Mms release (#3948) (#5110)

Vineel Pratap 2023-05-21 21:15:50 -07:00 committed by GitHub
parent bfd9dc6d27
commit 728b947019
23 changed files with 2654 additions and 67 deletions


@@ -0,0 +1,63 @@
# MMS Model Card
## Model details
**Organization developing the model** The FAIR team of Meta AI.
**Model version** This is version 1 of the model.
**Model type** MMS is a speech model based on the Transformer architecture. The pre-trained model comes in two sizes: 300M and 1B parameters. We fine-tune the model for speech recognition and make it available in the 1B variant. We also fine-tune the 1B variant for language identification.
**License** CC BY-NC
**Where to send questions or comments about the model** Questions and comments about MMS can be sent via the [GitHub repository](https://github.com/pytorch/fairseq/tree/master/examples/mms) of the project, by opening an issue and tagging it as MMS.
## Uses
**Primary intended uses** The primary use of MMS is to perform speech processing research for many more languages and to perform tasks such as automatic speech recognition, language identification, and speech synthesis.
**Primary intended users** The primary intended users of the model are researchers in speech processing, machine learning and artificial intelligence.
**Out-of-scope use cases** Fine-tuning the pre-trained models on other labeled datasets or downstream tasks requires further risk evaluation and mitigation.
## Bias and Risks
The MMS models were pre-trained on a blend of data from different domains, including readings of the New Testament. In the paper, we describe two studies analyzing gender bias and the use of religious language, which conclude that the models perform equally well for both genders and that, on average, there is little bias toward religious language (Section 8 of the paper).
# Training Details
## Training Data
MMS is pre-trained on VoxPopuli (parliamentary speech), MLS (read audiobooks), VoxLingua-107 (YouTube speech), CommonVoice (read Wikipedia text), BABEL (telephone conversations), MMS-lab-U (New Testament readings), and MMS-unlab (various read Christian texts).
Models are fine-tuned on FLEURS, VoxLingua-107, MLS, CommonVoice, and MMS-lab. We obtained the language information for MMS-lab, MMS-lab-U and MMS-unlab from our data source and did not manually verify it for every language.
## Training Procedure
Please refer to the research paper for details on this.
# Evaluation
## Testing Data, Factors & Metrics
We evaluate the models on different benchmarks for the downstream tasks. The evaluation details are presented in the paper. The models' performance is measured using standard metrics such as character error rate, word error rate, and classification accuracy.
# Citation
**BibTeX:**
```
@article{pratap2023mms,
title={Scaling Speech Technology to 1,000+ Languages},
author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
journal={arXiv},
year={2023}
}
```
# Model Card Contact
Please reach out to the authors at: [vineelkpratap@meta.com](mailto:vineelkpratap@meta.com) [androstj@meta.com](mailto:androstj@meta.com) [bshi@meta.com](mailto:bshi@meta.com) [michaelauli@meta.com](mailto:michaelauli@gmail.com)

examples/mms/README.md (new file, 175 lines)

@@ -0,0 +1,175 @@
# MMS: Scaling Speech Technology to 1000+ languages
The Massively Multilingual Speech (MMS) project expands speech technology from about 100 languages to over 1,000 by building a single multilingual speech recognition model supporting over 1,100 languages (more than 10 times as many as before), language identification models able to identify over [4,000 languages](https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html) (40 times more than before), pretrained models supporting over 1,400 languages, and text-to-speech models for over 1,100 languages. Our goal is to make it easier for people to access information and to use devices in their preferred language.
You can find details in the paper [Scaling Speech Technology to 1000+ languages](https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/) and the [blog post](https://ai.facebook.com/blog/multilingual-speech-recognition-model/).
An overview of the languages covered by MMS can be found [here](https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html).
## Pretrained models
| Model | Link
|---|---
MMS-300M | [download](https://dl.fbaipublicfiles.com/mms/pretraining/base_300m.pt)
MMS-1B | [download](https://dl.fbaipublicfiles.com/mms/pretraining/base_1b.pt)
Example commands to finetune the pretrained models can be found [here](https://github.com/fairinternal/fairseq-py/tree/mms_release/examples/wav2vec#fine-tune-a-pre-trained-model-with-ctc).
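If you prefer to fetch checkpoints from Python rather than with `wget`, a minimal sketch along these lines works; the destination path is arbitrary and the snippet is illustrative, not part of the release.
```python
# Illustrative download sketch (not part of the release): fetch the MMS-300M
# pretrained checkpoint with torch.hub, mirroring the download link above.
import torch

torch.hub.download_url_to_file(
    "https://dl.fbaipublicfiles.com/mms/pretraining/base_300m.pt",
    "/path/to/models/mms_300m.pt",  # destination path is arbitrary
)
```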
## Finetuned models
### ASR
| Model | Languages | Dataset | Checkpoint | Supported languages |
|---|---|---|---|---
MMS-1B:FL102 | 102 | FLEURS | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_fl102.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_fl102_langs.html)
MMS-1B:L1107| 1107 | MMS-lab | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_l1107.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_l1107_langs.html)
MMS-1B-all| 1162 | MMS-lab + FLEURS <br>+ CV + VP + MLS | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_all.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_all_langs.html)
### TTS
1. Download the list of [iso codes](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) of 1107 languages.
2. Find the iso code of the target language and download the checkpoint. Each folder contains 3 files: `G_100000.pth`, `config.json`, `vocab.txt`. The `G_100000.pth` is the generator trained for 100K updates, `config.json` is the training config, `vocab.txt` is the vocabulary for the TTS model.
```
# Examples:
wget https://dl.fbaipublicfiles.com/mms/tts/eng.tar.gz # English (eng)
wget https://dl.fbaipublicfiles.com/mms/tts/azj-script_latin.tar.gz # North Azerbaijani (azj-script_latin)
```
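The same downloads can be scripted from Python; the sketch below is illustrative (the helper name is ours, not part of the release) and assumes the per-language archives follow the `https://dl.fbaipublicfiles.com/mms/tts/<iso>.tar.gz` pattern shown above.
```python
# Illustrative helper (not part of the release): download and unpack a TTS
# checkpoint for one ISO code, mirroring the wget examples above.
import tarfile
import urllib.request

def download_tts_checkpoint(iso_code, outdir="."):
    url = f"https://dl.fbaipublicfiles.com/mms/tts/{iso_code}.tar.gz"
    archive = f"{outdir}/{iso_code}.tar.gz"
    urllib.request.urlretrieve(url, archive)
    with tarfile.open(archive) as tar:
        # expected contents: <iso_code>/G_100000.pth, config.json, vocab.txt
        tar.extractall(outdir)

download_tts_checkpoint("eng")  # English
```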
### LID
\# Languages | Dataset | Model | Dictionary | Supported languages |
|---|---|---|---|---
126 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l126.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l126/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l126_langs.html)
256 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l256.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l256/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l256_langs.html)
512 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l512.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l512/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l512_langs.html)
1024 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l1024.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l1024/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l1024_langs.html)
2048 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l2048.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l2048/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l2048_langs.html)
4017 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l4017/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017_langs.html)
## Commands to run inference
### ASR
Run this command to transcribe one or more audio files:
```shell command
cd /path/to/fairseq-py/
python examples/mms/asr/infer/mms_infer.py --model "/path/to/asr/model" --lang lang_code --audio "/path/to/audio_1.wav" "/path/to/audio_2.wav"
```
For more advanced configuration and to calculate CER/WER, you can prepare a manifest folder with this format:
```
$ ls /path/to/manifest
dev.tsv
dev.wrd
dev.ltr
dev.uid
# dev.tsv: each line contains <audio_path> <number_of_samples>
$ cat dev.tsv
/
/path/to/audio_1 180000
/path/to/audio_2 200000
$ cat dev.ltr
t h i s | i s | o n e |
t h i s | i s | t w o |
$ cat dev.wrd
this is one
this is two
$ cat dev.uid
audio_1
audio_2
```
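If you prepare the manifest programmatically, a sketch like the one below can generate the four files; `write_manifest` is a hypothetical helper (not part of the release) and assumes 16kHz mono audio that `soundfile` can read.
```python
# Hypothetical helper (not part of the release): write dev.tsv/.wrd/.ltr/.uid
# in the format shown above from (audio_path, transcript) pairs.
import os
import soundfile as sf

def write_manifest(pairs, outdir, subset="dev"):
    os.makedirs(outdir, exist_ok=True)
    with open(f"{outdir}/{subset}.tsv", "w") as tsv, \
         open(f"{outdir}/{subset}.wrd", "w") as wrd, \
         open(f"{outdir}/{subset}.ltr", "w") as ltr, \
         open(f"{outdir}/{subset}.uid", "w") as uid:
        tsv.write("/\n")  # first line is the audio root
        for audio_path, text in pairs:
            n_samples = sf.SoundFile(audio_path).frames
            tsv.write(f"{audio_path}\t{n_samples}\n")
            wrd.write(text + "\n")
            # letter targets: space-separated characters, "|" as word boundary
            ltr.write(" ".join("|".join(text.split())) + " |\n")
            uid.write(os.path.splitext(os.path.basename(audio_path))[0] + "\n")

write_manifest(
    [("/path/to/audio_1.wav", "this is one"), ("/path/to/audio_2.wav", "this is two")],
    "/path/to/manifest",
)
```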
Then run the command below:
```
lang_code=<iso_code>
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='/path/to/asr/model'" task.data='/path/to/manifest' dataset.gen_subset="${lang_code}:dev" common_eval.post_process=letter
```
Available options:
* To get the raw character-based output, set `common_eval.post_process=none`.
* To maximize GPU efficiency or avoid out-of-memory (OOM) errors, tune the `dataset.max_tokens` value.
* To run language model decoding, install the flashlight python bindings using
```
git clone --recursive git@github.com:flashlight/flashlight.git
cd flashlight;
git checkout 035ead6efefb82b47c8c2e643603e87d38850076
cd bindings/python
python3 setup.py install
```
Train a [KenLM language model](https://github.com/flashlight/wav2letter/tree/main/recipes/rasr#language-model) and prepare a lexicon file in [this](https://dl.fbaipublicfiles.com/wav2letter/rasr/tutorial/lexicon.txt) format.
```
LANG=<iso> # for example - 'eng', 'azj-script_latin'
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py --config-dir=examples/mms/asr/config \
--config-name=infer_common decoding.type=kenlm distributed_training.distributed_world_size=1 \
decoding.unique_wer_file=true decoding.beam=500 decoding.beamsizetoken=50 \
task.data=<MANIFEST_FOLDER_PATH> common_eval.path='<MODEL_PATH.pt>' decoding.lexicon=<LEXICON_FILE> decoding.lmpath=<LM_FILE> \
decoding.results_path=<OUTPUT_DIR> dataset.gen_subset=${LANG}:dev decoding.lmweight=??? decoding.wordscore=???
```
We typically sweep `lmweight` in the range of 0 to 5 and `wordscore` in the range of -3 to 3. The output directory will contain the reference and hypothesis outputs from the decoder.
For decoding with character-based language models, use an empty lexicon file (`decoding.lexicon=`), set `decoding.unitlm=True`, and sweep over `decoding.silweight` instead of `wordscore`.
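As a reference point, the lexicon maps each word to its letter spelling followed by the word-boundary token; the snippet below is an illustrative sketch (not part of the release) that writes this format for a word list, assuming a letter-based ASR model with `|` as the word boundary symbol.
```python
# Illustrative sketch (not part of the release): write a letter-based lexicon
# in the linked format for a list of words.
def write_lexicon(words, path):
    with open(path, "w") as f:
        for word in sorted(set(words)):
            spelling = " ".join(list(word)) + " |"
            f.write(f"{word}\t{spelling}\n")

write_lexicon(["hello", "world"], "lexicon.txt")
# lexicon.txt will then contain lines such as:
#   hello   h e l l o |
#   world   w o r l d |
```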
### TTS
Note: clone and install [VITS](https://github.com/jaywalnut310/vits) before running inference.
```shell script
## English TTS
$ PYTHONPATH=$PYTHONPATH:/path/to/vits python examples/mms/tts/infer.py --model-dir /path/to/model/eng \
--wav ./example.wav --txt "Expanding the language coverage of speech technology \
has the potential to improve access to information for many more people"
## Maithili TTS
$ PYTHONPATH=$PYTHONPATH:/path/to/vits python examples/mms/tts/infer.py --model-dir /path/to/model/mai \
--wav ./example.wav --txt "मुदा आइ धरि ई तकनीक सौ सं किछु बेसी भाषा तक सीमित छल जे सात हजार \
सं बेसी ज्ञात भाषाक एकटा अंश अछी"
```
`example.wav` contains synthesized audio for the language.
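To sanity-check the output, something like the following can be used (illustrative only); the sampling rate of the synthesized audio comes from the checkpoint's `config.json`.
```python
# Illustrative check (not part of the release): inspect the synthesized audio.
import soundfile as sf

audio, sample_rate = sf.read("./example.wav")
print(f"{len(audio) / sample_rate:.2f} s of audio at {sample_rate} Hz")
```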
### LID
Prepare two files in this format:
```
#/path/to/manifest.tsv
/
/path/to/audio1.wav
/path/to/audio2.wav
/path/to/audio3.wav
# /path/to/manifest.lang
eng 1
eng 1
eng 1
```
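A small sketch such as the one below can produce both files for a folder of `.wav` files; the helper is hypothetical (not part of the release), and the language column is only a placeholder label since LID predicts the language itself.
```python
# Hypothetical helper (not part of the release): write manifest.tsv and
# manifest.lang in the format shown above for a folder of .wav files.
import glob

def write_lid_manifest(wav_dir, out_prefix, placeholder_lang="eng"):
    wavs = sorted(glob.glob(f"{wav_dir}/*.wav"))
    with open(f"{out_prefix}.tsv", "w") as tsv, open(f"{out_prefix}.lang", "w") as lang:
        tsv.write("/\n")  # first line is the audio root
        for wav in wavs:
            tsv.write(wav + "\n")
            lang.write(f"{placeholder_lang} 1\n")

write_lid_manifest("/path/to/wavs", "/path/to/manifest")
```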
Download the model and the corresponding dictionary file for the LID model. The following command assumes there is a file named `dict.lang.txt` in `/path/to/dict/l126/`.
Use the following command to run inference:
```shell script
$ PYTHONPATH='.' python3 examples/mms/lid/infer.py /path/to/dict/l126/ --path /path/to/models/mms1b_l126.pt \
--task audio_classification --infer-manifest /path/to/manifest.tsv --output-path <OUTDIR>
```
`<OUTDIR>/predictions.txt` will contain the predictions from the model for the audio files in `manifest.tsv`.
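Each line of `predictions.txt` is a JSON list of `[language_code, probability]` pairs for the corresponding audio file; a minimal way to read it is sketched below (illustrative, not part of the release).
```python
# Illustrative post-processing (not part of the release): read the top-k LID
# predictions, one JSON line per input audio file.
import json

with open("predictions.txt") as f:
    for line_no, line in enumerate(f):
        topk = json.loads(line)          # e.g. [["eng", 0.98], ["deu", 0.01], ...]
        best_lang, best_prob = topk[0]
        print(f"audio {line_no}: {best_lang} ({best_prob:.2f})")
```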
# License
The MMS code and model weights are released under the CC-BY-NC 4.0 license.
# Citation
**BibTeX:**
```
@article{pratap2023mms,
title={Scaling Speech Technology to 1,000+ Languages},
author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
journal={arXiv},
year={2023}
}
```


@@ -0,0 +1,32 @@
# @package _global_
# defaults:
# - hydra/launcher: submitit_slurm
# @package _group_
task:
_name: audio_finetuning
data: null
labels: ltr
common_eval:
path: null
post_process: letter
# model_overrides: "{'task':{'multi_corpus_keys':None}}"
decoding:
type: viterbi
lexicon: null
unique_wer_file: false
results_path: null
distributed_training:
ddp_backend: legacy_ddp
distributed_world_size: 1
hydra:
run:
dir: ${common_eval.results_path}/${dataset.gen_subset}
sweep:
dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
subdir: ${dataset.gen_subset}
dataset:
max_tokens: 2_000_000
gen_subset: dev
required_batch_size_multiple: 1


@@ -0,0 +1,3 @@
#!/bin/bash
lang="$1"
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='/fsx-wav2vec/androstj/exps/wav2vec/mms/v4/finetune/xl1b_d5_dfls_0_0.3_u300k__ft_on_d5_127_dbeta1/ft_smax_adp_common.seed:1__dataset.max_tokens:2880000__optimization.lr:[0.001]__optimization.max_update:4000__merged_ckpt/checkpoints/checkpoint_last.pt'" task.data=/fsx-wav2vec/androstj/dataset/v4/fl/fseq dataset.gen_subset="${lang}:${lang}/dev" common_eval.post_process=none


@@ -0,0 +1,52 @@
#!/usr/bin/env python -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import soundfile as sf
import tempfile
from pathlib import Path
import os
import subprocess
import sys
import re
def parser():
parser = argparse.ArgumentParser(description="ASR inference script for MMS model")
parser.add_argument("--model", type=str, help="path to ASR model", required=True)
parser.add_argument("--audio", type=str, help="path to audio file", required=True, nargs='+')
parser.add_argument("--lang", type=str, help="audio language", required=True)
parser.add_argument("--format", type=str, choices=["none", "letter"], default="letter")
return parser.parse_args()
def process(args):
with tempfile.TemporaryDirectory() as tmpdir:
print(">>> preparing tmp manifest dir ...", file=sys.stderr)
tmpdir = Path(tmpdir)
with open(tmpdir / "dev.tsv", "w") as fw:
fw.write("/\n")
for audio in args.audio:
nsample = sf.SoundFile(audio).frames
fw.write(f"{audio}\t{nsample}\n")
with open(tmpdir / "dev.uid", "w") as fw:
fw.write(f"{audio}\n"*len(args.audio))
with open(tmpdir / "dev.ltr", "w") as fw:
fw.write("d u m m y | d u m m y\n"*len(args.audio))
with open(tmpdir / "dev.wrd", "w") as fw:
fw.write("dummy dummy\n"*len(args.audio))
cmd = f"""
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir}
"""
print(">>> loading model & running inference ...", file=sys.stderr)
subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,)
with open(tmpdir/"hypo.word") as fr:
for ii, hypo in enumerate(fr):
hypo = re.sub("\(\S+\)$", "", hypo).strip()
print(f'===============\nInput: {args.audio[ii]}\nOutput: {hypo}')
if __name__ == "__main__":
args = parser()
process(args)


@@ -0,0 +1,47 @@
# Data Preparation
We describe the process of aligning long audio files with their transcripts and generating shorter audio segments below.
- Step 1: Download and install torchaudio using the nightly version. We have open sourced the CTC forced alignment algorithm described in our paper via [torchaudio](https://github.com/pytorch/audio/pull/3348).
```
pip install --pre torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```
- Step 2: Download [uroman](https://github.com/isi-nlp/uroman) from GitHub. It is a universal romanizer which converts text in any script to the Latin alphabet. Use [this link](https://www.isi.edu/~ulf/uroman.html) to try their web interface.
```
git clone git@github.com:isi-nlp/uroman.git
```
- Step 3: Install a few other dependencies
```
pip install sox
pip install dataclasses
```
- Step 4: Create a text file containing the transcript for a (long) audio file. Each line in the text file will correspond to a separate audio segment that will be generated upon alignment.
Example content of the input text file:
```
Text of the desired first segment
Text of the desired second segment
Text of the desired third segment
```
- Step 5: Run forced alignment and segment the audio file into shorter segments.
```
python align_and_segment.py --audio /path/to/audio.wav --textfile /path/to/textfile --lang <iso> --outdir /path/to/output --uroman /path/to/uroman/bin
```
The above command will generate the audio segments under the output directory based on the content of each line in the input text file. It will also create a `manifest.json` file consisting of the segmented audio filepaths and their corresponding transcripts.
```
> head /path/to/output/manifest.json
{"audio_start_sec": 0.0, "audio_filepath": "/path/to/output/segment1.flac", "duration": 6.8, "text": "she wondered afterwards how she could have spoken with that hard serenity how she could have", "normalized_text": "she wondered afterwards how she could have spoken with that hard serenity how she could have", "uroman_tokens": "s h e w o n d e r e d a f t e r w a r d s h o w s h e c o u l d h a v e s p o k e n w i t h t h a t h a r d s e r e n i t y h o w s h e c o u l d h a v e"}
{"audio_start_sec": 6.8, "audio_filepath": "/path/to/output/segment2.flac", "duration": 5.3, "text": "gone steadily on with story after story poem after poem till", "normalized_text": "gone steadily on with story after story poem after poem till", "uroman_tokens": "g o n e s t e a d i l y o n w i t h s t o r y a f t e r s t o r y p o e m a f t e r p o e m t i l l"}
{"audio_start_sec": 12.1, "audio_filepath": "/path/to/output/segment3.flac", "duration": 5.9, "text": "allan's grip on her hands relaxed and he fell into a heavy tired sleep", "normalized_text": "allan's grip on her hands relaxed and he fell into a heavy tired sleep", "uroman_tokens": "a l l a n ' s g r i p o n h e r h a n d s r e l a x e d a n d h e f e l l i n t o a h e a v y t i r e d s l e e p"}
```
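A short sketch like the following (illustrative, not part of the release) can be used to iterate over the generated segments, e.g. to build a fine-tuning manifest or to spot-check durations.
```python
# Illustrative snippet (not part of the release): read the generated
# manifest.json and summarize the aligned segments.
import json

segments = []
with open("/path/to/output/manifest.json") as f:
    for line in f:
        sample = json.loads(line)
        segments.append((sample["audio_filepath"], sample["text"], sample["duration"]))

total_duration = sum(duration for _, _, duration in segments)
print(f"{len(segments)} segments, {total_duration:.1f} seconds of aligned audio")
```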
To visualize the segmented audio files, the [Speech Data Explorer](https://github.com/NVIDIA/NeMo/tree/main/tools/speech_data_explorer) tool from the NeMo toolkit can be used.
As our alignment model outputs uroman tokens for input audio in any language, it also works with non-English audio and their corresponding transcripts.


@@ -0,0 +1,187 @@
import os
import torch
import torchaudio
import sox
import json
import argparse
from examples.mms.data_prep.text_normalization import text_normalize
from examples.mms.data_prep.align_utils import (
get_uroman_tokens,
time_to_frame,
load_model_dict,
merge_repeats,
get_spans,
)
import torchaudio.functional as F
SAMPLING_FREQ = 16000
EMISSION_INTERVAL = 30
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def generate_emissions(model, audio_file):
waveform, _ = torchaudio.load(audio_file) # waveform: channels X T
waveform = waveform.to(DEVICE)
total_duration = sox.file_info.duration(audio_file)
audio_sf = sox.file_info.sample_rate(audio_file)
assert audio_sf == SAMPLING_FREQ
emissions_arr = []
with torch.inference_mode():
i = 0
while i < total_duration:
segment_start_time, segment_end_time = (i, i + EMISSION_INTERVAL)
context = EMISSION_INTERVAL * 0.1
input_start_time = max(segment_start_time - context, 0)
input_end_time = min(segment_end_time + context, total_duration)
waveform_split = waveform[
:,
int(SAMPLING_FREQ * input_start_time) : int(
SAMPLING_FREQ * (input_end_time)
),
]
model_outs, _ = model(waveform_split)
emissions_ = model_outs[0]
emission_start_frame = time_to_frame(segment_start_time)
emission_end_frame = time_to_frame(segment_end_time)
offset = time_to_frame(input_start_time)
emissions_ = emissions_[
:, emission_start_frame - offset : emission_end_frame - offset
]
emissions_arr.append(emissions_)
i += EMISSION_INTERVAL
emissions = torch.cat(emissions_arr, dim=1).squeeze()
emissions = torch.log_softmax(emissions, dim=-1)
stride = float(waveform.size(1) * 1000 / emissions.size(0) / SAMPLING_FREQ)
return emissions, stride
def get_alignments(
audio_file,
tokens,
model,
dictionary,
use_star,
):
# Generate emissions
emissions, stride = generate_emissions(model, audio_file)
T, N = emissions.size()
if use_star:
emissions = torch.cat([emissions, torch.zeros(T, 1).to(DEVICE)], dim=1)
# Force Alignment
if tokens:
token_indices = [dictionary[c] for c in " ".join(tokens).split(" ") if c in dictionary]
else:
print(f"Empty transcript!!!!! for audio file {audio_file}")
token_indices = []
blank = dictionary["<blank>"]
path, _ = F.force_align(
emissions, torch.Tensor(token_indices, device=DEVICE).int(), blank=blank
)
path = path.to("cpu").tolist()
segments = merge_repeats(path, {v: k for k, v in dictionary.items()})
return segments, stride
def main(args):
assert not os.path.exists(
args.outdir
), f"Error: Output path exists already {args.outdir}"
transcripts = []
with open(args.text_filepath) as f:
transcripts = [line.strip() for line in f]
print("Read {} lines from {}".format(len(transcripts), args.text_filepath))
norm_transcripts = [text_normalize(line.strip(), args.lang) for line in transcripts]
tokens = get_uroman_tokens(norm_transcripts, args.uroman_path, args.lang)
model, dictionary = load_model_dict()
model = model.to(DEVICE)
if args.use_star:
dictionary["<star>"] = len(dictionary)
tokens = ["<star>"] + tokens
transcripts = ["<star>"] + transcripts
norm_transcripts = ["<star>"] + norm_transcripts
segments, stride = get_alignments(
args.audio_filepath,
tokens,
model,
dictionary,
args.use_star,
)
# Get spans of each line in input text file
spans = get_spans(tokens, segments)
os.makedirs(args.outdir)
with open( f"{args.outdir}/manifest.json", "w") as f:
for i, t in enumerate(transcripts):
span = spans[i]
seg_start_idx = span[0].start
seg_end_idx = span[-1].end
output_file = f"{args.outdir}/segment{i}.flac"
audio_start_sec = seg_start_idx * stride / 1000
audio_end_sec = seg_end_idx * stride / 1000
tfm = sox.Transformer()
tfm.trim(audio_start_sec , audio_end_sec)
tfm.build_file(args.audio_filepath, output_file)
sample = {
"audio_start_sec": audio_start_sec,
"audio_filepath": str(output_file),
"duration": audio_end_sec - audio_start_sec,
"text": t,
"normalized_text":norm_transcripts[i],
"uroman_tokens": tokens[i],
}
f.write(json.dumps(sample) + "\n")
return segments, stride
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Align and segment long audio files")
parser.add_argument(
"-a", "--audio_filepath", type=str, help="Path to input audio file"
)
parser.add_argument(
"-t", "--text_filepath", type=str, help="Path to input text file "
)
parser.add_argument(
"-l", "--lang", type=str, default="eng", help="ISO code of the language"
)
parser.add_argument(
"-u", "--uroman_path", type=str, default="eng", help="Location to uroman/bin"
)
parser.add_argument(
"-s",
"--use_star",
action="store_true",
help="Use star at the start of transcript",
)
parser.add_argument(
"-o",
"--outdir",
type=str,
help="Output directory to store segmented audio files",
)
print("Using torch version:", torch.__version__)
print("Using torchaudio version:", torchaudio.__version__)
print("Using device: ", DEVICE)
args = parser.parse_args()
main(args)


@@ -0,0 +1,176 @@
import re
import os
import torch
import tempfile
import math
from dataclasses import dataclass
from torchaudio.models import wav2vec2_model
# iso codes with specialized rules in uroman
special_isos_uroman = "ara, bel, bul, deu, ell, eng, fas, grc, ell, eng, heb, kaz, kir, lav, lit, mkd, mkd2, oss, pnt, pus, rus, srp, srp2, tur, uig, ukr, yid".split(",")
special_isos_uroman = [i.strip() for i in special_isos_uroman]
def normalize_uroman(text):
text = text.lower()
text = re.sub("([^a-z' ])", " ", text)
text = re.sub(' +', ' ', text)
return text.strip()
def get_uroman_tokens(norm_transcripts, uroman_root_dir, iso = None):
tf = tempfile.NamedTemporaryFile()
tf2 = tempfile.NamedTemporaryFile()
with open(tf.name, "w") as f:
for t in norm_transcripts:
f.write(t + "\n")
assert os.path.exists(f"{uroman_root_dir}/uroman.pl"), "uroman not found"
cmd = f"perl {uroman_root_dir}/uroman.pl"
if iso in special_isos_uroman:
cmd += f" -l {iso} "
cmd += f" < {tf.name} > {tf2.name}"
os.system(cmd)
outtexts = []
with open(tf2.name) as f:
for line in f:
line = " ".join(line.strip())
line = re.sub(r"\s+", " ", line).strip()
outtexts.append(line)
assert len(outtexts) == len(norm_transcripts)
uromans = []
for ot in outtexts:
uromans.append(normalize_uroman(ot))
return uromans
@dataclass
class Segment:
label: str
start: int
end: int
def __repr__(self):
return f"{self.label}: [{self.start:5d}, {self.end:5d})"
@property
def length(self):
return self.end - self.start
def merge_repeats(path, idx_to_token_map):
i1, i2 = 0, 0
segments = []
while i1 < len(path):
while i2 < len(path) and path[i1] == path[i2]:
i2 += 1
segments.append(Segment(idx_to_token_map[path[i1]], i1, i2 - 1))
i1 = i2
return segments
def time_to_frame(time):
stride_msec = 20
frames_per_sec = 1000 / stride_msec
return int(time * frames_per_sec)
def load_model_dict():
model_path_name = "/tmp/ctc_alignment_mling_uroman_model.pt"
print("Downloading model and dictionary...")
if os.path.exists(model_path_name):
print("Model path already exists. Skipping downloading....")
else:
torch.hub.download_url_to_file(
"https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt",
model_path_name,
)
assert os.path.exists(model_path_name)
state_dict = torch.load(model_path_name, map_location="cpu")
model = wav2vec2_model(
extractor_mode="layer_norm",
extractor_conv_layer_config=[
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
],
extractor_conv_bias=True,
encoder_embed_dim=1024,
encoder_projection_dropout=0.0,
encoder_pos_conv_kernel=128,
encoder_pos_conv_groups=16,
encoder_num_layers=24,
encoder_num_heads=16,
encoder_attention_dropout=0.0,
encoder_ff_interm_features=4096,
encoder_ff_interm_dropout=0.1,
encoder_dropout=0.0,
encoder_layer_norm_first=True,
encoder_layer_drop=0.1,
aux_num_out=31,
)
model.load_state_dict(state_dict)
model.eval()
dict_path_name = "/tmp/ctc_alignment_mling_uroman_model.dict"
if os.path.exists(dict_path_name):
print("Dictionary path already exists. Skipping downloading....")
else:
torch.hub.download_url_to_file(
"https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/dictionary.txt",
dict_path_name,
)
assert os.path.exists(dict_path_name)
dictionary = {}
with open(dict_path_name) as f:
dictionary = {l.strip(): i for i, l in enumerate(f.readlines())}
return model, dictionary
def get_spans(tokens, segments):
ltr_idx = 0
tokens_idx = 0
intervals = []
start, end = (0, 0)
sil = "<blank>"
for (seg_idx, seg) in enumerate(segments):
if(tokens_idx == len(tokens)):
assert(seg_idx == len(segments) - 1)
assert(seg.label == '<blank>')
continue
cur_token = tokens[tokens_idx].split(' ')
ltr = cur_token[ltr_idx]
if seg.label == "<blank>": continue
assert(seg.label == ltr)
if(ltr_idx) == 0: start = seg_idx
if ltr_idx == len(cur_token) - 1:
ltr_idx = 0
tokens_idx += 1
intervals.append((start, seg_idx))
while tokens_idx < len(tokens) and len(tokens[tokens_idx]) == 0:
intervals.append((seg_idx, seg_idx))
tokens_idx += 1
else:
ltr_idx += 1
spans = []
for (idx, (start, end)) in enumerate(intervals):
span = segments[start:end + 1]
if start > 0:
prev_seg = segments[start - 1]
if prev_seg.label == sil:
pad_start = prev_seg.start if (idx == 0) else int((prev_seg.start + prev_seg.end)/2)
span = [Segment(sil, pad_start, span[0].start)] + span
if end+1 < len(segments):
next_seg = segments[end+1]
if next_seg.label == sil:
pad_end = next_seg.end if (idx == len(intervals) - 1) else math.floor((next_seg.start + next_seg.end) / 2)
span = span + [Segment(sil, span[-1].end, pad_end)]
spans.append(span)
return spans


@@ -0,0 +1,277 @@
import os
import re
colon = ":"
comma = ","
exclamation_mark = "!"
period = re.escape(".")
question_mark = re.escape("?")
semicolon = ";"
left_curly_bracket = "{"
right_curly_bracket = "}"
quotation_mark = '"'
basic_punc = (
period
+ question_mark
+ comma
+ colon
+ exclamation_mark
+ left_curly_bracket
+ right_curly_bracket
)
# General punc unicode block (0x2000-0x206F)
zero_width_space = r"\u200B"
zero_width_nonjoiner = r"\u200C"
left_to_right_mark = r"\u200E"
right_to_left_mark = r"\u200F"
left_to_right_embedding = r"\u202A"
pop_directional_formatting = r"\u202C"
# Here are some commonly ill-typed versions of apostrophe
right_single_quotation_mark = r"\u2019"
left_single_quotation_mark = r"\u2018"
# Language specific definitions
# Spanish
inverted_exclamation_mark = r"\u00A1"
inverted_question_mark = r"\u00BF"
# Hindi
hindi_danda = u"\u0964"
# Egyptian Arabic
# arabic_percent = r"\u066A"
arabic_comma = r"\u060C"
arabic_question_mark = r"\u061F"
arabic_semicolon = r"\u061B"
arabic_diacritics = r"\u064B-\u0652"
arabic_subscript_alef_and_inverted_damma = r"\u0656-\u0657"
# Chinese
full_stop = r"\u3002"
full_comma = r"\uFF0C"
full_exclamation_mark = r"\uFF01"
full_question_mark = r"\uFF1F"
full_semicolon = r"\uFF1B"
full_colon = r"\uFF1A"
full_parentheses = r"\uFF08\uFF09"
quotation_mark_horizontal = r"\u300C-\u300F"
quotation_mark_vertical = r"\uFF41-\uFF44"
title_marks = r"\u3008-\u300B"
wavy_low_line = r"\uFE4F"
ellipsis = r"\u22EF"
enumeration_comma = r"\u3001"
hyphenation_point = r"\u2027"
forward_slash = r"\uFF0F"
wavy_dash = r"\uFF5E"
box_drawings_light_horizontal = r"\u2500"
fullwidth_low_line = r"\uFF3F"
chinese_punc = (
full_stop
+ full_comma
+ full_exclamation_mark
+ full_question_mark
+ full_semicolon
+ full_colon
+ full_parentheses
+ quotation_mark_horizontal
+ quotation_mark_vertical
+ title_marks
+ wavy_low_line
+ ellipsis
+ enumeration_comma
+ hyphenation_point
+ forward_slash
+ wavy_dash
+ box_drawings_light_horizontal
+ fullwidth_low_line
)
# Armenian
armenian_apostrophe = r"\u055A"
emphasis_mark = r"\u055B"
exclamation_mark = r"\u055C"
armenian_comma = r"\u055D"
armenian_question_mark = r"\u055E"
abbreviation_mark = r"\u055F"
armenian_full_stop = r"\u0589"
armenian_punc = (
armenian_apostrophe
+ emphasis_mark
+ exclamation_mark
+ armenian_comma
+ armenian_question_mark
+ abbreviation_mark
+ armenian_full_stop
)
lesser_than_symbol = r"&lt;"
greater_than_symbol = r"&gt;"
lesser_than_sign = r"\u003c"
greater_than_sign = r"\u003e"
nbsp_written_form = r"&nbsp"
# Quotation marks
left_double_quotes = r"\u201c"
right_double_quotes = r"\u201d"
left_double_angle = r"\u00ab"
right_double_angle = r"\u00bb"
left_single_angle = r"\u2039"
right_single_angle = r"\u203a"
low_double_quotes = r"\u201e"
low_single_quotes = r"\u201a"
high_double_quotes = r"\u201f"
high_single_quotes = r"\u201b"
all_punct_quotes = (
left_double_quotes
+ right_double_quotes
+ left_double_angle
+ right_double_angle
+ left_single_angle
+ right_single_angle
+ low_double_quotes
+ low_single_quotes
+ high_double_quotes
+ high_single_quotes
+ right_single_quotation_mark
+ left_single_quotation_mark
)
mapping_quotes = (
"["
+ high_single_quotes
+ right_single_quotation_mark
+ left_single_quotation_mark
+ "]"
)
# Digits
english_digits = r"\u0030-\u0039"
bengali_digits = r"\u09e6-\u09ef"
khmer_digits = r"\u17e0-\u17e9"
devanagari_digits = r"\u0966-\u096f"
oriya_digits = r"\u0b66-\u0b6f"
extended_arabic_indic_digits = r"\u06f0-\u06f9"
kayah_li_digits = r"\ua900-\ua909"
fullwidth_digits = r"\uff10-\uff19"
malayam_digits = r"\u0d66-\u0d6f"
myanmar_digits = r"\u1040-\u1049"
roman_numeral = r"\u2170-\u2179"
nominal_digit_shapes = r"\u206f"
# Load punctuations from MMS-lab data
with open(f"{os.path.dirname(__file__)}/punctuations.lst", "r") as punc_f:
punc_list = punc_f.readlines()
punct_pattern = r""
for punc in punc_list:
# the first character in the tab separated line is the punc to be removed
punct_pattern += re.escape(punc.split("\t")[0])
shared_digits = (
english_digits
+ bengali_digits
+ khmer_digits
+ devanagari_digits
+ oriya_digits
+ extended_arabic_indic_digits
+ kayah_li_digits
+ fullwidth_digits
+ malayam_digits
+ myanmar_digits
+ roman_numeral
+ nominal_digit_shapes
)
shared_punc_list = (
basic_punc
+ all_punct_quotes
+ greater_than_sign
+ lesser_than_sign
+ inverted_question_mark
+ full_stop
+ semicolon
+ armenian_punc
+ inverted_exclamation_mark
+ arabic_comma
+ enumeration_comma
+ hindi_danda
+ quotation_mark
+ arabic_semicolon
+ arabic_question_mark
+ chinese_punc
+ punct_pattern
)
shared_mappping = {
lesser_than_symbol: "",
greater_than_symbol: "",
nbsp_written_form: "",
r"(\S+)" + mapping_quotes + r"(\S+)": r"\1'\2",
}
shared_deletion_list = (
left_to_right_mark
+ zero_width_nonjoiner
+ arabic_subscript_alef_and_inverted_damma
+ zero_width_space
+ arabic_diacritics
+ pop_directional_formatting
+ right_to_left_mark
+ left_to_right_embedding
)
norm_config = {
"*": {
"lower_case": True,
"punc_set": shared_punc_list,
"del_set": shared_deletion_list,
"mapping": shared_mappping,
"digit_set": shared_digits,
"unicode_norm": "NFKC",
"rm_diacritics" : False,
}
}
#=============== Mongolian ===============#
norm_config["mon"] = norm_config["*"].copy()
# add soft hyphen to punc list to match with fleurs
norm_config["mon"]["del_set"] += r"\u00AD"
norm_config["khk"] = norm_config["mon"].copy()
#=============== Hebrew ===============#
norm_config["heb"] = norm_config["*"].copy()
# add "HEBREW POINT" symbols to match with fleurs
norm_config["heb"]["del_set"] += r"\u05B0-\u05BF\u05C0-\u05CF"
#=============== Thai ===============#
norm_config["tha"] = norm_config["*"].copy()
# add "Zero width joiner" symbols to match with fleurs
norm_config["tha"]["punc_set"] += r"\u200D"
#=============== Arabic ===============#
norm_config["ara"] = norm_config["*"].copy()
norm_config["ara"]["mapping"]["ٱ"] = "ا"
norm_config["arb"] = norm_config["ara"].copy()
#=============== Javanese ===============#
norm_config["jav"] = norm_config["*"].copy()
norm_config["jav"]["rm_diacritics"] = True


@@ -0,0 +1,188 @@
 7355 INVALID UNICODE 0x81
 5265 INVALID UNICODE 0x90
 75 INVALID UNICODE 0x8
 31 INVALID UNICODE 0x8d
” 3 INVALID UNICODE 0x94
 2 INVALID UNICODE 0x8f
 2 INVALID UNICODE 0x1a
 1 INVALID UNICODE 0x9d
“ 1 INVALID UNICODE 0x93
’ 1 INVALID UNICODE 0x92
 8647 INVALID UNICODE 0xe295
 6650 INVALID UNICODE 0xf21d
 6234 INVALID UNICODE 0xf62d
 4815 INVALID UNICODE 0xf173
 4789 INVALID UNICODE 0xe514
 4409 INVALID UNICODE 0xe293
 3881 INVALID UNICODE 0xf523
 3788 INVALID UNICODE 0xe233
 2448 INVALID UNICODE 0xf50f
 2177 INVALID UNICODE 0xe232
 1955 INVALID UNICODE 0xea7b
 1926 INVALID UNICODE 0xf172
 973 INVALID UNICODE 0xe290
 972 INVALID UNICODE 0xf519
 661 INVALID UNICODE 0xe292
 591 INVALID UNICODE 0xe328
 509 INVALID UNICODE 0xe2fa
 458 INVALID UNICODE 0xe234
 446 INVALID UNICODE 0xe043
 419 INVALID UNICODE 0xe040
 399 INVALID UNICODE 0xe2fb
 387 INVALID UNICODE 0xe32b
 381 INVALID UNICODE 0xe236
 374 INVALID UNICODE 0xf511
 314 INVALID UNICODE 0xe517
 296 INVALID UNICODE 0xe2fe
 293 INVALID UNICODE 0xe492
 291 INVALID UNICODE 0xf52d
 289 INVALID UNICODE 0xe2fc
 195 INVALID UNICODE 0xf521
 190 INVALID UNICODE 0xe516
 182 INVALID UNICODE 0xe041
 178 INVALID UNICODE 0xf529
 113 INVALID UNICODE 0xe2f9
 87 INVALID UNICODE 0xe2d9
 78 INVALID UNICODE 0xe32a
 76 INVALID UNICODE 0xe291
 74 INVALID UNICODE 0xe296
 66 INVALID UNICODE 0xe518
 52 INVALID UNICODE 0xe32c
 46 INVALID UNICODE 0xe2db
 41 INVALID UNICODE 0xe231
 34 INVALID UNICODE 0xf522
 33 INVALID UNICODE 0xf518
 32 INVALID UNICODE 0xf513
 27 INVALID UNICODE 0xe32d
 25 INVALID UNICODE 0xe32e
 23 INVALID UNICODE 0xe06b
 15 INVALID UNICODE 0xea01
 12 INVALID UNICODE 0xe294
 11 INVALID UNICODE 0xe203
 8 INVALID UNICODE 0xf218
 7 INVALID UNICODE 0xe070
 7 INVALID UNICODE 0xe013
 5 INVALID UNICODE 0xe2de
 4 INVALID UNICODE 0xe493
 3 INVALID UNICODE 0xf7e8
 3 INVALID UNICODE 0xf7d0
 3 INVALID UNICODE 0xe313
 2 INVALID UNICODE 0xe329
 2 INVALID UNICODE 0xe06d
 2 INVALID UNICODE 0xe003
 1 INVALID UNICODE 0xf50e
 1 INVALID UNICODE 0xf171
 1 INVALID UNICODE 0xe01d
71 NOMINAL DIGIT SHAPES 0x206f
3 WORD JOINER 0x2060
― 126545 HORIZONTAL BAR 0x2015
־ 1028 HEBREW PUNCTUATION MAQAF 0x5be
) 98429 RIGHT PARENTHESIS 0x29
] 27108 RIGHT SQUARE BRACKET 0x5d
⌋ 1567 RIGHT FLOOR 0x230b
97 RIGHT TORTOISE SHELL BRACKET 0x3015
】 36 RIGHT BLACK LENTICULAR BRACKET 0x3011
14 ORNATE LEFT PARENTHESIS 0xfd3e
& 170517 AMPERSAND 0x26
། 106330 TIBETAN MARK SHAD 0xf0d
። 90203 ETHIOPIC FULL STOP 0x1362
፥ 60484 ETHIOPIC COLON 0x1365
༌ 60464 TIBETAN MARK DELIMITER TSHEG BSTAR 0xf0c
။ 51567 MYANMAR SIGN SECTION 0x104b
/ 46929 SOLIDUS 0x2f
၊ 38042 MYANMAR SIGN LITTLE SECTION 0x104a
· 37985 MIDDLE DOT 0xb7
‸ 36310 CARET 0x2038
* 34793 ASTERISK 0x2a
۔ 32432 ARABIC FULL STOP 0x6d4
፤ 31906 ETHIOPIC SEMICOLON 0x1364
၏ 21519 MYANMAR SYMBOL GENITIVE 0x104f
។ 20834 KHMER SIGN KHAN 0x17d4
꓾ 15773 LISU PUNCTUATION COMMA 0xa4fe
13473 CANADIAN SYLLABICS FULL STOP 0x166e
꤯ 12892 KAYAH LI SIGN SHYA 0xa92f
⵰ 11478 TIFINAGH SEPARATOR MARK 0x2d70
11118 LISU PUNCTUATION FULL STOP 0xa4ff
॥ 10763 DEVANAGARI DOUBLE DANDA 0x965
؞ 10403 ARABIC TRIPLE DOT PUNCTUATION MARK 0x61e
၍ 8936 MYANMAR SYMBOL COMPLETED 0x104d
· 8431 GREEK ANO TELEIA 0x387
† 7477 DAGGER 0x2020
၌ 6632 MYANMAR SYMBOL LOCATIVE 0x104c
፣ 5719 ETHIOPIC COMMA 0x1363
៖ 5528 KHMER SIGN CAMNUC PII KUUH 0x17d6
꤮ 4791 KAYAH LI SIGN CWI 0xa92e
※ 3439 REFERENCE MARK 0x203b
፦ 2727 ETHIOPIC PREFACE COLON 0x1366
• 1749 BULLET 0x2022
¶ 1507 PILCROW SIGN 0xb6
၎ 1386 MYANMAR SYMBOL AFOREMENTIONED 0x104e
﹖ 1224 SMALL QUESTION MARK 0xfe56
; 975 GREEK QUESTION MARK 0x37e
… 827 HORIZONTAL ELLIPSIS 0x2026
% 617 PERCENT SIGN 0x25
・ 468 KATAKANA MIDDLE DOT 0x30fb
༎ 306 TIBETAN MARK NYIS SHAD 0xf0e
‡ 140 DOUBLE DAGGER 0x2021
# 137 NUMBER SIGN 0x23
@ 125 COMMERCIAL AT 0x40
፡ 121 ETHIOPIC WORDSPACE 0x1361
៚ 55 KHMER SIGN KOOMUUT 0x17da
៕ 49 KHMER SIGN BARIYOOSAN 0x17d5
﹐ 10 SMALL COMMA 0xfe50
༅ 6 TIBETAN MARK CLOSING YIG MGO SGAB MA 0xf05
༄ 6 TIBETAN MARK INITIAL YIG MGO MDUN MA 0xf04
2 FULLWIDTH FULL STOP 0xff0e
﹗ 2 SMALL EXCLAMATION MARK 0xfe57
﹕ 2 SMALL COLON 0xfe55
‰ 2 PER MILLE SIGN 0x2030
・ 1 HALFWIDTH KATAKANA MIDDLE DOT 0xff65
( 98504 LEFT PARENTHESIS 0x28
[ 27245 LEFT SQUARE BRACKET 0x5b
⌊ 1567 LEFT FLOOR 0x230a
95 LEFT TORTOISE SHELL BRACKET 0x3014
【 36 LEFT BLACK LENTICULAR BRACKET 0x3010
﴿ 14 ORNATE RIGHT PARENTHESIS 0xfd3f
_ 4851 LOW LINE 0x5f
$ 72 DOLLAR SIGN 0x24
€ 14 EURO SIGN 0x20ac
£ 2 POUND SIGN 0xa3
~ 27462 TILDE 0x7e
= 11450 EQUALS SIGN 0x3d
| 8430 VERTICAL LINE 0x7c
3971 MINUS SIGN 0x2212
≫ 1904 MUCH GREATER-THAN 0x226b
≪ 1903 MUCH LESS-THAN 0x226a
+ 1450 PLUS SIGN 0x2b
345 FULLWIDTH LESS-THAN SIGN 0xff1c
344 FULLWIDTH GREATER-THAN SIGN 0xff1e
¬ 5 NOT SIGN 0xac
× 4 MULTIPLICATION SIGN 0xd7
→ 2 RIGHTWARDS ARROW 0x2192
537 CANADIAN SYLLABICS CHI SIGN 0x166d
° 499 DEGREE SIGN 0xb0
႟ 421 MYANMAR SYMBOL SHAN EXCLAMATION 0x109f
� 192 REPLACEMENT CHARACTER 0xfffd
⌟ 54 BOTTOM RIGHT CORNER 0x231f
⌞ 54 BOTTOM LEFT CORNER 0x231e
© 2 COPYRIGHT SIGN 0xa9
40 NARROW NO-BREAK SPACE 0x202f
1 SIX-PER-EM SPACE 0x2006
˜ 40261 SMALL TILDE 0x2dc
^ 6469 CIRCUMFLEX ACCENT 0x5e
¯ 20 MACRON 0xaf
ˇ 191442 CARON 0x2c7
ⁿ 38144 SUPERSCRIPT LATIN SMALL LETTER N 0x207f
ـ 9440 ARABIC TATWEEL 0x640
ๆ 6766 THAI CHARACTER MAIYAMOK 0xe46
ៗ 3310 KHMER SIGN LEK TOO 0x17d7
々 678 IDEOGRAPHIC ITERATION MARK 0x3005
ໆ 430 LAO KO LA 0xec6
ー 319 KATAKANA-HIRAGANA PROLONGED SOUND MARK 0x30fc
ⁱ 137 SUPERSCRIPT LATIN SMALL LETTER I 0x2071
৷ 11056 BENGALI CURRENCY NUMERATOR FOUR 0x9f7
⅓ 26 VULGAR FRACTION ONE THIRD 0x2153
½ 26 VULGAR FRACTION ONE HALF 0xbd
¼ 4 VULGAR FRACTION ONE QUARTER 0xbc
⅟ 1 FRACTION NUMERATOR ONE 0x215f
57 FRACTION SLASH 0x2044


@@ -0,0 +1,92 @@
import json
import re
import unicodedata
from examples.mms.data_prep.norm_config import norm_config
def text_normalize(text, iso_code, lower_case=True, remove_numbers=True, remove_brackets=False):
"""Given a text, normalize it by changing to lower case, removing punctuations, removing words that only contain digits and removing extra spaces
Args:
text : The string to be normalized
iso_code :
remove_numbers : Boolean flag to specify if words containing only digits should be removed
Returns:
normalized_text : the string after all normalization
"""
config = norm_config.get(iso_code, norm_config["*"])
for field in ["lower_case", "punc_set","del_set", "mapping", "digit_set", "unicode_norm"]:
if field not in config:
config[field] = norm_config["*"][field]
text = unicodedata.normalize(config["unicode_norm"], text)
# Convert to lower case
if config["lower_case"] and lower_case:
text = text.lower()
# brackets
# always text inside brackets with numbers in them. Usually corresponds to "(Sam 23:17)"
text = re.sub(r"\([^\)]*\d[^\)]*\)", " ", text)
if remove_brackets:
text = re.sub(r"\([^\)]*\)", " ", text)
# Apply mappings
for old, new in config["mapping"].items():
text = re.sub(old, new, text)
# Replace punctutations with space
punct_pattern = r"[" + config["punc_set"]
punct_pattern += "]"
normalized_text = re.sub(punct_pattern, " ", text)
# remove characters in delete list
delete_patten = r"[" + config["del_set"] + "]"
normalized_text = re.sub(delete_patten, "", normalized_text)
# Remove words containing only digits
# We check for 3 cases a)text starts with a number b) a number is present somewhere in the middle of the text c) the text ends with a number
# For each case we use lookaround regex pattern to see if the digit pattern in preceded and followed by whitespaces, only then we replace the numbers with space
# The lookaround enables overlapping pattern matches to be replaced
if remove_numbers:
digits_pattern = "[" + config["digit_set"]
digits_pattern += "]+"
complete_digit_pattern = (
r"^"
+ digits_pattern
+ "(?=\s)|(?<=\s)"
+ digits_pattern
+ "(?=\s)|(?<=\s)"
+ digits_pattern
+ "$"
)
normalized_text = re.sub(complete_digit_pattern, " ", normalized_text)
if config["rm_diacritics"]:
from unidecode import unidecode
normalized_text = unidecode(normalized_text)
# Remove extra spaces
normalized_text = re.sub(r"\s+", " ", normalized_text).strip()
return normalized_text

examples/mms/lid/infer.py (new file, 197 lines)

@@ -0,0 +1,197 @@
import torch
from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor
from fairseq import checkpoint_utils, distributed_utils, options, utils
from fairseq import checkpoint_utils, data, options, tasks
from fairseq.data import FileAudioDataset, AddTargetDataset, Dictionary
from fairseq.tasks.audio_classification import LabelEncoder
import copy
from tqdm import tqdm
import tempfile
import numpy as np
import json
def subset_manifest(infer_manifest, veri_pair):
with open(infer_manifest) as ff, open(veri_pair) as gg, tempfile.NamedTemporaryFile(
"w", delete=False
) as ww:
fnames = ff.read().strip().split("\n")
basedir = fnames[0]
needed_fname = []
for gi in gg.read().strip().split("\n"):
_, x1, x2 = gi.split()
needed_fname.append(x1)
needed_fname.append(x2)
needed_fname = set(needed_fname)
ww.write(basedir + "\n")
for ii in range(1, len(fnames)):
x1, x2 = fnames[ii].split()
if x1 in needed_fname:
ww.write(fnames[ii] + "\n")
print(f"| subset manifest for verification: {ww.name}")
return ww.name
def wrap_target_dataset(infer_manifest, dataset, task):
label_path = infer_manifest.replace(".tsv", ".lang")
text_compressor = TextCompressor(level=TextCompressionLevel.none)
with open(label_path, "r") as f:
labels = [text_compressor.compress(l) for i,l in enumerate(f)]
assert len(labels) == len(dataset)
process_label = LabelEncoder(task.target_dictionary)
dataset = AddTargetDataset(
dataset,
labels,
pad=task.target_dictionary.pad(),
eos=task.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
add_to_input=False,
)
return dataset
def resample_data(source, padding_mask, n_sample, max_sample_len):
# source: BxT
# padding_mask: BxT
B = source.shape[0]
T = source.shape[1]
sources = []
padding_masks = []
if B == 1:
return [source], [None]
seq_len = (~padding_mask).sum(1)
for jj in range(n_sample):
new_source = source.new_zeros(B, max_sample_len)
new_padding_mask = padding_mask.new_zeros(B, max_sample_len)
for ii in range(B):
if seq_len[ii] > max_sample_len:
start = np.random.randint(0, seq_len[ii] - max_sample_len + 1)
end = start + max_sample_len
else:
start = 0
end = seq_len[ii]
new_source[ii, 0 : end - start] = source[ii, start:end]
new_padding_mask[ii, end - start + 1 :] = True
sources.append(new_source)
padding_masks.append(new_padding_mask)
return sources, padding_masks
def resample_sample(sample, n_sample, max_sample_len):
new_sources, new_padding_masks = resample_data(
sample["net_input"]["source"],
sample["net_input"]["padding_mask"],
n_sample,
max_sample_len,
)
new_samples = []
for ii in range(n_sample):
new_sample = copy.deepcopy(sample)
new_sample["net_input"]["source"] = new_sources[ii]
new_sample["net_input"]["padding_mask"] = new_padding_masks[ii]
new_samples.append(new_sample)
return new_samples
def dict_to_nparr(dd):
dict_class = []
dict_idx = []
for ii, jj in enumerate(dd.symbols):
dict_idx.append(ii)
dict_class.append(jj)
dict_idx = np.array(dict_idx)
dict_class = np.array(dict_class)
return dict_class, dict_idx
if __name__ == "__main__":
np.random.seed(123)
# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task="audio_classification")
# parser.add_argument('--infer-merge', type=str, default='mean')
parser.add_argument("--infer-xtimes", type=int, default=1)
parser.add_argument("--infer-num-samples", type=int, default=None)
parser.add_argument("--top-k", type=int, default=3)
parser.add_argument(
"--infer-max-sample-size", type=int, default=5 * 16000
) # 5 secs
parser.add_argument("--infer-manifest", required=True, type=str)
parser.add_argument("--output-path", default="/tmp/", type=str)
args = options.parse_args_and_arch(parser)
# Setup task
# task = tasks.setup_task(args)
use_cuda = not args.cpu
# Load model & task
print("| loading model from {}".format(args.path))
arg_overrides = {
"task": {
"data": args.data
},
# 'mask_prob': 0
#'max_sample_size': sys.maxsize,
#'min_sample_size': 0,
}
state = checkpoint_utils.load_checkpoint_to_cpu(args.path, arg_overrides)
models, _model_args, task = checkpoint_utils.load_model_ensemble_and_task(
[args.path], arg_overrides=arg_overrides, task=None, state=state
)
model = models[0]
model.eval()
if use_cuda:
model.cuda()
# Load dataset
dict_class, dict_idx = dict_to_nparr(task.target_dictionary)
infer_manifest = args.infer_manifest
infer_dataset = FileAudioDataset(
infer_manifest,
sample_rate=task.cfg.sample_rate,
max_sample_size=10**10, # task.cfg.max_sample_size,
min_sample_size=1, # task.cfg.min_sample_size,
pad=True,
normalize=task.cfg.normalize,
)
# add target (if needed)
infer_dataset = wrap_target_dataset(infer_manifest, infer_dataset, task)
itr = task.get_batch_iterator(
dataset=infer_dataset,
max_sentences=1,
# max_tokens=args.max_tokens,
num_workers=4,
).next_epoch_itr(shuffle=False)
predictions = {}
with torch.no_grad():
for _, sample in tqdm(enumerate(itr)):
# resample if needed
samples = resample_sample(
sample, args.infer_xtimes, args.infer_max_sample_size
)
for sample in samples:
sample = utils.move_to_cuda(sample) if use_cuda else sample
try:
latent = model.forward_latent(**sample["net_input"])
except:
latent = None
logit = model.forward(**sample["net_input"])
logit_lsm = torch.log_softmax(logit.squeeze(), dim=-1)
scores, indices = torch.topk(logit_lsm, args.top_k, dim=-1)
scores = torch.exp(scores).to("cpu").tolist()
indices = indices.to("cpu").tolist()
assert sample["id"].numel() == 1
sample_idx = sample["id"].to("cpu").tolist()[0]
assert sample_idx not in predictions
predictions[sample_idx] = [(task.target_dictionary[int(i)], s) for s, i in zip(scores, indices)]
with open(f"{args.output_path}/predictions.txt", "w") as fo:
for idx in range(len(infer_dataset)):
fo.write(json.dumps(predictions[idx]) + "\n")
print(f"Outputs will be located at - {args.output_path}/predictions.txt")

examples/mms/tts/infer.py (new file, 102 lines)

@@ -0,0 +1,102 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import numpy as np
import commons
import utils
import argparse
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from scipy.io.wavfile import write
class TextMapper(object):
def __init__(self, vocab_file):
self.symbols = [x.replace("\n", "") for x in open(vocab_file).readlines()]
self.SPACE_ID = self.symbols.index(" ")
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
def text_to_sequence(self, text, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
clean_text = text.strip()
for symbol in clean_text:
symbol_id = self._symbol_to_id[symbol]
sequence += [symbol_id]
return sequence
def get_text(self, text, hps):
text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
if hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = torch.LongTensor(text_norm)
return text_norm
def filter_oov(self, text):
val_chars = self._symbol_to_id
txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
print(f"text after filtering OOV: {txt_filt}")
return txt_filt
def generate():
parser = argparse.ArgumentParser(description='TTS inference')
parser.add_argument('--model-dir', type=str, help='model checkpoint dir')
parser.add_argument('--wav', type=str, help='output wav path')
parser.add_argument('--txt', type=str, help='input text')
args = parser.parse_args()
ckpt_dir, wav_path, txt = args.model_dir, args.wav, args.txt
vocab_file = f"{ckpt_dir}/vocab.txt"
config_file = f"{ckpt_dir}/config.json"
assert os.path.isfile(config_file), f"{config_file} doesn't exist"
hps = utils.get_hparams_from_file(config_file)
text_mapper = TextMapper(vocab_file)
net_g = SynthesizerTrn(
len(text_mapper.symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model)
net_g.cuda()
_ = net_g.eval()
g_pth = f"{ckpt_dir}/G_100000.pth"
print(f"load {g_pth}")
_ = utils.load_checkpoint(g_pth, net_g, None)
print(f"text: {txt}")
txt = txt.lower()
txt = text_mapper.filter_oov(txt)
stn_tst = text_mapper.get_text(txt, hps)
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).cuda()
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
hyp = net_g.infer(
x_tst, x_tst_lengths, noise_scale=.667,
noise_scale_w=0.8, length_scale=1.0
)[0][0,0].cpu().float().numpy()
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
print(f"wav: {wav_path}")
write(wav_path, hps.data.sampling_rate, hyp)
return
if __name__ == '__main__':
generate()


@@ -10,6 +10,7 @@ import logging
import os
import shutil
import sys
import re
from dataclasses import dataclass, field, is_dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -101,6 +102,29 @@ class InferenceProcessor:
self.task = tasks.setup_task(cfg.task)
models, saved_cfg = self.load_model_ensemble()
### LOAD ADAPTER ####
ckpt_obj = checkpoint_utils.load_checkpoint_to_cpu(self.cfg.common_eval.path)
if "adapter" in ckpt_obj:
target_lang = self.cfg.dataset.gen_subset.split(":")[0]
assert target_lang in ckpt_obj["adapter"]
logger.info(f">>> LOADING ADAPTER: {target_lang}")
ft_obj = ckpt_obj["adapter"][target_lang]
ft_model = ft_obj["model"]
cdevice = models[0].w2v_encoder.proj.weight.device
cdtype = models[0].w2v_encoder.proj.weight.dtype
ft_proj_out, ft_proj_in = ft_model["w2v_encoder.proj.weight"].shape
ft_proj = torch.nn.Linear(ft_proj_in, ft_proj_out, bias=True)
ft_proj.to(device=cdevice, dtype=cdtype)
models[0].w2v_encoder.proj = ft_proj
with torch.no_grad():
for kk, vv in models[0].named_parameters():
if kk in ft_model:
vv.copy_(ft_model[kk])
self.task.load_state_dict(ft_obj["task_state"])
# overwrite gen_subset with master config
self.cfg.dataset.gen_subset = re.sub('^[\w-]+:', saved_cfg['task']['multi_corpus_keys']+":", self.cfg.dataset.gen_subset)
self.models = models
self.saved_cfg = saved_cfg
self.tgt_dict = self.task.target_dictionary


@@ -47,6 +47,7 @@ class RawAudioDataset(FairseqDataset):
expand_adjacent: bool = False,
mask_dropout: float = 0,
non_overlapping: bool = False,
corpus_key=None,
):
super().__init__()
@@ -72,6 +73,7 @@ class RawAudioDataset(FairseqDataset):
self.expand_adjacent = expand_adjacent
self.mask_dropout = mask_dropout
self.non_overlapping = non_overlapping
self.corpus_key = corpus_key
def __getitem__(self, index):
raise NotImplementedError()
@@ -144,6 +146,8 @@ class RawAudioDataset(FairseqDataset):
collated_sources[i] = self.crop_to_max_size(source, target_size)
input = {"source": collated_sources}
if self.corpus_key is not None:
input["corpus_key"] = [self.corpus_key] * len(sources)
out = {"id": torch.LongTensor([s["id"] for s in samples])}
if self.pad:
input["padding_mask"] = padding_mask


@@ -26,19 +26,21 @@ class Dictionary:
eos="</s>",
unk="<unk>",
extra_special_symbols=None,
add_special_symbols=True,
):
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
if add_special_symbols:
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
@@ -213,7 +215,7 @@ class Dictionary:
return self.unk_index
@classmethod
def load(cls, f):
def load(cls, f, add_special_symbols=True):
"""Loads the dictionary from a text file with the format:
```
@@ -222,7 +224,7 @@
...
```
"""
d = cls()
d = cls(add_special_symbols=add_special_symbols)
d.add_from_file(f)
return d


@@ -7,3 +7,4 @@ from .wav2vec import * # noqa
from .wav2vec2 import * # noqa
from .wav2vec2_asr import * # noqa
from .wav2vec2_laser import * # noqa
from .wav2vec2_classification import * # noqa


@@ -17,6 +17,7 @@ from fairseq.data.data_utils import compute_mask_indices
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.distributed import fsdp_wrap
from fairseq.models import BaseFairseqModel, register_model
from fairseq.distributed.fully_sharded_data_parallel import FullyShardedDataParallel
from fairseq.modules import (
Fp32GroupNorm,
Fp32LayerNorm,
@@ -37,7 +38,7 @@ from .utils import pad_to_multiple
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer"])
LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"])
@dataclass
@ -289,6 +290,20 @@ class Wav2Vec2Config(FairseqDataclass):
)
fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"})
# Adapter configuration
adp_num: int = field(
default=-1
)
adp_dim: int = field(
default=64
)
adp_act_fn: str = field(
default="relu"
)
adp_trf_idx: str = field(
default="all",
)
@register_model("wav2vec2", dataclass=Wav2Vec2Config)
class Wav2Vec2Model(BaseFairseqModel):
@ -588,6 +603,7 @@ class Wav2Vec2Model(BaseFairseqModel):
mask_indices=None,
mask_channel_indices=None,
padding_count=None,
corpus_key=None,
):
if self.feature_grad_mult > 0:
@ -672,7 +688,9 @@ class Wav2Vec2Model(BaseFairseqModel):
y = unmasked_features
mask_indices = None
x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer)
x, layer_results = self.encoder(
x, padding_mask=padding_mask, layer=layer, corpus_key=corpus_key
)
if features_only:
return {
@ -774,9 +792,16 @@ class Wav2Vec2Model(BaseFairseqModel):
x = self.layer_norm(x)
return self.quantizer.forward_idx(x)
def extract_features(self, source, padding_mask, mask=False, layer=None):
def extract_features(
self, source, padding_mask, mask=False, layer=None, corpus_key=None
):
res = self.forward(
source, padding_mask, mask=mask, features_only=True, layer=layer
source,
padding_mask,
mask=mask,
features_only=True,
layer=layer,
corpus_key=corpus_key,
)
return res
@ -917,7 +942,7 @@ def make_conv_pos(e, k, g):
class TransformerEncoder(nn.Module):
def build_encoder_layer(self, args: Wav2Vec2Config):
def build_encoder_layer(self, args: Wav2Vec2Config, layer_idx: int):
if args.layer_type == "transformer":
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
@ -941,6 +966,40 @@ class TransformerEncoder(nn.Module):
use_fp16=args.fp16,
pos_enc_type="abs",
)
elif args.layer_type == "trf_adp":
use_adp = False
if args.adp_trf_idx == "all":
use_adp = True
else:
adp_trf_idx = list(range(*[int(g) for g in args.adp_trf_idx.split(":")]))
if layer_idx in adp_trf_idx:
use_adp = True
if use_adp:
layer = TransformerSentenceEncoderWithAdapterLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
adapter_num=args.adp_num,
adapter_dim=args.adp_dim,
adapter_act_fn=args.adp_act_fn,
)
else:
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
)
layer = fsdp_wrap(layer)
if args.checkpoint_activations:
layer = checkpoint_wrapper(layer)
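`adp_trf_idx` selects which transformer layers receive adapters: `"all"` enables them everywhere, while a `start:stop` string is unpacked into `range` to pick a contiguous slice of layer indices. A small sketch of the parsing, with a hypothetical value:

```python
# Hypothetical config value: add adapters only to layers 16..47 of a 48-layer encoder.
adp_trf_idx = "16:48"
layer_indices = list(range(*[int(g) for g in adp_trf_idx.split(":")]))
print(layer_indices[:3], layer_indices[-1])  # [16, 17, 18] 47
```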
@ -991,7 +1050,7 @@ class TransformerEncoder(nn.Module):
)
self.layers = nn.ModuleList(
[self.build_encoder_layer(args) for _ in range(args.encoder_layers)]
[self.build_encoder_layer(args, ii) for ii in range(args.encoder_layers)]
)
self.layer_norm_first = args.layer_norm_first
self.layer_norm = LayerNorm(self.embedding_dim)
@ -999,8 +1058,10 @@ class TransformerEncoder(nn.Module):
self.apply(init_bert_params)
def forward(self, x, padding_mask=None, layer=None):
x, layer_results = self.extract_features(x, padding_mask, layer)
def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
x, layer_results = self.extract_features(
x, padding_mask, layer, corpus_key=corpus_key
)
if self.layer_norm_first and layer is None:
x = self.layer_norm(x)
@ -1013,6 +1074,7 @@ class TransformerEncoder(nn.Module):
padding_mask=None,
tgt_layer=None,
min_layer=0,
corpus_key=None,
):
if padding_mask is not None:
@ -1043,12 +1105,29 @@ class TransformerEncoder(nn.Module):
layer_results = []
r = None
for i, layer in enumerate(self.layers):
dropout_probability = np.random.random() if self.layerdrop > 0 else 1
if not self.training or (dropout_probability > self.layerdrop):
x, (z, lr) = layer(
x, self_attn_padding_mask=padding_mask, need_weights=False
)
layer_check = layer
if isinstance(layer, FullyShardedDataParallel):
layer_check = layer.unwrapped_module
if corpus_key is None or not isinstance(
layer_check, TransformerSentenceEncoderWithAdapterLayer
):
x, (z, lr) = layer(
x, self_attn_padding_mask=padding_mask, need_weights=False
)
else:
x, (z, lr) = layer(
x,
self_attn_padding_mask=padding_mask,
need_weights=False,
corpus_key=corpus_key,
)
if i >= min_layer:
layer_results.append((x, z, lr))
if i == tgt_layer:
@ -1282,3 +1361,125 @@ class TransformerSentenceEncoderLayer(nn.Module):
x = self.final_layer_norm(x)
return x, (attn, layer_result)
class AdapterFast(nn.Module):
def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
"""
Implements adapter modules directly with 3D tensor weights as parameters,
without using ModuleList, in order to speed up training throughput.
"""
super().__init__()
self.adapter_num = adapter_num
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
self.act_fn = nn.Identity()
if act_fn == "relu":
self.act_fn = nn.ReLU()
elif act_fn == "gelu":
self.act_fn = nn.GELU()
elif act_fn == "selu":
self.act_fn = nn.SELU()
else:
raise ValueError(f"unsupported {act_fn}")
self.input_dim = input_dim
self.reset_parameters()
def reset_parameters(self):
for ii in range(self.adapter_num):
nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.b_a[ii], -bound, bound)
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.b_b[ii], -bound, bound)
nn.init.ones_(self.ln_W)
nn.init.zeros_(self.ln_b)
def forward(self, x, adapter_id):
ii = adapter_id
h = x
h = F.layer_norm(h, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii])
h = F.linear(h, self.W_a[ii], self.b_a[ii])
h = self.act_fn(h)
h = F.linear(h, self.W_b[ii], self.b_b[ii])
outputs = h
return outputs
def extra_repr(self):
return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
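`AdapterFast` stores the weights of all corpus-specific adapters as stacked 3D tensors and indexes into them at call time, so switching corpora amounts to selecting a slice. A minimal shape sketch with hypothetical sizes (not the training setup):

```python
import torch

# Assumes the AdapterFast class defined above; sizes are hypothetical:
# 3 corpora, model dim 1024, adapter bottleneck 64.
adapter = AdapterFast(adapter_num=3, input_dim=1024, hidden_dim=64, act_fn="relu")
x = torch.randn(50, 2, 1024)   # (time, batch, dim), as inside the encoder
y = adapter(x, adapter_id=1)   # residual branch for corpus index 1; caller adds x + y
print(y.shape)                 # torch.Size([50, 2, 1024])
```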
class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
"""
Implements a Transformer encoder layer with adapters, as used in BERT/XLM-style
pre-trained models. An adapter module is applied on top of the vanilla Transformer layer output.
"""
def __init__(
self,
embedding_dim: float = 768,
ffn_embedding_dim: float = 3072,
num_attention_heads: int = 8,
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
activation_fn: str = "relu",
layer_norm_first: bool = False,
adapter_num=201,
adapter_dim=64,
adapter_act_fn="relu",
) -> None:
super().__init__(
embedding_dim=embedding_dim,
ffn_embedding_dim=ffn_embedding_dim,
num_attention_heads=num_attention_heads,
dropout=dropout,
attention_dropout=attention_dropout,
activation_dropout=activation_dropout,
activation_fn=activation_fn,
layer_norm_first=layer_norm_first,
)
self.adapter_num = adapter_num
self.adapter_dim = adapter_dim
self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
def forward(
self,
x: torch.Tensor,
self_attn_mask: torch.Tensor = None,
self_attn_padding_mask: torch.Tensor = None,
need_weights: bool = False,
att_args=None,
corpus_key=None,
):
x, (attn, layer_result) = super().forward(
x=x,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_weights=need_weights,
att_args=att_args,
)
assert corpus_key is not None
assert len(set(corpus_key)) == 1, f"corpus_key items are not same {corpus_key}"
y = self.adapter_layer(x, corpus_key[0])
x = x + y
return x, (attn, layer_result)

View File

@ -28,7 +28,7 @@ from fairseq.models import (
FairseqIncrementalDecoder,
register_model,
)
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, LAYER_TYPE_CHOICES, AdapterFast
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer
from fairseq.tasks import FairseqTask
@ -178,6 +178,27 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
layer_decay: float = 1
layer_type: LAYER_TYPE_CHOICES = field(
default="transformer", metadata={"help": "layer type in encoder"}
)
# Adapter configuration
adp_num: int = field(
default=-1
)
adp_dim: int = field(
default=64
)
adp_act_fn: str = field(
default="relu"
)
adp_trf_idx: str = field(
default="all",
)
freeze_regex: Optional[str] = field(
default=None,
)
@dataclass
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
blank_weight: float = 0
@ -416,6 +437,14 @@ class Wav2VecEncoder(FairseqEncoder):
"Please check that --normalize is set or unset for both pre-training and here"
)
with open_dict(w2v_args):
args_replacement = ["checkpoint_activations", "layer_type",
"adp_num", "adp_dim",
"adp_act_fn", "adp_trf_idx"]
for _args in args_replacement:
if hasattr(cfg, _args) and getattr(cfg, _args, None) is not None:
w2v_args.model[_args] = getattr(cfg, _args, None)
if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
with open_dict(w2v_args):
w2v_args.model.checkpoint_activations = cfg.checkpoint_activations
@ -423,7 +452,6 @@ class Wav2VecEncoder(FairseqEncoder):
w2v_args.task.data = cfg.data
task = tasks.setup_task(w2v_args.task, from_checkpoint=True)
model = task.build_model(w2v_args.model, from_checkpoint=True)
model.remove_pretraining_modules()
d = w2v_args.model.encoder_embed_dim
else:
@ -468,6 +496,9 @@ class Wav2VecEncoder(FairseqEncoder):
if targ_d is not None:
self.proj = Linear(d, targ_d)
if cfg.freeze_regex is not None:
self.freeze_regex(cfg.freeze_regex)
layer_decay = getattr(cfg, "layer_decay", 1)
if layer_decay < 1:
mod_encs = list(model.modality_encoders.values())
@ -491,6 +522,14 @@ class Wav2VecEncoder(FairseqEncoder):
optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
p.optim_overrides = optim_override
def freeze_regex(self, pattern):
unfrozen_names = []
for name, param in self.named_parameters():
if re.fullmatch(pattern, name) is not None:
param.requires_grad_(False)
else:
unfrozen_names.append(name)
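`freeze_regex` disables gradients for every parameter whose full name matches the pattern and leaves the rest trainable; in practice the pattern comes in via `cfg.freeze_regex`. A hedged example with a hypothetical pattern and encoder instance:

```python
# Hypothetical usage; the exact parameter names depend on the loaded checkpoint.
# `encoder` stands for a Wav2VecEncoder instance. The pattern below would freeze
# every parameter under the convolutional feature extractor, e.g.
# "w2v_model.feature_extractor.conv_layers.0.0.weight".
encoder.freeze_regex(r"w2v_model\.feature_extractor\..*")
```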
def load_model_weights(self, state, model, cfg):
if cfg.ddp_backend == "fully_sharded":
from fairseq.distributed import FullyShardedDataParallel
@ -553,6 +592,8 @@ class Wav2VecEncoder(FairseqEncoder):
"padding_mask": padding_mask,
"mask": self.apply_mask and self.training,
}
if "corpus_key" in kwargs:
w2v_args["corpus_key"] = kwargs["corpus_key"]
if self.is_d2v_multi:
w2v_args["mode"] = "AUDIO"

View File

@ -0,0 +1,348 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II, MISSING, open_dict
from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models import BaseFairseqModel, FairseqEncoder, register_model
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES, Wav2Vec2Config
from fairseq.models.wav2vec.wav2vec2_asr import Embedding, Linear, Wav2VecEncoder, Wav2Vec2AsrConfig
from fairseq.tasks import FairseqTask
logging.basicConfig(level=logging.DEBUG)
@dataclass
class Wav2Vec2ClassificationConfig(Wav2Vec2AsrConfig):
latent_embed_dim: Optional[int] = field(
default=None, metadata={"help": "latent dim (encoder w2v -> latent -> class)"}
)
pooling: str = field(
default="first_token",
metadata={"help": "pooling layer choices"},
)
activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
default="gelu", metadata={"help": "activation function to use"}
)
@register_model("wav2vec_classification", dataclass=Wav2Vec2ClassificationConfig)
class Wav2VecClassification(BaseFairseqModel):
# TODO: Can be shared/merged with ASR model class as w2v_encoder params are common.
def __init__(
self,
cfg: Wav2Vec2ClassificationConfig,
w2v_encoder: BaseFairseqModel,
pooling_layer,
):
super().__init__()
self.cfg = cfg
self.w2v_encoder = w2v_encoder
self.pooling_layer = pooling_layer
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
return state_dict
@classmethod
def build_model(cls, cfg: Wav2Vec2ClassificationConfig, task: FairseqTask):
"""Build a new model instance."""
w2v_encoder = Wav2VecEncoder(cfg, None)
pooling_layer = get_pooling_layer(
cfg,
w2v_encoder.w2v_model.encoder.layers[-1].embedding_dim,
len(task.target_dictionary),
len(w2v_encoder.w2v_model.encoder.layers),
)
return cls(cfg, w2v_encoder, pooling_layer)
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output
if log_probs:
return utils.log_softmax(logits.float(), dim=-1)
else:
return utils.softmax(logits.float(), dim=-1)
def get_logits(self, net_output):
return net_output
def forward(self, **kwargs):
encoder_out_dict = self.w2v_encoder(**kwargs)
w2v_encoder_out = encoder_out_dict["encoder_out"] # TxBxC
w2v_encoder_padding_mask = encoder_out_dict["padding_mask"] # BxT
# w2v_encoder_layer_results = encoder_out_dict["layer_results"]
return self.pooling_layer(
last_layer_feats=w2v_encoder_out,
padding_mask=w2v_encoder_padding_mask,
# all_layer_feats=w2v_encoder_layer_results,
)
# def forward_latent(self, **kwargs):
# encoder_out_dict = self.w2v_encoder(**kwargs)
# w2v_encoder_out = encoder_out_dict["encoder_out"]
# w2v_encoder_padding_mask = encoder_out_dict["encoder_padding_mask"]
# w2v_encoder_layer_results = encoder_out_dict["layer_results"]
# return self.pooling_layer.forward_latent(
# last_layer_feats=w2v_encoder_out,
# padding_mask=w2v_encoder_padding_mask,
# all_layer_feats=w2v_encoder_layer_results,
# )
def get_pooling_layer(
cfg: Wav2Vec2ClassificationConfig,
encoder_embed_dim: int,
num_targets: int,
encoder_layers: int,
):
assert cfg.pooling == 'mean', f"only mean pooling is currently supported, got {cfg.pooling}"
if cfg.pooling == "first_token":
return FirstToken(cfg, encoder_embed_dim, num_targets)
# elif cfg.pooling == "mean":
# return MeanPooling(cfg, encoder_embed_dim, num_targets)
elif cfg.pooling == "mean":
return MeanPoolingFast(cfg, encoder_embed_dim, num_targets)
elif cfg.pooling == "mean_amsoftmax":
return MeanPoolingFastAMSoftmax(cfg, encoder_embed_dim, num_targets)
elif cfg.pooling == "max":
return MaxPoolingFast(cfg, encoder_embed_dim, num_targets)
elif cfg.pooling == "elmo":
return LayerWeightedMeanPooling(
cfg, encoder_embed_dim, num_targets, encoder_layers
)
else:
raise NotImplementedError(f"{cfg.pooling} has not been implemented yet.")
class Pooling(nn.Module):
def __init__(
self,
cfg: Wav2Vec2ClassificationConfig,
encoder_embed_dim: int,
num_targets: int,
):
super().__init__()
self.projection = Linear(encoder_embed_dim, num_targets)
def forward(self, last_layer_feats, **kwargs):
raise NotImplementedError()
class FirstToken(Pooling):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, last_layer_feats, **kwargs):
return self.projection(last_layer_feats[:, 0])
# class MeanPooling(Pooling):
# def __init__(
# self,
# cfg: Wav2VecClassificationConfig,
# encoder_embed_dim: int,
# num_targets: int,
# **kwargs,
# ):
# super().__init__(cfg, encoder_embed_dim, num_targets)
# self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
# self.linear = Linear(encoder_embed_dim, encoder_embed_dim)
# def forward(self, last_layer_feats, padding_mask, **kwargs):
# # last_layer_feats: [BxTxD]
# # padding_mask: [BxT]
# last_layer_feats = self.linear(self.activation_fn(last_layer_feats))
# input_lengths = (1 - padding_mask.long()).sum(-1)
# pooled_feature_list = []
# for i in range(len(last_layer_feats)):
# length = input_lengths[i]
# pooled_feature = torch.mean(last_layer_feats[i][:length], dim=0)
# pooled_feature_list.append(pooled_feature)
# return self.projection(torch.stack(pooled_feature_list))
def fn_mean(x, mask):
"""
Args:
x: TxBxD
mask: BxT
Return:
y: BxD
"""
if mask is not None:
mask = mask.t()[:, :, None]
return (x * mask).sum(0) / mask.sum(0)
else:
return x.sum(0) / x.shape[0]
class MeanPoolingFast(nn.Module):
def __init__(
self,
cfg: Wav2Vec2ClassificationConfig,
encoder_embed_dim: int,
num_targets: int,
**kwargs,
):
super().__init__()
self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
self.latent_embed_dim = (
cfg.latent_embed_dim
if cfg.latent_embed_dim is not None
else encoder_embed_dim
)
logging.debug(f"| {self.latent_embed_dim=}")
self.linear = Linear(encoder_embed_dim, self.latent_embed_dim)
self.projection = Linear(self.latent_embed_dim, num_targets)
def forward(self, last_layer_feats, padding_mask, **kwargs):
"""
Arguments:
last_layer_feats - acoustic features of shape [TxBxD]
padding_mask - padding mask of shape [BxT]
"""
if padding_mask is not None:
feat_mask = (~padding_mask).to(last_layer_feats.dtype)
else:
feat_mask = None
feat = self.linear(last_layer_feats)
feat = fn_mean(feat, feat_mask)
feat = self.activation_fn(feat)
return self.projection(feat)
def forward_latent(self, last_layer_feats, padding_mask, **kwargs):
"""
Arguments:
last_layer_feats - acoustic features of shape [TxBxD]
padding_mask - padding mask of shape [BxT]
"""
if padding_mask is not None:
feat_mask = (~padding_mask).to(last_layer_feats.dtype)
else:
feat_mask = None
feat = self.linear(last_layer_feats)
feat = fn_mean(feat, feat_mask)
return feat
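`MeanPoolingFast` projects the frame features, averages them over the unpadded frames via `fn_mean`, applies the activation, and maps to the class logits. A small shape sketch with random tensors and a stand-in config (hypothetical sizes):

```python
import torch
from argparse import Namespace

# Stand-in config: only the fields MeanPoolingFast reads are provided.
cfg = Namespace(activation_fn="gelu", latent_embed_dim=None)
pool = MeanPoolingFast(cfg, encoder_embed_dim=1024, num_targets=128)

feats = torch.randn(50, 2, 1024)                      # TxBxD, as produced by the encoder
padding_mask = torch.zeros(2, 50, dtype=torch.bool)   # BxT; no padding in this toy batch
print(pool(feats, padding_mask).shape)                # torch.Size([2, 128])
```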
class MeanPoolingFastAMSoftmax(MeanPoolingFast):
def __init__(
self,
cfg: Wav2Vec2ClassificationConfig,
encoder_embed_dim: int,
num_targets: int,
**kwargs,
):
super().__init__(cfg, encoder_embed_dim, num_targets, **kwargs)
self.projection = Linear(self.latent_embed_dim, num_targets, bias=False)
nn.init.xavier_normal_(self.projection.weight, gain=1)
def forward(self, last_layer_feats, padding_mask, **kwargs):
"""
Arguments:
last_layer_feats - acoustic features of shape [TxBxD]
padding_mask - padding mask of shape [BxT]
"""
feat_mask = (~padding_mask).to(last_layer_feats.dtype)  # BxT
feat = self.linear(last_layer_feats)  # TxBxD
feat = fn_mean(feat, feat_mask)  # BxD
feat = self.activation_fn(feat)
# normalize feat
feat_norm = F.normalize(feat, p=2, dim=-1) # B,D
weight_norm = F.normalize(self.projection.weight.t(), p=2, dim=-1) # D,K
cos_fw = feat_norm @ weight_norm
return cos_fw
def fn_max(x, mask):
"""
Args:
x: TxBxD
mask: BxT
Return:
y: BxD
"""
mask = mask.t()[:, :, None].to(torch.bool)
# mask out padded frames with a large negative value before taking the max
return x.masked_fill(~mask, -1e8).max(0)[0]
class MaxPoolingFast(Pooling):
def __init__(
self,
cfg: Wav2Vec2ClassificationConfig,
encoder_embed_dim: int,
num_targets: int,
**kwargs,
):
super().__init__(cfg, encoder_embed_dim, num_targets)
self.activation_fn = utils.get_activation_fn(cfg.activation_fn)
self.linear = Linear(encoder_embed_dim, encoder_embed_dim)
def forward(self, last_layer_feats, padding_mask, **kwargs):
"""
Arguments:
last_layer_feats - acoustic features of shape [TxBxD]
padding_mask - padding mask of shape [BxT]
"""
feat_mask = (~padding_mask).to(last_layer_feats.dtype)
feat = self.linear(last_layer_feats)
feat = fn_max(feat, feat_mask)
feat = self.activation_fn(feat)
return self.projection(feat)
class LayerWeightedMeanPooling(MeanPoolingFast):
"""Elmo-style weighted average representation."""
def __init__(
self,
cfg: Wav2Vec2ClassificationConfig,
encoder_embed_dim: int,
num_targets: int,
encoder_layers: int,
):
super().__init__(cfg, encoder_embed_dim, num_targets)
self.num_layers = encoder_layers
self.weights = nn.Parameter(torch.ones(encoder_layers))
def forward(self, last_layer_feats, padding_mask, all_layer_feats):
# last_layer_feats: [BxTxD]
# padding_mask: [BxT]
if not self.training:
msg = (
f"Number of layers in input features = {len(all_layer_feats)}."
f" Expected {self.num_layers} layers."
)
assert len(all_layer_feats) == self.num_layers, msg
# Stack up all layers and reshape to (num_layers, features)
all_layer_feats_stacked = torch.stack(all_layer_feats, dim=0)
num_layers, *original_feat_shape = all_layer_feats_stacked.shape
all_layer_feats_stacked_flat = all_layer_feats_stacked.view(num_layers, -1)
# Weighted average
normalized_weights = F.softmax(self.weights, dim=-1)
weighted_avg_features = (
normalized_weights.unsqueeze(-1) * all_layer_feats_stacked_flat
).sum(dim=0)
weighted_avg_features = weighted_avg_features.view(*original_feat_shape)
# Mean Pooling on weighted average features.
return super().forward(weighted_avg_features, padding_mask)

View File

@ -0,0 +1,269 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
import itertools
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import torch
from omegaconf import II, MISSING
from sklearn import metrics as sklearn_metrics
from fairseq.data import AddTargetDataset, Dictionary, FileAudioDataset
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask
from fairseq.tasks.audio_finetuning import label_len_fn, LabelEncoder
from .. import utils
from ..logging import metrics
from . import FairseqTask, register_task
logger = logging.getLogger(__name__)
@dataclass
class AudioClassificationConfig(AudioPretrainingConfig):
target_dictionary: Optional[str] = field(
default=None, metadata={"help": "override default dictionary location"}
)
@register_task("audio_classification", dataclass=AudioClassificationConfig)
class AudioClassificationTask(AudioPretrainingTask):
"""Task for audio classification tasks."""
cfg: AudioClassificationConfig
def __init__(
self,
cfg: AudioClassificationConfig,
):
super().__init__(cfg)
self.state.add_factory("target_dictionary", self.load_target_dictionary)
logger.info(f"=== Number of labels = {len(self.target_dictionary)}")
def load_target_dictionary(self):
if self.cfg.labels:
target_dictionary = self.cfg.data
if self.cfg.target_dictionary: # override dict
target_dictionary = self.cfg.target_dictionary
dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt")
logger.info("Using dict_path : {}".format(dict_path))
return Dictionary.load(dict_path, add_special_symbols=False)
return None
def load_dataset(
self, split: str, task_cfg: AudioClassificationConfig = None, **kwargs
):
super().load_dataset(split, task_cfg, **kwargs)
task_cfg = task_cfg or self.cfg
assert task_cfg.labels is not None
text_compression_level = getattr(
TextCompressionLevel, str(self.cfg.text_compression_level)
)
data_path = self.cfg.data
if task_cfg.multi_corpus_keys is None:
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
text_compressor = TextCompressor(level=text_compression_level)
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f)
if i not in skipped_indices
]
assert len(labels) == len(self.datasets[split]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.datasets[split])}) do not match"
)
process_label = LabelEncoder(self.target_dictionary)
self.datasets[split] = AddTargetDataset(
self.datasets[split],
labels,
pad=self.target_dictionary.pad(),
eos=self.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=False,
# text_compression_level=text_compression_level,
)
else:
target_dataset_map = OrderedDict()
multi_corpus_keys = [
k.strip() for k in task_cfg.multi_corpus_keys.split(",")
]
corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
data_keys = [k.split(":") for k in split.split(",")]
multi_corpus_sampling_weights = [
float(val.strip())
for val in task_cfg.multi_corpus_sampling_weights.split(",")
]
data_weights = []
for key, file_name in data_keys:
k = key.strip()
label_path = os.path.join(
data_path, f"{file_name.strip()}.{task_cfg.labels}"
)
skipped_indices = getattr(
self.dataset_map[split][k], "skipped_indices", set()
)
text_compressor = TextCompressor(level=text_compression_level)
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f)
if i not in skipped_indices
]
assert len(labels) == len(self.dataset_map[split][k]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.dataset_map[split][k])}) do not match"
)
process_label = LabelEncoder(self.target_dictionary)
# TODO: Remove duplication of code from the if block above
target_dataset_map[k] = AddTargetDataset(
self.dataset_map[split][k],
labels,
pad=self.target_dictionary.pad(),
eos=self.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=False,
# text_compression_level=text_compression_level,
)
data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])
if len(target_dataset_map) == 1:
self.datasets[split] = list(target_dataset_map.values())[0]
else:
self.datasets[split] = MultiCorpusDataset(
target_dataset_map,
distribution=data_weights,
seed=0,
sort_indices=True,
)
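With `multi_corpus_keys` set, the split name itself encodes which corpus key maps to which manifest/label file: each comma-separated `key:filename` pair is matched against `multi_corpus_keys` and weighted by the corresponding entry of `multi_corpus_sampling_weights`. A hedged configuration sketch with hypothetical names:

```python
# Hypothetical task settings, for illustration only.
multi_corpus_keys = "fleurs,voxlingua"              # task.multi_corpus_keys
multi_corpus_sampling_weights = "0.7,0.3"           # task.multi_corpus_sampling_weights
split = "fleurs:train_fleurs,voxlingua:train_vl"    # e.g. dataset.train_subset

# The task then loads train_fleurs.tsv / train_fleurs.<labels> under key "fleurs"
# and train_vl.tsv / train_vl.<labels> under key "voxlingua", sampled 0.7 / 0.3.
data_keys = [k.split(":") for k in split.split(",")]
print(data_keys)  # [['fleurs', 'train_fleurs'], ['voxlingua', 'train_vl']]
```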
@property
def source_dictionary(self):
return None
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.state.target_dictionary
def train_step(self, sample, model, *args, **kwargs):
sample["target"] = sample["target"].to(dtype=torch.long)
loss, sample_size, logging_output = super().train_step(
sample, model, *args, **kwargs
)
self._log_metrics(sample, model, logging_output)
return loss, sample_size, logging_output
def valid_step(self, sample, model, criterion):
sample["target"] = sample["target"].to(dtype=torch.long)
loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
self._log_metrics(sample, model, logging_output)
return loss, sample_size, logging_output
def _log_metrics(self, sample, model, logging_output):
metrics = self._inference_with_metrics(
sample,
model,
)
"""
logging_output["_precision"] = metrics["precision"]
logging_output["_recall"] = metrics["recall"]
logging_output["_f1"] = metrics["f1"]
logging_output["_eer"] = metrics["eer"]
logging_output["_accuracy"] = metrics["accuracy"]
"""
logging_output["_correct"] = metrics["correct"]
logging_output["_total"] = metrics["total"]
def _inference_with_metrics(self, sample, model):
def _compute_eer(target_list, lprobs):
# from scipy.optimize import brentq
# from scipy.interpolate import interp1d
y_one_hot = np.eye(len(self.state.target_dictionary))[target_list]
fpr, tpr, thresholds = sklearn_metrics.roc_curve(
y_one_hot.ravel(), lprobs.ravel()
)
# Revisit the interpolation approach.
# eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
fnr = 1 - tpr
eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
return eer
with torch.no_grad():
net_output = model(**sample["net_input"])
lprobs = (
model.get_normalized_probs(net_output, log_probs=True).cpu().detach()
)
target_list = sample["target"][:, 0].detach().cpu()
predicted_list = torch.argmax(lprobs, 1).detach().cpu() # B,C->B
metrics = {
"correct": torch.sum(target_list == predicted_list).item(),
"total": len(target_list),
}
return metrics
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
zero = torch.scalar_tensor(0.0)
correct, total = 0, 0
for log in logging_outputs:
correct += log.get("_correct", zero)
total += log.get("_total", zero)
metrics.log_scalar("_correct", correct)
metrics.log_scalar("_total", total)
if total > 0:
def _fn_accuracy(meters):
if meters["_total"].sum > 0:
return utils.item(meters["_correct"].sum / meters["_total"].sum)
return float("nan")
metrics.log_derived("accuracy", _fn_accuracy)
"""
prec_sum, recall_sum, f1_sum, acc_sum, eer_sum = 0.0, 0.0, 0.0, 0.0, 0.0
for log in logging_outputs:
prec_sum += log.get("_precision", zero).item()
recall_sum += log.get("_recall", zero).item()
f1_sum += log.get("_f1", zero).item()
acc_sum += log.get("_accuracy", zero).item()
eer_sum += log.get("_eer", zero).item()
metrics.log_scalar("avg_precision", prec_sum / len(logging_outputs))
metrics.log_scalar("avg_recall", recall_sum / len(logging_outputs))
metrics.log_scalar("avg_f1", f1_sum / len(logging_outputs))
metrics.log_scalar("avg_accuracy", acc_sum / len(logging_outputs))
metrics.log_scalar("avg_eer", eer_sum / len(logging_outputs))
"""

View File

@ -7,12 +7,13 @@
import logging
import os
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
import torch
import json
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional, Any
from typing import Optional, Any, OrderedDict
from fairseq.data import AddTargetDataset, Dictionary, encoders
from fairseq.tasks.audio_pretraining import AudioPretrainingTask, AudioPretrainingConfig
@ -101,7 +102,12 @@ class AudioFinetuningConfig(AudioPretrainingConfig):
},
)
rebuild_batches: bool = True
target_dictionary: Optional[str] = field(
default=None,
metadata={
"help": "override default dictionary location"
}
)
@register_task("audio_finetuning", dataclass=AudioFinetuningConfig)
class AudioFinetuningTask(AudioPretrainingTask):
@ -120,7 +126,11 @@ class AudioFinetuningTask(AudioPretrainingTask):
def load_target_dictionary(self):
if self.cfg.labels:
dict_path = os.path.join(self.cfg.data, f"dict.{self.cfg.labels}.txt")
target_dictionary = self.cfg.data
if self.cfg.target_dictionary: # override dict
target_dictionary = self.cfg.target_dictionary
dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt")
logger.info('Using dict_path : {}'.format(dict_path))
return Dictionary.load(dict_path)
return None
@ -135,34 +145,84 @@ class AudioFinetuningTask(AudioPretrainingTask):
TextCompressionLevel, str(self.cfg.text_compression_level)
)
data_path = self.cfg.data
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
text_compressor = TextCompressor(level=text_compression_level)
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f)
if i not in skipped_indices
]
if task_cfg.multi_corpus_keys is None:
label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
text_compressor = TextCompressor(level=text_compression_level)
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f)
if i not in skipped_indices
]
assert len(labels) == len(self.datasets[split]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.datasets[split])}) do not match"
)
assert len(labels) == len(self.datasets[split]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.datasets[split])}) do not match"
)
process_label = LabelEncoder(self.target_dictionary)
process_label = LabelEncoder(self.target_dictionary)
self.datasets[split] = AddTargetDataset(
self.datasets[split],
labels,
pad=self.target_dictionary.pad(),
eos=self.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=task_cfg.get("autoregressive", False),
text_compression_level=text_compression_level,
)
self.datasets[split] = AddTargetDataset(
self.datasets[split],
labels,
pad=self.target_dictionary.pad(),
eos=self.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=task_cfg.get("autoregressive", False),
text_compression_level=text_compression_level,
)
else:
target_dataset_map = OrderedDict()
multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")]
corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
data_keys = [k.split(":") for k in split.split(",")]
multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")]
data_weights = []
for key, file_name in data_keys:
k = key.strip()
label_path = os.path.join(data_path, f"{file_name.strip()}.{task_cfg.labels}")
skipped_indices = getattr(self.dataset_map[split][k], "skipped_indices", set())
text_compressor = TextCompressor(level=text_compression_level)
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f)
if i not in skipped_indices
]
assert len(labels) == len(self.dataset_map[split][k]), (
f"labels length ({len(labels)}) and dataset length "
f"({len(self.dataset_map[split][k])}) do not match"
)
process_label = LabelEncoder(self.target_dictionary)
# TODO: Remove duplication of code from the if block above
target_dataset_map[k] = AddTargetDataset(
self.dataset_map[split][k],
labels,
pad=self.target_dictionary.pad(),
eos=self.target_dictionary.eos(),
batch_targets=True,
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=task_cfg.get("autoregressive", False),
text_compression_level=text_compression_level,
)
data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])
if len(target_dataset_map) == 1:
self.datasets[split] = list(target_dataset_map.values())[0]
else:
self.datasets[split] = MultiCorpusDataset(target_dataset_map, distribution=data_weights, seed=0, sort_indices=True)
@property
def target_dictionary(self):

View File

@ -11,8 +11,9 @@ import sys
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING, II
from typing import Optional, OrderedDict
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
from omegaconf import MISSING, II, OmegaConf
from fairseq.data import BinarizedAudioDataset, FileAudioDataset, SubsampleDataset
from fairseq.dataclass import FairseqDataclass, ChoiceEnum
@ -44,6 +45,12 @@ class AudioPretrainingConfig(FairseqDataclass):
default=None,
metadata={"help": "extension of the label file to load, used for fine-tuning"},
)
multi_corpus_keys: Optional[str] = field(
default=None,
metadata={"help": "Comma separated names for loading multi corpus datasets"})
multi_corpus_sampling_weights: Optional[str] = field(
default=None,
metadata={"help": "Comma separated string of sampling weights corresponding to the multi_corpus_keys"})
binarized_dataset: bool = field(
default=False,
metadata={
@ -121,7 +128,7 @@ class AudioPretrainingTask(FairseqTask):
TextCompressionLevel, str(self.cfg.text_compression_level)
)
compute_mask = task_cfg.precompute_mask_config is not None
compute_mask = getattr(task_cfg, "precompute_mask_config", None) is not None
mask_args = {}
if compute_mask:
mask_args = task_cfg.precompute_mask_config
@ -140,20 +147,59 @@ class AudioPretrainingTask(FairseqTask):
**mask_args,
)
else:
manifest_path = os.path.join(data_path, "{}.tsv".format(split))
if task_cfg.multi_corpus_keys is None:
manifest_path = os.path.join(data_path, "{}.tsv".format(split))
self.datasets[split] = FileAudioDataset(
manifest_path=manifest_path,
sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
max_sample_size=self.cfg.max_sample_size,
min_sample_size=self.cfg.min_sample_size,
pad=task_cfg.labels is not None or task_cfg.enable_padding,
normalize=task_cfg.normalize,
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
text_compression_level=text_compression_level,
compute_mask=compute_mask,
**mask_args,
)
self.datasets[split] = FileAudioDataset(
manifest_path=manifest_path,
sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
max_sample_size=self.cfg.max_sample_size,
min_sample_size=self.cfg.min_sample_size,
pad=task_cfg.labels is not None or task_cfg.enable_padding,
normalize=task_cfg.normalize,
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
text_compression_level=text_compression_level,
compute_mask=compute_mask,
**mask_args,
)
else:
dataset_map = OrderedDict()
self.dataset_map = {}
multi_corpus_keys = [k.strip() for k in task_cfg.multi_corpus_keys.split(",")]
corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}
data_keys = [k.split(":") for k in split.split(",")]
multi_corpus_sampling_weights = [float(val.strip()) for val in task_cfg.multi_corpus_sampling_weights.split(",")]
data_weights = []
for key, file_name in data_keys:
k = key.strip()
manifest_path = os.path.join(data_path, "{}.tsv".format(file_name.strip()))
# TODO: Remove duplication of code from the if block above
dataset_map[k] = FileAudioDataset(
manifest_path=manifest_path,
sample_rate=task_cfg.get("sample_rate", self.cfg.sample_rate),
max_sample_size=self.cfg.max_sample_size,
min_sample_size=self.cfg.min_sample_size,
pad=task_cfg.labels is not None or task_cfg.enable_padding,
normalize=task_cfg.normalize,
num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
text_compression_level=text_compression_level,
compute_mask=compute_mask,
corpus_key=corpus_idx_map[k],
**mask_args,
)
data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])
self.dataset_map[split] = dataset_map
if len(dataset_map) == 1:
self.datasets[split] = list(dataset_map.values())[0]
else:
self.datasets[split] = MultiCorpusDataset(dataset_map, distribution=data_weights, seed=0, sort_indices=True)
if getattr(task_cfg, "subsample", 1) < 1:
self.datasets[split] = SubsampleDataset(