speech-to-text OSS
Summary: Imported from https://github.com/fairinternal/fairseq-py/pull/1284. Updated according to PR comments. Main changes:

* New task: `fairseq.tasks.speech_to_text`
* Multilingual support: multiple train sub-splits, temperature-based sampling, language ID tokens
* New dataset: `fairseq.data.audio.speech_to_text_dataset`
* Added accuracy metrics and BOS prefix removal to label smoothed cross entropy
* New models: Transformer (`fairseq.models.speech_to_text.s2t_transformer`) and BLSTM (`fairseq.models.speech_to_text.berard`)
* Extended scorers:
  * Added a base scorer class: `fairseq.scorers.BaseScorer` (the parent class for all scorers except the C++ BLEU scorer)
  * Added an evaluation tokenizer: `fairseq.scorers.eval_tokenizer`, which leverages sacreBLEU's built-in tokenizers and also allows character-level tokenization and punctuation removal (for WER scoring)
  * Added a chrF scorer: `fairseq.scorers.chrf`
* Online Mel-filter bank speech feature extraction (via C++-based PyKaldi or Python-based TorchAudio): `fairseq.data.audio.audio_utils`
* Online speech feature transforms: `fairseq.data.audio.feature_transforms.*`
* Fixed the subsampled sequence lengths in VGGTransformer (`examples.speech_recognition.models.vggtransformer`)
* Examples under `examples/speech_to_text`:
  * LibriSpeech (ASR): better results than VGGTransformer with smaller Transformer-based models
  * MuST-C (ST): comparable to [SOTA results](https://arxiv.org/pdf/2004.10234.pdf) but with fewer tricks

Reviewed By: jmp84

Differential Revision: D24065273

fbshipit-source-id: 5f842ca9c826f92d4af660705611885fe440a9ab
This commit is contained in:
parent a2d0be4989
commit 1d1c145387
examples/speech_recognition/models/vggtransformer.py

@@ -251,6 +251,7 @@ class VGGTransformerEncoder(FairseqEncoder):
         self.conv_layers = nn.ModuleList()
         self.in_channels = in_channels
         self.input_dim = input_feat_per_channel
+        self.pooling_kernel_sizes = []

         if vggblock_config is not None:
             for _, config in enumerate(vggblock_config):
@@ -272,6 +273,7 @@ class VGGTransformerEncoder(FairseqEncoder):
                         layer_norm=layer_norm,
                     )
                 )
+                self.pooling_kernel_sizes.append(pooling_kernel_size)
                 in_channels = out_channels
                 input_feat_per_channel = self.conv_layers[-1].output_dim

@@ -336,9 +338,9 @@ class VGGTransformerEncoder(FairseqEncoder):
         x = x.transpose(1, 2).transpose(0, 1)
         x = x.contiguous().view(output_seq_len, bsz, -1)

-        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
-        # TODO: shouldn't subsampling_factor determined in advance ?
-        input_lengths = (src_lengths.float() / subsampling_factor).ceil().long()
+        input_lengths = src_lengths.clone()
+        for s in self.pooling_kernel_sizes:
+            input_lengths = (input_lengths.float() / s).ceil().long()

         encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
             input_lengths, batch_first=True
@@ -346,6 +348,7 @@ class VGGTransformerEncoder(FairseqEncoder):
         if not encoder_padding_mask.any():
             encoder_padding_mask = None

+        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
         attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)

         transformer_layer_idx = 0
examples/speech_to_text/README.md (new file, 216 lines)
@@ -0,0 +1,216 @@
# Speech-to-Text (S2T) Modeling

## Data Preparation
S2T modeling data consists of source speech features, target text and other optional information
(source text, speaker id, etc.). Fairseq S2T uses per-dataset-split TSV manifest files
to store this information. Each data field is represented by a column in the TSV file.

Unlike text token embeddings, speech features (e.g. log mel-filter banks) are usually fixed
during model training and can be pre-computed. The manifest file contains the path to
either the feature file in NumPy format or the WAV/FLAC audio file. For the latter,
features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed
into uncompressed ZIP files (then accessed via byte offset and length) to improve I/O performance.

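For illustration, a manifest row produced by the preparation scripts below looks roughly
as follows (the ID, transcript, byte offset and length are made-up placeholders; the
`audio` column holds either a file path or a `<zip filename>:<byte offset>:<byte length>` entry):

```
id	audio	n_frames	tgt_text	speaker
103-1240-0000	fbank80.zip:64:51328	584	CHAPTER ONE MISSUS RACHEL LYNDE IS SURPRISED	103
```
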
Fairseq S2T also employs a YAML file for data related configurations: tokenizer type and dictionary path
for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment,
temperature-based resampling, etc.

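As a sketch, a config written by the `gen_config_yaml` helper in this commit has roughly
the following shape (paths and vocabulary sizes are placeholders):

```yaml
audio_root: /path/to/data
vocab_filename: spm_unigram10000.txt
input_channels: 1
input_feat_per_channel: 80
specaugment:
  freq_mask_F: 27
  freq_mask_N: 1
  time_mask_N: 1
  time_mask_T: 100
  time_mask_p: 1.0
  time_warp_W: 0
bpe_tokenizer:
  bpe: sentencepiece
  sentencepiece_model: /path/to/data/spm_unigram10000.model
transforms:
  _train:
  - specaugment
```
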
## Model Training & Evaluation
Fairseq S2T uses the unified `fairseq-train`/`fairseq-generate` interface for model training and evaluation.
It requires arguments `--task speech_to_text` and `--arch <arch in fairseq.models.speech_to_text.*>`.


## Example 1: Speech Recognition (ASR) on LibriSpeech

#### Data preparation
Download and preprocess LibriSpeech data with
```bash
python examples/speech_to_text/prep_librispeech_data.py \
  --output-root ${LS_ROOT} --vocab-type unigram --vocab-size 10000
```
where `LS_ROOT` is the root path for downloaded data as well as generated manifest and feature files.

#### Training
```bash
fairseq-train ${LS_ROOT} --train-subset train --valid-subset dev --save-dir ${SAVE_DIR} --num-workers 4 \
  --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy --max-update 300000 \
  --arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
  --clip-norm 10.0 --seed 1 --update-freq 8
```
where `SAVE_DIR` is the checkpoint root path. Here we use `--arch s2t_transformer_s` (31M parameters) as an example.
You may switch to `s2t_transformer_m` (71M) or `s2t_transformer_l` (268M) for better performance. We set
`--update-freq 8` to simulate 8 GPUs with 1 GPU. You may want to update it accordingly when using more than 1 GPU.

#### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the 4 splits
(`dev-clean`, `dev-other`, `test-clean` and `test-other`):
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
  --inputs ${SAVE_DIR} --num-epoch-checkpoints 10 --output "${SAVE_DIR}/${CHECKPOINT_FILENAME}"
for SUBSET in dev-clean dev-other test-clean test-other; do
  fairseq-generate ${LS_ROOT} --gen-subset ${SUBSET} --task speech_to_text \
    --path ${SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring wer
done
```

#### Result

| --arch | Params | dev-clean | dev-other | test-clean | test-other |
|---|---|---|---|---|---|
| s2t_transformer_s | 30M | 4.1 | 9.3 | 4.4 | 9.2 |
| s2t_transformer_sp | 35M | 3.9 | 9.3 | 4.3 | 8.8 |
| s2t_transformer_m | 71M | 3.5 | 8.1 | 3.7 | 8.1 |
| s2t_transformer_mp | 84M | 3.3 | 7.8 | 3.7 | 8.2 |
| s2t_transformer_l | 268M | 3.3 | 7.7 | 3.5 | 7.8 |
| s2t_transformer_lp | 318M | 3.1 | 7.5 | 3.4 | 7.6 |


## Example 2: Speech Translation (ST) on MuST-C

#### Data Preparation
[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path `MUSTC_ROOT`, then preprocess it with
```bash
python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} \
  --asr-vocab-type unigram --asr-vocab-size 5000 \
  --st-vocab-type unigram --st-vocab-size 8000
```
The generated manifest and feature files will be available under `MUSTC_ROOT`.

#### ASR
###### Training
```bash
fairseq-train ${MUSTC_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \
  --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
  --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \
  --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.

###### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
  --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_asr --task speech_to_text \
  --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
  --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
```
###### Result
| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 18.2 | 17.6 | 17.7 | 17.2 | 17.9 | 19.1 | 18.1 | 17.7 |

#### ST
###### Training
```bash
fairseq-train ${MUSTC_ROOT} --train-subset train_st --valid-subset dev_st --save-dir ${ST_SAVE_DIR} \
  --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
  --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \
  --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
  --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by ASR for faster training and better
performance: `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.

###### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the `tst-COMMON` split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
  --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_st --task speech_to_text \
  --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu
```

###### Result
| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 22.7 | 27.3 | 27.2 | 32.9 | 22.7 | 28.1 | 21.9 | 15.3 |


## Example 3: ST on CoVoST
#### Data Preparation
Download and preprocess CoVoST data with
```bash
# En ASR
python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \
  --vocab-type char --src-lang en
# ST
python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \
  --vocab-type char --src-lang fr --tgt-lang en
```
where `COVOST_ROOT` is the root path for downloaded data as well as generated manifest and feature files.

#### ASR
###### Training
```bash
fairseq-train ${COVOST_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \
  --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
  --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \
  --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.

###### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
  --inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${COVOST_ROOT} --gen-subset test_asr_en --task speech_to_text \
  --path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
  --scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
```
###### Result
| --arch | Params | En |
|---|---|---|
| s2t_transformer_s | 31M | 25.6 |

#### ST
###### Training
```bash
fairseq-train ${COVOST_ROOT} --train-subset train_st_fr_en --valid-subset dev_st_fr_en --save-dir ${ST_SAVE_DIR} \
  --num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
  --report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \
  --lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
  --load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained by En ASR for faster training and better
performance: `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate 8 GPUs with 1 GPU.
You may want to update it accordingly when using more than 1 GPU.

###### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the test split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
  --inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${COVOST_ROOT} --gen-subset test_st_fr_en --task speech_to_text \
  --path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu
```

###### Result
| --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 26.3 | 17.1 | 23.0 | 18.8 | 16.3 | 21.8 | 13.1 | 13.2 |

## Citation
Please cite as:
```
@inproceedings{wang2020fairseqs2t,
  title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
  author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
  booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
  year = {2020},
}

@inproceedings{ott2019fairseq,
  title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
  author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
  booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
  year = {2019},
}
```
examples/speech_to_text/data_utils.py (new file, 218 lines)
@@ -0,0 +1,218 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from multiprocessing import cpu_count
import os
import os.path as op
from glob import glob
import zipfile
import csv
from functools import reduce
from typing import Dict, Any, List
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank

import sentencepiece as sp
from tqdm import tqdm
import numpy as np

from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN

UNK_TOKEN, UNK_TOKEN_ID = '<unk>', 3
BOS_TOKEN, BOS_TOKEN_ID = '<s>', 0
EOS_TOKEN, EOS_TOKEN_ID = '</s>', 2
PAD_TOKEN, PAD_TOKEN_ID = '<pad>', 1


def gen_vocab(
    input_path: str, output_path_prefix: str, model_type='bpe',
    vocab_size=1000,
):
    # Train SentencePiece Model
    arguments = [
        f'--input={input_path}',
        f'--model_prefix={output_path_prefix}',
        f'--model_type={model_type}',
        f'--vocab_size={vocab_size}',
        '--character_coverage=1.0',
        f'--num_threads={cpu_count()}',
        f'--unk_id={UNK_TOKEN_ID}',
        f'--bos_id={BOS_TOKEN_ID}',
        f'--eos_id={EOS_TOKEN_ID}',
        f'--pad_id={PAD_TOKEN_ID}'
    ]
    sp.SentencePieceTrainer.Train(' '.join(arguments))
    # Export fairseq dictionary
    spm = sp.SentencePieceProcessor()
    spm.Load(output_path_prefix + '.model')
    vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
    assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \
           vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \
           vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \
           vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
    vocab = {
        i: s for i, s in vocab.items()
        if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
    }
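    # The output is a fairseq dictionary: one "<piece> <count>" line per kept
    # piece (with a dummy count of 1); the special tokens are excluded because
    # fairseq's Dictionary class adds its own.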
    with open(output_path_prefix + '.txt', 'w') as f_out:
        for _, s in sorted(vocab.items(), key=lambda x: x[0]):
            f_out.write(f'{s} 1\n')


def extract_fbank_features(waveform, sample_rate, output_path=None,
                           n_mel_bins=80, apply_utterance_cmvn=True,
                           overwrite=False):
    if output_path is not None and op.exists(output_path) and not overwrite:
        return

    _waveform = waveform * (2 ** 15)  # Kaldi compliance: 16-bit signed integers
    _waveform = _waveform.squeeze().numpy()

    features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
    if features is None:
        features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
    if features is None:
        raise ImportError('Please install pyKaldi or torchaudio to enable '
                          'online filterbank feature extraction')

    if apply_utterance_cmvn:
        cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
        features = cmvn(features)
    if output_path is not None:
        np.save(output_path, features)
    else:
        return features


def create_zip(data_root, zip_path):
    cwd = os.path.abspath(os.curdir)
    os.chdir(data_root)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f:
        for filename in tqdm(glob('*.npy')):
            f.write(filename)
    os.chdir(cwd)


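# NumPy's .npy files begin with the magic string b'\x93NUMPY';
# 147 == 0x93 and 78 == ord('N').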
def is_npy_data(data: bytes) -> bool:
    return data[0] == 147 and data[1] == 78


def get_zip_manifest(zip_root, zip_filename):
    zip_path = op.join(zip_root, zip_filename)
    with zipfile.ZipFile(zip_path, mode='r') as f:
        info = f.infolist()
    manifest = {}
    for i in tqdm(info):
        utt_id = op.splitext(i.filename)[0]
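        # A ZIP local file header is 30 fixed bytes plus the filename (any
        # extra field is assumed absent for entries written by create_zip);
        # with ZIP_STORED, the uncompressed payload follows immediately.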
        offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
        manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}'
        with open(zip_path, 'rb') as f:
            f.seek(offset)
            data = f.read(file_size)
            assert len(data) > 1 and is_npy_data(data)
    return manifest


def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml',
                    specaugment_policy='lb'):
    assert specaugment_policy in {'lb', 'ld'}
    data_root = op.abspath(data_root)
    writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
    writer.set_audio_root(op.abspath(data_root))
    writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
    writer.set_input_channels(1)
    writer.set_input_feat_per_channel(80)
    if specaugment_policy == 'lb':
        writer.set_specaugment_lb_policy()
    else:
        writer.set_specaugment_ld_policy()
    writer.set_bpe_tokenizer(
        {'bpe': 'sentencepiece',
         'sentencepiece_model': op.join(data_root, spm_filename)}
    )
    writer.set_feature_transforms('_train', ['specaugment'])
    writer.flush()


def save_df_to_tsv(dataframe, path):
    dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8",
                     escapechar='\\', quoting=csv.QUOTE_NONE)


def filter_manifest_df(df, is_train_split=False, extra_filters=None,
                       min_n_frames=5, max_n_frames=3000):
    filters = {
        'no speech': df['audio'] == '',
        f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames,
        'empty sentence': df['tgt_text'] == '',
    }
    if is_train_split:
        filters[f'long speech (>{max_n_frames} frames)'] = \
            df['n_frames'] > max_n_frames
    if extra_filters is not None:
        filters.update(extra_filters)
    invalid = reduce(lambda x, y: x | y, filters.values())
    valid = ~invalid
    print(
        '| ' + ', '.join(f'{n}: {f.sum()}' for n, f in filters.items()) +
        f', total {invalid.sum()} filtered, {valid.sum()} remained.'
    )
    return df[valid]


class S2TDataConfigWriter(object):
    DEFAULT_VOCAB_FILENAME = 'dict.txt'
    DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
    DEFAULT_INPUT_CHANNELS = 1

    def __init__(self, yaml_path):
        try:
            import yaml
        except ImportError:
            raise ImportError(
                'Please install PyYAML to load YAML files for S2T data config'
            )
        self.yaml = yaml
        self.yaml_path = yaml_path
        self.config = {}

    def flush(self):
        with open(self.yaml_path, 'w') as f:
            self.yaml.dump(self.config, f)

    def set_audio_root(self, audio_root=''):
        self.config['audio_root'] = audio_root

    def set_vocab_filename(self, vocab_filename='dict.txt'):
        self.config['vocab_filename'] = vocab_filename

    def set_specaugment(self, time_warp_w: int, freq_mask_n: int,
                        freq_mask_f: int, time_mask_n: int, time_mask_t: int,
                        time_mask_p: float):
        # Key spelling matches what SpecAugmentTransform.from_config_dict reads
        # (the original spelled this 'time_wrap_W', which the transform ignores).
        self.config['specaugment'] = {
            'time_warp_W': time_warp_w, 'freq_mask_N': freq_mask_n,
            'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n,
            'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p,
        }

    def set_specaugment_lb_policy(self):
        self.set_specaugment(time_warp_w=0, freq_mask_n=1, freq_mask_f=27,
                             time_mask_n=1, time_mask_t=100, time_mask_p=1.0)

    def set_specaugment_ld_policy(self):
        self.set_specaugment(time_warp_w=0, freq_mask_n=2, freq_mask_f=27,
                             time_mask_n=2, time_mask_t=100, time_mask_p=1.0)

    def set_input_channels(self, input_channels=1):
        self.config['input_channels'] = input_channels

    def set_input_feat_per_channel(self, input_feat_per_channel=80):
        self.config['input_feat_per_channel'] = input_feat_per_channel

    def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
        self.config['bpe_tokenizer'] = bpe_tokenizer

    def set_feature_transforms(self, split, transforms: List[str]):
        if 'transforms' not in self.config:
            self.config['transforms'] = {}
        self.config['transforms'][split] = transforms
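
The `audio` manifest entries built by `get_zip_manifest` have the form
`<zip filename>:<byte offset>:<byte length>`. As a minimal sketch of how a consumer
can read one back (the function name is illustrative, not part of this commit):

```python
import io
import os.path as op

import numpy as np

from examples.speech_to_text.data_utils import is_npy_data


def read_features(zip_root: str, manifest_entry: str) -> np.ndarray:
    """Load the feature matrix for an entry like 'fbank80.zip:64:51328'."""
    zip_filename, offset, length = manifest_entry.rsplit(':', 2)
    with open(op.join(zip_root, zip_filename), 'rb') as f:
        f.seek(int(offset))
        data = f.read(int(length))
    assert is_npy_data(data)  # sanity-check the NumPy magic bytes
    return np.load(io.BytesIO(data))
```
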
examples/speech_to_text/prep_covost_data.py (new file, 232 lines)
@@ -0,0 +1,232 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
from typing import Tuple, Optional
import csv

from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor

from examples.speech_to_text.data_utils import (
    gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
    extract_fbank_features, gen_config_yaml, filter_manifest_df
)

log = logging.getLogger(__name__)


MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']


class CoVoST(Dataset):
    """Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).

    Args:
        root (str): root path to the dataset and generated manifests/features
        source_language (str): source (audio) language
        target_language (str, optional): target (text) language,
            None for no translation (default: None)
        version (int, optional): CoVoST version. (default: 2)
        download (bool, optional): Whether to download the dataset if it is not
            found at root path. (default: ``False``).
    """

    CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \
                      "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
    COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \
                          "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"

    VERSIONS = {2}
    SPLITS = ['train', 'dev', 'test']

    CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}

    XX_EN_LANGUAGES = {
        1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn',
            'zh-CN'],
        2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn',
            'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy']
    }
    EN_XX_LANGUAGES = {
        1: [],
        2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et',
            'id', 'ar', 'ta', 'lv', 'ja']
    }

    def __init__(
        self, root: str, split: str, source_language: str,
        target_language: Optional[str] = None, version: int = 2,
        download: bool = False
    ) -> None:
        assert version in self.VERSIONS and split in self.SPLITS
        assert source_language is not None
        self.no_translation = (target_language is None)
        if not self.no_translation:
            assert 'en' in {source_language, target_language}
            if source_language == 'en':
                assert target_language in self.EN_XX_LANGUAGES[version]
            else:
                assert source_language in self.XX_EN_LANGUAGES[version]
        else:
            # Hack here so that we can get the "split" column from the CoVoST
            # TSV. Note that we use the CoVoST train split for ASR, which is
            # an extension of the Common Voice train split.
            target_language = 'de' if source_language == 'en' else 'en'

        self.root = os.path.join(root, 'raw')
        os.makedirs(self.root, exist_ok=True)

        cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version],
                                             lang=source_language)
        cv_archive = os.path.join(self.root, os.path.basename(cv_url))
        if download:
            if not os.path.isfile(cv_archive):
                download_url(cv_url, self.root, hash_value=None)
            extract_archive(cv_archive)

        covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language,
                                                     tgt_lang=target_language)
        covost_archive = os.path.join(self.root, os.path.basename(covost_url))
        if download:
            if not os.path.isfile(covost_archive):
                download_url(covost_url, self.root, hash_value=None)
            extract_archive(covost_archive)

        cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv'))
        covost_tsv = self.load_from_tsv(
            os.path.join(self.root,
                         os.path.basename(covost_url).replace('.tar.gz', ''))
        )
        df = pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']],
                      right=covost_tsv[['path', 'translation', 'split']],
                      how='inner', on='path')
        if split == 'train':
            df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')]
        else:
            df = df[df['split'] == split]
        self.data = df.to_dict(orient='index').items()
        self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]

    @classmethod
    def load_from_tsv(cls, path: str):
        return pd.read_csv(
            path, sep='\t', header=0, encoding='utf-8', escapechar='\\',
            quoting=csv.QUOTE_NONE, na_filter=False
        )

    def __getitem__(
        self, n: int
    ) -> Tuple[Tensor, int, str, Optional[str], str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
            sample_id)``
        """
        data = self.data[n]
        path = os.path.join(self.root, 'clips', data['path'])
        waveform, sample_rate = torchaudio.load(path)
        sentence = data['sentence']
        translation = None if self.no_translation else data['translation']
        speaker_id = data['client_id']
        _id = data['path'].replace('.mp3', '')
        return waveform, sample_rate, sentence, translation, speaker_id, _id

    def __len__(self) -> int:
        return len(self.data)


def process(args):
    root = op.join(args.data_root, args.src_lang)
    os.makedirs(root, exist_ok=True)
    # Extract features
    feature_root = op.join(root, 'fbank80')
    os.makedirs(feature_root, exist_ok=True)
    for split in CoVoST.SPLITS:
        print(f'Fetching split {split}...')
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang,
                         download=True)
        print('Extracting log mel filter bank features...')
        for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
            extract_fbank_features(waveform, sample_rate,
                                   op.join(feature_root, f'{utt_id}.npy'))
    # Pack features into ZIP
    zip_filename = 'fbank80.zip'
    zip_path = op.join(root, zip_filename)
    print('ZIPing features...')
    create_zip(feature_root, zip_path)
    print('Fetching ZIP manifest...')
    zip_manifest = get_zip_manifest(args.data_root,
                                    f'{args.src_lang}/{zip_filename}')
    # Generate TSV manifest
    print('Generating manifest...')
    train_text = []
    task = f'asr_{args.src_lang}'
    if args.tgt_lang is not None:
        task = f'st_{args.src_lang}_{args.tgt_lang}'
    for split in CoVoST.SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
        for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
            manifest['id'].append(utt_id)
            manifest['audio'].append(zip_manifest[utt_id])
            # Frame count assumes a 25ms analysis window with a 10ms shift
            duration_ms = int(wav.size(1) / sr * 1000)
            manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
            manifest['tgt_text'].append(
                src_utt if args.tgt_lang is None else tgt_utt
            )
            manifest['speaker'].append(speaker_id)
        is_train_split = split.startswith('train')
        if is_train_split:
            train_text.extend(manifest['tgt_text'])
        df = pd.DataFrame.from_dict(manifest)
        df = filter_manifest_df(df, is_train_split=is_train_split)
        save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv'))
    # Generate vocab
    vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size)
    spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}'
    with NamedTemporaryFile(mode='w') as f:
        for t in train_text:
            f.write(t + '\n')
        f.flush()  # ensure all text is on disk before SentencePiece reads it
        gen_vocab(f.name, op.join(root, spm_filename_prefix),
                  args.vocab_type, args.vocab_size)
    # Generate config YAML
    gen_config_yaml(root, spm_filename_prefix + '.model',
                    yaml_filename=f'config_{task}.yaml',
                    specaugment_policy='lb')
    # Clean up
    shutil.rmtree(feature_root)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-root', '-d', required=True, type=str)
    parser.add_argument('--vocab-type', default='unigram', required=True,
                        type=str, choices=['bpe', 'unigram', 'char'])
    parser.add_argument('--vocab-size', default=1000, type=int)
    parser.add_argument('--src-lang', '-s', required=True, type=str)
    parser.add_argument('--tgt-lang', '-t', type=str)
    args = parser.parse_args()

    process(args)


if __name__ == '__main__':
    main()
examples/speech_to_text/prep_librispeech_data.py (new file, 96 lines)
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import shutil
import os.path as op

from tqdm import tqdm
from torchaudio.datasets import LIBRISPEECH
import pandas as pd

from examples.speech_to_text.data_utils import (
    gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
    extract_fbank_features, gen_config_yaml
)

log = logging.getLogger(__name__)

SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean',
          'dev-other', 'test-clean', 'test-other']

MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']


def process(args):
    os.makedirs(args.output_root, exist_ok=True)
    # Extract features
    feature_root = op.join(args.output_root, 'fbank80')
    os.makedirs(feature_root, exist_ok=True)
    for split in SPLITS:
        print(f'Fetching split {split}...')
        dataset = LIBRISPEECH(args.output_root, url=split, download=True)
        print('Extracting log mel filter bank features...')
        for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
            sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
            extract_fbank_features(wav, sample_rate,
                                   op.join(feature_root, f'{sample_id}.npy'))
    # Pack features into ZIP
    zip_filename = 'fbank80.zip'
    zip_path = op.join(args.output_root, zip_filename)
    print('ZIPing features...')
    create_zip(feature_root, zip_path)
    print('Fetching ZIP manifest...')
    zip_manifest = get_zip_manifest(args.output_root, zip_filename)
    # Generate TSV manifest
    print('Generating manifest...')
    train_text = []
    for split in SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = LIBRISPEECH(args.output_root, url=split)
        for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
            sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
            manifest['id'].append(sample_id)
            manifest['audio'].append(zip_manifest[sample_id])
            # Frame count assumes a 25ms analysis window with a 10ms shift
            duration_ms = int(wav.size(1) / sample_rate * 1000)
            manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
            manifest['tgt_text'].append(utt)
            manifest['speaker'].append(spk_id)
        save_df_to_tsv(pd.DataFrame.from_dict(manifest),
                       op.join(args.output_root, f'{split}.tsv'))
        if split.startswith('train'):
            train_text.extend(manifest['tgt_text'])
    # Generate vocab
    vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size)
    spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}'
    with NamedTemporaryFile(mode='w') as f:
        for t in train_text:
            f.write(t + '\n')
        f.flush()  # ensure all text is on disk before SentencePiece reads it
        gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix),
                  args.vocab_type, args.vocab_size)
    # Generate config YAML
    gen_config_yaml(args.output_root, spm_filename_prefix + '.model',
                    specaugment_policy='ld')
    # Clean up
    shutil.rmtree(feature_root)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--output-root', '-o', required=True, type=str)
    parser.add_argument('--vocab-type', default='unigram', required=True,
                        type=str, choices=['bpe', 'unigram', 'char'])
    parser.add_argument('--vocab-size', default=10000, type=int)
    args = parser.parse_args()

    process(args)


if __name__ == '__main__':
    main()
examples/speech_to_text/prep_mustc_data.py (new file, 172 lines)
@@ -0,0 +1,172 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
from typing import Tuple
from itertools import groupby

from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor

from examples.speech_to_text.data_utils import (
    gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
    extract_fbank_features, gen_config_yaml, filter_manifest_df
)

log = logging.getLogger(__name__)


MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
TASKS = ['asr', 'st']


class MUSTC(Dataset):
    """
    Create a Dataset for MuST-C. Each item is a tuple of the form:
    waveform, sample_rate, source utterance, target utterance, speaker_id,
    utterance_id
    """
    SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE']
    LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru']

    def __init__(self, root: str, lang: str, split: str) -> None:
        assert split in self.SPLITS and lang in self.LANGUAGES
        _root = op.join(root, f'en-{lang}', 'data', split)
        wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt')
        assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
        # Load audio segments
        try:
            import yaml
        except ImportError:
            raise ImportError('Please install PyYAML to load YAML files for '
                              'the MuST-C dataset')
        with open(op.join(txt_root, f'{split}.yaml')) as f:
            segments = yaml.load(f, Loader=yaml.BaseLoader)
        # Load source and target utterances
        for _lang in ['en', lang]:
            with open(op.join(txt_root, f'{split}.{_lang}')) as f:
                utterances = [r.strip() for r in f]
            assert len(segments) == len(utterances)
            for i, u in enumerate(utterances):
                segments[i][_lang] = u
        # Gather info
        self.data = []
        for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']):
            wav_path = op.join(wav_root, wav_filename)
            sample_rate = torchaudio.info(wav_path)[0].rate
            seg_group = sorted(_seg_group, key=lambda x: x['offset'])
            for i, segment in enumerate(seg_group):
                offset = int(float(segment['offset']) * sample_rate)
                n_frames = int(float(segment['duration']) * sample_rate)
                _id = f'{op.splitext(wav_filename)[0]}_{i}'
                self.data.append(
                    (wav_path, offset, n_frames, sample_rate, segment['en'],
                     segment[lang], segment['speaker_id'], _id)
                )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \
            self.data[n]
        waveform, _ = torchaudio.load(wav_path, offset=offset,
                                      num_frames=n_frames)
        return waveform, sr, src_utt, tgt_utt, spk_id, utt_id

    def __len__(self) -> int:
        return len(self.data)


def process(args):
    for lang in MUSTC.LANGUAGES:
        cur_root = op.join(args.data_root, f'en-{lang}')
        if not op.isdir(cur_root):
            print(f'{cur_root} does not exist. Skipped.')
            continue
        # Extract features
        feature_root = op.join(cur_root, 'fbank80')
        os.makedirs(feature_root, exist_ok=True)
        for split in MUSTC.SPLITS:
            print(f'Fetching split {split}...')
            dataset = MUSTC(args.data_root, lang, split)
            print('Extracting log mel filter bank features...')
            for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
                extract_fbank_features(waveform, sample_rate,
                                       op.join(feature_root, f'{utt_id}.npy'))
        # Pack features into ZIP
        zip_filename = 'fbank80.zip'
        zip_path = op.join(cur_root, zip_filename)
        print('ZIPing features...')
        create_zip(feature_root, zip_path)
        print('Fetching ZIP manifest...')
        zip_manifest = get_zip_manifest(args.data_root,
                                        f'en-{lang}/{zip_filename}')
        # Generate TSV manifest
        print('Generating manifest...')
        train_text = {task: [] for task in TASKS}
        for split in MUSTC.SPLITS:
            is_train_split = split.startswith('train')
            manifest = {c: [] for c in MANIFEST_COLUMNS}
            text = {task: [] for task in TASKS}
            dataset = MUSTC(args.data_root, lang, split)
            for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
                manifest['id'].append(utt_id)
                manifest['audio'].append(zip_manifest[utt_id])
                # Frame count assumes a 25ms analysis window with a 10ms shift
                duration_ms = int(wav.size(1) / sr * 1000)
                manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
                text['asr'].append(src_utt)
                text['st'].append(tgt_utt)
                manifest['speaker'].append(speaker_id)
            if is_train_split:
                for task in TASKS:
                    train_text[task].extend(text[task])
            for task in TASKS:
                manifest['tgt_text'] = text[task]
                df = pd.DataFrame.from_dict(manifest)
                df = filter_manifest_df(df, is_train_split=is_train_split)
                save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv'))
        # Generate vocab
        for task in TASKS:
            vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
            if task == 'st':
                vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
            vocab_size_str = '' if vocab_type == 'char' else str(vocab_size)
            spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}'
            with NamedTemporaryFile(mode='w') as f:
                for t in train_text[task]:
                    f.write(t + '\n')
                f.flush()  # ensure all text is on disk before SentencePiece reads it
                gen_vocab(f.name, op.join(cur_root, spm_filename_prefix),
                          vocab_type, vocab_size)
            # Generate config YAML
            gen_config_yaml(cur_root, spm_filename_prefix + '.model',
                            yaml_filename=f'config_{task}.yaml',
                            specaugment_policy='lb')
        # Clean up
        shutil.rmtree(feature_root)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-root', '-d', required=True, type=str)
    parser.add_argument('--asr-vocab-type', default='unigram', required=True,
                        type=str, choices=['bpe', 'unigram', 'char'])
    parser.add_argument('--st-vocab-type', default='unigram', required=True,
                        type=str, choices=['bpe', 'unigram', 'char'])
    parser.add_argument('--asr-vocab-size', default=5000, type=int)
    parser.add_argument('--st-vocab-size', default=8000, type=int)
    args = parser.parse_args()

    process(args)


if __name__ == '__main__':
    main()
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -5,6 +5,8 @@

 import math

+import torch
+
 from fairseq import metrics, utils
 from fairseq.criterions import FairseqCriterion, register_criterion

@@ -31,11 +33,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T

 @register_criterion('label_smoothed_cross_entropy')
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):

-    def __init__(self, task, sentence_avg, label_smoothing):
+    def __init__(self, task, sentence_avg, label_smoothing,
+                 ignore_prefix_size=0, report_accuracy=False):
         super().__init__(task)
         self.sentence_avg = sentence_avg
         self.eps = label_smoothing
+        self.ignore_prefix_size = ignore_prefix_size
+        self.report_accuracy = report_accuracy

     @staticmethod
     def add_args(parser):
@@ -43,6 +47,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         # fmt: off
         parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                             help='epsilon for label smoothing, 0 means no label smoothing')
+        parser.add_argument('--report-accuracy', action='store_true',
+                            help='report accuracy metric')
+        parser.add_argument('--ignore-prefix-size', default=0, type=int,
+                            help='Ignore first N tokens')
         # fmt: on

     def forward(self, model, sample, reduce=True):
@@ -63,19 +71,41 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
             'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
         }
+        if self.report_accuracy:
+            n_correct, total = self.compute_accuracy(model, net_output, sample)
+            logging_output['n_correct'] = utils.item(n_correct.data)
+            logging_output['total'] = utils.item(total.data)
         return loss, sample_size, logging_output

-    def compute_loss(self, model, net_output, sample, reduce=True):
+    def get_lprobs_and_target(self, model, net_output, sample):
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
-        lprobs = lprobs.view(-1, lprobs.size(-1))
-        target = model.get_targets(sample, net_output).view(-1, 1)
+        target = model.get_targets(sample, net_output)
+        if self.ignore_prefix_size > 0:
+            if getattr(lprobs, "batch_first", False):
+                lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
+                target = target[:, self.ignore_prefix_size:].contiguous()
+            else:
+                lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
+                target = target[self.ignore_prefix_size:, :].contiguous()
+        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
+        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
         loss, nll_loss = label_smoothed_nll_loss(
             lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
         )
         return loss, nll_loss

-    @staticmethod
-    def reduce_metrics(logging_outputs) -> None:
+    def compute_accuracy(self, model, net_output, sample):
+        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
+        mask = target.ne(self.padding_idx)
+        n_correct = torch.sum(
+            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
+        total = torch.sum(mask)
+        return n_correct, total
+
+    @classmethod
+    def reduce_metrics(cls, logging_outputs) -> None:
         """Aggregate logging outputs from data parallel training."""
         loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
         nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
@@ -86,6 +116,20 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
         metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))

+        total = utils.item(sum(log.get('total', 0) for log in logging_outputs))
+        if total > 0:
+            metrics.log_scalar('total', total)
+            n_correct = utils.item(
+                sum(log.get('n_correct', 0) for log in logging_outputs)
+            )
+            metrics.log_scalar('n_correct', n_correct)
+            metrics.log_derived(
+                'accuracy',
+                lambda meters: round(
+                    meters['n_correct'].sum * 100.0 / meters['total'].sum, 3
+                ) if meters['total'].sum > 0 else float('nan'),
+            )
+
     @staticmethod
     def logging_outputs_can_be_summed() -> bool:
         """
fairseq/data/audio/audio_utils.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import os.path as op
from typing import Union, BinaryIO, Optional, Tuple

import numpy as np


def get_waveform(
    path_or_fp: Union[str, BinaryIO], normalization=True
) -> Tuple[np.ndarray, int]:
    """Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC.

    Args:
        path_or_fp (str or BinaryIO): the path or file-like object
        normalization (bool): Normalize values to [-1, 1] (Default: True)
    """
    if isinstance(path_or_fp, str):
        ext = op.splitext(op.basename(path_or_fp))[1]
        if ext not in {'.flac', '.wav'}:
            raise ValueError(f'Unsupported audio format: {ext}')

    try:
        import soundfile as sf
    except ImportError:
        raise ImportError('Please install soundfile to load WAV/FLAC file')

    waveform, sample_rate = sf.read(path_or_fp, dtype='float32')
    if not normalization:
        waveform *= 2 ** 15  # denormalize to 16-bit signed integers
    return waveform, sample_rate


def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
    """Get mel-filter bank features via PyKaldi."""
    try:
        from kaldi.feat.mel import MelBanksOptions
        from kaldi.feat.fbank import FbankOptions, Fbank
        from kaldi.feat.window import FrameExtractionOptions
        from kaldi.matrix import Vector

        mel_opts = MelBanksOptions()
        mel_opts.num_bins = n_bins
        frame_opts = FrameExtractionOptions()
        frame_opts.samp_freq = sample_rate
        opts = FbankOptions()
        opts.mel_opts = mel_opts
        opts.frame_opts = frame_opts
        fbank = Fbank(opts=opts)
        features = fbank.compute(Vector(waveform), 1.0).numpy()
        return features
    except ImportError:
        return None


def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
    """Get mel-filter bank features via TorchAudio."""
    try:
        import torch
        import torchaudio.compliance.kaldi as ta_kaldi
        waveform = torch.from_numpy(waveform).unsqueeze(0)
        features = ta_kaldi.fbank(waveform, num_mel_bins=n_bins,
                                  sample_frequency=sample_rate)
        return features.numpy()
    except ImportError:
        return None


def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
    """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
    (faster C++ implementation) to TorchAudio (Python implementation). Note that
    Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
    waveform should not be normalized."""
    sound, sample_rate = get_waveform(path_or_fp, normalization=False)

    features = _get_kaldi_fbank(sound, sample_rate, n_bins)
    if features is None:
        features = _get_torchaudio_fbank(sound, sample_rate, n_bins)
    if features is None:
        raise ImportError('Please install pyKaldi or torchaudio to enable '
                          'online filterbank feature extraction')

    return features
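
A quick usage sketch for the fallback chain above (the file name is a placeholder):

```python
from fairseq.data.audio.audio_utils import get_fbank

# Returns a (num_frames, 80) log mel-filter bank matrix, computed by PyKaldi
# when available and by TorchAudio otherwise.
feats = get_fbank('utt1.wav', n_bins=80)
```
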
fairseq/data/audio/feature_transforms/__init__.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import importlib
import os
from typing import Optional, Dict
from abc import ABC, abstractmethod


class AudioFeatureTransform(ABC):
    @classmethod
    @abstractmethod
    def from_config_dict(cls, config: Optional[Dict] = None):
        pass


AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()


def register_audio_feature_transform(name):
    def register_audio_feature_transform_cls(cls):
        if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
            raise ValueError(f'Cannot register duplicate transform ({name})')
        if not issubclass(cls, AudioFeatureTransform):
            raise ValueError(f'Transform ({name}: {cls.__name__}) must extend '
                             'AudioFeatureTransform')
        if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
            raise ValueError(
                f'Cannot register audio feature transform with duplicate '
                f'class name ({cls.__name__})'
            )
        AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
        AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_audio_feature_transform_cls


def get_audio_feature_transform(name):
    return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]


transforms_dir = os.path.dirname(__file__)
for file in os.listdir(transforms_dir):
    path = os.path.join(transforms_dir, file)
    if (
        not file.startswith('_')
        and not file.startswith('.')
        and (file.endswith('.py') or os.path.isdir(path))
    ):
        name = file[:file.find('.py')] if file.endswith('.py') else file
        importlib.import_module('fairseq.data.audio.feature_transforms.' + name)


class CompositeAudioFeatureTransform(AudioFeatureTransform):
    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        _transforms = _config.get('transforms')
        if _transforms is None:
            return None
        transforms = [
            get_audio_feature_transform(_t).from_config_dict(_config.get(_t))
            for _t in _transforms
        ]
        return CompositeAudioFeatureTransform(transforms)

    def __init__(self, transforms):
        self.transforms = [t for t in transforms if t is not None]

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

    def __repr__(self):
        format_string = [self.__class__.__name__ + '('] + \
            [f"    {t.__repr__()}" for t in self.transforms] + [')']
        return '\n'.join(format_string)
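
As a usage sketch, the composite transform is built from a plain config dict; the
SpecAugment values below mirror the "lb" policy written by `data_utils.gen_config_yaml`:

```python
import numpy as np

from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform

# Each name listed under 'transforms' is looked up in the registry and
# configured from the top-level key of the same name.
config = {
    'transforms': ['specaugment'],
    'specaugment': {
        'time_warp_W': 0, 'freq_mask_N': 1, 'freq_mask_F': 27,
        'time_mask_N': 1, 'time_mask_T': 100, 'time_mask_p': 1.0,
    },
}
transform = CompositeAudioFeatureTransform.from_config_dict(config)

features = np.random.randn(584, 80).astype(np.float32)  # (num_frames, num_freqs)
augmented = transform(features)
```
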
fairseq/data/audio/feature_transforms/global_cmvn.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform, register_audio_feature_transform
)


@register_audio_feature_transform('global_cmvn')
class GlobalCMVN(AudioFeatureTransform):
    """Global CMVN (cepstral mean and variance normalization). The global mean
    and variance need to be pre-computed and stored in NumPy format (.npz)."""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return GlobalCMVN(_config.get('stats_npz_path'))

    def __init__(self, stats_npz_path):
        stats = np.load(stats_npz_path)
        self.mean, self.std = stats['mean'], stats['std']

    def __call__(self, x):
        x = np.subtract(x, self.mean)
        x = np.divide(x, self.std)
        return x
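
Nothing shown here generates the stats file itself; a minimal sketch of computing it
from pre-extracted feature files (the paths and output name are placeholders):

```python
from glob import glob

import numpy as np

# Stack all training features frame-wise, then save the per-dimension mean and
# standard deviation in the .npz layout that GlobalCMVN expects.
features = [np.load(p) for p in glob('fbank80/*.npy')]
stacked = np.concatenate(features, axis=0)  # (total_frames, num_freqs)
np.savez('global_cmvn_stats.npz', mean=stacked.mean(axis=0), std=stacked.std(axis=0))
```
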
fairseq/data/audio/feature_transforms/specaugment.py (new file, 126 lines)
@@ -0,0 +1,126 @@
import math
|
||||
import numbers
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fairseq.data.audio.feature_transforms import (
|
||||
AudioFeatureTransform, register_audio_feature_transform
|
||||
)
|
||||
|
||||
|
||||
@register_audio_feature_transform('specaugment')
|
||||
class SpecAugmentTransform(AudioFeatureTransform):
|
||||
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
|
||||
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return SpecAugmentTransform(
|
||||
_config.get('time_warp_W', 0),
|
||||
_config.get('freq_mask_N', 0),
|
||||
_config.get('freq_mask_F', 0),
|
||||
_config.get('time_mask_N', 0),
|
||||
_config.get('time_mask_T', 0),
|
||||
_config.get('time_mask_p', 0.0),
|
||||
_config.get('mask_value', None),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
time_warp_w: int = 0,
|
||||
freq_mask_n: int = 0,
|
||||
freq_mask_f: int = 0,
|
||||
time_mask_n: int = 0,
|
||||
time_mask_t: int = 0,
|
||||
time_mask_p: float = 0.0,
|
||||
mask_value: Optional[float] = 0.0,
|
||||
):
|
||||
# Sanity checks
|
||||
assert mask_value is None or isinstance(
|
||||
mask_value, numbers.Number
|
||||
), f"mask_value (type: {type(mask_value)}) must be None or a number"
|
||||
if freq_mask_n > 0:
|
||||
assert (
|
||||
freq_mask_f > 0
|
||||
), f"freq_mask_F ({freq_mask_f}) " \
|
||||
f"must be larger than 0 when doing freq masking."
|
||||
if time_mask_n > 0:
|
||||
assert (
|
||||
time_mask_t > 0
|
||||
), f"time_mask_T ({time_mask_t}) must be larger than 0 when " \
|
||||
f"doing time masking."
|
||||
|
||||
self.time_warp_w = time_warp_w
|
||||
self.freq_mask_n = freq_mask_n
|
||||
self.freq_mask_f = freq_mask_f
|
||||
self.time_mask_n = time_mask_n
|
||||
self.time_mask_t = time_mask_t
|
||||
self.time_mask_p = time_mask_p
|
||||
self.mask_value = mask_value
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(' + ', '.join(
|
||||
[f'time_warp_w={self.time_warp_w}',
|
||||
f'freq_mask_n={self.freq_mask_n}',
|
||||
f'freq_mask_f={self.freq_mask_f}',
|
||||
f'time_mask_n={self.time_mask_n}',
|
||||
f'time_mask_t={self.time_mask_t}',
|
||||
f'time_mask_p={self.time_mask_p}']
|
||||
) + ')'
|
||||
|
||||
def __call__(self, spectrogram):
|
||||
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
|
||||
|
||||
distorted = spectrogram.copy() # make a copy of input spectrogram.
|
||||
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
|
||||
num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
|
||||
mask_value = self.mask_value
|
||||
|
||||
if mask_value is None: # if no value was specified, use local mean.
|
||||
mask_value = spectrogram.mean()
|
||||
|
||||
if num_frames == 0:
|
||||
return spectrogram
|
||||
|
||||
if num_freqs < self.freq_mask_f:
|
||||
return spectrogram
|
||||
|
||||
if self.time_warp_w > 0:
|
||||
if 2 * self.time_warp_w < num_frames:
|
||||
import cv2
|
||||
w0 = np.random.randint(
|
||||
self.time_warp_w, num_frames - self.time_warp_w
|
||||
)
|
||||
w = np.random.randint(0, self.time_warp_w)
|
||||
upper, lower = distorted[:w0, :], distorted[w0:, :]
|
||||
upper = cv2.resize(
|
||||
upper, dsize=(num_freqs, w0 + w),
|
||||
interpolation=cv2.INTER_LINEAR
|
||||
)
|
||||
lower = cv2.resize(
|
||||
lower,
|
||||
dsize=(num_freqs, num_frames - w0 - w),
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
distorted = np.concatenate((upper, lower), axis=0)
|
||||
|
||||
for _i in range(self.freq_mask_n):
|
||||
f = np.random.randint(0, self.freq_mask_f)
|
||||
f0 = np.random.randint(0, num_freqs - f)
|
||||
if f != 0:
|
||||
distorted[:, f0: f0 + f] = mask_value
|
||||
|
||||
max_time_mask_t = min(
|
||||
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
|
||||
)
|
||||
if max_time_mask_t < 1:
|
||||
return distorted
|
||||
|
||||
for _i in range(self.time_mask_n):
|
||||
t = np.random.randint(0, max_time_mask_t)
|
||||
t0 = np.random.randint(0, num_frames - t)
|
||||
if t != 0:
|
||||
distorted[t0: t0 + t, :] = mask_value
|
||||
|
||||
return distorted
|
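
An illustrative config dict for the keys read by from_config_dict above; the values follow the "LD" policy from the SpecAugment paper and are examples, not repository defaults.

specaug = SpecAugmentTransform.from_config_dict({
    'time_warp_W': 0,   # time warping off; enabling it requires OpenCV (cv2)
    'freq_mask_N': 2, 'freq_mask_F': 27,
    'time_mask_N': 2, 'time_mask_T': 100, 'time_mask_p': 1.0,
})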
38
fairseq/data/audio/feature_transforms/utterance_cmvn.py
Normal file
@ -0,0 +1,38 @@
import numpy as np

from fairseq.data.audio.feature_transforms import (
    AudioFeatureTransform, register_audio_feature_transform
)


@register_audio_feature_transform('utterance_cmvn')
class UtteranceCMVN(AudioFeatureTransform):
    """Utterance-level CMVN (cepstral mean and variance normalization)"""

    @classmethod
    def from_config_dict(cls, config=None):
        _config = {} if config is None else config
        return UtteranceCMVN(
            _config.get('norm_means', True),
            _config.get('norm_vars', True),
        )

    def __init__(self, norm_means=True, norm_vars=True):
        self.norm_means, self.norm_vars = norm_means, norm_vars

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(norm_means={self.norm_means}, norm_vars={self.norm_vars})'

    def __call__(self, x):
        mean = x.mean(axis=0)
        square_sums = (x ** 2).sum(axis=0)

        if self.norm_means:
            x = np.subtract(x, mean)
        if self.norm_vars:
            var = square_sums / x.shape[0] - mean ** 2
            std = np.sqrt(np.maximum(var, 1e-10))
            x = np.divide(x, std)

        return x
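
A quick illustrative check of the math above: with both flags on, each feature dimension of the output has approximately zero mean and unit variance (the variance is computed as E[x^2] - E[x]^2, floored at 1e-10).

import numpy as np
feats = np.random.rand(200, 80).astype(np.float32)
out = UtteranceCMVN(norm_means=True, norm_vars=True)(feats)
assert np.allclose(out.mean(axis=0), 0., atol=1e-4)
assert np.allclose(out.std(axis=0), 1., atol=1e-2)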
478
fairseq/data/audio/speech_to_text_dataset.py
Normal file
@ -0,0 +1,478 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from typing import List, Tuple, Optional, Dict
import os.path as op
import csv
import io

import numpy as np
import torch
from fairseq.data import (FairseqDataset, Dictionary, ResamplingDataset,
                          ConcatDataset, data_utils as fairseq_data_utils)
from fairseq.data.audio.audio_utils import get_fbank, get_waveform
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform

logging.basicConfig(
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
)
logger = logging.getLogger(__name__)
class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path):
        try:
            import yaml
        except ImportError:
            print('Please install PyYAML to load YAML files for '
                  'S2T data config')
        self.config = {}
        if op.isfile(yaml_path):
            try:
                with open(yaml_path) as f:
                    self.config = yaml.load(f, Loader=yaml.FullLoader)
            except Exception as e:
                logger.info(f'Failed to load config from {yaml_path}: {e}')
        else:
            logger.info(f'Cannot find {yaml_path}')

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get('vocab_filename', 'dict.txt')

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get('shuffle', False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returns
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get('pre_tokenizer', {'tokenizer': None})

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returns
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get('bpe_tokenizer', None)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend the target language ID token as the target BOS (e.g. for
        the to-many multilingual setting). During inference, this requires
        `--prefix-size 1` to force BOS to be the language ID token."""
        return self.config.get('prepend_tgt_lang_tag', False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get('input_feat_per_channel', 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get('input_channels', 1)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get('sampling_alpha', 1.)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get('use_audio_input', False)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get('audio_root', '')

    def get_feature_transforms(self, split, is_train):
        """Split-specific feature transforms. Allows the train-set wildcard
        `_train`, the evaluation-set wildcard `_eval` and the general wildcard
        `*` for matching."""
        from copy import deepcopy
        cfg = deepcopy(self.config)
        _cur = cfg.get('transforms', {})
        cur = _cur.get(split)
        cur = _cur.get('_train') if cur is None and is_train else cur
        cur = _cur.get('_eval') if cur is None and not is_train else cur
        cur = _cur.get('*') if cur is None else cur
        cfg['transforms'] = cur
        return cfg
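
For illustration, a config written with PyYAML that exercises the properties above; every value here is an example, not a requirement. Note that for a train split the `_train` wildcard wins over `*`.

import yaml
example_config = {
    'vocab_filename': 'dict.txt',
    'input_feat_per_channel': 80,
    'input_channels': 1,
    'shuffle': True,
    'sampling_alpha': 0.5,
    'transforms': {'_train': ['specaugment'], '*': ['utterance_cmvn']},
}
with open('config.yaml', 'w') as f:
    yaml.dump(example_config, f)
data_cfg = S2TDataConfig('config.yaml')
assert data_cfg.sampling_alpha == 0.5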
def is_npy_data(data: bytes) -> bool:
    return data[0] == 147 and data[1] == 78  # b'\x93N' of the NumPy magic string


def is_flac_or_wav_data(data: bytes) -> bool:
    is_flac = (data[0] == 102 and data[1] == 76)  # b'fL' of 'fLaC'
    is_wav = (data[0] == 82 and data[1] == 73)  # b'RI' of 'RIFF'
    return is_flac or is_wav


def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes:
    with open(file_path, 'rb') as f:
        f.seek(offset)
        data = f.read(file_size)
    return data


def get_features_from_npy_or_audio(path):
    ext = op.splitext(op.basename(path))[1]
    if ext not in {'.npy', '.flac', '.wav'}:
        raise ValueError(f'Unsupported file format for "{path}"')
    return np.load(path) if ext == '.npy' else get_fbank(path)


def get_features_or_waveform_from_uncompressed_zip(
    path, byte_offset, byte_size, need_waveform=False
):
    assert path.endswith('.zip')
    data = read_from_uncompressed_zip(path, byte_offset, byte_size)
    f = io.BytesIO(data)
    if is_npy_data(data):
        features_or_waveform = np.load(f)
    elif is_flac_or_wav_data(data):
        features_or_waveform = get_waveform(f)[0] if need_waveform else get_fbank(f)
    else:
        raise ValueError(f'Unknown file format for "{path}"')
    return features_or_waveform


def get_features_or_waveform(path: str, need_waveform=False):
    """Get speech features from a .npy file or a waveform from a .wav/.flac
    file. The file may be inside an uncompressed ZIP file and is accessed via
    byte offset and length.

    Args:
        path (str): File path in the format of "<.npy/.wav/.flac path>" or
            "<zip path>:<byte offset>:<byte length>".
        need_waveform (bool): return waveform instead of features.

    Returns:
        features_or_waveform (numpy.ndarray): speech features or waveform.
    """
    _path, *extra = path.split(':')
    if not op.exists(_path):
        raise FileNotFoundError(f'File not found: {_path}')

    if len(extra) == 0:
        if need_waveform:
            return get_waveform(_path)
        return get_features_from_npy_or_audio(_path)
    elif len(extra) == 2:
        extra = [int(i) for i in extra]
        features_or_waveform = get_features_or_waveform_from_uncompressed_zip(
            _path, extra[0], extra[1], need_waveform=need_waveform
        )
    else:
        raise ValueError(f'Invalid path: {path}')

    return features_or_waveform
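
Illustrative path formats accepted by get_features_or_waveform (the file names and byte offsets are hypothetical):

feats = get_features_or_waveform('data/sample0001.npy')           # plain file
feats = get_features_or_waveform('data/fbank80.zip:2048:31690')   # zip member by offset:length
wave = get_features_or_waveform('data/sample0001.wav', need_waveform=True)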
def _collate_frames(frames: List[torch.Tensor],
                    is_audio_input: bool = False) -> torch.Tensor:
    """
    Convert a list of 2D frames into a padded 3D tensor.

    Args:
        frames (list): list of 2D frames of size L[i] x f_dim, where L[i] is
            the length of the i-th frame and f_dim is the static feature
            dimension.

    Returns:
        3D tensor of size len(frames) x len_max x f_dim, where len_max is
        the max of L[i].
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out
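
Example: two 80-dimensional feature sequences of lengths 3 and 5 collate into a (2, 5, 80) zero-padded batch.

import torch
batch = _collate_frames([torch.ones(3, 80), torch.ones(5, 80)])
assert batch.shape == (2, 5, 80)
assert batch[0, 3:].eq(0).all()  # trailing positions of the shorter item are padding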
class SpeechToTextDataset(FairseqDataset):
    LANG_TAG_TEMPLATE = '<lang:{}>'

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
    ):
        self.split, self.is_train_split = split, is_train_split
        self.data_cfg = data_cfg
        self.audio_paths, self.n_frames = audio_paths, n_frames
        self.n_samples = len(audio_paths)
        assert len(n_frames) == self.n_samples > 0
        assert src_texts is None or len(src_texts) == self.n_samples
        assert tgt_texts is None or len(tgt_texts) == self.n_samples
        assert speakers is None or len(speakers) == self.n_samples
        assert src_langs is None or len(src_langs) == self.n_samples
        assert tgt_langs is None or len(tgt_langs) == self.n_samples
        assert ids is None or len(ids) == self.n_samples
        assert (tgt_dict is None and tgt_texts is None) or \
            (tgt_dict is not None and tgt_texts is not None)
        self.tgt_dict = tgt_dict
        self.check_tgt_lang_tag()
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_langs, self.tgt_langs = src_langs, tgt_langs
        self.ids = ids
        self.shuffle = data_cfg.shuffle if is_train_split else False

        self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.data_cfg.get_feature_transforms(split, is_train_split)
        )

        self.pre_tokenizer = pre_tokenizer
        self.bpe_tokenizer = bpe_tokenizer

        logger.info(self.__repr__())

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(split="{self.split}", n_samples={self.n_samples}, ' \
            f'prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, ' \
            f'shuffle={self.shuffle}, transforms={self.feature_transforms})'

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace('{}', '(.*)')
        return re.match(pattern, token)

    def check_tgt_lang_tag(self):
        if self.data_cfg.prepend_tgt_lang_tag:
            assert self.tgt_langs is not None and self.tgt_dict is not None
            tgt_lang_tags = [self.LANG_TAG_TEMPLATE.format(t)
                             for t in set(self.tgt_langs)]
            assert all(t in self.tgt_dict for t in tgt_lang_tags)

    def tokenize_text(self, text: str):
        if self.pre_tokenizer is not None:
            text = self.pre_tokenizer.encode(text)
        if self.bpe_tokenizer is not None:
            text = self.bpe_tokenizer.encode(text)
        return text

    def __getitem__(
        self, index: int
    ) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]:
        source = get_features_or_waveform(
            self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input
        )
        if self.feature_transforms is not None:
            assert not self.data_cfg.use_audio_input
            source = self.feature_transforms(source)
        source = torch.from_numpy(source).float()

        target = None
        if self.tgt_texts is not None:
            tokenized = self.tokenize_text(self.tgt_texts[index])
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=True
            ).long()
            if self.data_cfg.prepend_tgt_lang_tag:
                lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
                lang_tag_idx = self.tgt_dict.index(lang_tag)
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
        return index, source, target

    def __len__(self):
        return self.n_samples

    def collater(
        self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]
    ) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long)
        frames = _collate_frames([s for _, s, _ in samples],
                                 self.data_cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor(
            [s.size(0) for _, s, _ in samples], dtype=torch.long
        )
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [t for _, _, t in samples], self.tgt_dict.pad(),
                self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=False
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [t.size(0) for _, _, t in samples], dtype=torch.long
            ).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [t for _, _, t in samples], self.tgt_dict.pad(),
                self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=True
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(t.size(0) for _, _, t in samples)

        out = {
            "id": indices,
            "net_input": {
                "src_tokens": frames,
                "src_lengths": n_frames,
                "prev_output_tokens": prev_output_tokens,
            },
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        return out

    def num_tokens(self, index):
        return self.n_frames[index]

    def size(self, index):
        t_len = 0
        if self.tgt_texts is not None:
            tokenized = self.tokenize_text(self.tgt_texts[index])
            t_len = len(tokenized.split(' '))
        return self.n_frames[index], t_len

    @property
    def sizes(self):
        return np.array(self.n_frames)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # first by descending order of # of frames then by original/random order
        order.append([-n for n in self.n_frames])
        return np.lexsort(order)

    def prefetch(self, indices):
        # prefetching is not supported by this dataset
        raise NotImplementedError
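A hedged sketch of the collater output for a constructed dataset ('dataset' and the indices below are assumptions for illustration):

samples = [dataset[i] for i in (0, 1, 2)]
batch = dataset.collater(samples)
# batch['net_input']['src_tokens']: (B, T_max, feat) padded features, sorted
#   by descending length; batch['net_input']['src_lengths']: (B,) lengths;
# batch['target'] / batch['net_input']['prev_output_tokens']: right-padded
#   token tensors, the latter with EOS moved to the beginning.
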
class SpeechToTextDatasetCreator(object):
    # mandatory columns
    KEY_ID, KEY_AUDIO, KEY_N_FRAMES = 'id', 'audio', 'n_frames'
    KEY_TGT_TEXT = 'tgt_text'
    # optional columns
    KEY_SPEAKER, KEY_SRC_TEXT = 'speaker', 'src_text'
    KEY_SRC_LANG, KEY_TGT_LANG = 'src_lang', 'tgt_lang'
    # default values
    DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ''

    @classmethod
    def _from_list(cls, split_name: str, is_train_split,
                   samples: List[List[Dict]], data_cfg: S2TDataConfig, tgt_dict,
                   pre_tokenizer, bpe_tokenizer) -> SpeechToTextDataset:
        audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], []
        speakers, src_langs, tgt_langs = [], [], []
        for s in samples:
            ids.extend([ss[cls.KEY_ID] for ss in s])
            audio_paths.extend([op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO])
                                for ss in s])
            n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s])
            tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s])
            src_texts.extend([ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT)
                              for ss in s])
            speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER)
                             for ss in s])
            src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG)
                              for ss in s])
            tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG)
                              for ss in s])
        return SpeechToTextDataset(
            split_name, is_train_split, data_cfg, audio_paths, n_frames,
            src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict,
            pre_tokenizer, bpe_tokenizer
        )

    @classmethod
    def _get_size_ratios(cls, ids: List[str], sizes: List[int],
                         alpha: float = 1.):
        """Size ratios for temperature-based sampling
        (https://arxiv.org/abs/1907.05019)"""
        _sizes = np.array(sizes)
        prob = _sizes / _sizes.sum()
        smoothed_prob = prob ** alpha
        smoothed_prob = smoothed_prob / smoothed_prob.sum()
        size_ratio = (smoothed_prob * _sizes.sum()) / _sizes

        o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)})
        logger.info(f"original sampling probability: {o_str}")
        p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)})
        logger.info(f"balanced sampling probability: {p_str}")
        sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)})
        logger.info(f"balanced sampling size ratio: {sr_str}")
        return size_ratio.tolist()
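A worked example of the ratios computed above: two train sub-splits of 900k and 100k utterances with alpha = 0.5 (temperature T = 2).

import numpy as np
sizes = np.array([900_000, 100_000])
prob = sizes / sizes.sum()               # [0.9, 0.1]
smoothed = prob ** 0.5
smoothed /= smoothed.sum()               # ~[0.75, 0.25]
ratios = smoothed * sizes.sum() / sizes  # ~[0.833, 2.5]
# the small split is upsampled ~2.5x; the large one downsampled to ~0.83x
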
    @classmethod
    def from_tsv(cls, root: str, data_cfg: S2TDataConfig, splits: str, tgt_dict,
                 pre_tokenizer, bpe_tokenizer, is_train_split: bool, epoch: int,
                 seed: int) -> SpeechToTextDataset:
        samples = []
        _splits = splits.split(',')
        for split in _splits:
            tsv_path = op.join(root, f'{split}.tsv')
            if not op.isfile(tsv_path):
                raise FileNotFoundError(f"Dataset not found: {tsv_path}")
            with open(tsv_path) as f:
                reader = csv.DictReader(
                    f, delimiter='\t', quotechar=None, doublequote=False,
                    lineterminator='\n', quoting=csv.QUOTE_NONE
                )
                samples.append([dict(e) for e in reader])
        assert len(samples) > 0

        datasets = [cls._from_list(name, is_train_split, [s], data_cfg, tgt_dict,
                                   pre_tokenizer, bpe_tokenizer)
                    for name, s in zip(_splits, samples)]

        if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.:
            # temperature-based sampling
            size_ratios = cls._get_size_ratios(
                _splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha
            )
            datasets = [
                ResamplingDataset(d, size_ratio=r, seed=seed, epoch=epoch,
                                  replace=(r >= 1.))
                for d, r in zip(datasets, size_ratios)
            ]
        return ConcatDataset(datasets)
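
A hedged usage sketch for the creator; the root directory, split name, and tokenizers below are assumptions for illustration.

dataset = SpeechToTextDatasetCreator.from_tsv(
    root='data/mustc_en_de', data_cfg=data_cfg, splits='train_st',
    tgt_dict=tgt_dict, pre_tokenizer=None, bpe_tokenizer=None,
    is_train_split=True, epoch=1, seed=1,
)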
@ -446,3 +446,14 @@ def get_mem_usage():
        return f'used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb'
    except ImportError:
        return 'N/A'


def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    return ~lengths_to_padding_mask(lens)
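
Example: lengths [3, 1] produce the mask below, where True marks padding positions.

import torch
mask = lengths_to_padding_mask(torch.LongTensor([3, 1]))
# tensor([[False, False, False],
#         [False,  True,  True]])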
7
fairseq/models/speech_to_text/__init__.py
Normal file
@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .berard import *  # noqa
from .s2t_transformer import *  # noqa
581
fairseq/models/speech_to_text/berard.py
Normal file
@ -0,0 +1,581 @@
#!/usr/bin/env python3

from ast import literal_eval
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, utils
from fairseq.models import (
    FairseqEncoder,
    FairseqIncrementalDecoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.data.data_utils import lengths_to_padding_mask


@register_model("s2t_berard")
class BerardModel(FairseqEncoderDecoderModel):
    """Implementation of a model similar to https://arxiv.org/abs/1802.04200

    Paper title: End-to-End Automatic Speech Translation of Audiobooks
    A TensorFlow implementation is available at
    https://github.com/eske/seq2seq
    Relevant files in that implementation are the config
    (https://github.com/eske/seq2seq/blob/master/config/LibriSpeech/AST.yaml)
    and the model code
    (https://github.com/eske/seq2seq/blob/master/translate/models.py).
    The encoder and decoder try to stay close to the original implementation.
    The attention is an MLP as in Bahdanau et al.
    (https://arxiv.org/abs/1409.0473).
    There is no state initialization by averaging the encoder outputs.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)
    @staticmethod
    def add_args(parser):
        parser.add_argument("--input-layers", type=str, metavar="EXPR",
                            help="List of linear layer dimensions. These "
                                 "layers are applied to the input features and "
                                 "are followed by tanh and possibly dropout.")
        parser.add_argument(
            "--dropout", type=float, metavar="D",
            help="Dropout probability to use in the encoder/decoder. "
                 "Note that this parameter controls dropout in various places; "
                 "there is no fine-grained control for dropout for embeddings "
                 "vs LSTM layers, for example."
        )
        parser.add_argument("--in-channels", type=int, metavar="N",
                            help="Number of encoder input channels. "
                                 "Typically the value is 1.")
        parser.add_argument("--conv-layers", type=str, metavar="EXPR",
                            help="List of conv layers "
                                 "(format: (channels, kernel, stride)).")
        parser.add_argument("--num-blstm-layers", type=int, metavar="N",
                            help="Number of encoder bi-LSTM layers.")
        parser.add_argument("--lstm-size", type=int, metavar="N",
                            help="LSTM hidden size.")
        parser.add_argument(
            "--decoder-embed-dim", type=int, metavar="N",
            help="Embedding dimension of the decoder target tokens."
        )
        parser.add_argument("--decoder-hidden-dim", type=int, metavar="N",
                            help="Decoder LSTM hidden dimension.")
        parser.add_argument("--decoder-num-layers", type=int, metavar="N",
                            help="Number of decoder LSTM layers.")
        parser.add_argument("--attention-dim", type=int, metavar="N",
                            help="Hidden layer dimension in MLP attention.")
        parser.add_argument(
            "--output-layer-dim", type=int, metavar="N",
            help="Hidden layer dim for linear layer prior to output projection."
        )
        parser.add_argument(
            "--load-pretrained-encoder-from", type=str, metavar="STR",
            help="model to take encoder weights from (for initialization)"
        )
        parser.add_argument(
            "--load-pretrained-decoder-from", type=str, metavar="STR",
            help="model to take decoder weights from (for initialization)"
        )
    @classmethod
    def build_encoder(cls, args, task):
        encoder = BerardEncoder(
            input_layers=literal_eval(args.input_layers),
            conv_layers=literal_eval(args.conv_layers),
            in_channels=args.input_channels,
            input_feat_per_channel=args.input_feat_per_channel,
            num_blstm_layers=args.num_blstm_layers,
            lstm_size=args.lstm_size,
            dropout=args.dropout,
        )
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
        return encoder

    @classmethod
    def build_decoder(cls, args, task):
        decoder = LSTMDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            num_layers=args.decoder_num_layers,
            hidden_size=args.decoder_hidden_dim,
            dropout=args.dropout,
            encoder_output_dim=2 * args.lstm_size,  # bidirectional
            attention_dim=args.attention_dim,
            output_layer_dim=args.output_layer_dim,
        )
        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        return decoder

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)

        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        # lprobs is a (B, T, D) tensor
        lprobs.batch_first = True
        return lprobs
class BerardEncoder(FairseqEncoder):
    def __init__(
        self,
        input_layers: List[int],
        conv_layers: List[Tuple[int]],
        in_channels: int,
        input_feat_per_channel: int,
        num_blstm_layers: int,
        lstm_size: int,
        dropout: float,
    ):
        """
        Args:
            input_layers: list of linear layer dimensions. These layers are
                applied to the input features and are followed by tanh and
                possibly dropout.
            conv_layers: list of conv2d layer configurations. A configuration
                is a tuple (out_channels, conv_kernel_size, stride).
            in_channels: number of input channels.
            input_feat_per_channel: number of input features per channel.
                These are speech features, typically 40 or 80.
            num_blstm_layers: number of bidirectional LSTM layers.
            lstm_size: size of the LSTM hidden (and cell) size.
            dropout: dropout probability. Dropout can be applied after the
                linear layers and LSTM layers but not to the convolutional
                layers.
        """
        super().__init__(None)

        self.input_layers = nn.ModuleList()
        in_features = input_feat_per_channel
        for out_features in input_layers:
            if dropout > 0:
                self.input_layers.append(
                    nn.Sequential(
                        nn.Linear(in_features, out_features),
                        nn.Dropout(p=dropout)
                    )
                )
            else:
                self.input_layers.append(nn.Linear(in_features, out_features))
            in_features = out_features

        self.in_channels = in_channels
        self.input_dim = input_feat_per_channel
        self.conv_kernel_sizes_and_strides = []
        self.conv_layers = nn.ModuleList()
        lstm_input_dim = input_layers[-1]
        for conv_layer in conv_layers:
            out_channels, conv_kernel_size, conv_stride = conv_layer
            self.conv_layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    conv_kernel_size,
                    stride=conv_stride,
                    padding=conv_kernel_size // 2,
                )
            )
            self.conv_kernel_sizes_and_strides.append(
                (conv_kernel_size, conv_stride)
            )
            in_channels = out_channels
            lstm_input_dim //= conv_stride

        lstm_input_dim *= conv_layers[-1][0]
        self.lstm_size = lstm_size
        self.num_blstm_layers = num_blstm_layers
        self.lstm = nn.LSTM(
            input_size=lstm_input_dim,
            hidden_size=lstm_size,
            num_layers=num_blstm_layers,
            dropout=dropout,
            bidirectional=True,
        )
        self.output_dim = 2 * lstm_size  # bidirectional
        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = None

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        """
        Args:
            src_tokens: padded tensor (B, T, C * feat)
            src_lengths: tensor of original lengths of input utterances (B,)
        """
        bsz, max_seq_len, _ = src_tokens.size()
        # (B, C, T, feat)
        x = (
            src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
            .transpose(1, 2)
            .contiguous()
        )

        for input_layer in self.input_layers:
            x = input_layer(x)
            x = torch.tanh(x)

        for conv_layer in self.conv_layers:
            x = conv_layer(x)

        bsz, _, output_seq_len, _ = x.size()

        # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) ->
        # (T, B, C * feat)
        x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len,
                                                                bsz, -1)

        input_lengths = src_lengths.clone()
        for k, s in self.conv_kernel_sizes_and_strides:
            p = k // 2
            input_lengths = (input_lengths.float() + 2 * p - k) / s + 1
            input_lengths = input_lengths.floor().long()

        packed_x = nn.utils.rnn.pack_padded_sequence(x, input_lengths)

        h0 = x.new(2 * self.num_blstm_layers, bsz, self.lstm_size).zero_()
        c0 = x.new(2 * self.num_blstm_layers, bsz, self.lstm_size).zero_()
        packed_outs, _ = self.lstm(packed_x, (h0, c0))

        # unpack outputs and apply dropout
        x, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_outs)
        if self.dropout is not None:
            x = self.dropout(x)

        encoder_padding_mask = lengths_to_padding_mask(output_lengths).to(
            src_tokens.device).t()

        return {
            "encoder_out": x,  # (T, B, C)
            "encoder_padding_mask": encoder_padding_mask,  # (T, B)
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        encoder_out["encoder_padding_mask"] = encoder_out[
            "encoder_padding_mask"
        ].index_select(1, new_order)
        return encoder_out
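A worked example of the length recursion in forward() above: with the default conv_layers "[(16, 3, 2), (16, 3, 2)]" (kernel 3, stride 2, padding 1), 100 input frames shrink to 50 and then 25, i.e. 4x temporal subsampling.

length = 100
for k, s in [(3, 2), (3, 2)]:
    p = k // 2
    length = (length + 2 * p - k) // s + 1  # 100 -> 50 -> 25
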
class MLPAttention(nn.Module):
    """The original attention from Bahdanau et al. (2014)

    https://arxiv.org/abs/1409.0473, based on a Multi-Layer Perceptron.
    The attention score between position i in the encoder and position j in
    the decoder is: alpha_ij = V_a * tanh(W_ae * enc_i + W_ad * dec_j + b_a)
    """

    def __init__(self, decoder_hidden_state_dim, context_dim, attention_dim):
        super().__init__()

        self.context_dim = context_dim
        self.attention_dim = attention_dim
        # W_ae and b_a
        self.encoder_proj = nn.Linear(context_dim, self.attention_dim,
                                      bias=True)
        # W_ad
        self.decoder_proj = nn.Linear(
            decoder_hidden_state_dim, self.attention_dim, bias=False
        )
        # V_a
        self.to_scores = nn.Linear(self.attention_dim, 1, bias=False)

    def forward(self, decoder_state, source_hids, encoder_padding_mask):
        """The expected input dimensions are:
        decoder_state: bsz x decoder_hidden_state_dim
        source_hids: src_len x bsz x context_dim
        encoder_padding_mask: src_len x bsz
        """
        src_len, bsz, _ = source_hids.size()
        # (src_len*bsz) x context_dim (to feed through linear)
        flat_source_hids = source_hids.view(-1, self.context_dim)
        # (src_len*bsz) x attention_dim
        encoder_component = self.encoder_proj(flat_source_hids)
        # src_len x bsz x attention_dim
        encoder_component = encoder_component.view(src_len, bsz,
                                                   self.attention_dim)
        # 1 x bsz x attention_dim
        decoder_component = self.decoder_proj(decoder_state).unsqueeze(0)
        # Sum with broadcasting and apply the non-linearity
        # src_len x bsz x attention_dim
        hidden_att = torch.tanh(
            (decoder_component + encoder_component).view(-1, self.attention_dim)
        )
        # Project onto the reals to get attention scores (src_len x bsz)
        attn_scores = self.to_scores(hidden_att).view(src_len, bsz)

        # Mask + softmax (src_len x bsz)
        if encoder_padding_mask is not None:
            attn_scores = (
                attn_scores.float()
                .masked_fill_(encoder_padding_mask, float("-inf"))
                .type_as(attn_scores)
            )  # FP16 support: cast to float and back
        # src_len x bsz
        normalized_masked_attn_scores = F.softmax(attn_scores, dim=0)

        # Sum weighted sources (bsz x context_dim)
        attn_weighted_context = (
            source_hids * normalized_masked_attn_scores.unsqueeze(2)
        ).sum(dim=0)

        return attn_weighted_context, normalized_masked_attn_scores
class LSTMDecoder(FairseqIncrementalDecoder):
    def __init__(
        self,
        dictionary,
        embed_dim,
        num_layers,
        hidden_size,
        dropout,
        encoder_output_dim,
        attention_dim,
        output_layer_dim,
    ):
        """
        Args:
            dictionary: target text dictionary.
            embed_dim: embedding dimension for target tokens.
            num_layers: number of LSTM layers.
            hidden_size: hidden size for LSTM layers.
            dropout: dropout probability. Dropout can be applied to the
                embeddings, the LSTM layers, and the context vector.
            encoder_output_dim: encoder output dimension (hidden size of
                encoder LSTM).
            attention_dim: attention dimension for MLP attention.
            output_layer_dim: size of the linear layer prior to output
                projection.
        """
        super().__init__(dictionary)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        self.embed_tokens = nn.Embedding(num_embeddings, embed_dim, padding_idx)
        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = None

        self.layers = nn.ModuleList()
        for layer_id in range(num_layers):
            input_size = embed_dim if layer_id == 0 else encoder_output_dim
            self.layers.append(
                nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
            )

        self.context_dim = encoder_output_dim
        self.attention = MLPAttention(
            decoder_hidden_state_dim=hidden_size,
            context_dim=encoder_output_dim,
            attention_dim=attention_dim,
        )

        self.deep_output_layer = nn.Linear(
            hidden_size + encoder_output_dim + embed_dim, output_layer_dim
        )
        self.output_projection = nn.Linear(output_layer_dim, num_embeddings)

    def forward(self, prev_output_tokens, encoder_out=None,
                incremental_state=None, **kwargs):
        encoder_padding_mask = encoder_out["encoder_padding_mask"]
        encoder_outs = encoder_out["encoder_out"]

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        srclen = encoder_outs.size(0)

        # embed tokens
        embeddings = self.embed_tokens(prev_output_tokens)
        x = embeddings
        if self.dropout is not None:
            x = self.dropout(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental
        # generation)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is not None:
            prev_hiddens, prev_cells = cached_state
        else:
            prev_hiddens = [
                encoder_out["encoder_out"].mean(dim=0)
            ] * self.num_layers
            prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers

        attn_scores = x.new_zeros(bsz, srclen)
        attention_outs = []
        outs = []
        for j in range(seqlen):
            input = x[j, :, :]
            attention_out = None
            for i, layer in enumerate(self.layers):
                # the previous state is one layer below except for the bottom
                # layer where the previous state is the state emitted by the
                # top layer
                hidden, cell = layer(
                    input,
                    (
                        prev_hiddens[(i - 1) % self.num_layers],
                        prev_cells[(i - 1) % self.num_layers],
                    ),
                )
                if self.dropout is not None:
                    hidden = self.dropout(hidden)
                prev_hiddens[i] = hidden
                prev_cells[i] = cell
                if attention_out is None:
                    attention_out, attn_scores = self.attention(
                        hidden, encoder_outs, encoder_padding_mask
                    )
                    if self.dropout is not None:
                        attention_out = self.dropout(attention_out)
                    attention_outs.append(attention_out)
                input = attention_out

            # collect the output of the top layer
            outs.append(hidden)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self, incremental_state, "cached_state", (prev_hiddens, prev_cells)
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
        attention_outs_concat = torch.cat(attention_outs, dim=0).view(
            seqlen, bsz, self.context_dim
        )

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)
        attention_outs_concat = attention_outs_concat.transpose(0, 1)

        # concat LSTM output, attention output and embedding
        # before output projection
        x = torch.cat((x, attention_outs_concat, embeddings), dim=2)
        x = self.deep_output_layer(x)
        x = torch.tanh(x)
        if self.dropout is not None:
            x = self.dropout(x)
        # project back to size of vocabulary
        x = self.output_projection(x)

        # to return the full attn_scores tensor, we need to fix the decoder
        # to account for the subsampling of input frames
        # return x, attn_scores
        return x, None

    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(
            self, incremental_state, "cached_state", new_state
        )
@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard")
|
||||
def berard(args):
|
||||
"""The original version: "End-to-End Automatic Speech Translation of
|
||||
Audiobooks" (https://arxiv.org/abs/1802.04200)
|
||||
"""
|
||||
args.input_layers = getattr(args, "input_layers", "[256, 128]")
|
||||
args.conv_layers = getattr(args, "conv_layers", "[(16, 3, 2), (16, 3, 2)]")
|
||||
args.num_blstm_layers = getattr(args, "num_blstm_layers", 3)
|
||||
args.lstm_size = getattr(args, "lstm_size", 256)
|
||||
args.dropout = getattr(args, "dropout", 0.2)
|
||||
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128)
|
||||
args.decoder_num_layers = getattr(args, "decoder_num_layers", 2)
|
||||
args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 512)
|
||||
args.attention_dim = getattr(args, "attention_dim", 512)
|
||||
args.output_layer_dim = getattr(args, "output_layer_dim", 128)
|
||||
args.load_pretrained_encoder_from = getattr(
|
||||
args, "load_pretrained_encoder_from", None
|
||||
)
|
||||
args.load_pretrained_decoder_from = getattr(
|
||||
args, "load_pretrained_decoder_from", None
|
||||
)
|
||||
|
||||
|
||||
@register_model_architecture(model_name="s2t_berard",
|
||||
arch_name="s2t_berard_256_3_3")
|
||||
def berard_256_3_3(args):
|
||||
"""Used in
|
||||
* "Harnessing Indirect Training Data for End-to-End Automatic Speech
|
||||
Translation: Tricks of the Trade" (https://arxiv.org/abs/1909.06515)
|
||||
* "CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus"
|
||||
(https://arxiv.org/pdf/2002.01320.pdf)
|
||||
* "Self-Supervised Representations Improve End-to-End Speech Translation"
|
||||
(https://arxiv.org/abs/2006.12124)
|
||||
"""
|
||||
args.decoder_num_layers = getattr(args, "decoder_num_layers", 3)
|
||||
berard(args)
|
||||
|
||||
|
||||
@register_model_architecture(model_name="s2t_berard",
|
||||
arch_name="s2t_berard_512_3_2")
|
||||
def berard_512_3_2(args):
|
||||
args.num_blstm_layers = getattr(args, "num_blstm_layers", 3)
|
||||
args.lstm_size = getattr(args, "lstm_size", 512)
|
||||
args.dropout = getattr(args, "dropout", 0.3)
|
||||
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
|
||||
args.decoder_num_layers = getattr(args, "decoder_num_layers", 2)
|
||||
args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 1024)
|
||||
args.attention_dim = getattr(args, "attention_dim", 512)
|
||||
args.output_layer_dim = getattr(args, "output_layer_dim", 256)
|
||||
berard(args)
|
||||
|
||||
|
||||
@register_model_architecture(model_name="s2t_berard",
|
||||
arch_name="s2t_berard_512_5_3")
|
||||
def berard_512_5_3(args):
|
||||
args.num_blstm_layers = getattr(args, "num_blstm_layers", 5)
|
||||
args.lstm_size = getattr(args, "lstm_size", 512)
|
||||
args.dropout = getattr(args, "dropout", 0.3)
|
||||
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
|
||||
args.decoder_num_layers = getattr(args, "decoder_num_layers", 3)
|
||||
args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 1024)
|
||||
args.attention_dim = getattr(args, "attention_dim", 512)
|
||||
args.output_layer_dim = getattr(args, "output_layer_dim", 256)
|
||||
berard(args)
|
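
A minimal sketch of how these architecture functions fill in defaults: fields already set on args win over the registered values (the Namespace below is illustrative).

from argparse import Namespace
args = Namespace(lstm_size=512)  # user override
berard(args)                     # fills every remaining hyper-parameter
assert args.lstm_size == 512 and args.num_blstm_layers == 3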
394
fairseq/models/speech_to_text/s2t_transformer.py
Normal file
@ -0,0 +1,394 @@
#!/usr/bin/env python3

import logging
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from fairseq import utils, checkpoint_utils
from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
                            register_model, register_model_architecture)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import (PositionalEmbedding, TransformerEncoderLayer,
                             FairseqDropout, LayerNorm)
from torch import Tensor


logger = logging.getLogger(__name__)
class Conv1dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 1D convolutions (along the
    temporal dimension) followed by non-linear activation via gated linear
    units (https://arxiv.org/abs/1911.08460)

    Args:
        in_channels (int): the number of input channels
        mid_channels (int): the number of intermediate channels
        out_channels (int): the number of output channels
        kernel_sizes (List[int]): the kernel size for each convolutional layer
    """

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int,
                 kernel_sizes: List[int] = (3, 3)):
        super(Conv1dSubsampler, self).__init__()
        self.n_layers = len(kernel_sizes)
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(
                in_channels if i == 0 else mid_channels // 2,
                mid_channels if i < self.n_layers - 1 else out_channels * 2,
                k, stride=2, padding=k // 2
            )
            for i, k in enumerate(kernel_sizes)
        )

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        out = in_seq_lens_tensor.clone()
        for _ in range(self.n_layers):
            out = ((out.float() - 1) / 2 + 1).floor().long()
        return out

    def forward(self, src_tokens, src_lengths):
        bsz, in_seq_len, _ = src_tokens.size()  # B x T x (C x D)
        x = src_tokens.transpose(1, 2).contiguous()  # -> B x (C x D) x T
        for conv in self.conv_layers:
            x = conv(x)
            x = nn.functional.glu(x, dim=1)
        _, _, out_seq_len = x.size()
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # -> T x B x (C x D)
        return x, self.get_out_seq_lens_tensor(src_lengths)
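Example of the length bookkeeping above: two stride-2 layers shrink lengths 100 -> 50 -> 25 and 37 -> 19 -> 10 (the channel sizes below are illustrative).

import torch
sub = Conv1dSubsampler(80, 1024, 256, kernel_sizes=[3, 3])
out_lens = sub.get_out_seq_lens_tensor(torch.LongTensor([100, 37]))
assert out_lens.tolist() == [25, 10]
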
@register_model("s2t_transformer")
class S2TTransformerModel(FairseqEncoderDecoderModel):
    """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
    speech-to-text tasks. The Transformer encoder/decoder remains the same.
    A trainable input subsampler is prepended to the Transformer encoder to
    project inputs into the encoder dimension as well as downsample the input
    sequence for computational efficiency."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # input
        parser.add_argument("--conv-kernel-sizes", type=str, metavar="N",
                            help="kernel sizes of Conv1d subsampling layers")
        parser.add_argument("--conv-channels", type=int, metavar="N",
                            help="# of channels in Conv1d subsampling layers")
        # Transformer
        parser.add_argument("--activation-fn", type=str, default='relu',
                            choices=utils.get_available_activation_fns(),
                            help="activation function to use")
        parser.add_argument("--dropout", type=float, metavar="D",
                            help="dropout probability")
        parser.add_argument("--attention-dropout", type=float, metavar="D",
                            help="dropout probability for attention weights")
        parser.add_argument("--activation-dropout", "--relu-dropout",
                            type=float, metavar="D",
                            help="dropout probability after activation in FFN.")
        parser.add_argument("--encoder-embed-dim", type=int, metavar="N",
                            help="encoder embedding dimension")
        parser.add_argument("--encoder-ffn-embed-dim", type=int, metavar="N",
                            help="encoder embedding dimension for FFN")
        parser.add_argument("--encoder-layers", type=int, metavar="N",
                            help="num encoder layers")
        parser.add_argument("--encoder-attention-heads", type=int, metavar="N",
                            help="num encoder attention heads")
        parser.add_argument("--encoder-normalize-before", action="store_true",
                            help="apply layernorm before each encoder block")
        parser.add_argument("--decoder-embed-dim", type=int, metavar="N",
                            help="decoder embedding dimension")
        parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N",
                            help="decoder embedding dimension for FFN")
        parser.add_argument("--decoder-layers", type=int, metavar="N",
                            help="num decoder layers")
        parser.add_argument("--decoder-attention-heads", type=int, metavar="N",
                            help="num decoder attention heads")
        parser.add_argument("--decoder-normalize-before", action="store_true",
                            help="apply layernorm before each decoder block")
        parser.add_argument("--layernorm-embedding", action="store_true",
                            help="add layernorm to embedding")
        parser.add_argument("--no-scale-embedding", action="store_true",
                            help="if True, don't scale embeddings")
        parser.add_argument(
            "--load-pretrained-encoder-from", type=str, metavar="STR",
            help="model to take encoder weights from (for initialization)"
        )

    @classmethod
    def build_encoder(cls, args):
        encoder = S2TTransformerEncoder(args)
        if getattr(args, "load_pretrained_encoder_from", None):
            encoder = checkpoint_utils.load_pretrained_component_from_model(
                component=encoder, checkpoint=args.load_pretrained_encoder_from
            )
            logger.info(f'loaded pretrained encoder from: '
                        f'{args.load_pretrained_encoder_from}')
        return encoder

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        return TransformerDecoderScriptable(args, task.target_dictionary,
                                            embed_tokens)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)

        decoder_embed_tokens = build_embedding(task.target_dictionary,
                                               args.decoder_embed_dim)
        encoder = cls.build_encoder(args)
        decoder = cls.build_decoder(args, task, decoder_embed_tokens)
        return cls(encoder, decoder)

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs,
                                                      sample)
        lprobs.batch_first = True
        return lprobs

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """
        The forward method inherited from the base class has a **kwargs
        argument in its input, which is not supported in torchscript. This
        method overrides the forward method definition without **kwargs.
        """
        encoder_out = self.encoder(src_tokens=src_tokens,
                                   src_lengths=src_lengths)
        decoder_out = self.decoder(prev_output_tokens=prev_output_tokens,
                                   encoder_out=encoder_out)
        return decoder_out
class S2TTransformerEncoder(FairseqEncoder):
    """Speech-to-text Transformer encoder that consists of input subsampler
    and Transformer encoder."""

    def __init__(self, args):
        super().__init__(None)

        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_scale = math.sqrt(args.encoder_embed_dim)
        if args.no_scale_embedding:
            self.embed_scale = 1.0
        self.padding_idx = 1

        self.subsample = Conv1dSubsampler(
            args.input_feat_per_channel * args.input_channels,
            args.conv_channels, args.encoder_embed_dim,
            [int(k) for k in args.conv_kernel_sizes.split(',')]
        )

        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim,
            self.padding_idx
        )

        self.transformer_layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None

    def forward(self, src_tokens, src_lengths):
        x, input_lengths = self.subsample(src_tokens, src_lengths)
        x = self.embed_scale * x

        encoder_padding_mask = lengths_to_padding_mask(input_lengths)
        positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
        x += positions
        x = self.dropout_module(x)

        for layer in self.transformer_layers:
            x = layer(x, encoder_padding_mask)

        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return EncoderOut(
            encoder_out=x, encoder_padding_mask=encoder_padding_mask,
            encoder_embedding=None, encoder_states=None, src_tokens=None,
            src_lengths=None
        )
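
# Illustration only, not part of this commit: the subsampler shrinks the time
# axis before self-attention. Assuming each of the two default stride-2
# convolutions ('5,5' kernels) maps n frames to (n - 1) // 2 + 1 -- the exact
# rounding inside Conv1dSubsampler is an assumption here -- the length
# bookkeeping looks like:
def approx_subsampled_length(n_frames: int, n_conv_layers: int = 2) -> int:
    for _ in range(n_conv_layers):
        n_frames = (n_frames - 1) // 2 + 1
    return n_frames

print(approx_subsampled_length(500))  # -> 125: roughly 4x fewer encoder states
# This ~4x reduction is what keeps self-attention affordable on frame-level input.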
    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
        """
        Since encoder_padding_mask and encoder_embedding are both of type
        Optional[Tensor] in EncoderOut, they need to be copied as local
        variables for Torchscript Optional refinement
        """

        encoder_padding_mask: Optional[Tensor] = \
            encoder_out.encoder_padding_mask
        encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding

        new_encoder_out = (
            encoder_out.encoder_out
            if encoder_out.encoder_out is None
            else encoder_out.encoder_out.index_select(1, new_order)
        )

        new_encoder_padding_mask = (
            encoder_padding_mask
            if encoder_padding_mask is None
            else encoder_padding_mask.index_select(0, new_order)
        )

        new_encoder_embedding = (
            encoder_embedding
            if encoder_embedding is None
            else encoder_embedding.index_select(0, new_order)
        )

        encoder_states = encoder_out.encoder_states
        if encoder_states is not None:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return EncoderOut(
            encoder_out=new_encoder_out,  # T x B x C
            encoder_padding_mask=new_encoder_padding_mask,  # B x T
            encoder_embedding=new_encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
            src_tokens=None,
            src_lengths=None,
        )
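
# Illustration only, not part of this commit: reorder_encoder_out exists for
# beam search, where the batch axis holds (sentence, beam) hypotheses and the
# survivors of each step are gathered by a permutation index. Toy example
# with invented shapes and new_order:
import torch

T, B, C = 7, 3, 4
encoder_states_toy = torch.randn(T, B, C)       # encoder_out is T x B x C
new_order = torch.tensor([0, 0, 2])             # beam 1 pruned, beam 0 kept twice
reordered = encoder_states_toy.index_select(1, new_order)     # batch is dim 1
padding_mask_toy = torch.zeros(B, T, dtype=torch.bool)        # mask is B x T
reordered_mask = padding_mask_toy.index_select(0, new_order)  # batch is dim 0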
class TransformerDecoderScriptable(TransformerDecoder):
    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        # call scriptable method from parent class
        x, _ = self.extract_features_scriptable(
            prev_output_tokens, encoder_out, incremental_state,
            full_context_alignment, alignment_layer, alignment_heads,
        )
        return x, None
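
# Sketch, not from this commit: the subclass exists purely for TorchScript
# typing -- extract_features normally returns an extra dict whose values the
# scripting compiler cannot type-refine, so the return type is pinned to
# (Tensor, None). Assuming every submodule is likewise scriptable, the model
# could then be compiled and shipped without its Python class definitions:
import torch

scripted = torch.jit.script(model)   # `model` is a hypothetical built instance
scripted.save('s2t_transformer_scripted.pt')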
@register_model_architecture(model_name="s2t_transformer",
                             arch_name="s2t_transformer")
def base_architecture(args):
    # Convolutional subsampler
    args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", '5,5')
    args.conv_channels = getattr(args, "conv_channels", 1024)
    # Transformer
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before",
                                            True)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim",
                                     args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim",
                                         args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before",
                                            True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
    args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff",
                                           None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_output_dim = getattr(args, "decoder_output_dim",
                                      args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, "decoder_input_dim",
                                     args.decoder_embed_dim)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
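
# Illustration only, not part of this commit: every line above uses the same
# getattr(args, name, default) idiom, so values parsed from the command line
# win and the registered defaults only fill the gaps. For example:
from argparse import Namespace

args = Namespace(encoder_layers=6)   # as if --encoder-layers 6 were passed
base_architecture(args)
print(args.encoder_layers)           # 6: the user's value survives
print(args.encoder_attention_heads)  # 8: the registered default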
@register_model_architecture("s2t_transformer", "s2t_transformer_s")
|
||||
def s2t_transformer_s(args):
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
|
||||
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
@register_model_architecture("s2t_transformer", "s2t_transformer_sp")
|
||||
def s2t_transformer_sp(args):
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 16)
|
||||
s2t_transformer_s(args)
|
||||
|
||||
|
||||
@register_model_architecture("s2t_transformer", "s2t_transformer_m")
|
||||
def s2t_transformer_m(args):
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
|
||||
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
||||
args.dropout = getattr(args, "dropout", 0.15)
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
@register_model_architecture("s2t_transformer", "s2t_transformer_mp")
|
||||
def s2t_transformer_mp(args):
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 16)
|
||||
s2t_transformer_m(args)
|
||||
|
||||
|
||||
@register_model_architecture("s2t_transformer", "s2t_transformer_l")
|
||||
def s2t_transformer_l(args):
|
||||
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
||||
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim",
|
||||
1024 * 4)
|
||||
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
||||
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
||||
args.dropout = getattr(args, "dropout", 0.2)
|
||||
base_architecture(args)
|
||||
|
||||
|
||||
@register_model_architecture("s2t_transformer", "s2t_transformer_lp")
|
||||
def s2t_transformer_lp(args):
|
||||
args.encoder_layers = getattr(args, "encoder_layers", 16)
|
||||
s2t_transformer_l(args)
|
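Taken together, the registered variants differ only in width, attention heads, and dropout; the "p" ("deep") variants additionally raise the encoder from the default 12 layers to 16:

arch               embed dim  FFN dim  heads  dropout
s2t_transformer_s        256     2048      4     0.1
s2t_transformer_m        512     2048      8     0.15
s2t_transformer_l       1024     4096     16     0.2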
fairseq/tasks/fairseq_task.py
@@ -10,10 +10,9 @@ from argparse import Namespace

import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators
from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators, encoders
from fairseq.dataclass.utils import gen_parser_from_dataclass


logger = logging.getLogger(__name__)


@@ -504,6 +503,14 @@ class FairseqTask(object):
        for this task)."""
        raise NotImplementedError

    def build_tokenizer(self, args):
        """Build the pre-tokenizer for this task."""
        return encoders.build_tokenizer(args)

    def build_bpe(self, args):
        """Build the tokenizer for this task."""
        return encoders.build_bpe(args)
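
# Illustration only, not part of this commit: these hooks keep generation
# code task-agnostic -- the defaults above build from the global args, while
# the speech_to_text task further below overrides them to read its data
# config. Hypothetical call site, mirroring fairseq_cli/generate.py:
tokenizer = task.build_tokenizer(args)  # pre-tokenizer (e.g. Moses), or None
bpe = task.build_bpe(args)              # subword codec (e.g. SentencePiece), or None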

class LegacyFairseqTask(FairseqTask):
    def __init__(self, args: Namespace):
120
fairseq/tasks/speech_to_text.py
Normal file
@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
import os.path as op

from fairseq.data import encoders, Dictionary
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig
)
from fairseq.tasks import FairseqTask, register_task

logging.basicConfig(
    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


@register_task('speech_to_text')
class SpeechToTextTask(FairseqTask):
    @staticmethod
    def add_args(parser):
        parser.add_argument('data', help='manifest root path')
        parser.add_argument(
            '--config-yaml', type=str, default='config.yaml',
            help='Configuration YAML filename (under manifest root)'
        )
        parser.add_argument('--max-source-positions', default=6000, type=int,
                            metavar='N',
                            help='max number of tokens in the source sequence')
        parser.add_argument('--max-target-positions', default=1024, type=int,
                            metavar='N',
                            help='max number of tokens in the target sequence')
    def __init__(self, args, tgt_dict):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))

    @classmethod
    def setup_task(cls, args, **kwargs):
        data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))
        dict_path = op.join(args.data, data_cfg.vocab_filename)
        if not op.isfile(dict_path):
            raise FileNotFoundError(f'Dict not found: {dict_path}')
        tgt_dict = Dictionary.load(dict_path)
        logger.info(f'dictionary size ({data_cfg.vocab_filename}): '
                    f'{len(tgt_dict):,}')

        if getattr(args, 'train_subset', None) is not None:
            if not all(s.startswith('train') for s in args.train_subset.split(',')):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)

    def build_criterion(self, args):
        from fairseq import criterions
        if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
            raise ValueError('Please set "--ignore-prefix-size 1" since '
                             'target language ID token is prepended as BOS.')
        return criterions.build_criterion(args, self)
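
# Sketch, not from this commit: with prepend_tgt_lang_tag enabled, targets
# carry a language ID token in position 0, e.g. (tag format assumed):
#
#     <lang:de> Wie geht es ? </s>
#
# --ignore-prefix-size 1 excludes that position from label smoothed cross
# entropy, and build_generator below enforces the matching --prefix-size 1
# so that decoding starts by forcing the same tag.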
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith('train')
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
            self.args.data, self.data_cfg, split, self.tgt_dict,
            pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split,
            epoch=epoch, seed=self.args.seed
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def source_dictionary(self):
        return None

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args):
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        return super(SpeechToTextTask, self).build_model(args)

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None,
    ):
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError('Please set "--prefix-size 1" since '
                             'target language ID token is prepended as BOS.')
        lang_token_ids = {
            i for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }
        extra_gen_cls_kwargs = {'symbols_to_strip_from_output': lang_token_ids}
        return super().build_generator(
            models, args, seq_gen_cls=None,
            extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )

    def build_tokenizer(self, args):
        logger.info(f'pre-tokenizer: {self.data_cfg.pre_tokenizer}')
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        logger.info(f'tokenizer: {self.data_cfg.bpe_tokenizer}')
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))
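
# Illustration only, not part of this commit: every data-dependent knob
# therefore lives in the manifest-root YAML. A hypothetical look at the
# accessors this file relies on (path and values are invented):
data_cfg = S2TDataConfig('data/mustc_en_de/config.yaml')
data_cfg.vocab_filename          # e.g. 'spm_unigram8000.txt'
data_cfg.pre_tokenizer           # e.g. {'tokenizer': None}
data_cfg.bpe_tokenizer           # e.g. {'bpe': 'sentencepiece', ...}
data_cfg.input_feat_per_channel  # e.g. 80 Mel filter-bank bins
data_cfg.input_channels          # e.g. 1
data_cfg.prepend_tgt_lang_tag    # True for multilingual ST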
    @classmethod
    def build_dataset_for_inference(cls, audio_paths, n_frames):
        return SpeechToTextDataset('interactive', False, {}, audio_paths,
                                   n_frames)
fairseq_cli/generate.py
@@ -21,7 +21,6 @@ import torch

from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.data import encoders


def main(args):

@@ -158,8 +157,8 @@ def _main(args, output_file):
    generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)
    tokenizer = task.build_tokenizer(args)
    bpe = task.build_bpe(args)

    def decode_fn(x):
        if bpe is not None:
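Switching from `encoders.build_*` to `task.build_*` is what lets the speech task inject its YAML-configured tokenizers here without any new CLI flags. The rest of `decode_fn` is cut off by the diff; presumably it reverses the encoding order, BPE first and detokenization second, along these lines (a sketch, not the verified body):

def decode_fn(x):
    if bpe is not None:
        x = bpe.decode(x)
    if tokenizer is not None:
        x = tokenizer.decode(x)
    return x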
2
setup.py
@@ -141,7 +141,7 @@ setup(
        'hydra-core',
        'numpy',
        'regex',
        'sacrebleu',
        'sacrebleu>=1.4.12',
        'torch',
        'tqdm',
    ],

tests/test_label_smoothing.py
@@ -38,6 +38,7 @@ class TestLabelSmoothing(unittest.TestCase):
        # build model
        self.args = argparse.Namespace()
        self.args.sentence_avg = False
        self.args.report_accuracy = False
        self.args.probs = torch.FloatTensor([
            # pad eos unk w1 w2 w3
            [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],