Asr initial push (#810)

Summary:
Initial code for speech recognition task.
Right now only one ASR model added - https://arxiv.org/abs/1904.11660

unit test testing:
python -m unittest discover tests

Also ran model training with this code and obtained
5.0 test_clean | 13.4 test_other
on librispeech with pytorch/audio features
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/810

Reviewed By: cpuhrsch

Differential Revision: D16706659

Pulled By: okhonko

fbshipit-source-id: 89a5f9883e50bc0e548234287aa0ea73f7402514
This commit is contained in:
Dmytro Okhonko 2019-08-08 02:42:38 -07:00 committed by Facebook Github Bot
parent 9a1038f68a
commit 72f9364cc6
24 changed files with 3054 additions and 247 deletions

View File

@ -0,0 +1,32 @@
# Speech Recognition
`examples/speech_recognition` implements the ASR task in Fairseq, along with the features, datasets, models and loss functions needed to train and run inference with the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
## Additional dependencies
On top of the main fairseq dependencies, there are a couple of additional requirements.
1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from the sctk package [here](http://www.openslr.org/4/). Training and inference don't require the Sclite dependency.
## Preparing librispeech data
```
./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
```
## Training on librispeech data
```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
```
## Inference for librispeech
`$SET` can be `test_clean` or `test_other`
Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`
```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
```
## Computing WER for librispeech
```
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
```
The `Sum/Avg` row from the first table of the report contains the WER.

View File

@ -0,0 +1 @@
from . import tasks, criterions, models # noqa

View File

@ -0,0 +1,7 @@
import importlib
import os
# Import every public module that lives next to this file; presumably this is
# what triggers each module's criterion registration on package import.
for file in os.listdir(os.path.dirname(__file__)):
    # Skip private modules and anything that is not a Python source file.
    if file.startswith('_') or not file.endswith('.py'):
        continue
    criterion_name = file.partition('.py')[0]
    importlib.import_module('examples.speech_recognition.criterions.' + criterion_name)

View File

@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import math
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion
@register_criterion("cross_entropy_acc")
class CrossEntropyWithAccCriterion(FairseqCriterion):
    """Cross-entropy criterion that additionally logs token-level accuracy."""

    def __init__(self, args, task):
        super().__init__(args, task)

    def compute_loss(self, model, net_output, target, reduction, log_probs):
        """Flatten targets and model outputs and compute the NLL loss.

        Args:
            model: model providing ``get_normalized_probs``.
            net_output: output of ``model(**sample["net_input"])``.
            target: target tensor; flattened to 1-D here.
            reduction (str): reduction mode forwarded to ``F.nll_loss``.
            log_probs (bool): forwarded to ``model.get_normalized_probs``.

        Returns:
            tuple: ``(lprobs, loss)`` where ``lprobs`` has been flattened to
            2-D with one row per target position.
        """
        # N, T -> N * T
        target = target.view(-1)
        lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
        if not hasattr(lprobs, "batch_first"):
            logging.warning(
                "ERROR: we need to know whether "
                "batch first for the net output; "
                "you need to set batch_first attribute for the return value of "
                "model.get_normalized_probs. Now, we assume this is true, but "
                "in the future, we will raise exception instead. "
            )
        batch_first = getattr(lprobs, "batch_first", True)
        if not batch_first:
            lprobs = lprobs.transpose(0, 1)

        # N, T, D -> N * T, D
        # ``reshape`` instead of ``view``: after the transpose above the tensor
        # is generally non-contiguous, in which case ``view`` raises.
        lprobs = lprobs.reshape(-1, lprobs.size(-1))
        loss = F.nll_loss(
            lprobs, target, ignore_index=self.padding_idx, reduction=reduction
        )
        return lprobs, loss

    def get_logging_output(self, sample, target, lprobs, loss):
        """Build the per-batch logging dict, including accuracy counters.

        Args:
            sample (dict): the mini-batch; ``"target"``, ``"ntokens"`` and
                ``"net_input"]["src_lengths"`` are read from it.
            target: target tensor (flattened here).
            lprobs: flattened 2-D scores as returned by ``compute_loss``.
            loss: loss tensor as returned by ``compute_loss``.

        Returns:
            tuple: ``(sample_size, logging_output)``.
        """
        target = target.view(-1)
        # Padding positions are excluded from both "correct" and "total".
        mask = target != self.padding_idx
        correct = torch.sum(
            lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
        )
        total = torch.sum(mask)
        sample_size = (
            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
        )

        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }

        return sample_size, logging_output

    def forward(self, model, sample, reduction="sum", log_probs=True):
        """Computes the cross entropy with accuracy metric for the given sample.

        This is similar to CrossEntropyCriterion in fairseq, but also
        computes accuracy metrics as part of logging

        Args:
            model: the model to run on ``sample["net_input"]``.
            sample (dict): mini-batch containing ``"net_input"`` plus whatever
                ``model.get_targets`` reads.
            reduction (str): reduction mode forwarded to ``F.nll_loss``
                (default: ``"sum"``).
            log_probs (bool): forwarded to ``model.get_normalized_probs``
                (default: ``True``).

        Returns:
            tuple: With three elements:
                1) the loss
                2) the sample size, which is used as the denominator for the gradient
                3) logging outputs to display while training

        TODO:
            * Currently this Criterion will only work with LSTMEncoderModels or
            FairseqModels which have decoder, or Models which return TorchTensor
            as net_output.
            We need to make a change to support all FairseqEncoder models.
        """
        net_output = model(**sample["net_input"])
        target = model.get_targets(sample, net_output)
        lprobs, loss = self.compute_loss(
            model, net_output, target, reduction, log_probs
        )
        sample_size, logging_output = self.get_logging_output(
            sample, target, lprobs, loss
        )
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
        total_sum = sum(log.get("total", 0) for log in logging_outputs)
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        nframes = sum(log.get("nframes", 0) for log in logging_outputs)
        agg_output = {
            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            # if args.sentence_avg, then sample_size is nsentences, then loss
            # is per-sentence loss; else sample_size is ntokens, the loss
            # becomes per-output token loss
            "ntokens": ntokens,
            "nsentences": nsentences,
            "nframes": nframes,
            "sample_size": sample_size,
            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
            "correct": correct_sum,
            "total": total_sum,
            # total is the number of validate tokens
        }
        # Guard against ntokens == 0 so an empty aggregation cannot raise
        # ZeroDivisionError; the key is simply omitted in that case.
        if sample_size != ntokens and ntokens > 0:
            # loss: normalized by sample_size (per sentence in this branch)
            # nll_loss: normalized by ntokens (per output token)
            agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
        return agg_output

View File

@ -0,0 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .asr_dataset import AsrDataset
__all__ = [
'AsrDataset',
]

View File

@ -0,0 +1,110 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
from fairseq.data import FairseqDataset
from . import data_utils
from .collaters import Seq2SeqCollater
class AsrDataset(FairseqDataset):
    """
    A dataset representing speech and corresponding transcription.

    Args:
        aud_paths: (List[str]): A list of str with paths to audio files.
        aud_durations_ms (List[int]): A list of int containing the durations of
            audio files.
        tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
            of target transcriptions.
        tgt_dict (~fairseq.data.Dictionary): target vocabulary.
        ids (List[str]): A list of utterance IDs.
        speakers (List[str]): A list of speakers corresponding to utterances.
        num_mel_bins (int): Number of triangular mel-frequency bins (default: 80)
        frame_length (float): Frame length in milliseconds (default: 25.0)
        frame_shift (float): Frame shift in milliseconds (default: 10.0)
    """

    def __init__(
        self, aud_paths, aud_durations_ms, tgt,
        tgt_dict, ids, speakers,
        num_mel_bins=80, frame_length=25.0, frame_shift=10.0
    ):
        assert frame_length > 0
        assert frame_shift > 0
        assert all(x > frame_length for x in aud_durations_ms)
        # Number of feature frames per utterance, derived from its duration.
        self.frame_sizes = [
            int(1 + (d - frame_length) / frame_shift)
            for d in aud_durations_ms
        ]

        assert len(aud_paths) > 0
        assert len(aud_paths) == len(aud_durations_ms)
        assert len(aud_paths) == len(tgt)
        assert len(aud_paths) == len(ids)
        assert len(aud_paths) == len(speakers)
        self.aud_paths = aud_paths
        self.tgt_dict = tgt_dict
        self.tgt = tgt
        self.ids = ids
        self.speakers = speakers
        self.num_mel_bins = num_mel_bins
        self.frame_length = frame_length
        self.frame_shift = frame_shift

        # Built once here under its own attribute name. Previously it was
        # assigned to ``self.collater`` inside ``__getitem__``, which shadowed
        # the ``collater`` method and made ``collater()`` fail when called
        # before any item had been fetched.
        self.s2s_collater = Seq2SeqCollater(
            0, 1, pad_index=self.tgt_dict.pad(),
            eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
        )

    def __getitem__(self, index):
        """Load one utterance and return its normalized fbank features and target.

        Returns:
            dict: ``{"id": index, "data": [features, target]}``.

        Raises:
            FileNotFoundError: if the audio file for ``index`` does not exist.
        """
        # Imported lazily so the module can be imported without torchaudio.
        import torchaudio
        import torchaudio.compliance.kaldi as kaldi
        tgt_item = self.tgt[index] if self.tgt is not None else None

        path = self.aud_paths[index]
        if not os.path.exists(path):
            raise FileNotFoundError("Audio file not found: {}".format(path))
        sound, sample_rate = torchaudio.load_wav(path)
        output = kaldi.fbank(
            sound,
            num_mel_bins=self.num_mel_bins,
            frame_length=self.frame_length,
            frame_shift=self.frame_shift
        )
        # Per-utterance mean/variance normalization of the features.
        output_cmvn = data_utils.apply_mv_norm(output)

        return {"id": index, "data": [output_cmvn.detach(), tgt_item]}

    def __len__(self):
        return len(self.aud_paths)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples as returned by ``__getitem__``

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        return self.s2s_collater.collate(samples)

    def num_tokens(self, index):
        """Return the number of feature frames of the example at ``index``."""
        return self.frame_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.frame_sizes[index],
            len(self.tgt[index]) if self.tgt is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self))

View File

@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This module contains collection of classes which implement
collate functionalities for various tasks.
Collaters should know what data to expect for each sample
and they should pack / collate them into batches
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils
class Seq2SeqCollater(object):
    """
    Implements collate function mainly for seq2seq tasks
    This expects each sample to contain feature (src_tokens) and
    targets.
    This collator is also used for aligned training task.
    """

    def __init__(
        self,
        feature_index=0,
        label_index=1,
        pad_index=1,
        eos_index=2,
        move_eos_to_beginning=True,
    ):
        # Positions of the feature and the label inside ``sample["data"]``.
        self.feature_index = feature_index
        self.label_index = label_index
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.move_eos_to_beginning = move_eos_to_beginning

    def _collate_frames(self, frames):
        """Convert a list of 2d frames into a padded 3d tensor

        Args:
            frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is
                length of i-th frame and f_dim is static dimension of features

        Returns:
            3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
        """
        len_max = max(frame.size(0) for frame in frames)
        f_dim = frames[0].size(1)
        # Zero-filled tensor with the dtype/device of the first frame.
        res = frames[0].new_zeros((len(frames), len_max, f_dim))

        for i, v in enumerate(frames):
            res[i, : v.size(0)] = v

        return res

    def collate(self, samples):
        """
        utility function to collate samples into batch for speech recognition.

        Args:
            samples (list): dicts with an ``"id"`` and a ``"data"`` sequence
                holding the feature at ``feature_index`` and the label at
                ``label_index``. Samples whose feature is ``None`` are skipped.

        Returns:
            dict: the collated batch, sorted by descending number of frames,
            or ``{}`` when there is no valid sample.
        """
        if len(samples) == 0:
            return {}

        # parse samples into torch tensors
        parsed_samples = []
        for s in samples:
            # skip invalid samples
            if s["data"][self.feature_index] is None:
                continue
            source = s["data"][self.feature_index]
            if isinstance(source, (np.ndarray, np.generic)):
                source = torch.from_numpy(source)
            target = s["data"][self.label_index]
            if isinstance(target, (np.ndarray, np.generic)):
                target = torch.from_numpy(target).long()

            parsed_sample = {"id": s["id"], "source": source, "target": target}
            parsed_samples.append(parsed_sample)
        samples = parsed_samples

        # Every sample was invalid: treat it like an empty batch instead of
        # crashing on ``max()`` of an empty sequence below.
        if len(samples) == 0:
            return {}

        sample_ids = torch.LongTensor([s["id"] for s in samples])
        frames = self._collate_frames([s["source"] for s in samples])
        # sort samples by descending number of frames
        frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
        frames_lengths, sort_order = frames_lengths.sort(descending=True)
        sample_ids = sample_ids.index_select(0, sort_order)
        frames = frames.index_select(0, sort_order)

        target = None
        target_lengths = None
        prev_output_tokens = None
        # Whether targets are present is decided from the first sample only.
        if samples[0].get("target", None) is not None:
            ntokens = sum(len(s["target"]) for s in samples)
            target = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, sort_order)
            target_lengths = torch.LongTensor(
                [s["target"].size(0) for s in samples]
            ).index_select(0, sort_order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=self.move_eos_to_beginning,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
        else:
            # Without targets, count source frames instead.
            ntokens = sum(len(s["source"]) for s in samples)

        batch = {
            "id": sample_ids,
            "ntokens": ntokens,
            "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
            "target": target,
            "target_lengths": target_lengths,
            "nsentences": len(samples),
        }
        if prev_output_tokens is not None:
            batch["net_input"]["prev_output_tokens"] = prev_output_tokens
        return batch

View File

@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def calc_mean_invstddev(feature):
    """Return the per-column mean and inverse standard deviation of ``feature``.

    Args:
        feature: a 2-D tensor; statistics are taken along dimension 0.

    Returns:
        tuple: ``(mean, invstddev)``.

    Raises:
        ValueError: if ``feature`` is not 2-D.
    """
    shape = feature.size()
    if len(shape) != 2:
        raise ValueError("We expect the input feature to be 2-D tensor")

    mean = feature.mean(0)
    variance = feature.var(0)
    stddev = torch.sqrt(variance)

    # avoid division by ~zero: if any column is (nearly) constant, offset
    # every standard deviation by a small epsilon before inverting
    tiny = 1e-8
    if (variance < tiny).any():
        stddev = stddev + tiny
    return mean, 1.0 / stddev
def apply_mv_norm(features):
    """Normalize ``features`` using the statistics from ``calc_mean_invstddev``.

    Subtracts the mean and multiplies by the inverse standard deviation.
    """
    offset, scale = calc_mean_invstddev(features)
    return (features - offset) * scale
def lengths_to_encoder_padding_mask(lengths, batch_first=False):
    """
    convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor

    Args:
        lengths: a (B, )-shaped tensor
        batch_first: if True the mask is returned as (B, max_length),
            otherwise as (max_length, B) (default: False)

    Return:
        encoder_padding_mask: a binary mask, where the entry for time step t
            of sequence b is 0 for t < lengths[b] and 1 otherwise
        max_length: maximum length of B sequences

    TODO:
        kernelize this function if benchmarking shows this function is slow
    """
    max_lengths = torch.max(lengths).item()
    bsz = lengths.size(0)

    # a (T, ) tensor with [0, ..., T-1], moved to the right device
    positions = torch.arange(max_lengths).to(lengths.device)
    # reshape to (1, T) and expand to (B, T)
    positions = positions.view(1, max_lengths).expand(bsz, -1)
    # per-sequence lengths as (B, 1), expanded to (B, T)
    limits = lengths.view(bsz, 1).expand(-1, max_lengths)

    encoder_padding_mask = positions >= limits

    if batch_first:
        return encoder_padding_mask, max_lengths
    return encoder_padding_mask.t(), max_lengths

View File

@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import namedtuple
import concurrent.futures
from itertools import chain
import argparse
import os
import json
import sentencepiece as spm
import multiprocessing
import torchaudio
from fairseq.data import Dictionary
MILLISECONDS_TO_SECONDS = 0.001
def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
    """Build the json entry for a single utterance.

    Args:
        aud_path: path to the audio file.
        lable: transcription text of the utterance.
        utt_id: utterance id, used as the key of the returned dict.
        sp: sentencepiece processor used to split the text into pieces.
        tgt_dict: dictionary used to map the pieces to ids.

    Returns:
        dict: ``{utt_id: {"input": {...}, "output": {...}}}``.
    """
    si, _ = torchaudio.info(aud_path)
    length_ms = int(si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS)

    token = " ".join(sp.EncodeAsPieces(lable))
    ids = tgt_dict.encode_line(token, append_eos=False)
    tokenid = ', '.join(str(t.tolist()) for t in ids)

    entry = {
        "input": {"length_ms": length_ms, "path": aud_path},
        "output": {"text": lable, "token": token, "tokenid": tokenid},
    }
    return {utt_id: entry}
def main():
    """Build the json manifest for a set of audio files and their labels.

    Reads ``<ID LABEL>`` lines from ``--labels``, walks ``--audio-dirs`` for
    files ending in ``--audio-format`` whose base name is a known utterance id,
    processes each one with ``process_sample`` and dumps the result as
    ``{"utts": {...}}`` to ``--output``.

    Raises:
        Exception: if the labels file contains no labels.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-dirs", nargs="+", default=['-'], required=True,
                        help="input directories with audio files")
    parser.add_argument("--labels", required=True,
                        help="aggregated input labels with format <ID LABEL> per line",
                        type=argparse.FileType('r', encoding='UTF-8'))
    parser.add_argument("--spm-model", required=True,
                        help="sentencepiece model to use for encoding",
                        type=argparse.FileType('r', encoding='UTF-8'))
    parser.add_argument("--dictionary", required=True,
                        help="file to load fairseq dictionary from",
                        type=argparse.FileType('r', encoding='UTF-8'))
    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
    parser.add_argument("--output", required=True, type=argparse.FileType('w'),
                        help="path to save json output")
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    # Only the path of the opened model file is used; sentencepiece loads it.
    sp.Load(args.spm_model.name)

    tgt_dict = Dictionary.load(args.dictionary)

    labels = {}
    for line in args.labels:
        (utt_id, label) = line.split(" ", 1)
        labels[utt_id] = label
    if len(labels) == 0:
        # ``args.labels`` is the opened file object; the parser defines no
        # ``labels_path`` attribute, so report the file's name instead.
        raise Exception('No labels found in ', args.labels.name)

    Sample = namedtuple('Sample', 'aud_path utt_id')
    samples = []
    for path, _, files in chain.from_iterable(os.walk(path) for path in args.audio_dirs):
        for f in files:
            if f.endswith(args.audio_format):
                # ``os.path.splitext`` always returns a 2-tuple, so the base
                # name can be taken directly as the utterance id.
                utt_id = os.path.splitext(f)[0]
                if utt_id not in labels:
                    continue
                samples.append(Sample(os.path.join(path, f), utt_id))

    utts = {}
    num_cpu = multiprocessing.cpu_count()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
        future_to_sample = {executor.submit(process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict): s for s in samples}
        for future in concurrent.futures.as_completed(future_to_sample):
            try:
                data = future.result()
            except Exception as exc:
                # Best effort: a failing utterance is reported and left out.
                print('generated an exception: ', exc)
            else:
                utts.update(data)
    json.dump({"utts": utts}, args.output, indent=4)


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,88 @@
#!/usr/bin/env bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Prepare librispeech dataset
# Downloads the LibriSpeech archives, merges the training parts, trains a
# sentencepiece model on the training transcripts and writes the json
# manifests plus dict.txt / spm.model into <out_dir>.
base_url=www.openslr.org/resources/12
train_dir=train_960
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <download_dir> <out_dir>"
echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
exit 1
fi
# Strip a single trailing slash from both arguments.
download_dir=${1%/}
out_dir=${2%/}
# NOTE(review): hard-coded checkout location; edit this for your environment.
fairseq_root=~/fairseq-py/
mkdir -p ${out_dir}
# NOTE(review): everything below runs from inside out_dir, so a relative
# download_dir is resolved against out_dir from here on.
cd ${out_dir} || exit
# Sentencepiece vocabulary size and model type.
nbpe=5000
bpemode=unigram
if [ ! -d "$fairseq_root" ]; then
echo "$0: Please set correct fairseq_root"
exit 1
fi
echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
url=$base_url/$part.tar.gz
if ! wget -P $download_dir $url; then
echo "$0: wget failed for $url"
exit 1
fi
if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
exit 1
fi
done
echo "Merge all train packs into one"
mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
for part in train-clean-100 train-clean-360 train-other-500; do
mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
done
echo "Merge train text"
# NOTE(review): the merged text files are written with ">>", so running the
# script a second time appends to (and duplicates) the existing content.
find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text
# Use combined dev-clean and dev-other as validation set
find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text
# Output locations (relative to out_dir) for the unit list, the encoded
# training text, the fairseq dictionary and the sentencepiece model prefix.
dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
echo "Dictionary preparation"
mkdir -p data/lang_char/
# Special symbols first; their ids match the --unk_id/--eos_id/--pad_id
# values passed to spm_train below.
echo "<unk> 3" > ${dict}
echo "</s> 2" >> ${dict}
echo "<pad> 1" >> ${dict}
# Drop the first space-separated field (the utterance id) of every line,
# keeping only the transcript.
cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
# ${dict}: every distinct piece followed by a running index starting at 4.
cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
# ${fairseq_dict}: every distinct piece followed by its occurrence count.
cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
wc -l ${dict}
echo "Prepare train and test jsons"
for part in train_960 test-other test-clean; do
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
done
# fairseq expects to find train.json and valid.json during training
mv train_960.json train.json
echo "Prepare valid json"
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json
# Copy the dictionary and the sentencepiece model next to the json files
# under fixed names.
cp ${fairseq_dict} ./dict.txt
cp ${bpemodel}.model ./spm.model

View File

@ -0,0 +1,243 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Run inference for pre-processed data with a trained model.
"""
import logging
import os
import sentencepiece as spm
import torch
from fairseq import options, progress_bar, utils, tasks
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def add_asr_eval_argument(parser):
    """Register the ASR-specific evaluation options on ``parser``.

    Args:
        parser: an ``argparse``-style parser exposing ``add_argument``.

    Returns:
        The same parser, to allow chaining.
    """
    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
    # NOTE(review): without a ``type`` any value given on the command line
    # (including "False") is a non-empty, truthy string. Left unchanged so that
    # existing invocations passing a value keep working.
    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="decoding type for rnnt models",
    )
    # ``type=float`` so that values given on the command line are numbers like
    # the defaults, not strings.
    parser.add_argument(
        "--lm_weight",
        type=float,
        default=0.2,
        help="weight for wfstlm while interpolating with neural score",
    )
    parser.add_argument(
        "--rnnt_len_penalty",
        type=float,
        default=-0.5,
        help="rnnt length penalty on word level",
    )
    return parser
def check_args(args):
    """Sanity-check the generation arguments; fails with AssertionError."""
    assert args.path is not None, "--path required for generation!"
    assert args.results_path is not None, "--results_path required for generation!"
    # Sampling is only valid when nbest and beam agree.
    assert not (
        args.sampling and args.nbest != args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    # Unknown-word replacement is only valid together with raw text.
    assert not (
        args.replace_unk is not None and not args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"
def get_dataset_itr(args, task):
    """Return an unshuffled iterator over the batches of ``args.gen_subset``."""
    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        # Effectively no length limit at inference time.
        max_positions=(1000000.0, 1000000.0),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    )
    return epoch_itr.next_epoch_itr(shuffle=False)
def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id):
    """Write the top hypotheses and the reference of one utterance.

    For each of the first ``args.nbest`` hypotheses, the hypothesis and the
    reference are written both as units and as words to the matching files in
    ``res_files``; every line is suffixed with ``(speaker-id)``.
    """
    line_template = "{} ({}-{})"
    selected = hypos[: min(len(hypos), args.nbest)]
    for hypo in selected:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        hyp_words = sp.DecodePieces(hyp_pieces.split())
        for key, text in (("hypo.units", hyp_pieces), ("hypo.words", hyp_words)):
            print(line_template.format(text, speaker, id), file=res_files[key])

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = sp.DecodePieces(tgt_pieces.split())
        for key, text in (("ref.units", tgt_pieces), ("ref.words", tgt_words)):
            print(line_template.format(text, speaker, id), file=res_files[key])

        # only score top hypothesis
        if args.quiet:
            continue
        logger.debug("HYPO:" + hyp_words)
        logger.debug("TARGET:" + tgt_words)
        logger.debug("___________________")
def prepare_result_files(args):
    """Open the four line-buffered result files inside ``args.results_path``.

    Returns:
        dict: maps "hypo.words", "hypo.units", "ref.words" and "ref.units" to
        open text files named ``<prefix>-<checkpoint basename>-<subset>.txt``.
    """
    def get_res_file(file_prefix):
        filename = "{}-{}-{}.txt".format(
            file_prefix, os.path.basename(args.path), args.gen_subset
        )
        return open(os.path.join(args.results_path, filename), "w", buffering=1)

    # (result key, file name prefix); the word-level prefixes are singular.
    key_to_prefix = (
        ("hypo.words", "hypo.word"),
        ("hypo.units", "hypo.units"),
        ("ref.words", "ref.word"),
        ("ref.units", "ref.units"),
    )
    return {key: get_res_file(prefix) for key, prefix in key_to_prefix}
def optimize_models(args, use_cuda, models):
    """Prepare every model of the ensemble for generation.

    Each model is switched to its fast generation mode, converted to half
    precision when ``args.fp16`` is set and moved to the GPU when ``use_cuda``
    is true.
    """
    for model in models:
        if args.no_beamable_mm:
            beam_size = None
        else:
            beam_size = args.beam
        model.make_generation_fast_(
            beamable_mm_beam_size=beam_size,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
def main(args):
    """Run inference on ``args.gen_subset`` and write hypothesis/reference files.

    Loads the task, dataset and model ensemble, decodes every batch and writes
    unit- and word-level hypotheses and references into ``args.results_path``.
    """
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
        if args.ctc:
            logger.info("| decoding a ctc model")
        if args.rnnt:
            logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    # NOTE(review): ``eval`` executes whatever is passed as --model-overrides;
    # only use this script with trusted command lines.
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, use_cuda, models)

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    # ``exist_ok`` avoids the race between checking for and creating the
    # directory.
    os.makedirs(args.results_path, exist_ok=True)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, 'spm.model'))

    res_files = prepare_result_files(args)
    try:
        with progress_bar.build_progress_bar(args, itr) as t:
            wps_meter = TimeMeter()
            for sample in t:
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                if "net_input" not in sample:
                    continue

                prefix_tokens = None
                if args.prefix_size > 0:
                    prefix_tokens = sample["target"][:, : args.prefix_size]

                gen_timer.start()
                hypos = task.inference_step(generator, models, sample, prefix_tokens)
                num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
                gen_timer.stop(num_generated_tokens)

                for i, sample_id in enumerate(sample['id'].tolist()):
                    speaker = task.dataset(args.gen_subset).speakers[int(sample_id)]
                    utt_id = task.dataset(args.gen_subset).ids[int(sample_id)]

                    target_tokens = (
                        utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                    )

                    # Process top predictions
                    process_predictions(
                        args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, utt_id
                    )

                wps_meter.update(num_generated_tokens)
                t.log({"wps": round(wps_meter.avg)})
                num_sentences += sample["nsentences"]
    finally:
        # Always release the result files, even if decoding fails midway.
        for res_file in res_files.values():
            res_file.close()

    logger.info(
        "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
        "sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )
    logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
def cli_main():
    """Command-line entry point: build the parser, parse and run ``main``."""
    parser = add_asr_eval_argument(options.get_generation_parser())
    main(options.parse_args_and_arch(parser))


if __name__ == "__main__":
    cli_main()

View File

@ -0,0 +1,7 @@
import importlib
import os
# Import every public module that lives next to this file; presumably this is
# what triggers each module's model registration on package import.
for file in os.listdir(os.path.dirname(__file__)):
    # Skip private modules and anything that is not a Python source file.
    if file.startswith('_') or not file.endswith('.py'):
        continue
    model_name = file.partition('.py')[0]
    importlib.import_module('examples.speech_recognition.models.' + model_name)

View File

@ -0,0 +1,838 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import math
from collections.abc import Iterable
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqModel,
register_model,
register_model_architecture,
)
from fairseq.modules import LinearizedConvolution
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock
@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqModel):
    """
    Transformers with convolutional context for ASR
    https://arxiv.org/abs/1904.11660
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--vggblock-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples each containing the configuration of one vggblock:
    [(out_channels,
      conv_kernel_size,
      pooling_kernel_size,
      num_conv_layers,
      use_layer_norm), ...])
            """,
        )
        parser.add_argument(
            "--transformer-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the encoder transformer layers
    configurations:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="""
    encoder output dimension, can be None. If specified, projecting the
    transformer output to the specified dimension""",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--tgt-embed-dim",
            type=int,
            metavar="N",
            help="embedding dimension of the decoder target tokens",
        )
        parser.add_argument(
            "--transformer-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the decoder transformer layers
    configurations:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--conv-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples for the decoder 1-D convolution config
        [(out_channels, conv_kernel_size, use_layer_norm), ...]""",
        )

    @classmethod
    def build_encoder(cls, args, task):
        """Build the VGG + transformer encoder from the parsed arguments."""
        # NOTE(review): the *-config options are Python expressions passed on
        # the command line and are evaluated with ``eval``; only use trusted
        # values.
        return VGGTransformerEncoder(
            input_feat_per_channel=args.input_feat_per_channel,
            vggblock_config=eval(args.vggblock_enc_config),
            transformer_config=eval(args.transformer_enc_config),
            encoder_output_dim=args.enc_output_dim,
            in_channels=args.in_channels,
        )

    @classmethod
    def build_decoder(cls, args, task):
        """Build the transformer decoder from the parsed arguments."""
        # NOTE(review): see build_encoder regarding ``eval`` on CLI values.
        return TransformerDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.tgt_embed_dim,
            transformer_config=eval(args.transformer_dec_config),
            conv_config=eval(args.conv_dec_config),
            encoder_output_dim=args.enc_output_dim,
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted
        # (in case there are any new ones)
        base_architecture(args)

        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)
        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Return the parent's normalized probabilities, tagged as batch-first."""
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        # Consumers such as the cross_entropy_acc criterion read this flag to
        # decide whether the output must be transposed.
        lprobs.batch_first = True
        return lprobs
# One tuple per VGG block:
#   (out_channels, conv_kernel_size, pooling_kernel_size, num_conv_layers,
#    use_layer_norm)
DEFAULT_ENC_VGGBLOCK_CONFIG = (
    (32, 3, 2, 2, False),
    (32, 3, 2, 2, False),
)

# One tuple per transformer layer:
#   256: embedding dimension
#   4: number of heads
#   1024: FFN
#   True: apply layerNorm before (dropout + residual) instead of after
#   0.2 (dropout): dropout after MultiheadAttention and second FC
#   0.2 (attention_dropout): dropout in MultiheadAttention
#   0.2 (relu_dropout): dropout after ReLu
DEFAULT_ENC_TRANSFORMER_CONFIG = (
    (256, 4, 1024, True, 0.2, 0.2, 0.2),
    (256, 4, 1024, True, 0.2, 0.2, 0.2),
)

# Same field layout as the encoder transformer config above.
DEFAULT_DEC_TRANSFORMER_CONFIG = (
    (256, 2, 1024, True, 0.2, 0.2, 0.2),
    (256, 2, 1024, True, 0.2, 0.2, 0.2),
)

# One tuple per decoder 1-D convolution:
#   (out_channels, conv_kernel_size, use_layer_norm)
DEFAULT_DEC_CONV_CONFIG = (
    (256, 3, True),
    (256, 3, True),
)
# TODO: replace the one-line transformer encoder config with
# explicit args to get rid of this transformation
def prepare_transformer_encoder_params(
    input_dim,
    num_heads,
    ffn_dim,
    normalize_before,
    dropout,
    attention_dropout,
    relu_dropout,
):
    """Turn one encoder layer config tuple into an ``argparse.Namespace``.

    The positional parameters follow the field order of one entry of the
    encoder transformer config. The returned namespace exposes them under
    the ``encoder_*`` / dropout attribute names set below.
    """
    return argparse.Namespace(
        encoder_embed_dim=input_dim,
        encoder_attention_heads=num_heads,
        attention_dropout=attention_dropout,
        dropout=dropout,
        # the "relu" dropout of the config is exposed as activation_dropout
        activation_dropout=relu_dropout,
        encoder_normalize_before=normalize_before,
        encoder_ffn_embed_dim=ffn_dim,
    )
def prepare_transformer_decoder_params(
    input_dim,
    num_heads,
    ffn_dim,
    normalize_before,
    dropout,
    attention_dropout,
    relu_dropout,
):
    """Turn one decoder layer config tuple into an ``argparse.Namespace``.

    The positional parameters follow the field order of one entry of the
    decoder transformer config. The returned namespace exposes them under
    the ``decoder_*`` / dropout attribute names set below.
    """
    return argparse.Namespace(
        decoder_embed_dim=input_dim,
        decoder_attention_heads=num_heads,
        attention_dropout=attention_dropout,
        dropout=dropout,
        # the "relu" dropout of the config is exposed as activation_dropout
        activation_dropout=relu_dropout,
        decoder_normalize_before=normalize_before,
        decoder_ffn_embed_dim=ffn_dim,
    )
class VGGTransformerEncoder(FairseqEncoder):
    """VGG + Transformer encoder"""

    def __init__(
        self,
        input_feat_per_channel,
        vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        encoder_output_dim=512,
        in_channels=1,
        transformer_context=None,
        transformer_sampling=None,
    ):
        """constructor for VGGTransformerEncoder

        Args:
            - input_feat_per_channel: feature dim (not including stacked,
                just base feature)
            - in_channels: # input channels (e.g., if stack 8 feature vector
                together, this is 8)
            - vggblock_config: configuration of vggblock, see comments on
                DEFAULT_ENC_VGGBLOCK_CONFIG
            - transformer_config: configuration of transformer layer, see comments
                on DEFAULT_ENC_TRANSFORMER_CONFIG
            - encoder_output_dim: final transformer output embedding dimension
            - transformer_context: (left, right) if set, self-attention will be focused
                on (t-left, t+right)
            - transformer_sampling: an iterable of int, must match with
                len(transformer_config), transformer_sampling[i] indicates sampling
                factor for i-th transformer layer, after multihead att and feedfoward
                part

        Raises:
            ValueError: if one of the configurations is malformed (see the
                validate_* / parse_* helpers of this class).
        """
        super().__init__(None)

        self.num_vggblocks = 0
        if vggblock_config is not None:
            if not isinstance(vggblock_config, Iterable):
                raise ValueError("vggblock_config is not iterable")
            self.num_vggblocks = len(vggblock_config)

        self.conv_layers = nn.ModuleList()
        # keep the *input* sizes; the local variables below are overwritten
        # while the blocks are chained together
        self.in_channels = in_channels
        self.input_dim = input_feat_per_channel

        if vggblock_config is not None:
            for config in vggblock_config:
                (
                    out_channels,
                    conv_kernel_size,
                    pooling_kernel_size,
                    num_conv_layers,
                    layer_norm,
                ) = config
                self.conv_layers.append(
                    VGGBlock(
                        in_channels,
                        out_channels,
                        conv_kernel_size,
                        pooling_kernel_size,
                        num_conv_layers,
                        input_dim=input_feat_per_channel,
                        layer_norm=layer_norm,
                    )
                )
                # the next block consumes what this one produces
                in_channels = out_channels
                input_feat_per_channel = self.conv_layers[-1].output_dim

        transformer_input_dim = self.infer_conv_output_dim(
            self.in_channels, self.input_dim
        )
        # transformer_input_dim is the output dimension of VGG part

        self.validate_transformer_config(transformer_config)
        self.transformer_context = self.parse_transformer_context(transformer_context)
        self.transformer_sampling = self.parse_transformer_sampling(
            transformer_sampling, len(transformer_config)
        )

        self.transformer_layers = nn.ModuleList()

        # a Linear adapter is inserted wherever two consecutive stages
        # disagree on the embedding dimension
        if transformer_input_dim != transformer_config[0][0]:
            self.transformer_layers.append(
                Linear(transformer_input_dim, transformer_config[0][0])
            )
        self.transformer_layers.append(
            TransformerEncoderLayer(
                prepare_transformer_encoder_params(*transformer_config[0])
            )
        )

        for i in range(1, len(transformer_config)):
            if transformer_config[i - 1][0] != transformer_config[i][0]:
                self.transformer_layers.append(
                    Linear(transformer_config[i - 1][0], transformer_config[i][0])
                )
            self.transformer_layers.append(
                TransformerEncoderLayer(
                    prepare_transformer_encoder_params(*transformer_config[i])
                )
            )

        self.encoder_output_dim = encoder_output_dim
        # final projection to the requested output size, then layer norm
        self.transformer_layers.extend(
            [
                Linear(transformer_config[-1][0], encoder_output_dim),
                LayerNorm(encoder_output_dim),
            ]
        )

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        src_tokens: padded tensor (B, T, C * feat)
        src_lengths: tensor of original lengths of input utterances (B,)

        Returns a dict with:
            - "encoder_out": (T, B, C) tensor
            - "encoder_padding_mask": (T, B) tensor, or None when no
              position of the batch is padded
        """
        bsz, max_seq_len, _ = src_tokens.size()
        x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
        x = x.transpose(1, 2).contiguous()
        # (B, C, T, feat)

        for conv_layer in self.conv_layers:
            x = conv_layer(x)

        bsz, _, output_seq_len, _ = x.size()

        # (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
        x = x.transpose(1, 2).transpose(0, 1)
        x = x.contiguous().view(output_seq_len, bsz, -1)

        # ratio between input and conv output length, rounded to nearest int
        subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
        # TODO: shouldn't subsampling_factor determined in advance ?
        input_lengths = (src_lengths.float() / subsampling_factor).ceil().long()

        encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
            input_lengths, batch_first=True
        )
        if not encoder_padding_mask.any():
            encoder_padding_mask = None

        attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)

        # self.transformer_layers mixes Linear/LayerNorm adapters with the
        # actual transformer layers; transformer_layer_idx counts only the
        # latter so that it indexes self.transformer_sampling correctly
        transformer_layer_idx = 0
        for layer in self.transformer_layers:
            if isinstance(layer, TransformerEncoderLayer):
                x = layer(x, encoder_padding_mask, attn_mask)

                sampling_factor = self.transformer_sampling[transformer_layer_idx]
                if sampling_factor != 1:
                    x, encoder_padding_mask, attn_mask = self.slice(
                        x, encoder_padding_mask, attn_mask, sampling_factor
                    )

                transformer_layer_idx += 1
            else:
                x = layer(x)

        # encoder_padding_mask is a (T x B) tensor, its [t, b] elements indicate
        # whether encoder_output[t, b] is valid or not (valid=0, invalid=1)
        return {
            "encoder_out": x,  # (T, B, C)
            "encoder_padding_mask": encoder_padding_mask.t()
            if encoder_padding_mask is not None
            else None,
            # (B, T) --> (T, B)
        }

    def infer_conv_output_dim(self, in_channels, input_dim):
        """Return the per-frame feature size produced by ``self.conv_layers``.

        The size is measured by pushing a random dummy batch through the
        convolutional layers and flattening everything but the first two
        dimensions after the (1, 2) transpose.
        """
        sample_seq_len = 200
        sample_bsz = 10
        x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
        x = x.transpose(1, 2)
        mb, seq = x.size()[:2]
        return x.contiguous().view(mb, seq, -1).size(-1)

    def validate_transformer_config(self, transformer_config):
        """Check that every layer's input dim is a multiple of its head count.

        Raises:
            ValueError: for the first configuration violating the rule.
        """
        for config in transformer_config:
            input_dim, num_heads = config[:2]
            if input_dim % num_heads != 0:
                # a single format call, so every placeholder gets filled
                msg = (
                    "ERROR in transformer config {}: input dimension {} "
                    "not dividable by number of heads {}"
                ).format(config, input_dim, num_heads)
                raise ValueError(msg)

    def parse_transformer_context(self, transformer_context):
        """
        transformer_context can be the following:
        -   None; indicates no context is used, i.e.,
            transformer can access full context
        -   a tuple/list of two int; indicates left and right context,
            any number <0 indicates infinite context
                * e.g., (5, 6) indicates that for query at x_t, transformer can
                access [t-5, t+6] (inclusive)
                * e.g., (-1, 6) indicates that for query at x_t, transformer can
                access [0, t+6] (inclusive)

        Returns None when both sides are unlimited, otherwise a
        (left_context, right_context) tuple in which an unlimited side is None.
        """
        if transformer_context is None:
            return None

        if not isinstance(transformer_context, Iterable):
            raise ValueError("transformer context must be Iterable if it is not None")

        if len(transformer_context) != 2:
            raise ValueError("transformer context must have length 2")

        left_context = transformer_context[0]
        if left_context < 0:
            left_context = None

        right_context = transformer_context[1]
        if right_context < 0:
            right_context = None

        if left_context is None and right_context is None:
            return None

        return (left_context, right_context)

    def parse_transformer_sampling(self, transformer_sampling, num_layers):
        """
        parsing transformer sampling configuration

        Args:
            - transformer_sampling, accepted input:
                * None, indicating no sampling
                * an Iterable with int (>0) as element
            - num_layers, expected number of transformer layers, must match with
              the length of transformer_sampling if it is not None

        Returns:
            - ``(1,) * num_layers`` when transformer_sampling is None,
              otherwise the validated transformer_sampling object itself

        Raises:
            ValueError: on a wrong type, a wrong length or a factor < 1
        """
        if transformer_sampling is None:
            return (1,) * num_layers

        if not isinstance(transformer_sampling, Iterable):
            raise ValueError(
                "transformer_sampling must be an iterable if it is not None"
            )

        if len(transformer_sampling) != num_layers:
            # format the whole message; formatting only the second literal
            # would leave the first placeholder unfilled
            raise ValueError(
                (
                    "transformer_sampling {} does not match with the number "
                    "of layers {}"
                ).format(transformer_sampling, num_layers)
            )

        for layer, value in enumerate(transformer_sampling):
            if not isinstance(value, int):
                raise ValueError(
                    "Invalid value in transformer_sampling: {}".format(value)
                )
            if value < 1:
                raise ValueError(
                    "{} layer's subsampling is {}.".format(layer, value)
                    + " This is not allowed! "
                )
        return transformer_sampling

    def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
        """
        Keep every sampling_factor-th time step.

        embedding is a (T, B, D) tensor
        padding_mask is a (B, T) tensor or None
        attn_mask is a (T, T) tensor or None
        """
        embedding = embedding[::sampling_factor, :, :]
        if padding_mask is not None:
            padding_mask = padding_mask[:, ::sampling_factor]
        if attn_mask is not None:
            attn_mask = attn_mask[::sampling_factor, ::sampling_factor]

        return embedding, padding_mask, attn_mask

    def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
        """
        create attention mask according to sequence lengths and transformer
        context

        Args:
            - input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
              the length of b-th sequence
            - subsampling_factor: int
                * Note that the left_context and right_context is specified in
                  the input frame-level while input to transformer may already
                  go through subsampling (e.g., the use of striding in vggblock)
                  we use subsampling_factor to scale the left/right context

        Return:
            - a (T, T) binary tensor or None, where T is max(input_lengths)
                * if self.transformer_context is None, None
                * if left_context is None,
                    * attn_mask[t, t + right_context + 1:] = 1
                    * others = 0
                * if right_context is None,
                    * attn_mask[t, 0:t - left_context] = 1
                    * others = 0
                * elsif
                    * attn_mask[t, t - left_context: t + right_context + 1] = 0
                    * others = 1
        """
        if self.transformer_context is None:
            return None

        maxT = torch.max(input_lengths).item()
        attn_mask = torch.zeros(maxT, maxT)

        left_context = self.transformer_context[0]
        right_context = self.transformer_context[1]
        # contexts are given in input frames; convert to subsampled frames
        if left_context is not None:
            left_context = math.ceil(self.transformer_context[0] / subsampling_factor)
        if right_context is not None:
            right_context = math.ceil(self.transformer_context[1] / subsampling_factor)

        for t in range(maxT):
            if left_context is not None:
                # mask everything strictly before t - left_context
                st = 0
                en = max(st, t - left_context)
                attn_mask[t, st:en] = 1
            if right_context is not None:
                # mask everything strictly after t + right_context. The start
                # may reach maxT, which gives an empty slice and masks
                # nothing; clamping it to maxT - 1 would mask the last frame
                # even when it lies inside the allowed window.
                st = min(t + right_context + 1, maxT)
                attn_mask[t, st:] = 1

        return attn_mask.to(input_lengths.device)

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension (dim 1) of *encoder_out* in place.

        Both "encoder_out" and, when present, "encoder_padding_mask" are
        re-indexed with *new_order*; the same dict is returned.
        """
        encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
            1, new_order
        )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(1, new_order)
        return encoder_out
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder: token embedding, a stack of 1-D convolutions, a stack
    of :class:`TransformerDecoderLayer` and a final projection to the
    vocabulary.

    Args:
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_dim (int, optional): dimension of the target token embeddings
        transformer_config: one tuple per decoder layer, in the argument
            order of ``prepare_transformer_decoder_params``
        conv_config: one ``(out_channels, kernel_size, layer_norm)`` tuple
            per convolution applied to the embeddings
        encoder_output_dim (int, optional): accepted for interface symmetry
            with the encoder; not used by this constructor
    """

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
        conv_config=DEFAULT_DEC_CONV_CONFIG,
        encoder_output_dim=512,
    ):
        # NOTE(review): the default above is the *encoder* transformer config
        # although DEFAULT_DEC_TRANSFORMER_CONFIG exists; kept unchanged
        # because switching it would alter the default architecture. Confirm.
        super().__init__(dictionary)
        vocab_size = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)

        self.conv_layers = nn.ModuleList()
        # the first convolution reads the embeddings, every following one
        # reads the output channels of its predecessor
        conv_in_channels = embed_dim
        for out_channels, kernel_size, layer_norm in conv_config:
            self.conv_layers.append(
                LinearizedConv1d(
                    conv_in_channels,
                    out_channels,
                    kernel_size,
                    padding=kernel_size - 1,
                )
            )
            if layer_norm:
                self.conv_layers.append(nn.LayerNorm(out_channels))
            self.conv_layers.append(nn.ReLU())
            conv_in_channels = out_channels

        self.layers = nn.ModuleList()
        # a Linear adapter is inserted wherever two consecutive stages
        # disagree on the embedding dimension
        if conv_config[-1][0] != transformer_config[0][0]:
            self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
        self.layers.append(TransformerDecoderLayer(
            prepare_transformer_decoder_params(*transformer_config[0])
        ))

        for i in range(1, len(transformer_config)):
            if transformer_config[i - 1][0] != transformer_config[i][0]:
                self.layers.append(
                    Linear(transformer_config[i - 1][0], transformer_config[i][0])
                )
            self.layers.append(TransformerDecoderLayer(
                prepare_transformer_decoder_params(*transformer_config[i])
            ))
        self.fc_out = Linear(transformer_config[-1][0], vocab_size)

    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention; its "encoder_out" and
                "encoder_padding_mask" entries are read
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the output of the final projection, of shape
                  `(batch, tgt_len, vocab)`
                - ``None`` (attention weights are not returned)
        """
        incremental = incremental_state is not None

        # padding positions are only masked when the whole target is decoded
        # at once; incremental decoding sees a single step
        target_padding_mask = (
            None
            if incremental
            else (prev_output_tokens == self.padding_idx).to(
                prev_output_tokens.device
            )
        )

        if incremental:
            # only the most recent token is new
            prev_output_tokens = prev_output_tokens[:, -1:]

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)

        # B x T x C -> T x B x C
        x = self._transpose_if_training(x, incremental_state)

        for layer in self.conv_layers:
            if isinstance(layer, LinearizedConvolution):
                x = layer(x, incremental_state)
            else:
                x = layer(x)

        # B x T x C -> T x B x C
        x = self._transpose_if_inference(x, incremental_state)

        # Both encoder inputs are guarded against a missing encoder output;
        # the encoder stores its mask as (T, B), the layer expects (B, T).
        encoder_states = None
        encoder_padding_mask = None
        if encoder_out is not None:
            encoder_states = encoder_out["encoder_out"]
            if encoder_out["encoder_padding_mask"] is not None:
                encoder_padding_mask = encoder_out["encoder_padding_mask"].t()

        # decoder layers
        for layer in self.layers:
            if isinstance(layer, TransformerDecoderLayer):
                x, _ = layer(
                    x,
                    encoder_states,
                    encoder_padding_mask,
                    incremental_state,
                    self_attn_mask=(
                        None if incremental else self.buffered_future_mask(x)
                    ),
                    self_attn_padding_mask=target_padding_mask,
                )
            else:
                x = layer(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        x = self.fc_out(x)

        return x, None

    def buffered_future_mask(self, tensor):
        """Return a cached (dim, dim) mask with -inf above the diagonal.

        *dim* is ``tensor.size(0)``. The cached mask is rebuilt when it is
        missing or lives on another device, and enlarged when too small.
        """
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def _transpose_if_training(self, x, incremental_state):
        """Swap dims 0 and 1 of *x* when no incremental state is in use."""
        if incremental_state is None:
            x = x.transpose(0, 1)
        return x

    def _transpose_if_inference(self, x, incremental_state):
        """Swap dims 0 and 1 of *x* when an incremental state is in use."""
        # Compare against None (as everywhere else in this class) rather than
        # testing truthiness: an empty dict is still an incremental state.
        if incremental_state is not None:
            x = x.transpose(0, 1)
        return x
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` table with the given padding index.

    No custom initialisation is applied here; the weights are left exactly
    as ``nn.Embedding`` creates them.
    """
    return nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
def Linear(in_features, out_features, bias=True, dropout=0):
    """Linear layer (input: N x T x C)

    Thin factory around ``nn.Linear``. ``dropout`` is accepted but not used,
    and no custom initialisation is applied.
    """
    return nn.Linear(in_features, out_features, bias=bias)
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    """Weight-normalized Conv1d layer optimized for decoding

    Builds a ``LinearizedConvolution`` (extra keyword arguments are passed
    through to it), initialises its weight from a zero-mean normal
    distribution and its bias with zeros, and wraps it with weight
    normalisation over dimension 2.
    """
    conv = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
    # the standard deviation shrinks with the kernel width and input channels
    fan = conv.kernel_size[0] * in_channels
    nn.init.normal_(conv.weight, mean=0, std=math.sqrt((4 * (1.0 - dropout)) / fan))
    nn.init.constant_(conv.bias, 0)
    return nn.utils.weight_norm(conv, dim=2)
def LayerNorm(embedding_dim):
    """Create an ``nn.LayerNorm`` over the last ``embedding_dim`` features."""
    return nn.LayerNorm(embedding_dim)
# seq2seq models
def base_architecture(args):
    """Give every model hyper-parameter missing from *args* its base default.

    Attributes that *args* already defines are left untouched.
    """
    # NOTE(review): transformer_dec_config falls back to the *encoder*
    # default although DEFAULT_DEC_TRANSFORMER_CONFIG exists - confirm intent.
    fallbacks = (
        ("input_feat_per_channel", 40),
        ("vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG),
        ("transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG),
        ("enc_output_dim", 512),
        ("in_channels", 1),
        ("tgt_embed_dim", 128),
        ("transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG),
        ("conv_dec_config", DEFAULT_DEC_CONV_CONFIG),
        ("transformer_context", "None"),
    )
    for name, fallback in fallbacks:
        setattr(args, name, getattr(args, name, fallback))
@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
def vggtransformer_1(args):
    """Preset "vggtransformer_1": 14 encoder and 4 decoder transformer layers.

    Only attributes missing from *args* are filled in.
    """
    vgg_blocks = "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
    encoder_layers = "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14"
    decoder_convs = "((256, 3, True),) * 4"
    decoder_layers = "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4"
    for name, fallback in (
        ("input_feat_per_channel", 80),
        ("vggblock_enc_config", vgg_blocks),
        ("transformer_enc_config", encoder_layers),
        ("enc_output_dim", 1024),
        ("tgt_embed_dim", 128),
        ("conv_dec_config", decoder_convs),
        ("transformer_dec_config", decoder_layers),
    ):
        setattr(args, name, getattr(args, name, fallback))
@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
def vggtransformer_2(args):
    """Preset "vggtransformer_2": 16 encoder and 6 decoder transformer layers.

    Only attributes missing from *args* are filled in.
    """
    vgg_blocks = "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
    encoder_layers = "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16"
    decoder_convs = "((256, 3, True),) * 4"
    decoder_layers = "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6"
    for name, fallback in (
        ("input_feat_per_channel", 80),
        ("vggblock_enc_config", vgg_blocks),
        ("transformer_enc_config", encoder_layers),
        ("enc_output_dim", 1024),
        ("tgt_embed_dim", 512),
        ("conv_dec_config", decoder_convs),
        ("transformer_dec_config", decoder_layers),
    ):
        setattr(args, name, getattr(args, name, fallback))
@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
def vggtransformer_base(args):
    """Preset "vggtransformer_base": 12 encoder and 6 decoder layers of width 512.

    Only attributes missing from *args* are filled in.
    """
    vgg_blocks = "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
    encoder_layers = "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
    decoder_convs = "((256, 3, True),) * 4"
    decoder_layers = "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
    for name, fallback in (
        ("input_feat_per_channel", 80),
        ("vggblock_enc_config", vgg_blocks),
        ("transformer_enc_config", encoder_layers),
        ("enc_output_dim", 512),
        ("tgt_embed_dim", 512),
        ("conv_dec_config", decoder_convs),
        ("transformer_dec_config", decoder_layers),
    ):
        setattr(args, name, getattr(args, name, fallback))
# Size estimations:
# Encoder:
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3*3 = 258K
# Transformer:
# - input dimension adapter: 2560 x 512 -> 1.31M
# - transformer_layers (x12) --> 37.74M
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
# * FFN weight: 512*2048*2 = 2.097M
# - output dimension adapter: 512 x 512 -> 0.26 M
# Decoder:
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3 = 0.98M
# - transformer_layer: (x6) --> 25.16M
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
# * FFN: 512*2048*2 = 2.097M
# Final FC:
# - FC: 512*5000 = 2.56M (assuming vocab size 5K)
# In total:
# ~68 M

View File

@ -0,0 +1,7 @@
import importlib
import os
# Import every non-private Python module that sits next to this file
# (presumably so that the tasks they declare get registered on import).
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        # Drop only the trailing extension: str.find('.py') would cut the
        # name at the first ".py" found anywhere inside it.
        task_name = file[:-len('.py')]
        importlib.import_module('examples.speech_recognition.tasks.' + task_name)

View File

@ -0,0 +1,116 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import re
import torch
from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task
from examples.speech_recognition.data import AsrDataset
def get_asr_dataset_from_json(data_json_path, tgt_dict):
    """
    Parse data json and create dataset.
    See scripts/asr_prep_json.py which pack json from raw files

    Json example:
    {
    "utts": {
        "4771-29403-0025": {
            "input": {
                "length_ms": 170,
                "path": "/tmp/file1.flac"
            },
            "output": {
                "text": "HELLO \\n",
                "token": "HE LLO",
                "tokenid": "4815, 861"
            }
        },
        "1564-142299-0096": {
            ...
        }
    }

    Args:
        data_json_path: path of the json file described above
        tgt_dict: target dictionary; its ``eos()`` index is appended to every
            target sequence and it is handed over to the dataset

    Returns:
        an ``AsrDataset`` whose samples are ordered by decreasing
        ``length_ms``

    Raises:
        FileNotFoundError: if *data_json_path* is not an existing file
        ValueError: if the file holds no utterance, or an utterance id does
            not have the ``<a>-<b>-<c>`` form the speaker name is derived from
    """
    if not os.path.isfile(data_json_path):
        raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
    with open(data_json_path, "rb") as f:
        data_samples = json.load(f)["utts"]
    # explicit check instead of an assert, which is stripped under -O
    if len(data_samples) == 0:
        raise ValueError(
            "Dataset contains no utterances: {}".format(data_json_path)
        )

    # longest utterances first
    sorted_samples = sorted(
        data_samples.items(),
        key=lambda sample: int(sample[1]["input"]["length_ms"]),
        reverse=True,
    )
    aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
    ids = [s[0] for s in sorted_samples]

    # the speaker name is made of the first two dash-separated id fields
    id_pattern = re.compile(r"(.+?)-(.+?)-(.+?)")
    speakers = []
    for utt_id in ids:
        m = id_pattern.search(utt_id)
        if m is None:
            raise ValueError(
                "Cannot derive speaker from utterance id: {}".format(utt_id)
            )
        speakers.append(m.group(1) + "_" + m.group(2))

    frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]

    # token ids are stored as one comma-separated string; append eos to each
    eos = tgt_dict.eos()
    tgt = [
        torch.LongTensor(
            [int(i) for i in s[1]["output"]["tokenid"].split(", ")] + [eos]
        )
        for s in sorted_samples
    ]
    return AsrDataset(
        aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers
    )
@register_task("speech_recognition")
class SpeechRecognitionTask(FairseqTask):
    """
    Fairseq task for training a speech recognition model on a data directory
    holding ``dict.txt`` and one ``<split>.json`` file per split.
    """

    @staticmethod
    def add_args(parser):
        """Register the command-line arguments owned by this task."""
        parser.add_argument("data", help="path to data directory")

    def __init__(self, args, tgt_dict):
        super().__init__(args)
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Load ``dict.txt`` from ``args.data`` and create the task.

        Raises:
            FileNotFoundError: if the dictionary file does not exist.
        """
        dictionary_file = os.path.join(args.data, "dict.txt")
        if os.path.isfile(dictionary_file):
            target_dictionary = Dictionary.load(dictionary_file)
            print("| dictionary: {} types".format(len(target_dictionary)))
            return cls(args, target_dictionary)
        raise FileNotFoundError("Dict not found: {}".format(dictionary_file))

    def load_dataset(self, split, combine=False, **kwargs):
        """Read ``<split>.json`` and store the dataset in ``self.datasets``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        split_file = "{}.json".format(split)
        dataset = get_asr_dataset_from_json(
            os.path.join(self.args.data, split_file), self.tgt_dict
        )
        self.datasets[split] = dataset

    @property
    def target_dictionary(self):
        """The :class:`~fairseq.data.Dictionary` of the target tokens."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Always ``None``: this task keeps no source dictionary."""
        return None

View File

@ -23,6 +23,8 @@ from fairseq.modules import (
MultiheadAttention,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
TransformerDecoderLayer,
TransformerEncoderLayer,
)
DEFAULT_MAX_SOURCE_POSITIONS = 1024
@ -504,253 +506,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
return state_dict
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments; the
            constructor reads ``encoder_embed_dim``,
            ``encoder_attention_heads``, ``attention_dropout``, ``dropout``,
            ``encoder_normalize_before`` and ``encoder_ffn_embed_dim``, plus
            the optional ``activation_fn``, ``activation_dropout`` and
            ``relu_dropout``
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.encoder_attention_heads,
            dropout=args.attention_dropout, self_attention=True
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        # the activation defaults to relu when args does not name one
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.encoder_normalize_before
        # position-wise feed-forward network: embed_dim -> ffn -> embed_dim
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`

        *state_dict* is modified in place; *name* is the key prefix of this
        layer inside it.
        """
        layer_norm_map = {
            '0': 'self_attn_layer_norm',
            '1': 'final_layer_norm'
        }
        for old, new in layer_norm_map.items():
            for m in ('weight', 'bias'):
                k = '{}.layer_norms.{}.{}'.format(name, old, m)
                if k in state_dict:
                    # move the tensor to its new key and drop the old one
                    state_dict[
                        '{}.{}.{}'.format(name, new, m)
                    ] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # --- self-attention sub-layer with residual connection ---
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        # --- feed-forward sub-layer with residual connection ---
        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return x

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply *layer_norm* to *x* only at the configured position.

        Exactly one of *before* / *after* must be set. The norm is applied
        for ``before`` when ``self.normalize_before`` is true and for
        ``after`` when it is false; otherwise *x* is returned unchanged.
        """
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = MultiheadAttention(
embed_dim=self.embed_dim,
num_heads=args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=True
)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, 'activation_fn', 'relu')
)
self.activation_dropout = getattr(args, 'activation_dropout', 0)
if self.activation_dropout == 0:
# for backwards compatibility with models that use args.relu_dropout
self.activation_dropout = getattr(args, 'relu_dropout', 0)
self.normalize_before = args.decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determint this.
# TODO remove this once we update apex with the fix
export = getattr(args, 'char_inputs', False)
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = MultiheadAttention(
self.embed_dim,
args.decoder_attention_heads,
kdim=getattr(args, 'encoder_embed_dim', None),
vdim=getattr(args, 'encoder_embed_dim', None),
dropout=args.attention_dropout,
encoder_decoder_attention=True,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
self.need_attn = True
self.onnx_trace = False
def prepare_for_onnx_export_(self):
self.onnx_trace = True
def forward(
self,
x,
encoder_out=None,
encoder_padding_mask=None,
incremental_state=None,
prev_self_attn_state=None,
prev_attn_state=None,
self_attn_mask=None,
self_attn_padding_mask=None,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
if prev_self_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_self_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.self_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.self_attn(
query=x,
key=x,
value=x,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
if self.encoder_attn is not None:
residual = x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
if prev_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_attn_state
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=(not self.training and self.need_attn),
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = self.activation_fn(self.fc1(x))
x = F.dropout(x, p=self.activation_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
return x, attn, self_attn_state
return x, attn
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
    """Apply ``layer_norm`` to ``x`` at exactly one of its two call sites.

    Every sub-layer calls this once with ``before=True`` and once with
    ``after=True``. The norm runs on the "before" call when
    ``self.normalize_before`` is set and on the "after" call otherwise; the
    other call hands ``x`` back untouched.

    Args:
        layer_norm (callable): normalisation to apply
        x: value to (possibly) normalise
        before (bool): this is the call placed ahead of the sub-layer
        after (bool): this is the call placed behind the sub-layer

    Returns:
        ``layer_norm(x)`` or ``x`` itself
    """
    # exactly one of the two position flags must be given
    assert before ^ after
    should_normalize = after ^ self.normalize_before
    return layer_norm(x) if should_normalize else x
def make_generation_fast_(self, need_attn=False, **kwargs):
    """Store the ``need_attn`` flag on the layer.

    Args:
        need_attn (bool, optional): value saved as ``self.need_attn``
            (default: False)
        **kwargs: accepted for call compatibility and ignored
    """
    self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)

View File

@ -26,6 +26,8 @@ from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
from .transformer_sentence_encoder import TransformerSentenceEncoder
from .unfold import unfold1d
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
from .vggblock import VGGBlock
__all__ = [
'AdaptiveInput',
@ -51,5 +53,8 @@ __all__ = [
'SinusoidalPositionalEmbedding',
'TransformerSentenceEncoderLayer',
'TransformerSentenceEncoder',
'TransformerDecoderLayer',
'TransformerEncoderLayer',
'VGGBlock',
'unfold1d',
]

View File

@ -0,0 +1,279 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules import LayerNorm, MultiheadAttention
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.encoder_attention_heads,
            dropout=args.attention_dropout, self_attention=True
        )
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        self.activation_dropout = getattr(args, 'activation_dropout', 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, 'relu_dropout', 0)
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`

        The renaming is done in place on ``state_dict``; keys that are not
        present are simply skipped.
        """
        layer_norm_map = {
            '0': 'self_attn_layer_norm',
            '1': 'final_layer_norm'
        }
        for old, new in layer_norm_map.items():
            for m in ('weight', 'bias'):
                k = '{}.layer_norms.{}.{}'.format(name, old, m)
                if k in state_dict:
                    state_dict[
                        '{}.{}.{}'.format(name, new, m)
                    ] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask, attn_mask=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
            attn_mask (Tensor, optional): binary tensor of shape (T_tgt, T_src),
                where T_tgt is the length of the query and T_src the length of
                the key (both are `x` here). attn_mask[t_tgt, t_src] = 1 means
                t_src is excluded (masked out) when computing the embedding
                for t_tgt; 0 means it is included in attention.
                NOTE(review): -1e8 is written into the mask's own dtype, so a
                floating-point mask looks required here -- confirm against
                callers.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if attn_mask is not None:
            # anything in original attn_mask = 1, becomes -1e8
            # anything in original attn_mask = 0, becomes 0
            # Note that we cannot use -inf here, because at some edge cases,
            # the attention weight (before softmax) for some padded element in query
            # will become -inf, which results in NaN in model parameters
            # TODO: to formally solve this problem, we need to change fairseq's
            # MultiheadAttention. We will do this later on.
            # `!= 0` (rather than `.byte()`) lets the comparison pick the mask
            # dtype; recent PyTorch rejects uint8 masks in masked_fill.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, -1e8)
        # The converted mask must actually reach the attention module;
        # otherwise the masking computed above has no effect.
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        residual = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=self.activation_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
        return x

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Apply ``layer_norm`` on the "before" call when
        ``self.normalize_before`` is set, on the "after" call otherwise."""
        assert before ^ after
        if after ^ self.normalize_before:
            return layer_norm(x)
        else:
            return x
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    The paper post-processes every operation (self-attention, encoder
    attention, FFN) with `dropout -> add residual -> layernorm`. The
    tensor2tensor variant instead applies layernorm first and post-processes
    with `dropout -> add residual`. The paper's ordering is the default; set
    *args.decoder_normalize_before* to ``True`` for the tensor2tensor one.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): when ``True`` the layer is built
            without an encoder-attention sub-layer (default: False).
        add_bias_kv (bool, optional): forwarded to the self-attention module
            (default: False).
        add_zero_attn (bool, optional): forwarded to the self-attention module
            (default: False).
    """

    def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        embed_dim = args.decoder_embed_dim
        self.embed_dim = embed_dim
        self.self_attn = MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=True
        )
        self.dropout = args.dropout
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu')
        )
        activation_dropout = getattr(args, 'activation_dropout', 0)
        if activation_dropout == 0:
            # older models configured this value as args.relu_dropout
            activation_dropout = getattr(args, 'relu_dropout', 0)
        self.activation_dropout = activation_dropout
        self.normalize_before = args.decoder_normalize_before

        # use LayerNorm rather than FusedLayerNorm for exporting.
        # char_inputs can be used to determine this.
        # TODO remove this once we update apex with the fix
        export = getattr(args, 'char_inputs', False)
        self.self_attn_layer_norm = LayerNorm(embed_dim, export=export)

        if not no_encoder_attn:
            encoder_dim = getattr(args, 'encoder_embed_dim', None)
            self.encoder_attn = MultiheadAttention(
                embed_dim,
                args.decoder_attention_heads,
                kdim=encoder_dim,
                vdim=encoder_dim,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(embed_dim, export=export)
        else:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None

        self.fc1 = Linear(embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, embed_dim)
        self.final_layer_norm = LayerNorm(embed_dim, export=export)

        self.need_attn = True
        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        """Switch the layer into ONNX-trace mode (see the extra return value
        of :meth:`forward`)."""
        self.onnx_trace = True

    def forward(
        self,
        x,
        encoder_out=None,
        encoder_padding_mask=None,
        incremental_state=None,
        prev_self_attn_state=None,
        prev_attn_state=None,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out: keys/values for the encoder-attention sub-layer
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
            incremental_state: state container handed to both attention
                modules; created here as ``{}`` if a previous state is given
                without one
            prev_self_attn_state: optional ``(prev_key, prev_value)`` pair
                loaded into the self-attention input buffer
            prev_attn_state: optional ``(prev_key, prev_value)`` pair loaded
                into the encoder-attention input buffer
            self_attn_mask: passed to self-attention as ``attn_mask``
            self_attn_padding_mask: passed to self-attention as
                ``key_padding_mask``

        Returns:
            ``(x, attn)`` with `x` the encoded output of shape
            `(seq_len, batch, embed_dim)`; in ONNX-trace mode with an
            incremental state a third element ``(prev_key, prev_value)`` read
            back from the self-attention buffer is appended.
        """
        # --- self-attention sub-layer ---
        shortcut = x
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
        if prev_self_attn_state is not None:
            if incremental_state is None:
                incremental_state = {}
            cached_key, cached_value = prev_self_attn_state
            self.self_attn._set_input_buffer(
                incremental_state,
                {"prev_key": cached_key, "prev_value": cached_value},
            )
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = shortcut + F.dropout(x, p=self.dropout, training=self.training)
        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

        # --- encoder-attention sub-layer (only if the layer was built with one) ---
        if self.encoder_attn is not None:
            shortcut = x
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
            if prev_attn_state is not None:
                if incremental_state is None:
                    incremental_state = {}
                cached_key, cached_value = prev_attn_state
                self.encoder_attn._set_input_buffer(
                    incremental_state,
                    {"prev_key": cached_key, "prev_value": cached_value},
                )
            # attention weights are only requested outside training
            want_weights = (not self.training and self.need_attn)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=want_weights,
            )
            x = shortcut + F.dropout(x, p=self.dropout, training=self.training)
            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

        # --- position-wise feed-forward sub-layer ---
        shortcut = x
        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
        hidden = self.activation_fn(self.fc1(x))
        hidden = F.dropout(hidden, p=self.activation_dropout, training=self.training)
        x = F.dropout(self.fc2(hidden), p=self.dropout, training=self.training)
        x = shortcut + x
        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)

        if not (self.onnx_trace and incremental_state is not None):
            return x, attn
        buffered = self.self_attn._get_input_buffer(incremental_state)
        return x, attn, (buffered["prev_key"], buffered["prev_value"])

    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
        """Run ``layer_norm`` on the "before" call when
        ``self.normalize_before`` is set, on the "after" call otherwise."""
        assert before ^ after
        apply_norm = after ^ self.normalize_before
        return layer_norm(x) if apply_norm else x

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Store ``need_attn``; other keyword arguments are ignored."""
        self.need_attn = need_attn
def Linear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features (int): input size
        out_features (int): output size
        bias (bool, optional): whether the layer has a bias term; when it
            does, the bias is initialised to 0 (default: True)

    Returns:
        the initialised ``nn.Linear`` module
    """
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if not bias:
        return layer
    nn.init.constant_(layer.bias, 0.)
    return layer

