speech-to-text OSS

Summary:
Imported from https://github.com/fairinternal/fairseq-py/pull/1284. Updated according to PR comments.

Main changes:
* New task: `fairseq.tasks.speech_to_text`
  * Multilingual support: multiple train sub-splits, temperature-based sampling (see the sketch after this list), and language ID tokens
* New dataset: `fairseq.data.audio.speech_to_text_dataset`
* Added accuracy metrics and BOS prefix removal to label smoothed cross entropy
* New models: Transformer (`fairseq.models.speech_to_text.s2t_transformer`) and BLSTM (`fairseq.models.speech_to_text.berard`)
* Extended scorers:
  * Added a base scorer class: `fairseq.scorers.BaseScorer` (the parent class for all scorers except the C++-based BLEU scorer)
  * Added an evaluation tokenizer: `fairseq.scorers.eval_tokenizer` which leverages sacreBLEU's built-in tokenizers and allows character-level tokenization as well as punctuation removal (for WER scoring).
  * Added chrF scorer: `fairseq.scorers.chrf`
* Online mel-filter bank speech feature extraction (via the C++-based PyKaldi or the Python-based TorchAudio): `fairseq.data.audio.audio_utils`
* Online speech feature transforms: `fairseq.data.audio.feature_transforms.*`
* Fixed the subsampled sequence lengths in VGGTransformer (`examples.speech_recognition.models.vggtransformer`)
* Examples under `examples/speech_to_text`:
  * LibriSpeech (ASR): better results than VGGTransformer with smaller Transformer-based models
  * MuST-C (ST): comparable to [SOTA results](https://arxiv.org/pdf/2004.10234.pdf) but with fewer tricks
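For reference, a minimal sketch (not the exact task implementation) of how temperature-based sampling ratios can be derived from per-split sizes, where `alpha = 1/T` corresponds to the new `sampling_alpha` data config option:
```python
import numpy as np

def size_ratios(sizes, alpha=0.3):
    # Up/down-sampling ratio per train sub-split so that sampling follows p_i^alpha (alpha = 1/T).
    sizes = np.asarray(sizes, dtype=float)
    probs = sizes / sizes.sum()            # original sampling probabilities
    smoothed = probs ** alpha
    smoothed /= smoothed.sum()             # temperature-smoothed probabilities
    return smoothed * sizes.sum() / sizes  # per-split resampling ratios

print(size_ratios([1_000_000, 10_000]))    # the low-resource split gets upsampled
```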

Reviewed By: jmp84

Differential Revision: D24065273

fbshipit-source-id: 5f842ca9c826f92d4af660705611885fe440a9ab
Changhan Wang 2020-10-14 12:27:45 -07:00 committed by Facebook GitHub Bot
parent a2d0be4989
commit 1d1c145387
22 changed files with 2941 additions and 16 deletions


@@ -251,6 +251,7 @@ class VGGTransformerEncoder(FairseqEncoder):
self.conv_layers = nn.ModuleList()
self.in_channels = in_channels
self.input_dim = input_feat_per_channel
self.pooling_kernel_sizes = []
if vggblock_config is not None:
for _, config in enumerate(vggblock_config):
@@ -272,6 +273,7 @@ class VGGTransformerEncoder(FairseqEncoder):
layer_norm=layer_norm,
)
)
self.pooling_kernel_sizes.append(pooling_kernel_size)
in_channels = out_channels
input_feat_per_channel = self.conv_layers[-1].output_dim
@@ -336,9 +338,9 @@ class VGGTransformerEncoder(FairseqEncoder):
x = x.transpose(1, 2).transpose(0, 1)
x = x.contiguous().view(output_seq_len, bsz, -1)
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
# TODO: shouldn't subsampling_factor determined in advance ?
input_lengths = (src_lengths.float() / subsampling_factor).ceil().long()
input_lengths = src_lengths.clone()
for s in self.pooling_kernel_sizes:
input_lengths = (input_lengths.float() / s).ceil().long()
encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
input_lengths, batch_first=True
@@ -346,6 +348,7 @@ class VGGTransformerEncoder(FairseqEncoder):
if not encoder_padding_mask.any():
encoder_padding_mask = None
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)
transformer_layer_idx = 0


@@ -0,0 +1,216 @@
# Speech-to-Text (S2T) Modeling
## Data Preparation
S2T modeling data consists of source speech features, target text and other optional information
(source text, speaker ID, etc.). Fairseq S2T uses per-dataset-split TSV manifest files
to store this information, with each data field represented by a column in the TSV file.
Unlike text token embeddings, speech features (e.g. log mel-filter banks) are usually fixed
during model training and can be pre-computed. The manifest file contains the path to
either the feature file in NumPy format or the WAV/FLAC audio file. For the latter,
features will be extracted on-the-fly by fairseq S2T. Optionally, feature/audio files can be packed
into uncompressed ZIP files (then accessed via byte offset and length) to improve I/O performance.
Fairseq S2T also employs a YAML file for data-related configuration: tokenizer type and dictionary path
for the target text, feature transforms such as CMVN (cepstral mean and variance normalization) and SpecAugment,
temperature-based resampling, etc.
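For illustration, here is a minimal sketch (not part of the toolkit) of loading a manifest produced by the preparation scripts below, assuming the default columns `id`, `audio`, `n_frames`, `tgt_text` and `speaker`:
```python
import csv
import pandas as pd

# Hypothetical path: any per-split manifest written by the prep scripts below.
df = pd.read_csv("train-clean-100.tsv", sep="\t", header=0, encoding="utf-8",
                 escapechar="\\", quoting=csv.QUOTE_NONE, na_filter=False)
print(df.columns.tolist())  # ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
# 'audio' is either a feature/audio file path or "<zip path>:<byte offset>:<byte length>"
print(df.iloc[0]["audio"])
```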
## Model Training & Evaluation
Fairseq S2T uses the unified `fairseq-train`/`fairseq-generate` interface for model training and evaluation.
It requires arguments `--task speech_to_text` and `--arch <arch in fairseq.models.speech_to_text.*>`.
## Example 1: Speech Recognition (ASR) on LibriSpeech
#### Data preparation
Download and preprocess LibriSpeech data with
```bash
python examples/speech_to_text/prep_librispeech_data.py \
--output-root ${LS_ROOT} --vocab-type unigram --vocab-size 10000
```
where `LS_ROOT` is the root path for downloaded data as well as generated manifest and feature files.
#### Training
```bash
fairseq-train ${LS_ROOT} --train-subset train --valid-subset dev --save-dir ${SAVE_DIR} --num-workers 4 \
--max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy --max-update 300000 \
--arch s2t_transformer_s --optimizer adam --lr 2e-3 --lr-scheduler inverse_sqrt --warmup-updates 10000 \
--clip-norm 10.0 --seed 1 --update-freq 8
```
where `SAVE_DIR` is the checkpoint root path. Here we use `--arch s2t_transformer_s` (31M parameters) as an example;
you may switch to `s2t_transformer_m` (71M) or `s2t_transformer_l` (268M) for better performance. We set
`--update-freq 8` to simulate training on 8 GPUs with a single GPU. Scale it down accordingly when training on more GPUs, as sketched below.
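As a rule of thumb (an illustration, not a fairseq API), the effective batch size per update is roughly `--max-tokens` × number of GPUs × `--update-freq`:
```python
# Hypothetical helper: keep the effective batch size of the simulated 8-GPU recipe above.
def scale_update_freq(num_gpus: int, simulated_gpus: int = 8) -> int:
    assert simulated_gpus % num_gpus == 0, "pick a GPU count that divides the simulated one"
    return simulated_gpus // num_gpus

print(scale_update_freq(1))  # 8, as in the command above
print(scale_update_freq(4))  # 2
```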
#### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the 4 splits
(`dev-clean`, `dev-other`, `test-clean` and `test-other`):
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${SAVE_DIR} --num-epoch-checkpoints 10 --output "${SAVE_DIR}/${CHECKPOINT_FILENAME}"
for SUBSET in dev-clean dev-other test-clean test-other; do
fairseq-generate ${LS_ROOT} --gen-subset ${SUBSET} --task speech_to_text \
--path ${SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring wer
done
```
#### Result (WER)
| --arch | Params | dev-clean | dev-other | test-clean | test-other |
|---|---|---|---|---|---|
| s2t_transformer_s | 30M | 4.1 | 9.3 | 4.4 | 9.2 |
| s2t_transformer_sp | 35M | 3.9 | 9.3 | 4.3 | 8.8 |
| s2t_transformer_m | 71M | 3.5 | 8.1 | 3.7 | 8.1 |
| s2t_transformer_mp | 84M | 3.3 | 7.8 | 3.7 | 8.2 |
| s2t_transformer_l | 268M | 3.3 | 7.7 | 3.5 | 7.8 |
| s2t_transformer_lp | 318M | 3.1 | 7.5 | 3.4 | 7.6 |
## Example 2: Speech Translation (ST) on MuST-C
#### Data Preparation
[Download](https://ict.fbk.eu/must-c) and unpack MuST-C data to a path `MUSTC_ROOT`, then preprocess it with
```bash
python examples/speech_to_text/prep_mustc_data.py --data-root ${MUSTC_ROOT} \
--asr-vocab-type unigram --asr-vocab-size 5000 \
--st-vocab-type unigram --st-vocab-size 8000
```
The generated manifest and feature files will be available under `MUSTC_ROOT`.
#### ASR
###### Training
```bash
fairseq-train ${MUSTC_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \
--num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
--report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate training on 8 GPUs with a
single GPU; scale it down accordingly when using more GPUs.
###### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_asr --task speech_to_text \
--path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
```
###### Result (WER)
| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 18.2 | 17.6 | 17.7 | 17.2 | 17.9 | 19.1 | 18.1 | 17.7 |
#### ST
###### Training
```bash
fairseq-train ${MUSTC_ROOT} --train-subset train_st --valid-subset dev_st --save-dir ${ST_SAVE_DIR} \
--num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
--report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained on ASR for faster training and better
performance, via `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate training
on 8 GPUs with a single GPU; scale it down accordingly when using more GPUs.
###### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the `tst-COMMON` split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${MUSTC_ROOT} --gen-subset tst-COMMON_st --task speech_to_text \
--path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu
```
###### Result (BLEU)
| --arch | Params | En-De | En-Nl | En-Es | En-Fr | En-It | En-Pt | En-Ro | En-Ru |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 22.7 | 27.3 | 27.2 | 32.9 | 22.7 | 28.1 | 21.9 | 15.3 |
## Example 3: ST on CoVoST
#### Data Preparation
Download and preprocess CoVoST data with
```bash
# En ASR
python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \
--vocab-type char --src-lang en
# ST
python examples/speech_to_text/prep_covost_data.py --data-root ${COVOST_ROOT} \
--vocab-type char --src-lang fr --tgt-lang en
```
where `COVOST_ROOT` is the root path for downloaded data as well as generated manifest and feature files.
#### ASR
###### Training
```bash
fairseq-train ${COVOST_ROOT} --train-subset train_asr --valid-subset dev_asr --save-dir ${ASR_SAVE_DIR} \
--num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
--report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 1e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8
```
where `ASR_SAVE_DIR` is the checkpoint root path. We set `--update-freq 8` to simulate training on 8 GPUs with a
single GPU; scale it down accordingly when using more GPUs.
###### Inference & Evaluation
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ASR_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${COVOST_ROOT} --gen-subset test_asr_en --task speech_to_text \
--path ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 \
--scoring wer --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
```
###### Result (WER)
| --arch | Params | En |
|---|---|---|
| s2t_transformer_s | 31M | 25.6 |
#### ST
###### Training
```bash
fairseq-train ${COVOST_ROOT} --train-subset train_st_fr_en --valid-subset dev_st_fr_en --save-dir ${ST_SAVE_DIR} \
--num-workers 4 --max-tokens 40000 --task speech_to_text --criterion label_smoothed_cross_entropy \
--report-accuracy --max-update 100000 --arch s2t_transformer_s --optimizer adam --lr 2e-3 \
--lr-scheduler inverse_sqrt --warmup-updates 10000 --clip-norm 10.0 --seed 1 --update-freq 8 \
--load-pretrained-encoder-from ${ASR_SAVE_DIR}/${CHECKPOINT_FILENAME}
```
where `ST_SAVE_DIR` is the checkpoint root path. The ST encoder is pre-trained on English ASR for faster training and
better performance, via `--load-pretrained-encoder-from <ASR checkpoint path>`. We set `--update-freq 8` to simulate
training on 8 GPUs with a single GPU; scale it down accordingly when using more GPUs.
###### Inference & Evaluation
Average the last 10 checkpoints and evaluate on the test split:
```bash
CHECKPOINT_FILENAME=avg_last_10_checkpoint.pt
python scripts/average_checkpoints.py \
--inputs ${ST_SAVE_DIR} --num-epoch-checkpoints 10 --output "${ST_SAVE_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${COVOST_ROOT} --gen-subset test_st_fr_en --task speech_to_text \
--path ${ST_SAVE_DIR}/${CHECKPOINT_FILENAME} --max-tokens 50000 --beam 5 --scoring sacrebleu
```
###### Result (BLEU)
| --arch | Params | Fr-En | De-En | Es-En | Ca-En | En-De | En-Ca | En-Fa | En-Et |
|---|---|---|---|---|---|---|---|---|---|
| s2t_transformer_s | 31M | 26.3 | 17.1 | 23.0 | 18.8 | 16.3 | 21.8 | 13.1 | 13.2 |
## Citation
Please cite as:
```
@inproceedings{wang2020fairseqs2t,
title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},
author = {Changhan Wang and Yun Tang and Xutai Ma and Anne Wu and Dmytro Okhonko and Juan Pino},
booktitle = {Proceedings of the 2020 Conference of the Asian Chapter of the Association for Computational Linguistics (AACL): System Demonstrations},
year = {2020},
}
@inproceedings{ott2019fairseq,
title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
year = {2019},
}
```


@@ -0,0 +1,218 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from multiprocessing import cpu_count
import os
import os.path as op
from glob import glob
import zipfile
import csv
from functools import reduce
from typing import Dict, Any, List
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
import sentencepiece as sp
from tqdm import tqdm
import numpy as np
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
UNK_TOKEN, UNK_TOKEN_ID = '<unk>', 3
BOS_TOKEN, BOS_TOKEN_ID = '<s>', 0
EOS_TOKEN, EOS_TOKEN_ID = '</s>', 2
PAD_TOKEN, PAD_TOKEN_ID = '<pad>', 1
def gen_vocab(
input_path: str, output_path_prefix: str, model_type='bpe',
vocab_size=1000,
):
# Train SentencePiece Model
arguments = [
f'--input={input_path}',
f'--model_prefix={output_path_prefix}',
f'--model_type={model_type}',
f'--vocab_size={vocab_size}',
'--character_coverage=1.0',
f'--num_threads={cpu_count()}',
f'--unk_id={UNK_TOKEN_ID}',
f'--bos_id={BOS_TOKEN_ID}',
f'--eos_id={EOS_TOKEN_ID}',
f'--pad_id={PAD_TOKEN_ID}'
]
sp.SentencePieceTrainer.Train(' '.join(arguments))
# Export fairseq dictionary
spm = sp.SentencePieceProcessor()
spm.Load(output_path_prefix + '.model')
vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
assert vocab.get(UNK_TOKEN_ID) == UNK_TOKEN and \
vocab.get(PAD_TOKEN_ID) == PAD_TOKEN and \
vocab.get(BOS_TOKEN_ID) == BOS_TOKEN and \
vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
vocab = {
i: s for i, s in vocab.items()
if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
}
with open(output_path_prefix + '.txt', 'w') as f_out:
for _, s in sorted(vocab.items(), key=lambda x: x[0]):
f_out.write(f'{s} 1\n')
def extract_fbank_features(waveform, sample_rate, output_path=None,
n_mel_bins=80, apply_utterance_cmvn=True,
overwrite=False):
if output_path is not None and op.exists(output_path) and not overwrite:
return
_waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers
_waveform = _waveform.squeeze().numpy()
features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
if features is None:
raise ImportError('Please install pyKaldi or torchaudio to enable '
'online filterbank feature extraction')
if apply_utterance_cmvn:
cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
features = cmvn(features)
if output_path is not None:
np.save(output_path, features)
else:
return features
def create_zip(data_root, zip_path):
cwd = os.path.abspath(os.curdir)
os.chdir(data_root)
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_STORED) as f:
for filename in tqdm(glob('*.npy')):
f.write(filename)
os.chdir(cwd)
def is_npy_data(data: bytes) -> bool:
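    # 147 and 78 are the first two bytes of the NumPy magic string b'\x93NUMPY' (0x93 == 147, ord('N') == 78).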
return data[0] == 147 and data[1] == 78
def get_zip_manifest(zip_root, zip_filename):
zip_path = op.join(zip_root, zip_filename)
with zipfile.ZipFile(zip_path, mode='r') as f:
info = f.infolist()
manifest = {}
for i in tqdm(info):
utt_id = op.splitext(i.filename)[0]
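        # Payload follows the fixed 30-byte ZIP local file header plus the filename (assuming no extra field, as for the ZIP_STORED entries written by create_zip).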
offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
manifest[utt_id] = f'{zip_filename}:{offset}:{file_size}'
with open(zip_path, 'rb') as f:
f.seek(offset)
data = f.read(file_size)
assert len(data) > 1 and is_npy_data(data)
return manifest
def gen_config_yaml(data_root, spm_filename, yaml_filename='config.yaml',
specaugment_policy='lb'):
assert specaugment_policy in {'lb', 'ld'}
data_root = op.abspath(data_root)
writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
writer.set_audio_root(op.abspath(data_root))
writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
writer.set_input_channels(1)
writer.set_input_feat_per_channel(80)
if specaugment_policy == 'lb':
writer.set_specaugment_lb_policy()
else:
writer.set_specaugment_ld_policy()
writer.set_bpe_tokenizer(
{'bpe': 'sentencepiece',
'sentencepiece_model': op.join(data_root, spm_filename)}
)
writer.set_feature_transforms('_train', ['specaugment'])
writer.flush()
def save_df_to_tsv(dataframe, path):
dataframe.to_csv(path, sep="\t", header=True, index=False, encoding="utf-8",
escapechar='\\', quoting=csv.QUOTE_NONE)
def filter_manifest_df(df, is_train_split=False, extra_filters=None,
min_n_frames=5, max_n_frames=3000):
filters = {
'no speech': df['audio'] == '',
f'short speech (<{min_n_frames} frames)': df['n_frames'] < min_n_frames,
'empty sentence': df['tgt_text'] == '',
}
if is_train_split:
filters[f'long speech (>{max_n_frames} frames)'] = \
df['n_frames'] > max_n_frames
if extra_filters is not None:
filters.update(extra_filters)
invalid = reduce(lambda x, y: x | y, filters.values())
valid = ~invalid
print(
'| ' + ', '.join(f'{n}: {f.sum()}' for n, f in filters.items()) +
f', total {invalid.sum()} filtered, {valid.sum()} remained.'
)
return df[valid]
class S2TDataConfigWriter(object):
DEFAULT_VOCAB_FILENAME = 'dict.txt'
DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
DEFAULT_INPUT_CHANNELS = 1
def __init__(self, yaml_path):
try:
import yaml
except ImportError:
print('Please install PyYAML to load YAML files for S2T data config')
self.yaml = yaml
self.yaml_path = yaml_path
self.config = {}
def flush(self):
with open(self.yaml_path, 'w') as f:
self.yaml.dump(self.config, f)
def set_audio_root(self, audio_root=''):
self.config['audio_root'] = audio_root
def set_vocab_filename(self, vocab_filename='dict.txt'):
self.config['vocab_filename'] = vocab_filename
    def set_specaugment(self, time_warp_w: int, freq_mask_n: int,
freq_mask_f: int, time_mask_n: int, time_mask_t: int,
time_mask_p: float):
self.config['specaugment'] = {
            'time_warp_W': time_warp_w, 'freq_mask_N': freq_mask_n,
'freq_mask_F': freq_mask_f, 'time_mask_N': time_mask_n,
'time_mask_T': time_mask_t, 'time_mask_p': time_mask_p,
}
def set_specaugment_lb_policy(self):
        self.set_specaugment(time_warp_w=0, freq_mask_n=1, freq_mask_f=27,
time_mask_n=1, time_mask_t=100, time_mask_p=1.0)
def set_specaugment_ld_policy(self):
        self.set_specaugment(time_warp_w=0, freq_mask_n=2, freq_mask_f=27,
time_mask_n=2, time_mask_t=100, time_mask_p=1.0)
def set_input_channels(self, input_channels=1):
self.config['input_channels'] = input_channels
def set_input_feat_per_channel(self, input_feat_per_channel=80):
self.config['input_feat_per_channel'] = input_feat_per_channel
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
self.config['bpe_tokenizer'] = bpe_tokenizer
def set_feature_transforms(self, split, transforms: List[str]):
if 'transforms' not in self.config:
self.config['transforms'] = {}
self.config['transforms'][split] = transforms


@@ -0,0 +1,232 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
from typing import Tuple, Optional
import csv
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor
from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml, filter_manifest_df
)
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
class CoVoST(Dataset):
"""Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).
Args:
root (str): root path to the dataset and generated manifests/features
source_language (str): source (audio) language
target_language (str, optional): target (text) language,
None for no translation (default: None)
version (int, optional): CoVoST version. (default: 2)
download (bool, optional): Whether to download the dataset if it is not
found at root path. (default: ``False``).
"""
CV_URL_TEMPLATE = "https://voice-prod-bundler-ee1969a6ce8178826482b88" \
"e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
COVOST_URL_TEMPLATE = "https://dl.fbaipublicfiles.com/covost/" \
"covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
VERSIONS = {2}
SPLITS = ['train', 'dev', 'test']
CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
XX_EN_LANGUAGES = {
1: ['fr', 'de', 'nl', 'ru', 'es', 'it', 'tr', 'fa', 'sv-SE', 'mn',
'zh-CN'],
2: ['fr', 'de', 'es', 'ca', 'it', 'ru', 'zh-CN', 'pt', 'fa', 'et', 'mn',
'nl', 'tr', 'ar', 'sv-SE', 'lv', 'sl', 'ta', 'ja', 'id', 'cy']
}
EN_XX_LANGUAGES = {
1: [],
2: ['de', 'tr', 'fa', 'sv-SE', 'mn', 'zh-CN', 'cy', 'ca', 'sl', 'et',
'id',
'ar', 'ta', 'lv', 'ja']
}
def __init__(
self, root: str, split: str, source_language: str,
target_language: Optional[str] = None, version: int = 2,
download: bool = False
) -> None:
assert version in self.VERSIONS and split in self.SPLITS
assert source_language is not None
self.no_translation = (target_language is None)
if not self.no_translation:
assert 'en' in {source_language, target_language}
if source_language == 'en':
assert target_language in self.EN_XX_LANGUAGES[version]
else:
assert source_language in self.XX_EN_LANGUAGES[version]
else:
# Hack here so that we can get "split" column from CoVoST TSV.
# Note that we use CoVoST train split for ASR which is an extension
# to Common Voice train split.
target_language = 'de' if source_language == 'en' else 'en'
self.root = os.path.join(root, 'raw')
os.makedirs(self.root, exist_ok=True)
cv_url = self.CV_URL_TEMPLATE.format(ver=self.CV_VERSION_ID[version],
lang=source_language)
cv_archive = os.path.join(self.root, os.path.basename(cv_url))
if download:
if not os.path.isfile(cv_archive):
download_url(cv_url, self.root, hash_value=None)
extract_archive(cv_archive)
covost_url = self.COVOST_URL_TEMPLATE.format(src_lang=source_language,
tgt_lang=target_language)
covost_archive = os.path.join(self.root, os.path.basename(covost_url))
if download:
if not os.path.isfile(covost_archive):
download_url(covost_url, self.root, hash_value=None)
extract_archive(covost_archive)
cv_tsv = self.load_from_tsv(os.path.join(self.root, 'validated.tsv'))
covost_tsv = self.load_from_tsv(
os.path.join(self.root,
os.path.basename(covost_url).replace('.tar.gz', ''))
)
df = pd.merge(left=cv_tsv[['path', 'sentence', 'client_id']],
right=covost_tsv[['path', 'translation', 'split']],
how='inner', on='path')
if split == 'train':
df = df[(df['split'] == split) | (df['split'] == f'{split}_covost')]
else:
df = df[df['split'] == split]
self.data = df.to_dict(orient='index').items()
self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
@classmethod
def load_from_tsv(cls, path: str):
return pd.read_csv(
path, sep='\t', header=0, encoding='utf-8', escapechar='\\',
quoting=csv.QUOTE_NONE, na_filter=False
)
def __getitem__(
self, n: int
    ) -> Tuple[Tensor, int, str, Optional[str], str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
sample_id)``
"""
data = self.data[n]
path = os.path.join(self.root, 'clips', data['path'])
waveform, sample_rate = torchaudio.load(path)
sentence = data['sentence']
translation = None if self.no_translation else data['translation']
speaker_id = data['client_id']
_id = data['path'].replace('.mp3', '')
return waveform, sample_rate, sentence, translation, speaker_id, _id
def __len__(self) -> int:
return len(self.data)
def process(args):
root = op.join(args.data_root, args.src_lang)
os.makedirs(root, exist_ok=True)
# Extract features
feature_root = op.join(root, 'fbank80')
os.makedirs(feature_root, exist_ok=True)
for split in CoVoST.SPLITS:
print(f'Fetching split {split}...')
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang,
download=True)
print('Extracting log mel filter bank features...')
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(waveform, sample_rate,
op.join(feature_root, f'{utt_id}.npy'))
# Pack features into ZIP
zip_filename = 'fbank80.zip'
zip_path = op.join(root, zip_filename)
print('ZIPing features...')
create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...')
zip_manifest = get_zip_manifest(args.data_root,
f'{args.src_lang}/{zip_filename}')
# Generate TSV manifest
print('Generating manifest...')
train_text = []
task = f'asr_{args.src_lang}'
if args.tgt_lang is not None:
task = f'st_{args.src_lang}_{args.tgt_lang}'
for split in CoVoST.SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest['id'].append(utt_id)
manifest['audio'].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
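            # Frame count assuming the default 25 ms window and 10 ms shift of the fbank features above.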
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
manifest['tgt_text'].append(
src_utt if args.tgt_lang is None else tgt_utt
)
manifest['speaker'].append(speaker_id)
is_train_split = split.startswith('train')
if is_train_split:
train_text.extend(manifest['tgt_text'])
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(root, f'{split}_{task}.tsv'))
# Generate vocab
vocab_size_str = '' if args.vocab_type == 'char' else str(args.vocab_size)
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size_str}_{task}'
with NamedTemporaryFile(mode='w') as f:
for t in train_text:
f.write(t + '\n')
gen_vocab(f.name, op.join(root, spm_filename_prefix),
args.vocab_type, args.vocab_size)
# Generate config YAML
gen_config_yaml(root, spm_filename_prefix + '.model',
yaml_filename=f'config_{task}.yaml',
specaugment_policy='lb')
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', '-d', required=True, type=str)
parser.add_argument('--vocab-type', default='unigram', required=True,
                        type=str, choices=['bpe', 'unigram', 'char'])
parser.add_argument('--vocab-size', default=1000, type=int)
parser.add_argument('--src-lang', '-s', required=True, type=str)
parser.add_argument('--tgt-lang', '-t', type=str)
args = parser.parse_args()
process(args)
if __name__ == '__main__':
main()


@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import shutil
import os.path as op
from tqdm import tqdm
from torchaudio.datasets import LIBRISPEECH
import pandas as pd
from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml
)
log = logging.getLogger(__name__)
SPLITS = ['train-clean-100', 'train-clean-360', 'train-other-500', 'dev-clean',
'dev-other', 'test-clean', 'test-other']
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
def process(args):
os.makedirs(args.output_root, exist_ok=True)
# Extract features
feature_root = op.join(args.output_root, 'fbank80')
os.makedirs(feature_root, exist_ok=True)
for split in SPLITS:
print(f'Fetching split {split}...')
dataset = LIBRISPEECH(args.output_root, url=split, download=True)
print('Extracting log mel filter bank features...')
for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
extract_fbank_features(wav, sample_rate,
op.join(feature_root, f'{sample_id}.npy'))
# Pack features into ZIP
zip_filename = 'fbank80.zip'
zip_path = op.join(args.output_root, zip_filename)
print('ZIPing features...')
create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...')
zip_manifest = get_zip_manifest(args.output_root, zip_filename)
# Generate TSV manifest
print('Generating manifest...')
train_text = []
for split in SPLITS:
manifest = {c: [] for c in MANIFEST_COLUMNS}
dataset = LIBRISPEECH(args.output_root, url=split)
for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
sample_id = f'{spk_id}-{chapter_id}-{utt_id}'
manifest['id'].append(sample_id)
manifest['audio'].append(zip_manifest[sample_id])
duration_ms = int(wav.size(1) / sample_rate * 1000)
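            # Frame count assuming the default 25 ms window and 10 ms shift of the fbank features above.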
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
manifest['tgt_text'].append(utt)
manifest['speaker'].append(spk_id)
save_df_to_tsv(pd.DataFrame.from_dict(manifest),
op.join(args.output_root, f'{split}.tsv'))
if split.startswith('train'):
train_text.extend(manifest['tgt_text'])
# Generate vocab
vocab_size = '' if args.vocab_type == 'char' else str(args.vocab_size)
spm_filename_prefix = f'spm_{args.vocab_type}{vocab_size}'
with NamedTemporaryFile(mode='w') as f:
for t in train_text:
f.write(t + '\n')
gen_vocab(f.name, op.join(args.output_root, spm_filename_prefix),
args.vocab_type, args.vocab_size)
# Generate config YAML
gen_config_yaml(args.output_root, spm_filename_prefix + '.model',
specaugment_policy='ld')
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--output-root', '-o', required=True, type=str)
parser.add_argument('--vocab-type', default='unigram', required=True,
                        type=str, choices=['bpe', 'unigram', 'char'])
parser.add_argument('--vocab-size', default=10000, type=int)
args = parser.parse_args()
process(args)
if __name__ == '__main__':
main()


@@ -0,0 +1,172 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
from tempfile import NamedTemporaryFile
import os
import os.path as op
import shutil
from typing import Tuple
from itertools import groupby
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset
import torchaudio
from torch import Tensor
from examples.speech_to_text.data_utils import (
gen_vocab, create_zip, get_zip_manifest, save_df_to_tsv,
extract_fbank_features, gen_config_yaml, filter_manifest_df
)
log = logging.getLogger(__name__)
MANIFEST_COLUMNS = ['id', 'audio', 'n_frames', 'tgt_text', 'speaker']
TASKS = ['asr', 'st']
class MUSTC(Dataset):
"""
Create a Dataset for MuST-C. Each item is a tuple of the form:
waveform, sample_rate, source utterance, target utterance, speaker_id,
utterance_id
"""
SPLITS = ['train', 'dev', 'tst-COMMON', 'tst-HE']
LANGUAGES = ['de', 'es', 'fr', 'it', 'nl', 'pt', 'ro', 'ru']
def __init__(self, root: str, lang: str, split: str) -> None:
assert split in self.SPLITS and lang in self.LANGUAGES
_root = op.join(root, f'en-{lang}', 'data', split)
wav_root, txt_root = op.join(_root, 'wav'), op.join(_root, 'txt')
assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
# Load audio segments
try:
import yaml
except ImportError:
print('Please install PyYAML to load YAML files for '
'the MuST-C dataset')
with open(op.join(txt_root, f'{split}.yaml')) as f:
segments = yaml.load(f, Loader=yaml.BaseLoader)
# Load source and target utterances
for _lang in ['en', lang]:
with open(op.join(txt_root, f'{split}.{_lang}')) as f:
utterances = [r.strip() for r in f]
assert len(segments) == len(utterances)
for i, u in enumerate(utterances):
segments[i][_lang] = u
# Gather info
self.data = []
for wav_filename, _seg_group in groupby(segments, lambda x: x['wav']):
wav_path = op.join(wav_root, wav_filename)
sample_rate = torchaudio.info(wav_path)[0].rate
seg_group = sorted(_seg_group, key=lambda x: x['offset'])
for i, segment in enumerate(seg_group):
offset = int(float(segment['offset']) * sample_rate)
n_frames = int(float(segment['duration']) * sample_rate)
_id = f'{op.splitext(wav_filename)[0]}_{i}'
self.data.append(
(wav_path, offset, n_frames, sample_rate, segment['en'],
segment[lang], segment['speaker_id'], _id)
)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = \
self.data[n]
waveform, _ = torchaudio.load(wav_path, offset=offset,
num_frames=n_frames)
return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
def __len__(self) -> int:
return len(self.data)
def process(args):
for lang in MUSTC.LANGUAGES:
cur_root = op.join(args.data_root, f'en-{lang}')
if not op.isdir(cur_root):
print(f'{cur_root} does not exist. Skipped.')
continue
# Extract features
feature_root = op.join(cur_root, 'fbank80')
os.makedirs(feature_root, exist_ok=True)
for split in MUSTC.SPLITS:
print(f'Fetching split {split}...')
dataset = MUSTC(args.data_root, lang, split)
print('Extracting log mel filter bank features...')
for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
extract_fbank_features(waveform, sample_rate,
op.join(feature_root, f'{utt_id}.npy'))
# Pack features into ZIP
zip_filename = 'fbank80.zip'
zip_path = op.join(cur_root, zip_filename)
print('ZIPing features...')
create_zip(feature_root, zip_path)
print('Fetching ZIP manifest...')
zip_manifest = get_zip_manifest(args.data_root,
f'en-{lang}/{zip_filename}')
# Generate TSV manifest
print('Generating manifest...')
train_text = {task: [] for task in TASKS}
for split in MUSTC.SPLITS:
is_train_split = split.startswith('train')
manifest = {c: [] for c in MANIFEST_COLUMNS}
text = {task: [] for task in TASKS}
dataset = MUSTC(args.data_root, lang, split)
for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
manifest['id'].append(utt_id)
manifest['audio'].append(zip_manifest[utt_id])
duration_ms = int(wav.size(1) / sr * 1000)
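                # Frame count assuming the default 25 ms window and 10 ms shift of the fbank features above.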
manifest['n_frames'].append(int(1 + (duration_ms - 25) / 10))
text['asr'].append(src_utt)
text['st'].append(tgt_utt)
manifest['speaker'].append(speaker_id)
if is_train_split:
for task in TASKS:
train_text[task].extend(text[task])
for task in TASKS:
manifest['tgt_text'] = text[task]
df = pd.DataFrame.from_dict(manifest)
df = filter_manifest_df(df, is_train_split=is_train_split)
save_df_to_tsv(df, op.join(cur_root, f'{split}_{task}.tsv'))
# Generate vocab
for task in TASKS:
vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
if task == 'st':
vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
vocab_size_str = '' if vocab_type == 'char' else str(vocab_size)
spm_filename_prefix = f'spm_{vocab_type}{vocab_size_str}_{task}'
with NamedTemporaryFile(mode='w') as f:
for t in train_text[task]:
f.write(t + '\n')
gen_vocab(f.name, op.join(cur_root, spm_filename_prefix),
vocab_type, vocab_size)
# Generate config YAML
gen_config_yaml(cur_root, spm_filename_prefix + '.model',
yaml_filename=f'config_{task}.yaml',
specaugment_policy='lb')
# Clean up
shutil.rmtree(feature_root)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', '-d', required=True, type=str)
parser.add_argument('--asr-vocab-type', default='unigram', required=True,
                        type=str, choices=['bpe', 'unigram', 'char'])
parser.add_argument('--st-vocab-type', default='unigram', required=True,
                        type=str, choices=['bpe', 'unigram', 'char'])
parser.add_argument('--asr-vocab-size', default=5000, type=int)
parser.add_argument('--st-vocab-size', default=8000, type=int)
args = parser.parse_args()
process(args)
if __name__ == '__main__':
main()


@@ -5,6 +5,8 @@
import math
import torch
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
@@ -31,11 +33,13 @@ def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=T
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(self, task, sentence_avg, label_smoothing):
def __init__(self, task, sentence_avg, label_smoothing,
ignore_prefix_size=0, report_accuracy=False):
super().__init__(task)
self.sentence_avg = sentence_avg
self.eps = label_smoothing
self.ignore_prefix_size = ignore_prefix_size
self.report_accuracy = report_accuracy
@staticmethod
def add_args(parser):
@@ -43,6 +47,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
# fmt: off
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
parser.add_argument('--report-accuracy', action='store_true',
help='report accuracy metric')
parser.add_argument('--ignore-prefix-size', default=0, type=int,
help='Ignore first N tokens')
# fmt: on
def forward(self, model, sample, reduce=True):
@@ -63,19 +71,41 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output, sample)
logging_output['n_correct'] = utils.item(n_correct.data)
logging_output['total'] = utils.item(total.data)
return loss, sample_size, logging_output
def compute_loss(self, model, net_output, sample, reduce=True):
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1, 1)
target = model.get_targets(sample, net_output)
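        # Optionally drop a prepended prefix (e.g. a target language ID token used as BOS) before computing loss/accuracy.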
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
target = target[:, self.ignore_prefix_size:].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size:, :, :].contiguous()
target = target[self.ignore_prefix_size:, :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs, target, self.eps, ignore_index=self.padding_idx, reduce=reduce,
)
return loss, nll_loss
@staticmethod
def reduce_metrics(logging_outputs) -> None:
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask)))
total = torch.sum(mask)
return n_correct, total
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
@@ -86,6 +116,20 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_derived('ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
total = utils.item(sum(log.get('total', 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar('total', total)
n_correct = utils.item(
sum(log.get('n_correct', 0) for log in logging_outputs)
)
metrics.log_scalar('n_correct', n_correct)
metrics.log_derived(
'accuracy',
lambda meters: round(
meters['n_correct'].sum * 100.0 / meters['total'].sum, 3
) if meters['total'].sum > 0 else float('nan'),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""


@@ -0,0 +1,81 @@
import os.path as op
from typing import Union, BinaryIO, Optional, Tuple
import numpy as np
def get_waveform(
path_or_fp: Union[str, BinaryIO], normalization=True
) -> Tuple[np.ndarray, int]:
"""Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC.
Args:
path_or_fp (str or BinaryIO): the path or file-like object
normalization (bool): Normalize values to [-1, 1] (Default: True)
"""
if isinstance(path_or_fp, str):
ext = op.splitext(op.basename(path_or_fp))[1]
if ext not in {'.flac', '.wav'}:
raise ValueError(f'Unsupported audio format: {ext}')
try:
import soundfile as sf
except ImportError:
raise ImportError('Please install soundfile to load WAV/FLAC file')
waveform, sample_rate = sf.read(path_or_fp, dtype='float32')
if not normalization:
waveform *= 2 ** 15 # denormalized to 16-bit signed integers
return waveform, sample_rate
def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
"""Get mel-filter bank features via PyKaldi."""
try:
from kaldi.feat.mel import MelBanksOptions
from kaldi.feat.fbank import FbankOptions, Fbank
from kaldi.feat.window import FrameExtractionOptions
from kaldi.matrix import Vector
mel_opts = MelBanksOptions()
mel_opts.num_bins = n_bins
frame_opts = FrameExtractionOptions()
frame_opts.samp_freq = sample_rate
opts = FbankOptions()
opts.mel_opts = mel_opts
opts.frame_opts = frame_opts
fbank = Fbank(opts=opts)
features = fbank.compute(Vector(waveform), 1.0).numpy()
return features
except ImportError:
return None
def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
"""Get mel-filter bank features via TorchAudio."""
try:
import torch
import torchaudio.compliance.kaldi as ta_kaldi
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=n_bins,
sample_frequency=sample_rate)
return features.numpy()
except ImportError:
return None
def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
(faster CPP implementation) to TorchAudio (Python implementation). Note that
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
waveform should not be normalized."""
sound, sample_rate = get_waveform(path_or_fp, normalization=False)
features = _get_kaldi_fbank(sound, sample_rate, n_bins)
if features is None:
features = _get_torchaudio_fbank(sound, sample_rate, n_bins)
if features is None:
raise ImportError('Please install pyKaldi or torchaudio to enable '
'online filterbank feature extraction')
return features


@@ -0,0 +1,77 @@
import importlib
import os
from typing import Optional, Dict
from abc import ABC, abstractmethod
class AudioFeatureTransform(ABC):
@classmethod
@abstractmethod
def from_config_dict(cls, config: Optional[Dict] = None):
pass
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
def register_audio_feature_transform(name):
def register_audio_feature_transform_cls(cls):
if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
raise ValueError(f'Cannot register duplicate transform ({name})')
if not issubclass(cls, AudioFeatureTransform):
raise ValueError(f'Transform ({name}: {cls.__name__}) must extend '
'AudioFeatureTransform')
if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
raise ValueError(
f'Cannot register audio feature transform with duplicate '
f'class name ({cls.__name__})'
)
AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
return cls
return register_audio_feature_transform_cls
def get_audio_feature_transform(name):
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
transforms_dir = os.path.dirname(__file__)
for file in os.listdir(transforms_dir):
path = os.path.join(transforms_dir, file)
if (
not file.startswith('_')
and not file.startswith('.')
and (file.endswith('.py') or os.path.isdir(path))
):
name = file[:file.find('.py')] if file.endswith('.py') else file
importlib.import_module('fairseq.data.audio.feature_transforms.' + name)
class CompositeAudioFeatureTransform(AudioFeatureTransform):
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
_transforms = _config.get('transforms')
if _transforms is None:
return None
transforms = [
get_audio_feature_transform(_t).from_config_dict(_config.get(_t))
for _t in _transforms
]
return CompositeAudioFeatureTransform(transforms)
def __init__(self, transforms):
self.transforms = [t for t in transforms if t is not None]
def __call__(self, x):
for t in self.transforms:
x = t(x)
return x
def __repr__(self):
format_string = [self.__class__.__name__ + '('] + \
[f" {t.__repr__()}" for t in self.transforms] + [')']
return '\n'.join(format_string)


@@ -0,0 +1,24 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform, register_audio_feature_transform
)
@register_audio_feature_transform('global_cmvn')
class GlobalCMVN(AudioFeatureTransform):
"""Global CMVN (cepstral mean and variance normalization). The global mean
and variance need to be pre-computed and stored in NumPy format (.npz)."""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return GlobalCMVN(_config.get('stats_npz_path'))
def __init__(self, stats_npz_path):
stats = np.load(stats_npz_path)
self.mean, self.std = stats['mean'], stats['std']
def __call__(self, x):
x = np.subtract(x, self.mean)
x = np.divide(x, self.std)
return x


@@ -0,0 +1,126 @@
import math
import numbers
from typing import Optional
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform, register_audio_feature_transform
)
@register_audio_feature_transform('specaugment')
class SpecAugmentTransform(AudioFeatureTransform):
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return SpecAugmentTransform(
_config.get('time_warp_W', 0),
_config.get('freq_mask_N', 0),
_config.get('freq_mask_F', 0),
_config.get('time_mask_N', 0),
_config.get('time_mask_T', 0),
_config.get('time_mask_p', 0.0),
_config.get('mask_value', None),
)
def __init__(
self,
time_warp_w: int = 0,
freq_mask_n: int = 0,
freq_mask_f: int = 0,
time_mask_n: int = 0,
time_mask_t: int = 0,
time_mask_p: float = 0.0,
mask_value: Optional[float] = 0.0,
):
# Sanity checks
assert mask_value is None or isinstance(
mask_value, numbers.Number
), f"mask_value (type: {type(mask_value)}) must be None or a number"
if freq_mask_n > 0:
assert (
freq_mask_f > 0
), f"freq_mask_F ({freq_mask_f}) " \
f"must be larger than 0 when doing freq masking."
if time_mask_n > 0:
assert (
time_mask_t > 0
), f"time_mask_T ({time_mask_t}) must be larger than 0 when " \
f"doing time masking."
self.time_warp_w = time_warp_w
self.freq_mask_n = freq_mask_n
self.freq_mask_f = freq_mask_f
self.time_mask_n = time_mask_n
self.time_mask_t = time_mask_t
self.time_mask_p = time_mask_p
self.mask_value = mask_value
def __repr__(self):
return self.__class__.__name__ + '(' + ', '.join(
[f'time_warp_w={self.time_warp_w}',
f'freq_mask_n={self.freq_mask_n}',
f'freq_mask_f={self.freq_mask_f}',
f'time_mask_n={self.time_mask_n}',
f'time_mask_t={self.time_mask_t}',
f'time_mask_p={self.time_mask_p}']
) + ')'
def __call__(self, spectrogram):
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
distorted = spectrogram.copy() # make a copy of input spectrogram.
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
        num_freqs = spectrogram.shape[1]  # or 'nu' in the paper.
mask_value = self.mask_value
if mask_value is None: # if no value was specified, use local mean.
mask_value = spectrogram.mean()
if num_frames == 0:
return spectrogram
if num_freqs < self.freq_mask_f:
return spectrogram
if self.time_warp_w > 0:
if 2 * self.time_warp_w < num_frames:
import cv2
w0 = np.random.randint(
self.time_warp_w, num_frames - self.time_warp_w
)
w = np.random.randint(0, self.time_warp_w)
upper, lower = distorted[:w0, :], distorted[w0:, :]
upper = cv2.resize(
upper, dsize=(num_freqs, w0 + w),
interpolation=cv2.INTER_LINEAR
)
lower = cv2.resize(
lower,
dsize=(num_freqs, num_frames - w0 - w),
interpolation=cv2.INTER_LINEAR,
)
distorted = np.concatenate((upper, lower), axis=0)
for _i in range(self.freq_mask_n):
f = np.random.randint(0, self.freq_mask_f)
f0 = np.random.randint(0, num_freqs - f)
if f != 0:
distorted[:, f0: f0 + f] = mask_value
max_time_mask_t = min(
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
)
if max_time_mask_t < 1:
return distorted
for _i in range(self.time_mask_n):
t = np.random.randint(0, max_time_mask_t)
t0 = np.random.randint(0, num_frames - t)
if t != 0:
distorted[t0: t0 + t, :] = mask_value
return distorted


@@ -0,0 +1,38 @@
import numpy as np
from fairseq.data.audio.feature_transforms import (
AudioFeatureTransform, register_audio_feature_transform
)
@register_audio_feature_transform('utterance_cmvn')
class UtteranceCMVN(AudioFeatureTransform):
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
@classmethod
def from_config_dict(cls, config=None):
_config = {} if config is None else config
return UtteranceCMVN(
_config.get('norm_means', True),
_config.get('norm_vars', True),
)
def __init__(self, norm_means=True, norm_vars=True):
self.norm_means, self.norm_vars = norm_means, norm_vars
def __repr__(self):
return self.__class__.__name__ + \
f'(norm_means={self.norm_means}, norm_vars={self.norm_vars})'
def __call__(self, x):
mean = x.mean(axis=0)
square_sums = (x ** 2).sum(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean ** 2
std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std)
return x


@@ -0,0 +1,478 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import re
from typing import List, Tuple, Optional, Dict
import os.path as op
import csv
import io
import numpy as np
import torch
from fairseq.data import (FairseqDataset, Dictionary, ResamplingDataset,
ConcatDataset, data_utils as fairseq_data_utils)
from fairseq.data.audio.audio_utils import get_fbank, get_waveform
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
logging.basicConfig(
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
)
logger = logging.getLogger(__name__)
class S2TDataConfig(object):
"""Wrapper class for data config YAML"""
def __init__(self, yaml_path):
try:
import yaml
except ImportError:
print('Please install PyYAML to load YAML files for '
'S2T data config')
self.config = {}
if op.isfile(yaml_path):
try:
with open(yaml_path) as f:
self.config = yaml.load(f, Loader=yaml.FullLoader)
except Exception as e:
logger.info(f'Failed to load config from {yaml_path}: {e}')
else:
logger.info(f'Cannot find {yaml_path}')
@property
def vocab_filename(self):
"""fairseq vocabulary file under data root"""
return self.config.get('vocab_filename', 'dict.txt')
@property
def shuffle(self) -> bool:
"""Shuffle dataset samples before batching"""
return self.config.get('shuffle', False)
@property
def pre_tokenizer(self) -> Dict:
"""Pre-tokenizer to apply before subword tokenization. Returning
a dictionary with `tokenizer` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get('pre_tokenizer', {'tokenizer': None})
@property
def bpe_tokenizer(self) -> Dict:
"""Subword tokenizer to apply after pre-tokenization. Returning
a dictionary with `bpe` providing the tokenizer name and
the other items providing the tokenizer-specific arguments.
Tokenizers are defined in `fairseq.data.encoders.*`"""
return self.config.get('bpe_tokenizer', None)
@property
def prepend_tgt_lang_tag(self) -> bool:
"""Prepend target lang ID token as the target BOS (e.g. for to-many
multilingual setting). During inference, this requires `--prefix-size 1`
to force BOS to be lang ID token."""
return self.config.get('prepend_tgt_lang_tag', False)
@property
def input_feat_per_channel(self):
"""The dimension of input features (per audio channel)"""
return self.config.get('input_feat_per_channel', 80)
@property
def input_channels(self):
"""The number of channels in the input audio"""
return self.config.get('input_channels', 1)
@property
def sampling_alpha(self):
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
(alpha = 1 for no resampling)"""
return self.config.get('sampling_alpha', 1.)
@property
def use_audio_input(self):
"""Needed by the dataset loader to see if the model requires
raw audio as inputs."""
return self.config.get('use_audio_input', False)
@property
def audio_root(self):
"""Audio paths in the manifest TSV can be relative and this provides
the root path. Set this to empty string when using absolute paths."""
return self.config.get('audio_root', '')
def get_feature_transforms(self, split, is_train):
"""Split-specific feature transforms. Allowing train set wildcard `_train`,
evaluation set wildcard `_eval` and general wildcard `*` for matching."""
from copy import deepcopy
cfg = deepcopy(self.config)
_cur = cfg.get('transforms', {})
cur = _cur.get(split)
cur = _cur.get('_train') if cur is None and is_train else cur
cur = _cur.get('_eval') if cur is None and not is_train else cur
cur = _cur.get('*') if cur is None else cur
cfg['transforms'] = cur
return cfg
def is_npy_data(data: bytes) -> bool:
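    # 147 and 78 are the first two bytes of the NumPy magic string b'\x93NUMPY' (0x93 == 147, ord('N') == 78).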
return data[0] == 147 and data[1] == 78
def is_flac_or_wav_data(data: bytes) -> bool:
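    # 'fL' (102, 76) starts the FLAC magic 'fLaC'; 'RI' (82, 73) starts the RIFF header of WAV files.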
is_flac = (data[0] == 102 and data[1] == 76)
is_wav = (data[0] == 82 and data[1] == 73)
return is_flac or is_wav
def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes:
with open(file_path, 'rb') as f:
f.seek(offset)
data = f.read(file_size)
return data
def get_features_from_npy_or_audio(path):
ext = op.splitext(op.basename(path))[1]
if ext not in {'.npy', '.flac', '.wav'}:
raise ValueError(f'Unsupported file format for "{path}"')
return np.load(path) if ext == '.npy' else get_fbank(path)
def get_features_or_waveform_from_uncompressed_zip(
path, byte_offset, byte_size, need_waveform=False
):
assert path.endswith('.zip')
data = read_from_uncompressed_zip(path, byte_offset, byte_size)
f = io.BytesIO(data)
if is_npy_data(data):
features_or_waveform = np.load(f)
elif is_flac_or_wav_data(data):
features_or_waveform = get_waveform(f)[0] if need_waveform else get_fbank(f)
else:
raise ValueError(f'Unknown file format for "{path}"')
return features_or_waveform
def get_features_or_waveform(path: str, need_waveform=False):
"""Get speech features from .npy file or waveform from .wav/.flac file.
The file may be inside an uncompressed ZIP file and is accessed via byte
offset and length.
Args:
path (str): File path in the format of "<.npy/.wav/.flac path>" or
"<zip path>:<byte offset>:<byte length>".
need_waveform (bool): return waveform instead of features.
Returns:
features_or_waveform (numpy.ndarray): speech features or waveform.
"""
_path, *extra = path.split(':')
if not op.exists(_path):
raise FileNotFoundError(f'File not found: {_path}')
if len(extra) == 0:
if need_waveform:
return get_waveform(_path)
return get_features_from_npy_or_audio(_path)
elif len(extra) == 2:
extra = [int(i) for i in extra]
features_or_waveform = get_features_or_waveform_from_uncompressed_zip(
_path, extra[0], extra[1], need_waveform=need_waveform
)
else:
raise ValueError(f'Invalid path: {path}')
return features_or_waveform
def _collate_frames(frames: List[torch.Tensor],
is_audio_input: bool = False) -> torch.Tensor:
"""
Convert a list of 2D frames into a padded 3D tensor
Args:
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
length of i-th frame and f_dim is static dimension of features
Returns:
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
"""
max_len = max(frame.size(0) for frame in frames)
if is_audio_input:
out = frames[0].new_zeros((len(frames), max_len))
else:
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
for i, v in enumerate(frames):
out[i, : v.size(0)] = v
return out
class SpeechToTextDataset(FairseqDataset):
LANG_TAG_TEMPLATE = '<lang:{}>'
def __init__(
self,
split: str,
is_train_split: bool,
data_cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
):
self.split, self.is_train_split = split, is_train_split
self.data_cfg = data_cfg
self.audio_paths, self.n_frames = audio_paths, n_frames
self.n_samples = len(audio_paths)
assert len(n_frames) == self.n_samples > 0
assert src_texts is None or len(src_texts) == self.n_samples
assert tgt_texts is None or len(tgt_texts) == self.n_samples
assert speakers is None or len(speakers) == self.n_samples
assert src_langs is None or len(src_langs) == self.n_samples
assert tgt_langs is None or len(tgt_langs) == self.n_samples
assert ids is None or len(ids) == self.n_samples
assert (tgt_dict is None and tgt_texts is None) or \
(tgt_dict is not None and tgt_texts is not None)
self.tgt_dict = tgt_dict
self.check_tgt_lang_tag()
self.src_texts, self.tgt_texts = src_texts, tgt_texts
self.src_langs, self.tgt_langs = src_langs, tgt_langs
self.ids = ids
self.shuffle = data_cfg.shuffle if is_train_split else False
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
self.data_cfg.get_feature_transforms(split, is_train_split)
)
self.pre_tokenizer = pre_tokenizer
self.bpe_tokenizer = bpe_tokenizer
logger.info(self.__repr__())
def __repr__(self):
return self.__class__.__name__ + \
f'(split="{self.split}", n_samples={self.n_samples}, ' \
f'prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, ' \
f'shuffle={self.shuffle}, transforms={self.feature_transforms})'
@classmethod
def is_lang_tag(cls, token):
pattern = cls.LANG_TAG_TEMPLATE.replace('{}', '(.*)')
return re.match(pattern, token)
def check_tgt_lang_tag(self):
if self.data_cfg.prepend_tgt_lang_tag:
assert self.tgt_langs is not None and self.tgt_dict is not None
tgt_lang_tags = [self.LANG_TAG_TEMPLATE.format(t)
for t in set(self.tgt_langs)]
assert all(t in self.tgt_dict for t in tgt_lang_tags)
def tokenize_text(self, text: str):
if self.pre_tokenizer is not None:
text = self.pre_tokenizer.encode(text)
if self.bpe_tokenizer is not None:
text = self.bpe_tokenizer.encode(text)
return text
def __getitem__(
self, index: int
) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]:
source = get_features_or_waveform(
self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input
)
if self.feature_transforms is not None:
assert not self.data_cfg.use_audio_input
source = self.feature_transforms(source)
source = torch.from_numpy(source).float()
target = None
if self.tgt_texts is not None:
tokenized = self.tokenize_text(self.tgt_texts[index])
target = self.tgt_dict.encode_line(
tokenized, add_if_not_exist=False, append_eos=True
).long()
if self.data_cfg.prepend_tgt_lang_tag:
lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
lang_tag_idx = self.tgt_dict.index(lang_tag)
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
return index, source, target
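# Illustrative target layout when prepend_tgt_lang_tag is enabled (assuming a
# German target): [<lang:de>, w_1, ..., w_n, </s>]. The tag sits in the BOS
# position and is later handled via the prefix options in the task/criterion.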
def __len__(self):
return self.n_samples
def collater(
self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]
) -> Dict:
if len(samples) == 0:
return {}
indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long)
frames = _collate_frames([s for _, s, _ in samples],
self.data_cfg.use_audio_input)
# sort samples by descending number of frames
n_frames = torch.tensor(
[s.size(0) for _, s, _ in samples], dtype=torch.long
)
n_frames, order = n_frames.sort(descending=True)
indices = indices.index_select(0, order)
frames = frames.index_select(0, order)
target, target_lengths = None, None
prev_output_tokens = None
ntokens = None
if self.tgt_texts is not None:
target = fairseq_data_utils.collate_tokens(
[t for _, _, t in samples], self.tgt_dict.pad(),
self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=False
)
target = target.index_select(0, order)
target_lengths = torch.tensor(
[t.size(0) for _, _, t in samples], dtype=torch.long
).index_select(0, order)
prev_output_tokens = fairseq_data_utils.collate_tokens(
[t for _, _, t in samples], self.tgt_dict.pad(),
self.tgt_dict.eos(), left_pad=False, move_eos_to_beginning=True
)
prev_output_tokens = prev_output_tokens.index_select(0, order)
ntokens = sum(t.size(0) for _, _, t in samples)
out = {
"id": indices,
"net_input": {
"src_tokens": frames,
"src_lengths": n_frames,
"prev_output_tokens": prev_output_tokens,
},
"target": target,
"target_lengths": target_lengths,
"ntokens": ntokens,
"nsentences": len(samples),
}
return out
def num_tokens(self, index):
return self.n_frames[index]
def size(self, index):
t_len = 0
if self.tgt_texts is not None:
tokenized = self.tokenize_text(self.tgt_texts[index])
t_len = len(tokenized.split(' '))
return self.n_frames[index], t_len
@property
def sizes(self):
return np.array(self.n_frames)
@property
def can_reuse_epoch_itr_across_epochs(self):
return True
def ordered_indices(self):
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
# first by descending order of # of frames then by original/random order
order.append([-n for n in self.n_frames])
return np.lexsort(order)
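# Illustrative ordering: with n_frames = [5, 9, 7] and shuffle disabled,
# np.lexsort((np.arange(3), [-5, -9, -7])) returns [1, 2, 0], i.e. indices
# sorted by descending frame count, with the first key breaking ties.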
def prefetch(self, indices):
raise NotImplementedError
class SpeechToTextDatasetCreator(object):
# mandatory columns
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = 'id', 'audio', 'n_frames'
KEY_TGT_TEXT = 'tgt_text'
# optional columns
KEY_SPEAKER, KEY_SRC_TEXT = 'speaker', 'src_text'
KEY_SRC_LANG, KEY_TGT_LANG = 'src_lang', 'tgt_lang'
# default values
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ''
@classmethod
def _from_list(cls, split_name: str, is_train_split,
samples: List[List[Dict]], data_cfg: S2TDataConfig, tgt_dict,
pre_tokenizer, bpe_tokenizer) -> SpeechToTextDataset:
audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], []
speakers, src_langs, tgt_langs = [], [], []
for s in samples:
ids.extend([ss[cls.KEY_ID] for ss in s])
audio_paths.extend([op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO])
for ss in s])
n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s])
tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s])
src_texts.extend([ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT)
for ss in s])
speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER)
for ss in s])
src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG)
for ss in s])
tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG)
for ss in s])
return SpeechToTextDataset(
split_name, is_train_split, data_cfg, audio_paths, n_frames,
src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict,
pre_tokenizer, bpe_tokenizer
)
@classmethod
def _get_size_ratios(cls, ids: List[str], sizes: List[int],
alpha: float = 1.):
"""Size ratios for temperature-based sampling
(https://arxiv.org/abs/1907.05019)"""
_sizes = np.array(sizes)
prob = _sizes / _sizes.sum()
smoothed_prob = prob ** alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
size_ratio = (smoothed_prob * _sizes.sum()) / _sizes
o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)})
logger.info(f"original sampling probability: {o_str}")
p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)})
logger.info(f"balanced sampling probability: {p_str}")
sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)})
logger.info(f"balanced sampling size ratio: {sr_str}")
return size_ratio.tolist()
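# Worked example (hypothetical sizes): with sizes = [9000, 1000] and
# alpha = 0.5, prob = [0.9, 0.1], smoothed_prob = [0.75, 0.25] and the
# returned ratios are [10000 * 0.75 / 9000, 10000 * 0.25 / 1000]
# = [0.833, 2.5], i.e. the small split is upsampled 2.5x per epoch while
# the large one is downsampled to ~83%.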
@classmethod
def from_tsv(cls, root: str, data_cfg: S2TDataConfig, splits: str, tgt_dict,
pre_tokenizer, bpe_tokenizer, is_train_split: bool, epoch: int,
seed: int) -> SpeechToTextDataset:
samples = []
_splits = splits.split(',')
for split in _splits:
tsv_path = op.join(root, f'{split}.tsv')
if not op.isfile(tsv_path):
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
with open(tsv_path) as f:
reader = csv.DictReader(
f, delimiter='\t', quotechar=None, doublequote=False,
lineterminator='\n', quoting=csv.QUOTE_NONE
)
samples.append([dict(e) for e in reader])
assert len(samples) > 0
datasets = [cls._from_list(name, is_train_split, [s], data_cfg, tgt_dict,
pre_tokenizer, bpe_tokenizer)
for name, s in zip(_splits, samples)]
if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.:
# temperature-based sampling
size_ratios = cls._get_size_ratios(
_splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha
)
datasets = [
ResamplingDataset(d, size_ratio=r, seed=seed, epoch=epoch,
replace=(r >= 1.))
for d, r in zip(datasets, size_ratios)
]
return ConcatDataset(datasets)
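# A minimal sketch of one per-split manifest TSV (tab-separated; values are
# hypothetical), using the mandatory columns plus the optional "speaker":
#
#   id      audio                   n_frames    tgt_text        speaker
#   utt001  fbank80.zip:2048:93680  512         how are you     spk1
#
# Audio paths are resolved relative to data_cfg.audio_root (see _from_list).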

View File

@@ -446,3 +446,14 @@ def get_mem_usage():
return f'used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb'
except ImportError:
return 'N/A'
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
bsz, max_lens = lens.size(0), torch.max(lens).item()
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
return mask
def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor:
return ~lengths_to_padding_mask(lens)
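# Illustrative behavior: for lens = torch.LongTensor([3, 1]),
# lengths_to_padding_mask(lens) returns
#   [[False, False, False],
#    [False, True,  True ]]
# (True marks padding positions); lengths_to_mask is its element-wise negation.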

View File

@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .berard import * # noqa
from .s2t_transformer import * # noqa

View File

@@ -0,0 +1,581 @@
#!/usr/bin/env python3
from ast import literal_eval
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import checkpoint_utils, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.data.data_utils import lengths_to_padding_mask
@register_model("s2t_berard")
class BerardModel(FairseqEncoderDecoderModel):
"""Implementation of a model similar to https://arxiv.org/abs/1802.04200
Paper title: End-to-End Automatic Speech Translation of Audiobooks
An implementation is available in TensorFlow at
https://github.com/eske/seq2seq
Relevant files in this implementation are the config
(https://github.com/eske/seq2seq/blob/master/config/LibriSpeech/AST.yaml)
and the model code
(https://github.com/eske/seq2seq/blob/master/translate/models.py).
The encoder and decoder try to be close to the original implementation.
The attention is an MLP as in Bahdanau et al.
(https://arxiv.org/abs/1409.0473).
There is no state initialization by averaging the encoder outputs.
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
parser.add_argument("--input-layers", type=str, metavar="EXPR",
help="List of linear layer dimensions. These "
"layers are applied to the input features and "
"are followed by tanh and possibly dropout.")
parser.add_argument(
"--dropout", type=float, metavar="D",
help="Dropout probability to use in the encoder/decoder. "
"Note that this parameters control dropout in various places, "
"there is no fine-grained control for dropout for embeddings "
"vs LSTM layers for example."
)
parser.add_argument("--in-channels", type=int, metavar="N",
help="Number of encoder input channels. "
"Typically value is 1.")
parser.add_argument("--conv-layers", type=str, metavar="EXPR",
help="List of conv layers "
"(format: (channels, kernel, stride)).")
parser.add_argument("--num-blstm-layers", type=int, metavar="N",
help="Number of encoder bi-LSTM layers.")
parser.add_argument("--lstm-size", type=int, metavar="N",
help="LSTM hidden size.")
parser.add_argument(
"--decoder-embed-dim", type=int, metavar="N",
help="Embedding dimension of the decoder target tokens."
)
parser.add_argument("--decoder-hidden-dim", type=int, metavar="N",
help="Decoder LSTM hidden dimension.")
parser.add_argument("--decoder-num-layers", type=int, metavar="N",
help="Number of decoder LSTM layers.")
parser.add_argument("--attention-dim", type=int, metavar="N",
help="Hidden layer dimension in MLP attention.")
parser.add_argument(
"--output-layer-dim", type=int, metavar="N",
help="Hidden layer dim for linear layer prior to output projection."
)
parser.add_argument(
"--load-pretrained-encoder-from", type=str, metavar="STR",
help="model to take encoder weights from (for initialization)"
)
parser.add_argument(
"--load-pretrained-decoder-from", type=str, metavar="STR",
help="model to take decoder weights from (for initialization)"
)
@classmethod
def build_encoder(cls, args, task):
encoder = BerardEncoder(
input_layers=literal_eval(args.input_layers),
conv_layers=literal_eval(args.conv_layers),
in_channels=args.input_channels,
input_feat_per_channel=args.input_feat_per_channel,
num_blstm_layers=args.num_blstm_layers,
lstm_size=args.lstm_size,
dropout=args.dropout,
)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from
)
return encoder
@classmethod
def build_decoder(cls, args, task):
decoder = LSTMDecoder(
dictionary=task.target_dictionary,
embed_dim=args.decoder_embed_dim,
num_layers=args.decoder_num_layers,
hidden_size=args.decoder_hidden_dim,
dropout=args.dropout,
encoder_output_dim=2 * args.lstm_size, # bidirectional
attention_dim=args.attention_dim,
output_layer_dim=args.output_layer_dim,
)
if getattr(args, "load_pretrained_decoder_from", None):
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from
)
return decoder
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
encoder = cls.build_encoder(args, task)
decoder = cls.build_decoder(args, task)
return cls(encoder, decoder)
def get_normalized_probs(self, net_output, log_probs, sample=None):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
# lprobs is a (B, T, D) tensor
lprobs.batch_first = True
return lprobs
class BerardEncoder(FairseqEncoder):
def __init__(
self,
input_layers: List[int],
conv_layers: List[Tuple[int]],
in_channels: int,
input_feat_per_channel: int,
num_blstm_layers: int,
lstm_size: int,
dropout: float,
):
"""
Args:
input_layers: list of linear layer dimensions. These layers are
applied to the input features and are followed by tanh and
possibly dropout.
conv_layers: list of conv2d layer configurations. A configuration is
a tuple (out_channels, conv_kernel_size, stride).
in_channels: number of input channels.
input_feat_per_channel: number of input features per channel. These
are speech features, typically 40 or 80.
num_blstm_layers: number of bidirectional LSTM layers.
lstm_size: the LSTM hidden (and cell) state size.
dropout: dropout probability. Dropout can be applied after the
linear layers and LSTM layers but not to the convolutional
layers.
"""
super().__init__(None)
self.input_layers = nn.ModuleList()
in_features = input_feat_per_channel
for out_features in input_layers:
if dropout > 0:
self.input_layers.append(
nn.Sequential(
nn.Linear(in_features, out_features),
nn.Dropout(p=dropout)
)
)
else:
self.input_layers.append(nn.Linear(in_features, out_features))
in_features = out_features
self.in_channels = in_channels
self.input_dim = input_feat_per_channel
self.conv_kernel_sizes_and_strides = []
self.conv_layers = nn.ModuleList()
lstm_input_dim = input_layers[-1]
for conv_layer in conv_layers:
out_channels, conv_kernel_size, conv_stride = conv_layer
self.conv_layers.append(
nn.Conv2d(
in_channels,
out_channels,
conv_kernel_size,
stride=conv_stride,
padding=conv_kernel_size // 2,
)
)
self.conv_kernel_sizes_and_strides.append(
(conv_kernel_size, conv_stride)
)
in_channels = out_channels
lstm_input_dim //= conv_stride
lstm_input_dim *= conv_layers[-1][0]
self.lstm_size = lstm_size
self.num_blstm_layers = num_blstm_layers
self.lstm = nn.LSTM(
input_size=lstm_input_dim,
hidden_size=lstm_size,
num_layers=num_blstm_layers,
dropout=dropout,
bidirectional=True,
)
self.output_dim = 2 * lstm_size # bidirectional
if dropout > 0:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = None
def forward(self, src_tokens, src_lengths=None, **kwargs):
"""
Args
src_tokens: padded tensor (B, T, C * feat)
src_lengths: tensor of original lengths of input utterances (B,)
"""
bsz, max_seq_len, _ = src_tokens.size()
# (B, C, T, feat)
x = (
src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
.transpose(1, 2)
.contiguous()
)
for input_layer in self.input_layers:
x = input_layer(x)
x = torch.tanh(x)
for conv_layer in self.conv_layers:
x = conv_layer(x)
bsz, _, output_seq_len, _ = x.size()
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) ->
# (T, B, C * feat)
x = x.transpose(1, 2).transpose(0, 1).contiguous().view(output_seq_len,
bsz, -1)
input_lengths = src_lengths.clone()
for k, s in self.conv_kernel_sizes_and_strides:
p = k // 2
input_lengths = (input_lengths.float() + 2 * p - k) / s + 1
input_lengths = input_lengths.floor().long()
packed_x = nn.utils.rnn.pack_padded_sequence(x, input_lengths)
h0 = x.new(2 * self.num_blstm_layers, bsz, self.lstm_size).zero_()
c0 = x.new(2 * self.num_blstm_layers, bsz, self.lstm_size).zero_()
packed_outs, _ = self.lstm(packed_x, (h0, c0))
# unpack outputs and apply dropout
x, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_outs)
if self.dropout is not None:
x = self.dropout(x)
encoder_padding_mask = lengths_to_padding_mask(output_lengths).to(
src_tokens.device).t()
return {
"encoder_out": x, # (T, B, C)
"encoder_padding_mask": encoder_padding_mask, # (T, B)
}
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
].index_select(1, new_order)
return encoder_out
class MLPAttention(nn.Module):
"""The original attention from Badhanau et al. (2014)
https://arxiv.org/abs/1409.0473, based on a Multi-Layer Perceptron.
The attention score between position i in the encoder and position j in the
decoder is: alpha_ij = V_a * tanh(W_ae * enc_i + W_ad * dec_j + b_a)
"""
def __init__(self, decoder_hidden_state_dim, context_dim, attention_dim):
super().__init__()
self.context_dim = context_dim
self.attention_dim = attention_dim
# W_ae and b_a
self.encoder_proj = nn.Linear(context_dim, self.attention_dim,
bias=True)
# W_ad
self.decoder_proj = nn.Linear(
decoder_hidden_state_dim, self.attention_dim, bias=False
)
# V_a
self.to_scores = nn.Linear(self.attention_dim, 1, bias=False)
def forward(self, decoder_state, source_hids, encoder_padding_mask):
"""The expected input dimensions are:
decoder_state: bsz x decoder_hidden_state_dim
source_hids: src_len x bsz x context_dim
encoder_padding_mask: src_len x bsz
"""
src_len, bsz, _ = source_hids.size()
# (src_len*bsz) x context_dim (to feed through linear)
flat_source_hids = source_hids.view(-1, self.context_dim)
# (src_len*bsz) x attention_dim
encoder_component = self.encoder_proj(flat_source_hids)
# src_len x bsz x attention_dim
encoder_component = encoder_component.view(src_len, bsz,
self.attention_dim)
# 1 x bsz x attention_dim
decoder_component = self.decoder_proj(decoder_state).unsqueeze(0)
# Sum with broadcasting and apply the non-linearity
# src_len x bsz x attention_dim
hidden_att = torch.tanh(
(decoder_component + encoder_component).view(-1, self.attention_dim)
)
# Project onto the reals to get attention scores (src_len x bsz)
attn_scores = self.to_scores(hidden_att).view(src_len, bsz)
# Mask + softmax (src_len x bsz)
if encoder_padding_mask is not None:
attn_scores = (
attn_scores.float()
.masked_fill_(encoder_padding_mask, float("-inf"))
.type_as(attn_scores)
) # FP16 support: cast to float and back
# srclen x bsz
normalized_masked_attn_scores = F.softmax(attn_scores, dim=0)
# Sum weighted sources (bsz x context_dim)
attn_weighted_context = (
source_hids * normalized_masked_attn_scores.unsqueeze(2)
).sum(dim=0)
return attn_weighted_context, normalized_masked_attn_scores
class LSTMDecoder(FairseqIncrementalDecoder):
def __init__(
self,
dictionary,
embed_dim,
num_layers,
hidden_size,
dropout,
encoder_output_dim,
attention_dim,
output_layer_dim,
):
"""
Args:
dictionary: target text dictionary.
embed_dim: embedding dimension for target tokens.
num_layers: number of LSTM layers.
hidden_size: hidden size for LSTM layers.
dropout: dropout probability. Dropout can be applied to the
embeddings, the LSTM layers, and the context vector.
encoder_output_dim: encoder output dimension (hidden size of
encoder LSTM).
attention_dim: attention dimension for MLP attention.
output_layer_dim: size of the linear layer prior to output
projection.
"""
super().__init__(dictionary)
self.num_layers = num_layers
self.hidden_size = hidden_size
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = nn.Embedding(num_embeddings, embed_dim, padding_idx)
if dropout > 0:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = None
self.layers = nn.ModuleList()
for layer_id in range(num_layers):
input_size = embed_dim if layer_id == 0 else encoder_output_dim
self.layers.append(
nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
)
self.context_dim = encoder_output_dim
self.attention = MLPAttention(
decoder_hidden_state_dim=hidden_size,
context_dim=encoder_output_dim,
attention_dim=attention_dim,
)
self.deep_output_layer = nn.Linear(
hidden_size + encoder_output_dim + embed_dim, output_layer_dim
)
self.output_projection = nn.Linear(output_layer_dim, num_embeddings)
def forward(self, prev_output_tokens, encoder_out=None,
incremental_state=None, **kwargs):
encoder_padding_mask = encoder_out["encoder_padding_mask"]
encoder_outs = encoder_out["encoder_out"]
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
bsz, seqlen = prev_output_tokens.size()
srclen = encoder_outs.size(0)
# embed tokens
embeddings = self.embed_tokens(prev_output_tokens)
x = embeddings
if self.dropout is not None:
x = self.dropout(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# initialize previous states (or get from cache during incremental
# generation)
cached_state = utils.get_incremental_state(
self, incremental_state, "cached_state"
)
if cached_state is not None:
prev_hiddens, prev_cells = cached_state
else:
prev_hiddens = [
encoder_out["encoder_out"].mean(dim=0)
] * self.num_layers
prev_cells = [x.new_zeros(bsz, self.hidden_size)] * self.num_layers
attn_scores = x.new_zeros(bsz, srclen)
attention_outs = []
outs = []
for j in range(seqlen):
input = x[j, :, :]
attention_out = None
for i, layer in enumerate(self.layers):
# the previous state is one layer below except for the bottom
# layer where the previous state is the state emitted by the
# top layer
hidden, cell = layer(
input,
(
prev_hiddens[(i - 1) % self.num_layers],
prev_cells[(i - 1) % self.num_layers],
),
)
if self.dropout is not None:
hidden = self.dropout(hidden)
prev_hiddens[i] = hidden
prev_cells[i] = cell
if attention_out is None:
attention_out, attn_scores = self.attention(
hidden, encoder_outs, encoder_padding_mask
)
if self.dropout is not None:
attention_out = self.dropout(attention_out)
attention_outs.append(attention_out)
input = attention_out
# collect the output of the top layer
outs.append(hidden)
# cache previous states (no-op except during incremental generation)
utils.set_incremental_state(
self, incremental_state, "cached_state", (prev_hiddens, prev_cells)
)
# collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
attention_outs_concat = torch.cat(attention_outs, dim=0).view(
seqlen, bsz, self.context_dim
)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
attention_outs_concat = attention_outs_concat.transpose(0, 1)
# concat LSTM output, attention output and embedding
# before output projection
x = torch.cat((x, attention_outs_concat, embeddings), dim=2)
x = self.deep_output_layer(x)
x = torch.tanh(x)
if self.dropout is not None:
x = self.dropout(x)
# project back to size of vocabulary
x = self.output_projection(x)
# to return the full attn_scores tensor, we need to fix the decoder
# to account for subsampling input frames
# return x, attn_scores
return x, None
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
cached_state = utils.get_incremental_state(
self, incremental_state, "cached_state"
)
if cached_state is None:
return
def reorder_state(state):
if isinstance(state, list):
return [reorder_state(state_i) for state_i in state]
return state.index_select(0, new_order)
new_state = tuple(map(reorder_state, cached_state))
utils.set_incremental_state(
self, incremental_state, "cached_state", new_state
)
@register_model_architecture(model_name="s2t_berard", arch_name="s2t_berard")
def berard(args):
"""The original version: "End-to-End Automatic Speech Translation of
Audiobooks" (https://arxiv.org/abs/1802.04200)
"""
args.input_layers = getattr(args, "input_layers", "[256, 128]")
args.conv_layers = getattr(args, "conv_layers", "[(16, 3, 2), (16, 3, 2)]")
args.num_blstm_layers = getattr(args, "num_blstm_layers", 3)
args.lstm_size = getattr(args, "lstm_size", 256)
args.dropout = getattr(args, "dropout", 0.2)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 128)
args.decoder_num_layers = getattr(args, "decoder_num_layers", 2)
args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 512)
args.attention_dim = getattr(args, "attention_dim", 512)
args.output_layer_dim = getattr(args, "output_layer_dim", 128)
args.load_pretrained_encoder_from = getattr(
args, "load_pretrained_encoder_from", None
)
args.load_pretrained_decoder_from = getattr(
args, "load_pretrained_decoder_from", None
)
@register_model_architecture(model_name="s2t_berard",
arch_name="s2t_berard_256_3_3")
def berard_256_3_3(args):
"""Used in
* "Harnessing Indirect Training Data for End-to-End Automatic Speech
Translation: Tricks of the Trade" (https://arxiv.org/abs/1909.06515)
* "CoVoST: A Diverse Multilingual Speech-To-Text Translation Corpus"
(https://arxiv.org/pdf/2002.01320.pdf)
* "Self-Supervised Representations Improve End-to-End Speech Translation"
(https://arxiv.org/abs/2006.12124)
"""
args.decoder_num_layers = getattr(args, "decoder_num_layers", 3)
berard(args)
@register_model_architecture(model_name="s2t_berard",
arch_name="s2t_berard_512_3_2")
def berard_512_3_2(args):
args.num_blstm_layers = getattr(args, "num_blstm_layers", 3)
args.lstm_size = getattr(args, "lstm_size", 512)
args.dropout = getattr(args, "dropout", 0.3)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
args.decoder_num_layers = getattr(args, "decoder_num_layers", 2)
args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 1024)
args.attention_dim = getattr(args, "attention_dim", 512)
args.output_layer_dim = getattr(args, "output_layer_dim", 256)
berard(args)
@register_model_architecture(model_name="s2t_berard",
arch_name="s2t_berard_512_5_3")
def berard_512_5_3(args):
args.num_blstm_layers = getattr(args, "num_blstm_layers", 5)
args.lstm_size = getattr(args, "lstm_size", 512)
args.dropout = getattr(args, "dropout", 0.3)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 256)
args.decoder_num_layers = getattr(args, "decoder_num_layers", 3)
args.decoder_hidden_dim = getattr(args, "decoder_hidden_dim", 1024)
args.attention_dim = getattr(args, "attention_dim", 512)
args.output_layer_dim = getattr(args, "output_layer_dim", 256)
berard(args)

View File

@@ -0,0 +1,394 @@
#!/usr/bin/env python3
import logging
import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from fairseq import utils, checkpoint_utils
from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
register_model, register_model_architecture)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.modules import (PositionalEmbedding, TransformerEncoderLayer,
FairseqDropout, LayerNorm)
from torch import Tensor
logger = logging.getLogger(__name__)
class Conv1dSubsampler(nn.Module):
"""Convolutional subsampler: a stack of 1D convolution (along temporal
dimension) followed by non-linear activation via gated linear units
(https://arxiv.org/abs/1911.08460)
Args:
in_channels (int): the number of input channels
mid_channels (int): the number of intermediate channels
out_channels (int): the number of output channels
kernel_sizes (List[int]): the kernel size for each convolutional layer
"""
def __init__(self, in_channels: int, mid_channels: int, out_channels: int,
kernel_sizes: List[int] = (3, 3)):
super(Conv1dSubsampler, self).__init__()
self.n_layers = len(kernel_sizes)
self.conv_layers = nn.ModuleList(
nn.Conv1d(
in_channels if i == 0 else mid_channels // 2,
mid_channels if i < self.n_layers - 1 else out_channels * 2,
k, stride=2, padding=k // 2
)
for i, k in enumerate(kernel_sizes)
)
def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
out = in_seq_lens_tensor.clone()
for _ in range(self.n_layers):
out = ((out.float() - 1) / 2 + 1).floor().long()
return out
def forward(self, src_tokens, src_lengths):
bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D)
x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T
for conv in self.conv_layers:
x = conv(x)
x = nn.functional.glu(x, dim=1)
_, _, out_seq_len = x.size()
x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D)
return x, self.get_out_seq_lens_tensor(src_lengths)
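# Worked example (hypothetical shapes): with kernel_sizes=(3, 3), stride 2 and
# padding k // 2, each layer maps an input length T to floor((T - 1) / 2 + 1),
# so a 1000-frame utterance (B x 1000 x 80) comes out with about 250 time
# steps before entering the Transformer encoder.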
@register_model("s2t_transformer")
class S2TTransformerModel(FairseqEncoderDecoderModel):
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
speech-to-text tasks. The Transformer encoder/decoder remains the same.
A trainable input subsampler is prepended to the Transformer encoder to
project inputs into the encoder dimension and to downsample the input
sequence for computational efficiency."""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# input
parser.add_argument("--conv-kernel-sizes", type=str, metavar="N",
help="kernel sizes of Conv1d subsampling layers")
parser.add_argument("--conv-channels", type=int, metavar="N",
help="# of channels in Conv1d subsampling layers")
# Transformer
parser.add_argument("--activation-fn", type=str, default='relu',
choices=utils.get_available_activation_fns(),
help="activation function to use")
parser.add_argument("--dropout", type=float, metavar="D",
help="dropout probability")
parser.add_argument("--attention-dropout", type=float, metavar="D",
help="dropout probability for attention weights")
parser.add_argument("--activation-dropout", "--relu-dropout",
type=float, metavar="D",
help="dropout probability after activation in FFN.")
parser.add_argument("--encoder-embed-dim", type=int, metavar="N",
help="encoder embedding dimension")
parser.add_argument("--encoder-ffn-embed-dim", type=int, metavar="N",
help="encoder embedding dimension for FFN")
parser.add_argument("--encoder-layers", type=int, metavar="N",
help="num encoder layers")
parser.add_argument("--encoder-attention-heads", type=int, metavar="N",
help="num encoder attention heads")
parser.add_argument("--encoder-normalize-before", action="store_true",
help="apply layernorm before each encoder block")
parser.add_argument("--decoder-embed-dim", type=int, metavar="N",
help="decoder embedding dimension")
parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N",
help="decoder embedding dimension for FFN")
parser.add_argument("--decoder-layers", type=int, metavar="N",
help="num decoder layers")
parser.add_argument("--decoder-attention-heads", type=int, metavar="N",
help="num decoder attention heads")
parser.add_argument("--decoder-normalize-before", action="store_true",
help="apply layernorm before each decoder block")
parser.add_argument("--layernorm-embedding", action="store_true",
help="add layernorm to embedding")
parser.add_argument("--no-scale-embedding", action="store_true",
help="if True, dont scale embeddings")
parser.add_argument(
"--load-pretrained-encoder-from", type=str, metavar="STR",
help="model to take encoder weights from (for initialization)"
)
@classmethod
def build_encoder(cls, args):
encoder = S2TTransformerEncoder(args)
if getattr(args, "load_pretrained_encoder_from", None):
encoder = checkpoint_utils.load_pretrained_component_from_model(
component=encoder, checkpoint=args.load_pretrained_encoder_from
)
logger.info(f'loaded pretrained encoder from: '
f'{args.load_pretrained_encoder_from}')
return encoder
@classmethod
def build_decoder(cls, args, task, embed_tokens):
return TransformerDecoderScriptable(args, task.target_dictionary,
embed_tokens)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture(args)
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
decoder_embed_tokens = build_embedding(task.target_dictionary,
args.decoder_embed_dim)
encoder = cls.build_encoder(args)
decoder = cls.build_decoder(args, task, decoder_embed_tokens)
return cls(encoder, decoder)
def get_normalized_probs(
self,
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs,
sample)
lprobs.batch_first = True
return lprobs
def forward(self, src_tokens, src_lengths, prev_output_tokens):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in TorchScript. This
method overrides the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens=src_tokens,
src_lengths=src_lengths)
decoder_out = self.decoder(prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out)
return decoder_out
class S2TTransformerEncoder(FairseqEncoder):
"""Speech-to-text Transformer encoder that consists of input subsampler and
Transformer encoder."""
def __init__(self, args):
super().__init__(None)
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.embed_scale = math.sqrt(args.encoder_embed_dim)
if args.no_scale_embedding:
self.embed_scale = 1.0
self.padding_idx = 1
self.subsample = Conv1dSubsampler(
args.input_feat_per_channel * args.input_channels,
args.conv_channels, args.encoder_embed_dim,
[int(k) for k in args.conv_kernel_sizes.split(',')]
)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim,
self.padding_idx
)
self.transformer_layers = nn.ModuleList(
[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(args.encoder_embed_dim)
else:
self.layer_norm = None
def forward(self, src_tokens, src_lengths):
x, input_lengths = self.subsample(src_tokens, src_lengths)
x = self.embed_scale * x
encoder_padding_mask = lengths_to_padding_mask(input_lengths)
positions = self.embed_positions(encoder_padding_mask).transpose(0, 1)
x += positions
x = self.dropout_module(x)
for layer in self.transformer_layers:
x = layer(x, encoder_padding_mask)
if not encoder_padding_mask.any():
encoder_padding_mask = None
if self.layer_norm is not None:
x = self.layer_norm(x)
return EncoderOut(
encoder_out=x, encoder_padding_mask=encoder_padding_mask,
encoder_embedding=None, encoder_states=None, src_tokens=None,
src_lengths=None
)
@torch.jit.export
def reorder_encoder_out(self, encoder_out: EncoderOut, new_order):
"""
Since encoder_padding_mask and encoder_embedding are both of type
Optional[Tensor] in EncoderOut, they need to be copied as local
variables for Torchscript Optional refinement
"""
encoder_padding_mask: Optional[Tensor] = \
encoder_out.encoder_padding_mask
encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding
new_encoder_out = (
encoder_out.encoder_out
if encoder_out.encoder_out is None
else encoder_out.encoder_out.index_select(1, new_order)
)
new_encoder_padding_mask = (
encoder_padding_mask
if encoder_padding_mask is None
else encoder_padding_mask.index_select(0, new_order)
)
new_encoder_embedding = (
encoder_embedding
if encoder_embedding is None
else encoder_embedding.index_select(0, new_order)
)
encoder_states = encoder_out.encoder_states
if encoder_states is not None:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return EncoderOut(
encoder_out=new_encoder_out, # T x B x C
encoder_padding_mask=new_encoder_padding_mask, # B x T
encoder_embedding=new_encoder_embedding, # B x T x C
encoder_states=encoder_states, # List[T x B x C]
src_tokens=None,
src_lengths=None,
)
class TransformerDecoderScriptable(TransformerDecoder):
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[EncoderOut] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
# call scriptable method from parent class
x, _ = self.extract_features_scriptable(
prev_output_tokens, encoder_out, incremental_state,
full_context_alignment, alignment_layer, alignment_heads,
)
return x, None
@register_model_architecture(model_name="s2t_transformer",
arch_name="s2t_transformer")
def base_architecture(args):
# Convolutional subsampler
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", '5,5')
args.conv_channels = getattr(args, "conv_channels", 1024)
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_layers = getattr(args, "encoder_layers", 12)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before",
True)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim",
args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim",
args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before",
True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu")
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff",
None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(args, "decoder_output_dim",
args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, "decoder_input_dim",
args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
@register_model_architecture("s2t_transformer", "s2t_transformer_s")
def s2t_transformer_s(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
base_architecture(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_sp")
def s2t_transformer_sp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_transformer_s(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_m")
def s2t_transformer_m(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 512 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.15)
base_architecture(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_mp")
def s2t_transformer_mp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_transformer_m(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_l")
def s2t_transformer_l(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim",
1024 * 4)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.dropout = getattr(args, "dropout", 0.2)
base_architecture(args)
@register_model_architecture("s2t_transformer", "s2t_transformer_lp")
def s2t_transformer_lp(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
s2t_transformer_l(args)

View File

@@ -10,10 +10,9 @@ from argparse import Namespace
import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators
from fairseq.data import Dictionary, FairseqDataset, data_utils, iterators, encoders
from fairseq.dataclass.utils import gen_parser_from_dataclass
logger = logging.getLogger(__name__)
@@ -504,6 +503,14 @@ class FairseqTask(object):
for this task)."""
raise NotImplementedError
def build_tokenizer(self, args):
"""Build the pre-tokenizer for this task."""
return encoders.build_tokenizer(args)
def build_bpe(self, args):
"""Build the tokenizer for this task."""
return encoders.build_bpe(args)
class LegacyFairseqTask(FairseqTask):
def __init__(self, args: Namespace):

View File

@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from argparse import Namespace
import os.path as op
from fairseq.data import encoders, Dictionary
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig
)
from fairseq.tasks import FairseqTask, register_task
logging.basicConfig(
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
)
logger = logging.getLogger(__name__)
@register_task('speech_to_text')
class SpeechToTextTask(FairseqTask):
@staticmethod
def add_args(parser):
parser.add_argument('data', help='manifest root path')
parser.add_argument(
'--config-yaml', type=str, default='config.yaml',
help='Configuration YAML filename (under manifest root)'
)
parser.add_argument('--max-source-positions', default=6000, type=int,
metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int,
metavar='N',
help='max number of tokens in the target sequence')
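# A sketch of a typical invocation of this task (data root, subset names and
# model size are placeholders):
#
#   fairseq-train ${DATA_ROOT} --task speech_to_text \
#     --config-yaml config.yaml --train-subset train --valid-subset dev \
#     --arch s2t_transformer_s --criterion label_smoothed_cross_entropy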
def __init__(self, args, tgt_dict):
super().__init__(args)
self.tgt_dict = tgt_dict
self.data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))
@classmethod
def setup_task(cls, args, **kwargs):
data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))
dict_path = op.join(args.data, data_cfg.vocab_filename)
if not op.isfile(dict_path):
raise FileNotFoundError(f'Dict not found: {dict_path}')
tgt_dict = Dictionary.load(dict_path)
logger.info(f'dictionary size ({data_cfg.vocab_filename}): '
f'{len(tgt_dict):,}')
if getattr(args, 'train_subset', None) is not None:
if not all(s.startswith('train') for s in args.train_subset.split(',')):
raise ValueError('Train splits should be named like "train*".')
return cls(args, tgt_dict)
def build_criterion(self, args):
from fairseq import criterions
if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
raise ValueError('Please set "--ignore-prefix-size 1" since '
'target language ID token is prepended as BOS.')
return criterions.build_criterion(args, self)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
is_train_split = split.startswith('train')
pre_tokenizer = self.build_tokenizer(self.args)
bpe_tokenizer = self.build_bpe(self.args)
self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
self.args.data, self.data_cfg, split, self.tgt_dict,
pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split,
epoch=epoch, seed=self.args.seed
)
@property
def target_dictionary(self):
return self.tgt_dict
@property
def source_dictionary(self):
return None
def max_positions(self):
return self.args.max_source_positions, self.args.max_target_positions
def build_model(self, args):
args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
args.input_channels = self.data_cfg.input_channels
return super(SpeechToTextTask, self).build_model(args)
def build_generator(
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None,
):
if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
raise ValueError('Please set "--prefix-size 1" since '
'target language ID token is prepended as BOS.')
lang_token_ids = {
i for s, i in self.tgt_dict.indices.items()
if SpeechToTextDataset.is_lang_tag(s)
}
extra_gen_cls_kwargs = {'symbols_to_strip_from_output': lang_token_ids}
return super().build_generator(
models, args, seq_gen_cls=None,
extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
def build_tokenizer(self, args):
logger.info(f'pre-tokenizer: {self.data_cfg.pre_tokenizer}')
return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))
def build_bpe(self, args):
logger.info(f'tokenizer: {self.data_cfg.bpe_tokenizer}')
return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))
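# The two dicts above come from config.yaml; a minimal sketch of that section
# (tokenizer choices and the SentencePiece model name are hypothetical):
#
#   pre_tokenizer:
#     tokenizer: moses
#   bpe_tokenizer:
#     bpe: sentencepiece
#     sentencepiece_model: spm_unigram10000.model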
def build_dataset_for_inference(self, audio_paths, n_frames):
return SpeechToTextDataset('interactive', False, self.data_cfg,
audio_paths, n_frames)

View File

@@ -21,7 +21,6 @@ import torch
from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.data import encoders
def main(args):
@@ -158,8 +157,8 @@ def _main(args, output_file):
generator = task.build_generator(models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs)
# Handle tokenization and BPE
tokenizer = encoders.build_tokenizer(args)
bpe = encoders.build_bpe(args)
tokenizer = task.build_tokenizer(args)
bpe = task.build_bpe(args)
def decode_fn(x):
if bpe is not None:

View File

@@ -141,7 +141,7 @@ setup(
'hydra-core',
'numpy',
'regex',
'sacrebleu',
'sacrebleu>=1.4.12',
'torch',
'tqdm',
],

View File

@@ -38,6 +38,7 @@ class TestLabelSmoothing(unittest.TestCase):
# build model
self.args = argparse.Namespace()
self.args.sentence_avg = False
self.args.report_accuracy = False
self.args.probs = torch.FloatTensor([
# pad eos unk w1 w2 w3
[0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],