115
fairseq/modules/vggblock.py Normal file
View File

@ -0,0 +1,115 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import, division, print_function, unicode_literals
from collections.abc import Iterable
from itertools import repeat
import torch
import torch.nn as nn
def _pair(v):
if isinstance(v, Iterable):
assert len(v) == 2, "len(v) != 2"
return v
return tuple(repeat(v, 2))
def infer_conv_output_dim(conv_op, input_dim, sample_inchannel):
    """Probe ``conv_op`` with a random batch to find its output feature sizes.

    A random tensor of shape ``(10, sample_inchannel, 200, input_dim)``
    (N x C x H x W) is pushed through ``conv_op``.

    Args:
        conv_op (callable): operation applied to the probe tensor
        input_dim (int): size of the last (W) axis of the probe
        sample_inchannel (int): number of channels (C) of the probe

    Returns:
        tuple ``(C_out * W_out, W_out)`` taken from the output of ``conv_op``
    """
    probe_bsz, probe_seq_len = 10, 200
    # N x C x H x W
    probe = torch.randn(probe_bsz, sample_inchannel, probe_seq_len, input_dim)
    # N x C x H x W  ->  N x H x C x W
    out = conv_op(probe).transpose(1, 2)
    dims = out.size()
    bsz, seq = dims[0], dims[1]
    per_channel_dim = dims[3]
    # folding C and W together gives the flattened per-timestep feature size
    flat_dim = out.contiguous().view(bsz, seq, -1).size(-1)
    return flat_dim, per_channel_dim
class VGGBlock(torch.nn.Module):
    """
    VGG-motivated CNN module, see https://arxiv.org/pdf/1409.1556.pdf

    The block stacks ``num_conv_layers`` times ``Conv2d [-> LayerNorm] -> ReLU``
    and finishes with one ``MaxPool2d(ceil_mode=True)``.

    Args:
        in_channels: (int) number of input channels (typically 1)
        out_channels: (int) number of output channels
        conv_kernel_size: size of the convolving kernel.
            Can be a single number or a pair
        pooling_kernel_size: the size of the pooling window to take a max over
        num_conv_layers: (int) number of convolution layers
        input_dim: (int) input dimension
        conv_stride: the stride of the convolving kernel.
            Can be a single number or a tuple (sH, sW) Default: 1
        padding: implicit paddings on both sides of the input.
            Can be a single number or a tuple (padH, padW).
            Default: None, meaning half the kernel size per axis
        layer_norm: (bool) if layer norm is going to be applied. Default: False

    Shape:
        Input: BxCxTxfeat, i.e. (batch_size, input_size, timesteps, features)
        Output: BxCxTxfeat, i.e. (batch_size, input_size, timesteps, features)

    Attributes set after construction:
        total_output_dim, output_dim: the two values returned by
        ``infer_conv_output_dim`` for the pooling layer.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        conv_kernel_size,
        pooling_kernel_size,
        num_conv_layers,
        input_dim,
        conv_stride=1,
        padding=None,
        layer_norm=False,
    ):
        assert (
            input_dim is not None
        ), "Need input_dim for LayerNorm and infer_conv_output_dim"
        super(VGGBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_kernel_size = _pair(conv_kernel_size)
        self.pooling_kernel_size = _pair(pooling_kernel_size)
        self.num_conv_layers = num_conv_layers
        if padding is None:
            # default: pad by half the kernel size along each axis
            self.padding = tuple(k // 2 for k in self.conv_kernel_size)
        else:
            self.padding = _pair(padding)
        self.conv_stride = _pair(conv_stride)

        self.layers = nn.ModuleList()
        conv_in_channels = in_channels
        for _ in range(num_conv_layers):
            conv_op = nn.Conv2d(
                conv_in_channels,
                out_channels,
                self.conv_kernel_size,
                stride=self.conv_stride,
                padding=self.padding,
            )
            self.layers.append(conv_op)
            if layer_norm:
                # normalise over the per-channel feature axis the conv produces
                _, per_channel_dim = infer_conv_output_dim(
                    conv_op, input_dim, conv_in_channels
                )
                self.layers.append(nn.LayerNorm(per_channel_dim))
                input_dim = per_channel_dim
            self.layers.append(nn.ReLU())
            # every convolution after the first consumes out_channels
            conv_in_channels = out_channels

        pool_op = nn.MaxPool2d(kernel_size=pooling_kernel_size, ceil_mode=True)
        self.layers.append(pool_op)
        self.total_output_dim, self.output_dim = infer_conv_output_dim(
            pool_op, input_dim, out_channels
        )

    def forward(self, x):
        """Run ``x`` through every layer of the block in order."""
        for layer in self.layers:
            x = layer(x)
        return x

View File

View File

@ -0,0 +1,549 @@
#!/usr/bin/env python3
import argparse
import os
import unittest
from inspect import currentframe, getframeinfo
import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.dictionary import Dictionary
from fairseq.models import (
BaseFairseqModel,
FairseqDecoder,
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqEncoderModel,
FairseqModel,
)
from fairseq.tasks.fairseq_task import FairseqTask
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
# Default number of dummy symbols that get_dummy_dictionary() adds.
DEFAULT_TEST_VOCAB_SIZE = 100
# ///////////////////////////////////////////////////////////////////////////
# utility function to setup dummy dict/task/input
# ///////////////////////////////////////////////////////////////////////////
def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
    """Create a ``Dictionary`` and add the symbols "0" .. "vocab_size - 1".

    Each symbol is added with 1000 as the second positional argument of
    ``add_symbol`` (presumably the symbol count -- TODO confirm).
    """
    dictionary = Dictionary()
    # add dummy symbols to satisfy the requested vocab size
    for index in range(vocab_size):
        dictionary.add_symbol(str(index), 1000)
    return dictionary
class DummyTask(FairseqTask):
    """Minimal task for tests: it only carries a dummy dictionary."""

    def __init__(self, args):
        super().__init__(args)
        dictionary = get_dummy_dictionary()
        # CTC setups get one additional blank symbol
        if getattr(self.args, "ctc", False):
            dictionary.add_symbol("<ctc_blank>")
        self.dictionary = dictionary
        self.tgt_dict = dictionary

    @property
    def target_dictionary(self):
        """The dummy dictionary built in ``__init__``."""
        return self.dictionary
def get_dummy_task_and_parser():
    """
    Building a fairseq model needs a parser and a task. This helper creates a
    dummy pair of both to facilitate model/criterion tests.

    Note: ``DummyTask`` is used as the task. Provide another function if a
    different task is needed.

    Returns:
        tuple ``(task, parser)``
    """
    parser = argparse.ArgumentParser(
        description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
    )
    DummyTask.add_args(parser)
    # parse an empty command line
    task = DummyTask.setup_task(parser.parse_args([]))
    return task, parser
def get_dummy_input(T=100, D=80, B=5, K=100):
    """Create a random seq2seq forward input, sorted by decreasing source length.

    Args:
        T (int): max sequence length
        D (int): feature vector dimension
        B (int): batch size
        K (int): target dimension size (token ids are drawn from ``[0, K)``)

    Returns:
        dict with
        - ``src_tokens``: random features of shape ``(B, T, D)``
        - ``src_lengths``: lengths in descending order, the largest being ``T``
        - ``prev_output_tokens``: collated random token rows

        All three entries are ordered by the same permutation, so row ``b`` of
        each belongs to the same sample and a sample's token row was drawn
        with a length no larger than its source length.
    """
    feature = torch.randn(B, T, D)
    # this (B, T, D) layout is just a convention, you can override it by
    # write your own _prepare_forward_input function
    src_lengths = torch.from_numpy(
        np.random.randint(low=1, high=T, size=B).astype(np.int64)
    )
    src_lengths[0] = T  # make sure the maximum length matches
    prev_output_tokens = []
    for b in range(B):
        token_length = np.random.randint(low=1, high=src_lengths[b].item() + 1)
        tokens = np.random.randint(low=0, high=K, size=token_length)
        prev_output_tokens.append(torch.from_numpy(tokens))

    prev_output_tokens = fairseq_data_utils.collate_tokens(
        prev_output_tokens,
        pad_idx=1,
        eos_idx=2,
        left_pad=False,
        move_eos_to_beginning=False,
    )
    src_lengths, sorted_order = src_lengths.sort(descending=True)
    # Re-order the token rows with the same permutation as the features and
    # lengths; otherwise each row would be paired with another sample's length.
    return {
        "src_tokens": feature.index_select(0, sorted_order),
        "src_lengths": src_lengths,
        "prev_output_tokens": prev_output_tokens.index_select(0, sorted_order),
    }
def get_dummy_encoder_output(encoder_out_shape=(100, 80, 5)):
    """
    Example generator for a dummy encoder output.

    Args:
        encoder_out_shape: ``(T, B, D)`` shape of the random ``encoder_out``

    Returns:
        dict with a float32 ``encoder_out`` of the given shape and an
        ``encoder_padding_mask`` of shape ``(T, B)`` whose (t, b)-th element
        tells whether ``encoder_out[t, b]`` is valid (=0) or not (=1)
    """
    (T, B, D) = encoder_out_shape
    encoder_states = torch.from_numpy(
        np.random.randn(*encoder_out_shape).astype(np.float32)
    )
    seq_lengths = torch.from_numpy(np.random.randint(low=1, high=T, size=B))
    # some dummy mask: a position is padding once it reaches the sequence length
    positions = torch.arange(T).view(1, T).expand(B, -1)
    padding_mask = positions >= seq_lengths.view(B, 1).expand(-1, T)
    # (B, T) -> (T, B), in place
    padding_mask.t_()
    return {"encoder_out": encoder_states, "encoder_padding_mask": padding_mask}
def _current_postion_info():
cf = currentframe()
frameinfo = " (at {}:{})".format(
os.path.basename(getframeinfo(cf).filename), cf.f_back.f_lineno
)
return frameinfo
def check_encoder_output(encoder_output, batch_size=None):
    """Validate the structure of an encoder output.

    We expect encoder_output to be a dict with the following key/value pairs:
    - encoder_out: a float32 torch.Tensor
    - encoder_padding_mask: ``None`` or a binary 2-d torch.Tensor in shape
      (T, B); both ``torch.uint8`` and (where the installed PyTorch has it)
      ``torch.bool`` masks are accepted

    Args:
        encoder_output: the value to check
        batch_size (int, optional): when given, ``size(1)`` of the mask must
            equal it

    Returns:
        ``(True, None)`` when every check passes, otherwise ``(False, msg)``
        with ``msg`` describing the first failed check.
    """
    if not isinstance(encoder_output, dict):
        msg = (
            "FairseqEncoderModel.forward(...) must be a dict" + _current_postion_info()
        )
        return False, msg

    if "encoder_out" not in encoder_output:
        msg = (
            "FairseqEncoderModel.forward(...) must contain encoder_out"
            + _current_postion_info()
        )
        return False, msg

    if "encoder_padding_mask" not in encoder_output:
        msg = (
            "FairseqEncoderModel.forward(...) must contain encoder_padding_mask"
            + _current_postion_info()
        )
        return False, msg

    if not isinstance(encoder_output["encoder_out"], torch.Tensor):
        msg = "encoder_out must be a torch.Tensor" + _current_postion_info()
        return False, msg

    if encoder_output["encoder_out"].dtype != torch.float32:
        msg = "encoder_out must have float32 dtype" + _current_postion_info()
        return False, msg

    mask = encoder_output["encoder_padding_mask"]
    if mask is not None:
        if not isinstance(mask, torch.Tensor):
            msg = (
                "encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
            )
            return False, msg
        # Comparison ops yield torch.bool on newer PyTorch and torch.uint8 on
        # older releases; both are binary masks, so accept either.
        mask_dtypes = [torch.uint8]
        if hasattr(torch, "bool"):
            mask_dtypes.append(torch.bool)
        if mask.dtype not in mask_dtypes:
            msg = (
                "encoder_padding_mask must have dtype of uint8 or bool"
                + _current_postion_info()
            )
            return False, msg

        if mask.dim() != 2:
            msg = (
                "we expect encoder_padding_mask to be a 2-d tensor, in shape (T, B)"
                + _current_postion_info()
            )
            return False, msg

        if batch_size is not None and mask.size(1) != batch_size:
            msg = (
                "we expect encoder_padding_mask to be a 2-d tensor, with size(1)"
                + " being the batch size"
                + _current_postion_info()
            )
            return False, msg
    return True, None
def check_decoder_output(decoder_output):
    """Validate the structure of a decoder output.

    The output must be a 2-element tuple where
    - the first element is a torch.Tensor
    - the second element can be anything (reserved for future use)

    Returns:
        ``(True, None)`` on success, otherwise ``(False, msg)`` describing the
        first failed check.
    """
    if not isinstance(decoder_output, tuple):
        return False, "FariseqDecoder output must be a tuple" + _current_postion_info()

    if len(decoder_output) != 2:
        return (
            False,
            "FairseqDecoder output must be 2-elem tuple" + _current_postion_info(),
        )

    first = decoder_output[0]
    if not isinstance(first, torch.Tensor):
        return (
            False,
            "FariseqDecoder output[0] must be a torch.Tensor" + _current_postion_info(),
        )

    return True, None
# ///////////////////////////////////////////////////////////////////////////
# Base Test class
# ///////////////////////////////////////////////////////////////////////////
class TestBaseFairseqModelBase(unittest.TestCase):
    """
    Helper base for unit tests of any class derived from `BaseFairseqModel`.

    The base itself is skipped; only subclasses are meant to run.
    """

    @classmethod
    def setUpClass(cls):
        # do not run the (model-less) base class as a test case of its own
        if cls is TestBaseFairseqModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model):
        """Check that ``model`` is a ``BaseFairseqModel`` and keep it."""
        self.assertTrue(isinstance(model, BaseFairseqModel))
        self.model = model

    def setupInput(self):
        """Hook without behaviour in the base class."""

    def setUp(self):
        """Reset the model and forward input before every test."""
        self.model = None
        self.forward_input = None
class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
    """
    Tests for a FairseqEncoderDecoderModel (formerly known as `FairseqModel`)
    must be derived from this base class.
    """

    @classmethod
    def setUpClass(cls):
        # the base is only a template, never a test case on its own
        if cls is TestFairseqEncoderDecoderModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model_cls, extra_args_setters=None):
        """Build ``model_cls`` from default args and store it as ``self.model``.

        Args:
            model_cls: a FairseqEncoderDecoderModel / FairseqModel subclass
            extra_args_setters: optional iterable of callables, each invoked
                with the parsed args before the model is built
        """
        self.assertTrue(
            issubclass(model_cls, (FairseqEncoderDecoderModel, FairseqModel)),
            msg="This class only tests for FairseqModel subclasses",
        )

        task, parser = get_dummy_task_and_parser()
        model_cls.add_args(parser)
        args = parser.parse_args([])

        for apply_override in extra_args_setters or ():
            apply_override(args)

        self.model = model_cls.build_model(args, task)

    def setUpInput(self, input=None):
        """Use ``input`` as forward input, or a fresh dummy one when ``None``."""
        if input is None:
            input = get_dummy_input()
        self.forward_input = input

    def setUp(self):
        super().setUp()

    def test_forward(self):
        if not (self.model and self.forward_input):
            return
        forward_output = self.model.forward(**self.forward_input)
        # for FairseqEncoderDecoderModel, forward returns a tuple of two
        # elements, the first one is a Torch.Tensor
        succ, msg = check_decoder_output(forward_output)
        self.assertTrue(succ, msg=msg)
        self.forward_output = forward_output

    def test_get_normalized_probs(self):
        if not (self.model and self.forward_input):
            return
        forward_output = self.model.forward(**self.forward_input)
        logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
        prob = self.model.get_normalized_probs(forward_output, log_probs=False)

        # For models and criterions to interoperate we need to know whether
        # the (log-)probabilities are batch first. The result is expected to
        # carry an extra `batch_first` attribute. If this fails, override
        # FairseqModel.get_normalized_probs, see example at
        # https://fburl.com/batch_first_example
        for output in (logprob, prob):
            self.assertTrue(hasattr(output, "batch_first"))

        for output in (logprob, prob):
            self.assertTrue(torch.is_tensor(output))
class TestFairseqEncoderModelBase(TestBaseFairseqModelBase):
    """
    Base class for tests of a FairseqEncoderModel.
    """

    @classmethod
    def setUpClass(cls):
        # the base is only a template, never a test case on its own
        if cls is TestFairseqEncoderModelBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpModel(self, model_cls, extra_args_setters=None):
        """Build ``model_cls`` from default args and store it as ``self.model``.

        Args:
            model_cls: a FairseqEncoderModel subclass
            extra_args_setters: optional iterable of callables, each invoked
                with the parsed args before the model is built
        """
        self.assertTrue(
            issubclass(model_cls, FairseqEncoderModel),
            msg="This class is only used for testing FairseqEncoderModel",
        )
        task, parser = get_dummy_task_and_parser()
        model_cls.add_args(parser)
        args = parser.parse_args([])
        for apply_override in extra_args_setters or ():
            apply_override(args)

        self.model = model_cls.build_model(args, task)

    def setUpInput(self, input=None):
        """Store the forward input, without any ``prev_output_tokens`` entry."""
        if input is None:
            input = get_dummy_input()
        self.forward_input = input
        # get_dummy_input() is originally for s2s, here we delete extra dict
        # items, so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        super().setUp()

    def test_forward(self):
        if not (self.forward_input and self.model):
            return
        bsz = self.forward_input["src_tokens"].size(0)
        forward_output = self.model.forward(**self.forward_input)

        # forward_output is expected to be a dict with
        # - encoder_out: a Torch.Tensor
        # - encoder_padding_mask: a binary Torch.Tensor
        succ, msg = check_encoder_output(forward_output, batch_size=bsz)
        self.assertTrue(succ, msg=msg)
        self.forward_output = forward_output

    def test_get_normalized_probs(self):
        if not (self.model and self.forward_input):
            return
        forward_output = self.model.forward(**self.forward_input)
        logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
        prob = self.model.get_normalized_probs(forward_output, log_probs=False)

        # For models and criterions to interoperate we need to know whether
        # the (log-)probabilities are batch first. The result is expected to
        # carry an extra `batch_first` attribute. If this fails, override
        # FairseqModel.get_normalized_probs, see example at
        # https://fburl.com/batch_first_example
        for output in (logprob, prob):
            self.assertTrue(hasattr(output, "batch_first"))

        for output in (logprob, prob):
            self.assertTrue(torch.is_tensor(output))
class TestFairseqEncoderBase(unittest.TestCase):
    """
    Base class for tests of a FairseqEncoder.
    """

    @classmethod
    def setUpClass(cls):
        # the base is only a template, never a test case on its own
        if cls is TestFairseqEncoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpEncoder(self, encoder):
        """Check that ``encoder`` is a ``FairseqEncoder`` and keep it."""
        self.assertTrue(
            isinstance(encoder, FairseqEncoder),
            msg="This class is only used for test FairseqEncoder",
        )
        self.encoder = encoder

    def setUpInput(self, input=None):
        """Store the forward input, without any ``prev_output_tokens`` entry."""
        if input is None:
            input = get_dummy_input()
        self.forward_input = input
        # get_dummy_input() is originally for s2s, here we delete extra dict
        # items, so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        """Reset encoder and forward input before every test."""
        self.encoder = None
        self.forward_input = None

    def test_forward(self):
        # nothing to check until a subclass provided encoder and input
        if not (self.encoder and self.forward_input):
            return
        bsz = self.forward_input["src_tokens"].size(0)
        forward_output = self.encoder.forward(**self.forward_input)
        succ, msg = check_encoder_output(forward_output, batch_size=bsz)
        self.assertTrue(succ, msg=msg)
        self.forward_output = forward_output
class TestFairseqDecoderBase(unittest.TestCase):
    """
    base class to test FairseqDecoder

    Subclasses provide a decoder (``setUpDecoder``), an encoder output used as
    forward input (``setUpInput``) and previous output tokens
    (``setUpPrevOutputTokens``); ``test_forward`` is a no-op until all three
    are set.
    """

    @classmethod
    def setUpClass(cls):
        # the base is only a template, never a test case on its own
        if cls is TestFairseqDecoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpDecoder(self, decoder):
        """Check that ``decoder`` is a ``FairseqDecoder`` and keep it."""
        self.assertTrue(
            isinstance(decoder, FairseqDecoder),
            msg="This class is only used for test FairseqDecoder",
        )
        self.decoder = decoder

    def setUpInput(self, input=None):
        """Use ``input`` as encoder output, or a dummy one when ``None``."""
        self.forward_input = get_dummy_encoder_output() if input is None else input

    def setUpPrevOutputTokens(self, tokens=None):
        """Use ``tokens`` as previous output tokens, or take them from a
        freshly generated dummy input when ``None``."""
        if tokens is None:
            self.encoder_input = get_dummy_input()
            self.prev_output_tokens = self.encoder_input["prev_output_tokens"]
        else:
            self.prev_output_tokens = tokens

    def setUp(self):
        """Reset decoder, forward input and tokens before every test."""
        self.decoder = None
        self.forward_input = None
        self.prev_output_tokens = None

    def test_forward(self):
        if (
            self.decoder is not None
            and self.forward_input is not None
            and self.prev_output_tokens is not None
        ):
            forward_output = self.decoder.forward(
                prev_output_tokens=self.prev_output_tokens,
                encoder_out=self.forward_input,
            )
            succ, msg = check_decoder_output(forward_output)
            self.assertTrue(succ, msg=msg)
            # Keep the result separate from the input, as the other test
            # bases do; storing it in forward_input would clobber the input.
            self.forward_output = forward_output
class DummyEncoderModel(FairseqEncoderModel):
    """Encoder model for tests that wraps a ``DummyEncoder``."""

    def __init__(self, encoder):
        super().__init__(encoder)

    @classmethod
    def build_model(cls, args, task):
        """Return a model around a fresh ``DummyEncoder``; ``args`` and
        ``task`` are not used."""
        return cls(DummyEncoder())

    def get_logits(self, net_output):
        """Turn ``net_output["encoder_out"]`` (read as probabilities ``p``)
        into logits ``log(p / (1 - p))``."""
        # Inverse of sigmoid to use with BinaryCrossEntropyWithLogitsCriterion as
        # F.binary_cross_entropy_with_logits combines sigmoid and CE
        probs = net_output["encoder_out"]
        odds = torch.div(probs, 1 - probs)
        return torch.log(odds)
class DummyEncoder(FairseqEncoder):
    """Encoder for tests that echoes its input as ``encoder_out``."""

    def __init__(self):
        super().__init__(None)

    def forward(self, src_tokens, src_lengths):
        """Return ``src_tokens`` unchanged together with the padding mask
        derived from ``src_lengths``."""
        # the helper's second return value is not needed here
        padding_mask, _ = lengths_to_encoder_padding_mask(src_lengths)
        return {"encoder_out": src_tokens, "encoder_padding_mask": padding_mask}
class CrossEntropyCriterionTestBase(unittest.TestCase):
    """Base for criterion tests; subclasses set ``self.criterion_cls`` before
    calling ``setUp``."""

    @classmethod
    def setUpClass(cls):
        # the base is only a template, never a test case on its own
        if cls is CrossEntropyCriterionTestBase:
            raise unittest.SkipTest("Skipping base class test case")
        super().setUpClass()

    def setUpArgs(self):
        """Return the namespace handed to the criterion and the dummy task."""
        return argparse.Namespace(
            sentence_avg=False,
            threshold=0.1,  # to use with BinaryCrossEntropyWithLogitsCriterion
        )

    def setUp(self):
        args = self.setUpArgs()
        self.model = DummyEncoderModel(encoder=DummyEncoder())
        self.criterion = self.criterion_cls(args=args, task=DummyTask(args))

    def get_src_tokens(self, correct_prediction, aggregate):
        """
        Build one-hot "predictions" for a batch of 2 over 2 classes.

        correct_prediction: True if the net_output (src_tokens) should
            predict the correct target (class 0), otherwise class 1 is hot
        aggregate: True if the criterion expects net_output (src_tokens)
            aggregated across the time axis, giving shape (2, 2) instead
            of (2, 10, 2)
        """
        hot_class = 0 if correct_prediction else 1
        shape = (2, 2) if aggregate else (2, 10, 2)
        src_tokens = torch.zeros(shape, dtype=torch.float)
        # the class axis is the last one in both layouts
        src_tokens[..., hot_class] = 1.0
        return src_tokens

    def get_target(self, soft_target):
        """Return the target: a float (2, 2) tensor with class 0 set to 1 for
        soft targets, otherwise a long (2, 10) tensor of zeros."""
        if not soft_target:
            return torch.zeros((2, 10), dtype=torch.long)
        target = torch.zeros((2, 2), dtype=torch.float)
        target[:, 0] = 1.0
        return target

    def get_test_sample(self, correct, soft_target, aggregate):
        """Assemble a sample dict with ``net_input``, ``target`` and
        ``ntokens`` (product of the first two ``src_tokens`` sizes)."""
        src_tokens = self.get_src_tokens(correct, aggregate)
        target = self.get_target(soft_target)
        first_dim, second_dim = src_tokens.size(0), src_tokens.size(1)
        net_input = {
            "src_tokens": src_tokens,
            "src_lengths": torch.tensor([second_dim]),
        }
        return {
            "net_input": net_input,
            "target": target,
            "ntokens": first_dim * second_dim,
        }

View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import unittest
import numpy as np
import torch
from examples.speech_recognition.data.collaters import Seq2SeqCollater
class TestSeq2SeqCollator(unittest.TestCase):
    """Checks the batch produced by ``Seq2SeqCollater.collate``."""

    def test_collate(self):
        eos_idx = 1
        pad_idx = 0
        collater = Seq2SeqCollater(
            feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx
        )

        # the first sample has 2 frames, the second one 3
        short_frames = np.array([[7, 8], [9, 10]])
        long_frames = np.array([[1, 2], [3, 4], [5, 6]])
        short_target = np.array([4, 2, 3, eos_idx])
        long_target = np.array([3, 2, eos_idx])
        samples = [
            {"id": 0, "data": [short_frames, short_target]},
            {"id": 1, "data": [long_frames, long_target]},
        ]

        batch = collater.collate(samples)
        net_input = batch["net_input"]

        # collate sorts inputs by number of frames before creating the batch,
        # so the 3-frame sample (id 1) is expected first
        self.assertTensorEqual(batch["id"], torch.tensor([1, 0]))
        self.assertEqual(batch["ntokens"], 7)

        expected_src_tokens = torch.tensor(
            [[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]]
        )
        self.assertTensorEqual(net_input["src_tokens"], expected_src_tokens)

        expected_prev_output = torch.tensor(
            [[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]
        )
        self.assertTensorEqual(net_input["prev_output_tokens"], expected_prev_output)

        self.assertTensorEqual(net_input["src_lengths"], torch.tensor([3, 2]))

        expected_target = torch.tensor(
            [[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]
        )
        self.assertTensorEqual(batch["target"], expected_target)
        self.assertEqual(batch["nsentences"], 2)

    def assertTensorEqual(self, t1, t2):
        """Fail unless both tensors have the same size and no differing element."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        num_mismatches = t1.ne(t2).long().sum()
        self.assertEqual(num_mismatches, 0)
# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()

View File

@ -0,0 +1,36 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion
from .asr_test_base import CrossEntropyCriterionTestBase
class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
    """Checks the counters logged by ``CrossEntropyWithAccCriterion``.

    The checks use ``self.assertEqual`` rather than bare ``assert`` so they
    are not stripped when Python runs with ``-O`` and so a failure reports
    the offending values.
    """

    def setUp(self):
        # Select the criterion under test before delegating to the base
        # setUp (presumably the base builds ``self.criterion`` from it --
        # confirm in asr_test_base).
        self.criterion_cls = CrossEntropyWithAccCriterion
        super().setUp()

    def _compute_logging_output(self, correct):
        """Run the criterion on a test sample and return its logging output.

        Args:
            correct: forwarded to ``get_test_sample`` as its ``correct`` flag.
        """
        sample = self.get_test_sample(
            correct=correct, soft_target=False, aggregate=False
        )
        # Only the logging output (third element) is inspected by the tests.
        _, _, logging_output = self.criterion(
            self.model, sample, "sum", log_probs=True
        )
        return logging_output

    def test_cross_entropy_all_correct(self):
        """With a fully correct sample, all 20 tokens are counted as correct."""
        logging_output = self._compute_logging_output(correct=True)
        self.assertEqual(logging_output["correct"], 20)
        self.assertEqual(logging_output["total"], 20)
        self.assertEqual(logging_output["sample_size"], 20)
        self.assertEqual(logging_output["ntokens"], 20)

    def test_cross_entropy_all_wrong(self):
        """With a fully wrong sample, none of the 20 tokens is counted as correct."""
        logging_output = self._compute_logging_output(correct=False)
        self.assertEqual(logging_output["correct"], 0)
        self.assertEqual(logging_output["total"], 20)
        self.assertEqual(logging_output["sample_size"], 20)
        self.assertEqual(logging_output["ntokens"], 20)

View File

@ -0,0 +1,135 @@
#!/usr/bin/env python3
# import models/encoder/decoder to be tested
from examples.speech_recognition.models.vggtransformer import (
TransformerDecoder,
VGGTransformerEncoder,
VGGTransformerModel,
vggtransformer_1,
vggtransformer_2,
vggtransformer_base,
)
# import base test class
from .asr_test_base import (
DEFAULT_TEST_VOCAB_SIZE,
TestFairseqDecoderBase,
TestFairseqEncoderBase,
TestFairseqEncoderDecoderModelBase,
get_dummy_dictionary,
get_dummy_encoder_output,
get_dummy_input,
)
class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
    """Model test for the ``vggtransformer_1`` architecture with a reduced encoder."""

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """
            vggtransformer_1 uses 14 transformer layers, which is too
            expensive for a unit test. Cut the encoder down to 3 layers so
            the test turns around quickly.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The architecture defaults are applied first, then the test override.
        self.setUpModel(VGGTransformerModel, [vggtransformer_1, shrink_encoder])

        dummy_input = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_input)
class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
    """Model test for the ``vggtransformer_2`` architecture with a reduced encoder."""

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """
            vggtransformer_2 uses 16 transformer layers, which is too
            expensive for a unit test. Cut the encoder down to 3 layers so
            the test turns around quickly.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The architecture defaults are applied first, then the test override.
        self.setUpModel(VGGTransformerModel, [vggtransformer_2, shrink_encoder])

        dummy_input = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_input)
class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
    """Model test for the ``vggtransformer_base`` architecture with a reduced encoder."""

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """
            vggtransformer_base uses 12 transformer layers, which is too
            expensive for a unit test. Cut the encoder down to 3 layers so
            the test turns around quickly.
            """
            args.transformer_enc_config = (
                "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The architecture defaults are applied first, then the test override.
        self.setUpModel(VGGTransformerModel, [vggtransformer_base, shrink_encoder])

        dummy_input = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_input)
class VGGTransformerEncoderTest(TestFairseqEncoderBase):
    """Runs the base encoder forward test over several VGGTransformerEncoder configs."""

    def setUp(self):
        super().setUp()
        self.setUpInput(get_dummy_input(T=50, D=80, B=5))

    def test_forward(self):
        """Run the inherited forward test once per encoder configuration.

        Each scenario installs a freshly built encoder and then invokes the
        base-class ``test_forward``. The windowed-context scenario (5)
        previously built its encoder without running the forward check, so
        that configuration was never tested; every scenario now runs it.
        """
        # (banner to print, extra keyword arguments for VGGTransformerEncoder)
        scenarios = (
            ("1. test standard vggtransformer", {}),
            (
                "2. test vggtransformer with limited right context",
                {"transformer_context": (-1, 5)},
            ),
            (
                "3. test vggtransformer with limited left context",
                {"transformer_context": (5, -1)},
            ),
            (
                "4. test vggtransformer with limited right context and sampling",
                {"transformer_context": (-1, 12), "transformer_sampling": (2, 2)},
            ),
            (
                "5. test vggtransformer with windowed context and sampling",
                {"transformer_context": (12, 12), "transformer_sampling": (2, 2)},
            ),
        )
        for banner, extra_kwargs in scenarios:
            print(banner)
            self.setUpEncoder(
                VGGTransformerEncoder(input_feat_per_channel=80, **extra_kwargs)
            )
            super().test_forward()
class TransformerDecoderTest(TestFairseqDecoderBase):
    """Decoder test fixture for ``TransformerDecoder`` using dummy encoder output."""

    def setUp(self):
        super().setUp()

        # Named ``dictionary`` rather than ``dict`` to avoid shadowing the builtin.
        dictionary = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
        decoder = TransformerDecoder(dictionary)
        dummy_encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))

        self.setUpDecoder(decoder)
        self.setUpInput(dummy_encoder_output)
        self.setUpPrevOutputTokens()