Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-10-26 17:32:57 +03:00)
Asr initial push (#810)
Summary: Initial code for the speech recognition task. Right now only one ASR model is added - https://arxiv.org/abs/1904.11660. Unit testing: python -m unittest discover tests. Also ran model training with this code and obtained 5.0 WER on test_clean | 13.4 WER on test_other on librispeech with pytorch/audio features. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/810 Reviewed By: cpuhrsch Differential Revision: D16706659 Pulled By: okhonko fbshipit-source-id: 89a5f9883e50bc0e548234287aa0ea73f7402514
This commit is contained in: parent 9a1038f68a, commit 72f9364cc6
32 examples/speech_recognition/README.md Normal file
@ -0,0 +1,32 @@
# Speech Recognition

`examples/speech_recognition` implements the ASR task in Fairseq, along with the features, datasets, models and loss functions needed to train and run inference with the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).

## Additional dependencies

On top of the main fairseq dependencies there are a couple of additional requirements.

1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.

2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and built from the sctk package sources [here](http://www.openslr.org/4/). Training and inference don't require the Sclite dependency.

## Preparing librispeech data

```
./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
```

## Training librispeech data

```
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
```

## Inference for librispeech

`$SET` can be `test_clean` or `test_other`.

Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`.

```
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
```

## Scoring WER for librispeech

```
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
```

The `Sum/Avg` row of the first table in the report contains the WER.
1 examples/speech_recognition/__init__.py Normal file
@ -0,0 +1 @@
from . import tasks, criterions, models  # noqa
7 examples/speech_recognition/criterions/__init__.py Normal file
@ -0,0 +1,7 @@
import importlib
import os

# Automatically import every criterion module in this package so that its
# @register_criterion decorator runs and the criterion becomes available.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        criterion_name = file[:file.find('.py')]
        importlib.import_module('examples.speech_recognition.criterions.' + criterion_name)
129 examples/speech_recognition/criterions/cross_entropy_acc.py Normal file
@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import math

import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("cross_entropy_acc")
class CrossEntropyWithAccCriterion(FairseqCriterion):
    def __init__(self, args, task):
        super().__init__(args, task)

    def compute_loss(self, model, net_output, target, reduction, log_probs):
        # N, T -> N * T
        target = target.view(-1)
        lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
        if not hasattr(lprobs, "batch_first"):
            logging.warning(
                "ERROR: we need to know whether "
                "batch first for the net output; "
                "you need to set batch_first attribute for the return value of "
                "model.get_normalized_probs. Now, we assume this is true, but "
                "in the future, we will raise exception instead. "
            )
        batch_first = getattr(lprobs, "batch_first", True)
        if not batch_first:
            lprobs = lprobs.transpose(0, 1)

        # N, T, D -> N * T, D
        lprobs = lprobs.view(-1, lprobs.size(-1))
        loss = F.nll_loss(
            lprobs, target, ignore_index=self.padding_idx, reduction=reduction
        )
        return lprobs, loss

    def get_logging_output(self, sample, target, lprobs, loss):
        target = target.view(-1)
        mask = target != self.padding_idx
        correct = torch.sum(
            lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
        )
        total = torch.sum(mask)
        sample_size = (
            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
        )

        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "correct": utils.item(correct.data),
            "total": utils.item(total.data),
            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
        }

        return sample_size, logging_output

    def forward(self, model, sample, reduction="sum", log_probs=True):
        """Computes the cross entropy with accuracy metric for the given sample.

        This is similar to CrossEntropyCriterion in fairseq, but also
        computes accuracy metrics as part of logging

        Args:
            logprobs (Torch.tensor) of shape N, T, D i.e.
                batchsize, timesteps, dimensions
            targets (Torch.tensor) of shape N, T i.e. batchsize, timesteps

        Returns:
            tuple: With three elements:
                1) the loss
                2) the sample size, which is used as the denominator for the gradient
                3) logging outputs to display while training

        TODO:
            * Currently this Criterion will only work with LSTMEncoderModels or
              FairseqModels which have a decoder, or models which return a
              torch.Tensor as net_output.
              We need to make a change to support all FairseqEncoder models.
        """
        net_output = model(**sample["net_input"])
        target = model.get_targets(sample, net_output)
        lprobs, loss = self.compute_loss(
            model, net_output, target, reduction, log_probs
        )
        sample_size, logging_output = self.get_logging_output(
            sample, target, lprobs, loss
        )
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
        total_sum = sum(log.get("total", 0) for log in logging_outputs)
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        nframes = sum(log.get("nframes", 0) for log in logging_outputs)
        agg_output = {
            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
            # if args.sentence_avg, then sample_size is nsentences, and loss
            # is per-sentence loss; else sample_size is ntokens, and the loss
            # becomes per-output-token loss
            "ntokens": ntokens,
            "nsentences": nsentences,
            "nframes": nframes,
            "sample_size": sample_size,
            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
            "correct": correct_sum,
            "total": total_sum,
            # total is the number of validation tokens
        }
        if sample_size != ntokens:
            agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
            # loss: per-sentence loss (sample_size is nsentences here)
            # nll_loss: per-output-token loss
        return agg_output
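# A minimal sketch (hypothetical values) of how two workers' logging outputs
# would be aggregated by this criterion:
#
#   logs = [
#       {"loss": 40.0, "ntokens": 10, "nsentences": 2, "sample_size": 10,
#        "correct": 7, "total": 10, "nframes": 500},
#       {"loss": 60.0, "ntokens": 20, "nsentences": 3, "sample_size": 20,
#        "correct": 12, "total": 20, "nframes": 900},
#   ]
#   agg = CrossEntropyWithAccCriterion.aggregate_logging_outputs(logs)
#   assert agg["acc"] == 19 * 100.0 / 30
#   assert agg["loss"] == 100.0 / 30 / math.log(2)  # bits per output token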
10 examples/speech_recognition/data/__init__.py Normal file
@ -0,0 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .asr_dataset import AsrDataset

__all__ = [
    'AsrDataset',
]
110 examples/speech_recognition/data/asr_dataset.py Normal file
@ -0,0 +1,110 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import numpy as np
from fairseq.data import FairseqDataset

from . import data_utils
from .collaters import Seq2SeqCollater


class AsrDataset(FairseqDataset):
    """
    A dataset representing speech and corresponding transcription.

    Args:
        aud_paths (List[str]): A list of str with paths to audio files.
        aud_durations_ms (List[int]): A list of int containing the durations of
            audio files.
        tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
            of target transcriptions.
        tgt_dict (~fairseq.data.Dictionary): target vocabulary.
        ids (List[str]): A list of utterance IDs.
        speakers (List[str]): A list of speakers corresponding to utterances.
        num_mel_bins (int): Number of triangular mel-frequency bins (default: 80)
        frame_length (float): Frame length in milliseconds (default: 25.0)
        frame_shift (float): Frame shift in milliseconds (default: 10.0)
    """

    def __init__(
        self, aud_paths, aud_durations_ms, tgt,
        tgt_dict, ids, speakers,
        num_mel_bins=80, frame_length=25.0, frame_shift=10.0
    ):
        assert frame_length > 0
        assert frame_shift > 0
        assert all(x > frame_length for x in aud_durations_ms)
        self.frame_sizes = [
            int(1 + (d - frame_length) / frame_shift)
            for d in aud_durations_ms
        ]

        assert len(aud_paths) > 0
        assert len(aud_paths) == len(aud_durations_ms)
        assert len(aud_paths) == len(tgt)
        assert len(aud_paths) == len(ids)
        assert len(aud_paths) == len(speakers)
        self.aud_paths = aud_paths
        self.tgt_dict = tgt_dict
        self.tgt = tgt
        self.ids = ids
        self.speakers = speakers
        self.num_mel_bins = num_mel_bins
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        # named s2s_collater (not collater) so the attribute does not shadow
        # the collater() method defined below
        self.s2s_collater = Seq2SeqCollater(
            0, 1, pad_index=self.tgt_dict.pad(),
            eos_index=self.tgt_dict.eos(), move_eos_to_beginning=True
        )

    def __getitem__(self, index):
        import torchaudio
        import torchaudio.compliance.kaldi as kaldi
        tgt_item = self.tgt[index] if self.tgt is not None else None

        path = self.aud_paths[index]
        if not os.path.exists(path):
            raise FileNotFoundError("Audio file not found: {}".format(path))
        sound, sample_rate = torchaudio.load_wav(path)
        output = kaldi.fbank(
            sound,
            num_mel_bins=self.num_mel_bins,
            frame_length=self.frame_length,
            frame_shift=self.frame_shift
        )
        output_cmvn = data_utils.apply_mv_norm(output)

        return {"id": index, "data": [output_cmvn.detach(), tgt_item]}

    def __len__(self):
        return len(self.aud_paths)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        return self.s2s_collater.collate(samples)

    def num_tokens(self, index):
        return self.frame_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return (
            self.frame_sizes[index],
            len(self.tgt[index]) if self.tgt is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self))
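# A minimal usage sketch (hypothetical paths, durations and target indices),
# assuming a dict.txt produced by the prepare-librispeech.sh script below:
#
#   import torch
#   from fairseq.data import Dictionary
#   from examples.speech_recognition.data import AsrDataset
#
#   tgt_dict = Dictionary.load("dict.txt")
#   dataset = AsrDataset(
#       aud_paths=["utt1.wav"], aud_durations_ms=[2500],
#       tgt=[torch.LongTensor([5, 6, 7])], tgt_dict=tgt_dict,
#       ids=["utt1"], speakers=["spk1"],
#   )
#   sample = dataset[0]   # {"id": 0, "data": [cmvn_fbank_features, tgt_item]}
#   batch = dataset.collater([sample])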
129 examples/speech_recognition/data/collaters.py Normal file
@ -0,0 +1,129 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This module contains a collection of classes which implement
collate functionalities for various tasks.

Collaters should know what data to expect for each sample
and they should pack / collate them into batches
"""


from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
import torch
from fairseq.data import data_utils as fairseq_data_utils


class Seq2SeqCollater(object):
    """
    Implements a collate function mainly for seq2seq tasks.
    This expects each sample to contain features (src_tokens) and
    targets.
    This collater is also used for the aligned training task.
    """

    def __init__(
        self,
        feature_index=0,
        label_index=1,
        pad_index=1,
        eos_index=2,
        move_eos_to_beginning=True,
    ):
        self.feature_index = feature_index
        self.label_index = label_index
        self.pad_index = pad_index
        self.eos_index = eos_index
        self.move_eos_to_beginning = move_eos_to_beginning

    def _collate_frames(self, frames):
        """Convert a list of 2d frames into a padded 3d tensor

        Args:
            frames (list): list of 2d frames of size L[i]*f_dim, where L[i] is
                the length of the i-th frame and f_dim is the static dimension
                of the features

        Returns:
            3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
        """
        len_max = max(frame.size(0) for frame in frames)
        f_dim = frames[0].size(1)
        res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)

        for i, v in enumerate(frames):
            res[i, : v.size(0)] = v

        return res
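    # Example (hypothetical shapes): collating feature matrices of shapes
    # (3, 80) and (5, 80) with _collate_frames yields a zero-padded tensor of
    # shape (2, 5, 80); rows 3 and 4 of the shorter entry remain 0.0.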

    def collate(self, samples):
        """
        utility function to collate samples into a batch for speech recognition.
        """
        if len(samples) == 0:
            return {}

        # parse samples into torch tensors
        parsed_samples = []
        for s in samples:
            # skip invalid samples
            if s["data"][self.feature_index] is None:
                continue
            source = s["data"][self.feature_index]
            if isinstance(source, (np.ndarray, np.generic)):
                source = torch.from_numpy(source)
            target = s["data"][self.label_index]
            if isinstance(target, (np.ndarray, np.generic)):
                target = torch.from_numpy(target).long()

            parsed_sample = {"id": s["id"], "source": source, "target": target}
            parsed_samples.append(parsed_sample)
        samples = parsed_samples

        id = torch.LongTensor([s["id"] for s in samples])
        frames = self._collate_frames([s["source"] for s in samples])
        # sort samples by descending number of frames
        frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
        frames_lengths, sort_order = frames_lengths.sort(descending=True)
        id = id.index_select(0, sort_order)
        frames = frames.index_select(0, sort_order)

        target = None
        target_lengths = None
        prev_output_tokens = None
        if samples[0].get("target", None) is not None:
            ntokens = sum(len(s["target"]) for s in samples)
            target = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, sort_order)
            target_lengths = torch.LongTensor(
                [s["target"].size(0) for s in samples]
            ).index_select(0, sort_order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [s["target"] for s in samples],
                self.pad_index,
                self.eos_index,
                left_pad=False,
                move_eos_to_beginning=self.move_eos_to_beginning,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
        else:
            ntokens = sum(len(s["source"]) for s in samples)

        batch = {
            "id": id,
            "ntokens": ntokens,
            "net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
            "target": target,
            "target_lengths": target_lengths,
            "nsentences": len(samples),
        }
        if prev_output_tokens is not None:
            batch["net_input"]["prev_output_tokens"] = prev_output_tokens
        return batch
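    # For two samples with 100 and 80 frames of 80-dim features and targets of
    # lengths 7 and 5 (hypothetical numbers), the returned batch has, e.g.:
    #   batch["net_input"]["src_tokens"].shape == (2, 100, 80)
    #   batch["net_input"]["src_lengths"].tolist() == [100, 80]
    #   batch["target"].shape == (2, 7)   # shorter target padded with pad_index
    #   batch["ntokens"] == 12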
60 examples/speech_recognition/data/data_utils.py Normal file
@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


def calc_mean_invstddev(feature):
    if len(feature.size()) != 2:
        raise ValueError("We expect the input feature to be a 2-D tensor")
    mean = feature.mean(0)
    var = feature.var(0)
    # avoid division by ~zero
    eps = 1e-8
    if (var < eps).any():
        return mean, 1.0 / (torch.sqrt(var) + eps)
    return mean, 1.0 / torch.sqrt(var)


def apply_mv_norm(features):
    mean, invstddev = calc_mean_invstddev(features)
    res = (features - mean) * invstddev
    return res
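# Sanity-check sketch (hypothetical input): after mean-variance normalization
# each feature dimension has roughly zero mean and unit variance:
#
#   x = torch.randn(100, 80) * 3.0 + 5.0
#   y = apply_mv_norm(x)
#   # y.mean(0) is close to 0 and y.var(0) is close to 1 in every dimension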


def lengths_to_encoder_padding_mask(lengths, batch_first=False):
    """
    convert lengths (a 1-D Long/Int tensor) to a 2-D binary tensor

    Args:
        lengths: a (B, )-shaped tensor

    Return:
        encoder_padding_mask: a (max_length, B) binary mask, where
            [t, b] = 0 for t < lengths[b] and 1 otherwise
        max_length: maximum length of the B sequences

    TODO:
        kernelize this function if benchmarking shows this function is slow
    """
    max_lengths = torch.max(lengths).item()
    bsz = lengths.size(0)
    encoder_padding_mask = (
        torch.arange(max_lengths)  # a (T,) tensor with [0, ..., T-1]
        .to(lengths.device)  # move to the right device
        .view(1, max_lengths)  # reshape to a (1, T)-shaped tensor
        .expand(bsz, -1)  # expand to a (B, T)-shaped tensor
        >= lengths.view(bsz, 1).expand(-1, max_lengths)
    )
    if not batch_first:
        return encoder_padding_mask.t(), max_lengths
    else:
        return encoder_padding_mask, max_lengths
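# Example (hypothetical lengths): for lengths = torch.LongTensor([3, 1]) and
# batch_first=True the mask is
#   [[0, 0, 0],
#    [0, 1, 1]]
# i.e. positions at or beyond each sequence's length are marked 1 (padding).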
96 examples/speech_recognition/datasets/asr_prep_json.py Normal file
@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

from collections import namedtuple
import concurrent.futures
from itertools import chain
import argparse
import os
import json
import sentencepiece as spm
import multiprocessing
import torchaudio

from fairseq.data import Dictionary

MILLISECONDS_TO_SECONDS = 0.001


def process_sample(aud_path, label, utt_id, sp, tgt_dict):
    input = {}
    output = {}
    si, ei = torchaudio.info(aud_path)
    input["length_ms"] = int(si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS)
    input["path"] = aud_path

    token = " ".join(sp.EncodeAsPieces(label))
    ids = tgt_dict.encode_line(token, append_eos=False)
    output["text"] = label
    output["token"] = token
    output["tokenid"] = ', '.join(map(str, [t.tolist() for t in ids]))
    return {utt_id: {"input": input, "output": output}}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-dirs", nargs="+", default=['-'], required=True,
                        help="input directories with audio files")
    parser.add_argument("--labels", required=True,
                        help="aggregated input labels with format <ID LABEL> per line",
                        type=argparse.FileType('r', encoding='UTF-8'))
    parser.add_argument("--spm-model", required=True,
                        help="sentencepiece model to use for encoding",
                        type=argparse.FileType('r', encoding='UTF-8'))
    parser.add_argument("--dictionary", required=True,
                        help="file to load fairseq dictionary from",
                        type=argparse.FileType('r', encoding='UTF-8'))
    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
    parser.add_argument("--output", required=True, type=argparse.FileType('w'),
                        help="path to save json output")
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.spm_model.name)

    tgt_dict = Dictionary.load(args.dictionary)

    labels = {}
    for line in args.labels:
        (utt_id, label) = line.split(" ", 1)
        labels[utt_id] = label
    if len(labels) == 0:
        raise Exception('No labels found in ', args.labels.name)

    Sample = namedtuple('Sample', 'aud_path utt_id')
    samples = []
    for path, _, files in chain.from_iterable(os.walk(path) for path in args.audio_dirs):
        for f in files:
            if f.endswith(args.audio_format):
                if len(os.path.splitext(f)) != 2:
                    raise Exception('Expect <utt_id.extension> file name. Got: ', f)
                utt_id = os.path.splitext(f)[0]
                if utt_id not in labels:
                    continue
                samples.append(Sample(os.path.join(path, f), utt_id))

    utts = {}
    num_cpu = multiprocessing.cpu_count()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
        future_to_sample = {
            executor.submit(process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict): s
            for s in samples
        }
        for future in concurrent.futures.as_completed(future_to_sample):
            try:
                data = future.result()
            except Exception as exc:
                print('generated an exception: ', exc)
            else:
                utts.update(data)
    json.dump({"utts": utts}, args.output, indent=4)


if __name__ == "__main__":
    main()
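# Each utterance becomes one entry in the output json, roughly of the form
# (hypothetical utterance id, text and token ids):
#
#   {"utts": {"1272-128104-0000": {
#       "input":  {"length_ms": 5855, "path": ".../1272-128104-0000.flac"},
#       "output": {"text": "MISTER QUILTER ...",
#                  "token": "▁MISTER ▁QUILTER ...",
#                  "tokenid": "123, 456, ..."}}}}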
88 examples/speech_recognition/datasets/prepare-librispeech.sh Executable file
@ -0,0 +1,88 @@
#!/usr/bin/env bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Prepare librispeech dataset

base_url=www.openslr.org/resources/12
train_dir=train_960

if [ "$#" -ne 2 ]; then
  echo "Usage: $0 <download_dir> <out_dir>"
  echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
  exit 1
fi

download_dir=${1%/}
out_dir=${2%/}

fairseq_root=~/fairseq-py/
mkdir -p ${out_dir}
cd ${out_dir} || exit

nbpe=5000
bpemode=unigram

if [ ! -d "$fairseq_root" ]; then
  echo "$0: Please set correct fairseq_root"
  exit 1
fi

echo "Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
  url=$base_url/$part.tar.gz
  if ! wget -P $download_dir $url; then
    echo "$0: wget failed for $url"
    exit 1
  fi
  if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
    echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
    exit 1
  fi
done

echo "Merge all train packs into one"
mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
for part in train-clean-100 train-clean-360 train-other-500; do
  mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
done
echo "Merge train text"
find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text

# Use combined dev-clean and dev-other as validation set
find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text


dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
echo "Dictionary preparation"
mkdir -p data/lang_char/
echo "<unk> 3" > ${dict}
echo "</s> 2" >> ${dict}
echo "<pad> 1" >> ${dict}
cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
wc -l ${dict}

echo "Prepare train and test jsons"
for part in train_960 test-other test-clean; do
  python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
done
# fairseq expects to find train.json and valid.json during training
mv train_960.json train.json

echo "Prepare valid json"
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json

cp ${fairseq_dict} ./dict.txt
cp ${bpemodel}.model ./spm.model
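# After the script finishes, ${out_dir} should contain train.json, valid.json,
# test-clean.json, test-other.json, dict.txt and spm.model -- the layout that
# train.py and infer.py assume when pointed at $DIR_FOR_PREPROCESSED_DATA.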
243 examples/speech_recognition/infer.py Normal file
@ -0,0 +1,243 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""

import logging
import os

import sentencepiece as spm
import torch
from fairseq import options, progress_bar, utils, tasks
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.utils import import_user_module


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def add_asr_eval_argument(parser):
    parser.add_argument("--ctc", action="store_true", help="decode a ctc model")
    parser.add_argument("--rnnt", default=False, help="decode a rnnt model")
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="wfstlm on dictionary output units",
    )
    parser.add_argument(
        "--lm_weight",
        default=0.2,
        help="weight for wfstlm while interpolating with neural score",
    )
    parser.add_argument(
        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
    )
    return parser


def check_args(args):
    assert args.path is not None, "--path required for generation!"
    assert args.results_path is not None, "--results_path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def get_dataset_itr(args, task):
    return task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=(1000000.0, 1000000.0),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)


def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id):
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        hyp_words = sp.DecodePieces(hyp_pieces.split())
        print(
            "{} ({}-{})".format(hyp_pieces, speaker, id),
            file=res_files["hypo.units"],
        )
        print(
            "{} ({}-{})".format(hyp_words, speaker, id),
            file=res_files["hypo.words"],
        )

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = sp.DecodePieces(tgt_pieces.split())
        print(
            "{} ({}-{})".format(tgt_pieces, speaker, id),
            file=res_files["ref.units"],
        )
        print(
            "{} ({}-{})".format(tgt_words, speaker, id),
            file=res_files["ref.words"],
        )
        # only score top hypothesis
        if not args.quiet:
            logger.debug("HYPO:" + hyp_words)
            logger.debug("TARGET:" + tgt_words)
            logger.debug("___________________")


def prepare_result_files(args):
    def get_res_file(file_prefix):
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        return open(path, "w", buffering=1)

    return {
        "hypo.words": get_res_file("hypo.word"),
        "hypo.units": get_res_file("hypo.units"),
        "ref.words": get_res_file("ref.word"),
        "ref.units": get_res_file("ref.units"),
    }
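# Note: with --path $MODEL_PATH/checkpoint_last.pt and --gen-subset test_clean
# this writes e.g. hypo.word-checkpoint_last.pt-test_clean.txt and
# ref.word-checkpoint_last.pt-test_clean.txt, the file names the sclite
# command in the README consumes.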

def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation"""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()


def main(args):
    check_args(args)
    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 30000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    # Set dictionary
    tgt_dict = task.target_dictionary

    if args.ctc or args.rnnt:
        tgt_dict.add_symbol("<ctc_blank>")
        if args.ctc:
            logger.info("| decoding a ctc model")
        if args.rnnt:
            logger.info("| decoding a rnnt model")

    # Load ensemble
    logger.info("| loading model(s) from {}".format(args.path))
    models, _model_args = utils.load_ensemble_for_inference(
        args.path.split(":"),
        task,
        model_arg_overrides=eval(args.model_overrides),  # noqa
    )
    optimize_models(args, use_cuda, models)

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)

    num_sentences = 0

    if not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    sp = spm.SentencePieceProcessor()
    sp.Load(os.path.join(args.data, 'spm.model'))

    res_files = prepare_result_files(args)
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                speaker = task.dataset(args.gen_subset).speakers[int(sample_id)]
                id = task.dataset(args.gen_subset).ids[int(sample_id)]
                target_tokens = (
                    utils.strip_pad(sample["target"][i, :], tgt_dict.pad()).int().cpu()
                )
                # Process top predictions
                process_predictions(
                    args, hypos[i], sp, tgt_dict, target_tokens, res_files, speaker, id
                )

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += sample["nsentences"]

    logger.info(
        "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f} "
        "sentences/s, {:.2f} tokens/s)".format(
            num_sentences,
            gen_timer.n,
            gen_timer.sum,
            num_sentences / gen_timer.sum,
            1.0 / gen_timer.avg,
        )
    )
    logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))


def cli_main():
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
7 examples/speech_recognition/models/__init__.py Normal file
@ -0,0 +1,7 @@
import importlib
import os

# Automatically import every model module in this package so that its
# @register_model decorator runs and the model becomes available.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        model_name = file[:file.find('.py')]
        importlib.import_module('examples.speech_recognition.models.' + model_name)
838 examples/speech_recognition/models/vggtransformer.py Normal file
@ -0,0 +1,838 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import math
from collections.abc import Iterable

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.models import (
    FairseqEncoder,
    FairseqIncrementalDecoder,
    FairseqModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import LinearizedConvolution
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer, VGGBlock


@register_model("asr_vggtransformer")
class VGGTransformerModel(FairseqModel):
    """
    Transformers with convolutional context for ASR
    https://arxiv.org/abs/1904.11660
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--input-feat-per-channel",
            type=int,
            metavar="N",
            help="encoder input dimension per input channel",
        )
        parser.add_argument(
            "--vggblock-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples each containing the configuration of one vggblock:
    [(out_channels,
      conv_kernel_size,
      pooling_kernel_size,
      num_conv_layers,
      use_layer_norm), ...]
            """,
        )
        parser.add_argument(
            "--transformer-enc-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the encoder transformer layers:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--enc-output-dim",
            type=int,
            metavar="N",
            help="""
    encoder output dimension, can be None. If specified, projecting the
    transformer output to the specified dimension""",
        )
        parser.add_argument(
            "--in-channels",
            type=int,
            metavar="N",
            help="number of encoder input channels",
        )
        parser.add_argument(
            "--tgt-embed-dim",
            type=int,
            metavar="N",
            help="embedding dimension of the decoder target tokens",
        )
        parser.add_argument(
            "--transformer-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    a tuple containing the configuration of the decoder transformer layers:
    [(input_dim,
      num_heads,
      ffn_dim,
      normalize_before,
      dropout,
      attention_dropout,
      relu_dropout), ...]
            """,
        )
        parser.add_argument(
            "--conv-dec-config",
            type=str,
            metavar="EXPR",
            help="""
    an array of tuples for the decoder 1-D convolution config
    [(out_channels, conv_kernel_size, use_layer_norm), ...]""",
        )

    @classmethod
    def build_encoder(cls, args, task):
        return VGGTransformerEncoder(
            input_feat_per_channel=args.input_feat_per_channel,
            vggblock_config=eval(args.vggblock_enc_config),
            transformer_config=eval(args.transformer_enc_config),
            encoder_output_dim=args.enc_output_dim,
            in_channels=args.in_channels,
        )

    @classmethod
    def build_decoder(cls, args, task):
        return TransformerDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.tgt_embed_dim,
            transformer_config=eval(args.transformer_dec_config),
            conv_config=eval(args.conv_dec_config),
            encoder_output_dim=args.enc_output_dim,
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted
        # (in case there are any new ones)
        base_architecture(args)

        encoder = cls.build_encoder(args, task)
        decoder = cls.build_decoder(args, task)
        return cls(encoder, decoder)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = super().get_normalized_probs(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs
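    # The batch_first attribute set above is what the cross_entropy_acc
    # criterion checks (via getattr) to decide whether the model output needs
    # its batch and time axes transposed before the loss is computed.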
||||
|
||||
|
||||
DEFAULT_ENC_VGGBLOCK_CONFIG = ((32, 3, 2, 2, False),) * 2
|
||||
DEFAULT_ENC_TRANSFORMER_CONFIG = ((256, 4, 1024, True, 0.2, 0.2, 0.2),) * 2
|
||||
# 256: embedding dimension
|
||||
# 4: number of heads
|
||||
# 1024: FFN
|
||||
# True: apply layerNorm before (dropout + resiaul) instead of after
|
||||
# 0.2 (dropout): dropout after MultiheadAttention and second FC
|
||||
# 0.2 (attention_dropout): dropout in MultiheadAttention
|
||||
# 0.2 (relu_dropout): dropout after ReLu
|
||||
DEFAULT_DEC_TRANSFORMER_CONFIG = ((256, 2, 1024, True, 0.2, 0.2, 0.2),) * 2
|
||||
DEFAULT_DEC_CONV_CONFIG = ((256, 3, True),) * 2
|
||||
|
||||
|
||||
# TODO: repace transformer encoder config from one liner
|
||||
# to explicit args to get rid of this transformation
|
||||
def prepare_transformer_encoder_params(
|
||||
input_dim,
|
||||
num_heads,
|
||||
ffn_dim,
|
||||
normalize_before,
|
||||
dropout,
|
||||
attention_dropout,
|
||||
relu_dropout,
|
||||
):
|
||||
args = argparse.Namespace()
|
||||
args.encoder_embed_dim = input_dim
|
||||
args.encoder_attention_heads = num_heads
|
||||
args.attention_dropout = attention_dropout
|
||||
args.dropout = dropout
|
||||
args.activation_dropout = relu_dropout
|
||||
args.encoder_normalize_before = normalize_before
|
||||
args.encoder_ffn_embed_dim = ffn_dim
|
||||
return args
|
||||
|
||||
|
||||
def prepare_transformer_decoder_params(
|
||||
input_dim,
|
||||
num_heads,
|
||||
ffn_dim,
|
||||
normalize_before,
|
||||
dropout,
|
||||
attention_dropout,
|
||||
relu_dropout,
|
||||
):
|
||||
args = argparse.Namespace()
|
||||
args.decoder_embed_dim = input_dim
|
||||
args.decoder_attention_heads = num_heads
|
||||
args.attention_dropout = attention_dropout
|
||||
args.dropout = dropout
|
||||
args.activation_dropout = relu_dropout
|
||||
args.decoder_normalize_before = normalize_before
|
||||
args.decoder_ffn_embed_dim = ffn_dim
|
||||
return args
|
||||
|
||||
|
||||
class VGGTransformerEncoder(FairseqEncoder):
|
||||
"""VGG + Transformer encoder"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_feat_per_channel,
|
||||
vggblock_config=DEFAULT_ENC_VGGBLOCK_CONFIG,
|
||||
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
|
||||
encoder_output_dim=512,
|
||||
in_channels=1,
|
||||
transformer_context=None,
|
||||
transformer_sampling=None,
|
||||
):
|
||||
"""constructor for VGGTransformerEncoder
|
||||
|
||||
Args:
|
||||
- input_feat_per_channel: feature dim (not including stacked,
|
||||
just base feature)
|
||||
- in_channel: # input channels (e.g., if stack 8 feature vector
|
||||
together, this is 8)
|
||||
- vggblock_config: configuration of vggblock, see comments on
|
||||
DEFAULT_ENC_VGGBLOCK_CONFIG
|
||||
- transformer_config: configuration of transformer layer, see comments
|
||||
on DEFAULT_ENC_TRANSFORMER_CONFIG
|
||||
- encoder_output_dim: final transformer output embedding dimension
|
||||
- transformer_context: (left, right) if set, self-attention will be focused
|
||||
on (t-left, t+right)
|
||||
- transformer_sampling: an iterable of int, must match with
|
||||
len(transformer_config), transformer_sampling[i] indicates sampling
|
||||
factor for i-th transformer layer, after multihead att and feedfoward
|
||||
part
|
||||
"""
|
||||
super().__init__(None)
|
||||
|
||||
self.num_vggblocks = 0
|
||||
if vggblock_config is not None:
|
||||
if not isinstance(vggblock_config, Iterable):
|
||||
raise ValueError("vggblock_config is not iterable")
|
||||
self.num_vggblocks = len(vggblock_config)
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.in_channels = in_channels
|
||||
self.input_dim = input_feat_per_channel
|
||||
|
||||
if vggblock_config is not None:
|
||||
for _, config in enumerate(vggblock_config):
|
||||
(
|
||||
out_channels,
|
||||
conv_kernel_size,
|
||||
pooling_kernel_size,
|
||||
num_conv_layers,
|
||||
layer_norm,
|
||||
) = config
|
||||
self.conv_layers.append(
|
||||
VGGBlock(
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_kernel_size,
|
||||
pooling_kernel_size,
|
||||
num_conv_layers,
|
||||
input_dim=input_feat_per_channel,
|
||||
layer_norm=layer_norm,
|
||||
)
|
||||
)
|
||||
in_channels = out_channels
|
||||
input_feat_per_channel = self.conv_layers[-1].output_dim
|
||||
|
||||
transformer_input_dim = self.infer_conv_output_dim(
|
||||
self.in_channels, self.input_dim
|
||||
)
|
||||
# transformer_input_dim is the output dimension of VGG part
|
||||
|
||||
self.validate_transformer_config(transformer_config)
|
||||
self.transformer_context = self.parse_transformer_context(transformer_context)
|
||||
self.transformer_sampling = self.parse_transformer_sampling(
|
||||
transformer_sampling, len(transformer_config)
|
||||
)
|
||||
|
||||
self.transformer_layers = nn.ModuleList()
|
||||
|
||||
if transformer_input_dim != transformer_config[0][0]:
|
||||
self.transformer_layers.append(
|
||||
Linear(transformer_input_dim, transformer_config[0][0])
|
||||
)
|
||||
self.transformer_layers.append(
|
||||
TransformerEncoderLayer(
|
||||
prepare_transformer_encoder_params(*transformer_config[0])
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(1, len(transformer_config)):
|
||||
if transformer_config[i - 1][0] != transformer_config[i][0]:
|
||||
self.transformer_layers.append(
|
||||
Linear(transformer_config[i - 1][0], transformer_config[i][0])
|
||||
)
|
||||
self.transformer_layers.append(
|
||||
TransformerEncoderLayer(
|
||||
prepare_transformer_encoder_params(*transformer_config[i])
|
||||
)
|
||||
)
|
||||
|
||||
self.encoder_output_dim = encoder_output_dim
|
||||
self.transformer_layers.extend(
|
||||
[
|
||||
Linear(transformer_config[-1][0], encoder_output_dim),
|
||||
LayerNorm(encoder_output_dim),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, src_tokens, src_lengths, **kwargs):
|
||||
"""
|
||||
src_tokens: padded tensor (B, T, C * feat)
|
||||
src_lengths: tensor of original lengths of input utterances (B,)
|
||||
"""
|
||||
bsz, max_seq_len, _ = src_tokens.size()
|
||||
x = src_tokens.view(bsz, max_seq_len, self.in_channels, self.input_dim)
|
||||
x = x.transpose(1, 2).contiguous()
|
||||
# (B, C, T, feat)
|
||||
|
||||
for layer_idx in range(len(self.conv_layers)):
|
||||
x = self.conv_layers[layer_idx](x)
|
||||
|
||||
bsz, _, output_seq_len, _ = x.size()
|
||||
|
||||
# (B, C, T, feat) -> (B, T, C, feat) -> (T, B, C, feat) -> (T, B, C * feat)
|
||||
x = x.transpose(1, 2).transpose(0, 1)
|
||||
x = x.contiguous().view(output_seq_len, bsz, -1)
|
||||
|
||||
subsampling_factor = int(max_seq_len * 1.0 / output_seq_len + 0.5)
|
||||
# TODO: shouldn't subsampling_factor determined in advance ?
|
||||
input_lengths = (src_lengths.float() / subsampling_factor).ceil().long()
|
||||
|
||||
encoder_padding_mask, _ = lengths_to_encoder_padding_mask(
|
||||
input_lengths, batch_first=True
|
||||
)
|
||||
if not encoder_padding_mask.any():
|
||||
encoder_padding_mask = None
|
||||
|
||||
attn_mask = self.lengths_to_attn_mask(input_lengths, subsampling_factor)
|
||||
|
||||
transformer_layer_idx = 0
|
||||
|
||||
for layer_idx in range(len(self.transformer_layers)):
|
||||
|
||||
if isinstance(self.transformer_layers[layer_idx], TransformerEncoderLayer):
|
||||
x = self.transformer_layers[layer_idx](
|
||||
x, encoder_padding_mask, attn_mask
|
||||
)
|
||||
|
||||
if self.transformer_sampling[transformer_layer_idx] != 1:
|
||||
sampling_factor = self.transformer_sampling[transformer_layer_idx]
|
||||
x, encoder_padding_mask, attn_mask = self.slice(
|
||||
x, encoder_padding_mask, attn_mask, sampling_factor
|
||||
)
|
||||
|
||||
transformer_layer_idx += 1
|
||||
|
||||
else:
|
||||
x = self.transformer_layers[layer_idx](x)
|
||||
|
||||
# encoder_padding_maks is a (T x B) tensor, its [t, b] elements indicate
|
||||
# whether encoder_output[t, b] is valid or not (valid=0, invalid=1)
|
||||
|
||||
return {
|
||||
"encoder_out": x, # (T, B, C)
|
||||
"encoder_padding_mask": encoder_padding_mask.t()
|
||||
if encoder_padding_mask is not None
|
||||
else None,
|
||||
# (B, T) --> (T, B)
|
||||
}
|
||||
|
||||
def infer_conv_output_dim(self, in_channels, input_dim):
|
||||
sample_seq_len = 200
|
||||
sample_bsz = 10
|
||||
x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim)
|
||||
for i, _ in enumerate(self.conv_layers):
|
||||
x = self.conv_layers[i](x)
|
||||
x = x.transpose(1, 2)
|
||||
mb, seq = x.size()[:2]
|
||||
return x.contiguous().view(mb, seq, -1).size(-1)
|
||||
|
||||
def validate_transformer_config(self, transformer_config):
|
||||
for config in transformer_config:
|
||||
input_dim, num_heads = config[:2]
|
||||
if input_dim % num_heads != 0:
|
||||
msg = (
|
||||
"ERROR in transformer config {}:".format(config)
|
||||
+ "input dimension {} ".format(input_dim)
|
||||
+ "not dividable by number of heads".format(num_heads)
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def parse_transformer_context(self, transformer_context):
|
||||
"""
|
||||
transformer_context can be the following:
|
||||
- None; indicates no context is used, i.e.,
|
||||
transformer can access full context
|
||||
- a tuple/list of two int; indicates left and right context,
|
||||
any number <0 indicates infinite context
|
||||
* e.g., (5, 6) indicates that for query at x_t, transformer can
|
||||
access [t-5, t+6] (inclusive)
|
||||
* e.g., (-1, 6) indicates that for query at x_t, transformer can
|
||||
access [0, t+6] (inclusive)
|
||||
"""
|
||||
if transformer_context is None:
|
||||
return None
|
||||
|
||||
if not isinstance(transformer_context, Iterable):
|
||||
raise ValueError("transformer context must be Iterable if it is not None")
|
||||
|
||||
if len(transformer_context) != 2:
|
||||
raise ValueError("transformer context must have length 2")
|
||||
|
||||
left_context = transformer_context[0]
|
||||
if left_context < 0:
|
||||
left_context = None
|
||||
|
||||
right_context = transformer_context[1]
|
||||
if right_context < 0:
|
||||
right_context = None
|
||||
|
||||
if left_context is None and right_context is None:
|
||||
return None
|
||||
|
||||
return (left_context, right_context)
|
||||
|
||||
def parse_transformer_sampling(self, transformer_sampling, num_layers):
|
||||
"""
|
||||
parsing transformer sampling configuration
|
||||
|
||||
Args:
|
||||
- transformer_sampling, accepted input:
|
||||
* None, indicating no sampling
|
||||
* an Iterable with int (>0) as element
|
||||
- num_layers, expected number of transformer layers, must match with
|
||||
the length of transformer_sampling if it is not None
|
||||
|
||||
Returns:
|
||||
- A tuple with length num_layers
|
||||
"""
|
||||
if transformer_sampling is None:
|
||||
return (1,) * num_layers
|
||||
|
||||
if not isinstance(transformer_sampling, Iterable):
|
||||
raise ValueError(
|
||||
"transformer_sampling must be an iterable if it is not None"
|
||||
)
|
||||
|
||||
if len(transformer_sampling) != num_layers:
|
||||
raise ValueError(
|
||||
"transformer_sampling {} does not match with the number "
|
||||
+ "of layers {}".format(transformer_sampling, num_layers)
|
||||
)
|
||||
|
||||
for layer, value in enumerate(transformer_sampling):
|
||||
if not isinstance(value, int):
|
||||
raise ValueError("Invalid value in transformer_sampling: ")
|
||||
if value < 1:
|
||||
raise ValueError(
|
||||
"{} layer's subsampling is {}.".format(layer, value)
|
||||
+ " This is not allowed! "
|
||||
)
|
||||
return transformer_sampling
|
||||
|
||||
def slice(self, embedding, padding_mask, attn_mask, sampling_factor):
|
||||
"""
|
||||
embedding is a (T, B, D) tensor
|
||||
padding_mask is a (B, T) tensor or None
|
||||
attn_mask is a (T, T) tensor or None
|
||||
"""
|
||||
embedding = embedding[::sampling_factor, :, :]
|
||||
if padding_mask is not None:
|
||||
padding_mask = padding_mask[:, ::sampling_factor]
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask[::sampling_factor, ::sampling_factor]
|
||||
|
||||
return embedding, padding_mask, attn_mask
|
||||
|
||||
def lengths_to_attn_mask(self, input_lengths, subsampling_factor=1):
|
||||
"""
|
||||
create attention mask according to sequence lengths and transformer
|
||||
context
|
||||
|
||||
Args:
|
||||
- input_lengths: (B, )-shape Int/Long tensor; input_lengths[b] is
|
||||
the length of b-th sequence
|
||||
- subsampling_factor: int
|
||||
* Note that the left_context and right_context is specified in
|
||||
the input frame-level while input to transformer may already
|
||||
go through subsampling (e.g., the use of striding in vggblock)
|
||||
we use subsampling_factor to scale the left/right context
|
||||
|
||||
Return:
|
||||
- a (T, T) binary tensor or None, where T is max(input_lengths)
|
||||
* if self.transformer_context is None, None
|
||||
* if left_context is None,
|
||||
* attn_mask[t, t + right_context + 1:] = 1
|
||||
* others = 0
|
||||
* if right_context is None,
|
||||
* attn_mask[t, 0:t - left_context] = 1
|
||||
* others = 0
|
||||
* elsif
|
||||
* attn_mask[t, t - left_context: t + right_context + 1] = 0
|
||||
* others = 1
|
||||
"""
|
||||
if self.transformer_context is None:
|
||||
return None
|
||||
|
||||
maxT = torch.max(input_lengths).item()
|
||||
attn_mask = torch.zeros(maxT, maxT)
|
||||
|
||||
left_context = self.transformer_context[0]
|
||||
right_context = self.transformer_context[1]
|
||||
if left_context is not None:
|
||||
left_context = math.ceil(self.transformer_context[0] / subsampling_factor)
|
||||
if right_context is not None:
|
||||
right_context = math.ceil(self.transformer_context[1] / subsampling_factor)
|
||||
|
||||
for t in range(maxT):
|
||||
if left_context is not None:
|
||||
st = 0
|
||||
en = max(st, t - left_context)
|
||||
attn_mask[t, st:en] = 1
|
||||
if right_context is not None:
|
||||
st = t + right_context + 1
|
||||
st = min(st, maxT - 1)
|
||||
attn_mask[t, st:] = 1
|
||||
|
||||
return attn_mask.to(input_lengths.device)
|
||||
|
||||
def reorder_encoder_out(self, encoder_out, new_order):
|
||||
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
|
||||
1, new_order
|
||||
)
|
||||
if encoder_out["encoder_padding_mask"] is not None:
|
||||
encoder_out["encoder_padding_mask"] = encoder_out[
|
||||
"encoder_padding_mask"
|
||||
].index_select(1, new_order)
|
||||
return encoder_out
|
||||
|
||||
|
||||
class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
"""
|
||||
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
|
||||
is a :class:`TransformerDecoderLayer`.
|
||||
Args:
|
||||
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_dim (int): dimension of the target token embeddings
transformer_config: tuple of per-layer transformer configs, each
    entry unpacked by prepare_transformer_decoder_params
conv_config: tuple of (out_channels, kernel_size, layer_norm)
    tuples for the convolutional layers preceding the transformer
    layers
encoder_output_dim (int): dimension of the encoder output
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dictionary,
|
||||
embed_dim=512,
|
||||
transformer_config=DEFAULT_ENC_TRANSFORMER_CONFIG,
|
||||
conv_config=DEFAULT_DEC_CONV_CONFIG,
|
||||
encoder_output_dim=512,
|
||||
):
|
||||
|
||||
super().__init__(dictionary)
|
||||
vocab_size = len(dictionary)
|
||||
self.padding_idx = dictionary.pad()
|
||||
self.embed_tokens = Embedding(vocab_size, embed_dim, self.padding_idx)
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
for i in range(len(conv_config)):
|
||||
out_channels, kernel_size, layer_norm = conv_config[i]
|
||||
if i == 0:
|
||||
conv_layer = LinearizedConv1d(
|
||||
embed_dim, out_channels, kernel_size, padding=kernel_size - 1
|
||||
)
|
||||
else:
|
||||
conv_layer = LinearizedConv1d(
|
||||
conv_config[i - 1][0],
|
||||
out_channels,
|
||||
kernel_size,
|
||||
padding=kernel_size - 1,
|
||||
)
|
||||
self.conv_layers.append(conv_layer)
|
||||
if layer_norm:
|
||||
self.conv_layers.append(nn.LayerNorm(out_channels))
|
||||
self.conv_layers.append(nn.ReLU())
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
if conv_config[-1][0] != transformer_config[0][0]:
|
||||
self.layers.append(Linear(conv_config[-1][0], transformer_config[0][0]))
|
||||
self.layers.append(TransformerDecoderLayer(
|
||||
prepare_transformer_decoder_params(*transformer_config[0])
|
||||
))
|
||||
|
||||
for i in range(1, len(transformer_config)):
|
||||
if transformer_config[i - 1][0] != transformer_config[i][0]:
|
||||
self.layers.append(
|
||||
Linear(transformer_config[i - 1][0], transformer_config[i][0])
|
||||
)
|
||||
self.layers.append(TransformerDecoderLayer(
|
||||
prepare_transformer_decoder_params(*transformer_config[i])
|
||||
))
|
||||
self.fc_out = Linear(transformer_config[-1][0], vocab_size)
|
||||
|
||||
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
|
||||
"""
|
||||
Args:
|
||||
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
||||
`(batch, tgt_len)`, for input feeding/teacher forcing
|
||||
encoder_out (Tensor, optional): output from the encoder, used for
|
||||
encoder-side attention
|
||||
incremental_state (dict): dictionary used for storing state during
|
||||
:ref:`Incremental decoding`
|
||||
Returns:
|
||||
tuple:
|
||||
- the last decoder layer's output of shape `(batch, tgt_len,
|
||||
vocab)`
|
||||
- the last decoder layer's attention weights of shape `(batch,
|
||||
tgt_len, src_len)`
|
||||
"""
|
||||
target_padding_mask = (
|
||||
(prev_output_tokens == self.padding_idx).to(prev_output_tokens.device)
|
||||
if incremental_state is None
|
||||
else None
|
||||
)
|
||||
|
||||
if incremental_state is not None:
|
||||
prev_output_tokens = prev_output_tokens[:, -1:]
|
||||
|
||||
# embed tokens
|
||||
x = self.embed_tokens(prev_output_tokens)
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = self._transpose_if_training(x, incremental_state)
|
||||
|
||||
for layer in self.conv_layers:
|
||||
if isinstance(layer, LinearizedConvolution):
|
||||
x = layer(x, incremental_state)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = self._transpose_if_inference(x, incremental_state)
|
||||
|
||||
# decoder layers
|
||||
for layer in self.layers:
|
||||
if isinstance(layer, TransformerDecoderLayer):
|
||||
x, _ = layer(
|
||||
x,
|
||||
(encoder_out["encoder_out"] if encoder_out is not None else None),
|
||||
(
|
||||
encoder_out["encoder_padding_mask"].t()
|
||||
if encoder_out["encoder_padding_mask"] is not None
|
||||
else None
|
||||
),
|
||||
incremental_state,
|
||||
self_attn_mask=(
|
||||
self.buffered_future_mask(x)
|
||||
if incremental_state is None
|
||||
else None
|
||||
),
|
||||
self_attn_padding_mask=(
|
||||
target_padding_mask if incremental_state is None else None
|
||||
),
|
||||
)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
# T x B x C -> B x T x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
x = self.fc_out(x)
|
||||
|
||||
return x, None
|
||||
|
||||
def buffered_future_mask(self, tensor):
|
||||
dim = tensor.size(0)
|
||||
if (
|
||||
not hasattr(self, "_future_mask")
|
||||
or self._future_mask is None
|
||||
or self._future_mask.device != tensor.device
|
||||
):
|
||||
self._future_mask = torch.triu(
|
||||
utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
|
||||
)
|
||||
if self._future_mask.size(0) < dim:
|
||||
self._future_mask = torch.triu(
|
||||
utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
|
||||
)
|
||||
return self._future_mask[:dim, :dim]
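# Illustration (added note): the buffer holds an additive causal mask; for
# dim = 3 it is
#
#   tensor([[0., -inf, -inf],
#           [0.,   0., -inf],
#           [0.,   0.,   0.]])
#
# so position t attends only to positions <= t. The mask is cached on the
# right device and re-grown lazily instead of being rebuilt every forward.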
|
||||
|
||||
def _transpose_if_training(self, x, incremental_state):
|
||||
if incremental_state is None:
|
||||
x = x.transpose(0, 1)
|
||||
return x
|
||||
|
||||
def _transpose_if_inference(self, x, incremental_state):
|
||||
if incremental_state:
|
||||
x = x.transpose(0, 1)
|
||||
return x
|
||||
|
||||
|
||||
def Embedding(num_embeddings, embedding_dim, padding_idx):
|
||||
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
# nn.init.uniform_(m.weight, -0.1, 0.1)
|
||||
# nn.init.constant_(m.weight[padding_idx], 0)
|
||||
return m
|
||||
|
||||
|
||||
def Linear(in_features, out_features, bias=True, dropout=0):
|
||||
"""Linear layer (input: N x T x C)"""
|
||||
m = nn.Linear(in_features, out_features, bias=bias)
|
||||
# m.weight.data.uniform_(-0.1, 0.1)
|
||||
# if bias:
|
||||
# m.bias.data.uniform_(-0.1, 0.1)
|
||||
return m
|
||||
|
||||
|
||||
def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
|
||||
"""Weight-normalized Conv1d layer optimized for decoding"""
|
||||
m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
|
||||
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
|
||||
nn.init.normal_(m.weight, mean=0, std=std)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
return nn.utils.weight_norm(m, dim=2)
|
||||
|
||||
|
||||
def LayerNorm(embedding_dim):
|
||||
m = nn.LayerNorm(embedding_dim)
|
||||
return m
|
||||
|
||||
|
||||
# seq2seq models
|
||||
def base_architecture(args):
|
||||
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 40)
|
||||
args.vggblock_enc_config = getattr(
|
||||
args, "vggblock_enc_config", DEFAULT_ENC_VGGBLOCK_CONFIG
|
||||
)
|
||||
args.transformer_enc_config = getattr(
|
||||
args, "transformer_enc_config", DEFAULT_ENC_TRANSFORMER_CONFIG
|
||||
)
|
||||
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
|
||||
args.in_channels = getattr(args, "in_channels", 1)
|
||||
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
|
||||
args.transformer_dec_config = getattr(
|
||||
args, "transformer_dec_config", DEFAULT_ENC_TRANSFORMER_CONFIG
|
||||
)
|
||||
args.conv_dec_config = getattr(args, "conv_dec_config", DEFAULT_DEC_CONV_CONFIG)
|
||||
args.transformer_context = getattr(args, "transformer_context", "None")
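# Note on the config strings (editor's sketch; field meanings inferred from
# how the tuples are unpacked in this file):
#   - vggblock_enc_config: each tuple appears to be (out_channels,
#     conv_kernel_size, pooling_kernel_size, num_conv_layers, layer_norm),
#     mirroring the VGGBlock constructor.
#   - transformer_enc_config / transformer_dec_config: each tuple appears to
#     be (embed_dim, num_heads, ffn_embed_dim, normalize_before, dropout,
#     attention_dropout, relu_dropout).
#   - conv_dec_config: each tuple is (out_channels, kernel_size, layer_norm),
#     as unpacked in TransformerDecoder.__init__.
# For example, "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14" builds 14
# identical layers with model dim 1024, 16 heads and a 4096-dim FFN.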
|
||||
|
||||
|
||||
@register_model_architecture("asr_vggtransformer", "vggtransformer_1")
|
||||
def vggtransformer_1(args):
|
||||
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
||||
args.vggblock_enc_config = getattr(
|
||||
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
|
||||
)
|
||||
args.transformer_enc_config = getattr(
|
||||
args,
|
||||
"transformer_enc_config",
|
||||
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 14",
|
||||
)
|
||||
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
|
||||
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 128)
|
||||
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
|
||||
args.transformer_dec_config = getattr(
|
||||
args,
|
||||
"transformer_dec_config",
|
||||
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 4",
|
||||
)
|
||||
|
||||
|
||||
@register_model_architecture("asr_vggtransformer", "vggtransformer_2")
|
||||
def vggtransformer_2(args):
|
||||
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
||||
args.vggblock_enc_config = getattr(
|
||||
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
|
||||
)
|
||||
args.transformer_enc_config = getattr(
|
||||
args,
|
||||
"transformer_enc_config",
|
||||
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 16",
|
||||
)
|
||||
args.enc_output_dim = getattr(args, "enc_output_dim", 1024)
|
||||
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
|
||||
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
|
||||
args.transformer_dec_config = getattr(
|
||||
args,
|
||||
"transformer_dec_config",
|
||||
"((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 6",
|
||||
)
|
||||
|
||||
|
||||
@register_model_architecture("asr_vggtransformer", "vggtransformer_base")
|
||||
def vggtransformer_base(args):
|
||||
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
||||
args.vggblock_enc_config = getattr(
|
||||
args, "vggblock_enc_config", "[(64, 3, 2, 2, True), (128, 3, 2, 2, True)]"
|
||||
)
|
||||
args.transformer_enc_config = getattr(
|
||||
args, "transformer_enc_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 12"
|
||||
)
|
||||
|
||||
args.enc_output_dim = getattr(args, "enc_output_dim", 512)
|
||||
args.tgt_embed_dim = getattr(args, "tgt_embed_dim", 512)
|
||||
args.conv_dec_config = getattr(args, "conv_dec_config", "((256, 3, True),) * 4")
|
||||
args.transformer_dec_config = getattr(
|
||||
args, "transformer_dec_config", "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 6"
|
||||
)
|
||||
# Size estimations:
|
||||
# Encoder:
|
||||
# - vggblock param: 64*1*3*3 + 64*64*3*3 + 128*64*3*3 + 128*128*3*3 = 258K
|
||||
# Transformer:
|
||||
# - input dimension adapter: 2560 x 512 -> 1.31M
|
||||
# - transformer_layers (x12) --> 37.74M
|
||||
# * MultiheadAttention: 512*512*3 (in_proj) + 512*512 (out_proj) = 1.048M
|
||||
# * FFN weight: 512*2048*2 = 2.097M
|
||||
# - output dimension adapter: 512 x 512 -> 0.26 M
|
||||
# Decoder:
|
||||
# - LinearizedConv1d: 512 * 256 * 3 + 256 * 256 * 3 * 3
|
||||
# - transformer_layer: (x6) --> 25.16M
|
||||
# * MultiheadAttention (self-attention): 512*512*3 + 512*512 = 1.048M
|
||||
# * MultiheadAttention (encoder-attention): 512*512*3 + 512*512 = 1.048M
|
||||
# * FFN: 512*2048*2 = 2.097M
|
||||
# Final FC:
|
||||
# - FC: 512*5000 = 256K (assuming vocab size 5K)
|
||||
# In total:
|
||||
# ~65 M
|
7
examples/speech_recognition/tasks/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
import importlib
|
||||
import os
|
||||
|
||||
for file in os.listdir(os.path.dirname(__file__)):
|
||||
if file.endswith('.py') and not file.startswith('_'):
|
||||
task_name = file[:file.find('.py')]
|
||||
importlib.import_module('examples.speech_recognition.tasks.' + task_name)
|
116
examples/speech_recognition/tasks/speech_recognition.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.tasks import FairseqTask, register_task
|
||||
from examples.speech_recognition.data import AsrDataset
|
||||
|
||||
|
||||
def get_asr_dataset_from_json(data_json_path, tgt_dict):
|
||||
"""
|
||||
Parse data json and create dataset.
|
||||
See scripts/asr_prep_json.py, which packs the json from raw files.
|
||||
|
||||
Json example:
|
||||
{
|
||||
"utts": {
|
||||
"4771-29403-0025": {
|
||||
"input": {
|
||||
"length_ms": 170,
|
||||
"path": "/tmp/file1.flac"
|
||||
},
|
||||
"output": {
|
||||
"text": "HELLO \n",
|
||||
"token": "HE LLO",
|
||||
"tokenid": "4815, 861"
|
||||
}
|
||||
},
|
||||
"1564-142299-0096": {
|
||||
...
|
||||
}
|
||||
}
|
||||
"""
|
||||
if not os.path.isfile(data_json_path):
|
||||
raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
|
||||
with open(data_json_path, "rb") as f:
|
||||
data_samples = json.load(f)["utts"]
|
||||
assert len(data_samples) != 0
|
||||
sorted_samples = sorted(
|
||||
data_samples.items(),
|
||||
key=lambda sample: int(sample[1]["input"]["length_ms"]),
|
||||
reverse=True,
|
||||
)
|
||||
aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
|
||||
ids = [s[0] for s in sorted_samples]
|
||||
speakers = []
|
||||
for s in sorted_samples:
|
||||
# librispeech utterance ids look like "speaker-chapter-utterance"
m = re.search("(.+?)-(.+?)-(.+?)", s[0])
|
||||
speakers.append(m.group(1) + "_" + m.group(2))
|
||||
frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
|
||||
tgt = [
|
||||
torch.LongTensor(
|
||||
[int(i) for i in s[1]["output"]["tokenid"].split(", ")]
|
||||
)
|
||||
for s in sorted_samples
|
||||
]
|
||||
# append eos
|
||||
tgt = [torch.cat([t, torch.LongTensor([tgt_dict.eos()])]) for t in tgt]
|
||||
return AsrDataset(
|
||||
aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers
|
||||
)
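# Usage sketch (hypothetical paths, not part of the original code):
#
#   from fairseq.data import Dictionary
#   tgt_dict = Dictionary.load("/path/to/dict.txt")
#   dataset = get_asr_dataset_from_json("/path/to/train.json", tgt_dict)
#
# Samples are sorted by length_ms, longest first, so max-token batching
# groups utterances of similar duration together.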
|
||||
|
||||
|
||||
@register_task("speech_recognition")
|
||||
class SpeechRecognitionTask(FairseqTask):
|
||||
"""
|
||||
Task for training speech recognition models.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add task-specific arguments to the parser."""
|
||||
parser.add_argument("data", help="path to data directory")
|
||||
|
||||
def __init__(self, args, tgt_dict):
|
||||
super().__init__(args)
|
||||
self.tgt_dict = tgt_dict
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, args, **kwargs):
|
||||
"""Setup the task (e.g., load dictionaries)."""
|
||||
dict_path = os.path.join(args.data, "dict.txt")
|
||||
if not os.path.isfile(dict_path):
|
||||
raise FileNotFoundError("Dict not found: {}".format(dict_path))
|
||||
tgt_dict = Dictionary.load(dict_path)
|
||||
|
||||
print("| dictionary: {} types".format(len(tgt_dict)))
|
||||
return cls(args, tgt_dict)
|
||||
|
||||
def load_dataset(self, split, combine=False, **kwargs):
|
||||
"""Load a given dataset split.
|
||||
|
||||
Args:
|
||||
split (str): name of the split (e.g., train, valid, test)
|
||||
"""
|
||||
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
|
||||
self.datasets[split] = get_asr_dataset_from_json(
|
||||
data_json_path, self.tgt_dict)
|
||||
|
||||
@property
|
||||
def target_dictionary(self):
|
||||
"""Return the :class:`~fairseq.data.Dictionary` for the language
|
||||
model."""
|
||||
return self.tgt_dict
|
||||
|
||||
@property
|
||||
def source_dictionary(self):
|
||||
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
|
||||
for this task)."""
|
||||
return None
|
@ -23,6 +23,8 @@ from fairseq.modules import (
|
||||
MultiheadAttention,
|
||||
PositionalEmbedding,
|
||||
SinusoidalPositionalEmbedding,
|
||||
TransformerDecoderLayer,
|
||||
TransformerEncoderLayer,
|
||||
)
|
||||
|
||||
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
||||
@ -504,253 +506,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
return state_dict
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""Encoder layer block.
|
||||
|
||||
In the original paper each operation (multi-head attention or FFN) is
|
||||
postprocessed with: `dropout -> add residual -> layernorm`. In the
|
||||
tensor2tensor code they suggest that learning is more robust when
|
||||
preprocessing each layer with layernorm and postprocessing with:
|
||||
`dropout -> add residual`. We default to the approach in the paper, but the
|
||||
tensor2tensor approach can be enabled by setting
|
||||
*args.encoder_normalize_before* to ``True``.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
"""
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.embed_dim = args.encoder_embed_dim
|
||||
self.self_attn = MultiheadAttention(
|
||||
self.embed_dim, args.encoder_attention_heads,
|
||||
dropout=args.attention_dropout, self_attention=True
|
||||
)
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.dropout = args.dropout
|
||||
self.activation_fn = utils.get_activation_fn(
|
||||
activation=getattr(args, 'activation_fn', 'relu')
|
||||
)
|
||||
self.activation_dropout = getattr(args, 'activation_dropout', 0)
|
||||
if self.activation_dropout == 0:
|
||||
# for backwards compatibility with models that use args.relu_dropout
|
||||
self.activation_dropout = getattr(args, 'relu_dropout', 0)
|
||||
self.normalize_before = args.encoder_normalize_before
|
||||
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
|
||||
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
|
||||
self.final_layer_norm = LayerNorm(self.embed_dim)
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
"""
|
||||
Rename layer norm states from `...layer_norms.0.weight` to
|
||||
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
|
||||
`...final_layer_norm.weight`
|
||||
"""
|
||||
layer_norm_map = {
|
||||
'0': 'self_attn_layer_norm',
|
||||
'1': 'final_layer_norm'
|
||||
}
|
||||
for old, new in layer_norm_map.items():
|
||||
for m in ('weight', 'bias'):
|
||||
k = '{}.layer_norms.{}.{}'.format(name, old, m)
|
||||
if k in state_dict:
|
||||
state_dict[
|
||||
'{}.{}.{}'.format(name, new, m)
|
||||
] = state_dict[k]
|
||||
del state_dict[k]
|
||||
|
||||
def forward(self, x, encoder_padding_mask):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
||||
`(batch, src_len)` where padding elements are indicated by ``1``.
|
||||
|
||||
Returns:
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
|
||||
x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
|
||||
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
|
||||
return x
|
||||
|
||||
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
|
||||
assert before ^ after
|
||||
if after ^ self.normalize_before:
|
||||
return layer_norm(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
"""Decoder layer block.
|
||||
|
||||
In the original paper each operation (multi-head attention, encoder
|
||||
attention or FFN) is postprocessed with: `dropout -> add residual ->
|
||||
layernorm`. In the tensor2tensor code they suggest that learning is more
|
||||
robust when preprocessing each layer with layernorm and postprocessing with:
|
||||
`dropout -> add residual`. We default to the approach in the paper, but the
|
||||
tensor2tensor approach can be enabled by setting
|
||||
*args.decoder_normalize_before* to ``True``.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
||||
(default: False).
|
||||
"""
|
||||
|
||||
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
|
||||
super().__init__()
|
||||
self.embed_dim = args.decoder_embed_dim
|
||||
self.self_attn = MultiheadAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=args.decoder_attention_heads,
|
||||
dropout=args.attention_dropout,
|
||||
add_bias_kv=add_bias_kv,
|
||||
add_zero_attn=add_zero_attn,
|
||||
self_attention=True
|
||||
)
|
||||
self.dropout = args.dropout
|
||||
self.activation_fn = utils.get_activation_fn(
|
||||
activation=getattr(args, 'activation_fn', 'relu')
|
||||
)
|
||||
self.activation_dropout = getattr(args, 'activation_dropout', 0)
|
||||
if self.activation_dropout == 0:
|
||||
# for backwards compatibility with models that use args.relu_dropout
|
||||
self.activation_dropout = getattr(args, 'relu_dropout', 0)
|
||||
self.normalize_before = args.decoder_normalize_before
|
||||
|
||||
# use layerNorm rather than FusedLayerNorm for exporting.
|
||||
# char_inputs can be used to determine this.
|
||||
# TODO remove this once we update apex with the fix
|
||||
export = getattr(args, 'char_inputs', False)
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
|
||||
|
||||
if no_encoder_attn:
|
||||
self.encoder_attn = None
|
||||
self.encoder_attn_layer_norm = None
|
||||
else:
|
||||
self.encoder_attn = MultiheadAttention(
|
||||
self.embed_dim,
|
||||
args.decoder_attention_heads,
|
||||
kdim=getattr(args, 'encoder_embed_dim', None),
|
||||
vdim=getattr(args, 'encoder_embed_dim', None),
|
||||
dropout=args.attention_dropout,
|
||||
encoder_decoder_attention=True,
|
||||
)
|
||||
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
|
||||
|
||||
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
|
||||
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
|
||||
|
||||
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
|
||||
self.need_attn = True
|
||||
|
||||
self.onnx_trace = False
|
||||
|
||||
def prepare_for_onnx_export_(self):
|
||||
self.onnx_trace = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
encoder_out=None,
|
||||
encoder_padding_mask=None,
|
||||
incremental_state=None,
|
||||
prev_self_attn_state=None,
|
||||
prev_attn_state=None,
|
||||
self_attn_mask=None,
|
||||
self_attn_padding_mask=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
||||
`(batch, src_len)` where padding elements are indicated by ``1``.
|
||||
|
||||
Returns:
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
|
||||
if prev_self_attn_state is not None:
|
||||
if incremental_state is None:
|
||||
incremental_state = {}
|
||||
prev_key, prev_value = prev_self_attn_state
|
||||
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
|
||||
self.self_attn._set_input_buffer(incremental_state, saved_state)
|
||||
x, attn = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
incremental_state=incremental_state,
|
||||
need_weights=False,
|
||||
attn_mask=self_attn_mask,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
|
||||
|
||||
if self.encoder_attn is not None:
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
|
||||
if prev_attn_state is not None:
|
||||
if incremental_state is None:
|
||||
incremental_state = {}
|
||||
prev_key, prev_value = prev_attn_state
|
||||
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
|
||||
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
|
||||
x, attn = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_out,
|
||||
value=encoder_out,
|
||||
key_padding_mask=encoder_padding_mask,
|
||||
incremental_state=incremental_state,
|
||||
static_kv=True,
|
||||
need_weights=(not self.training and self.need_attn),
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
|
||||
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
|
||||
if self.onnx_trace and incremental_state is not None:
|
||||
saved_state = self.self_attn._get_input_buffer(incremental_state)
|
||||
self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
|
||||
return x, attn, self_attn_state
|
||||
return x, attn
|
||||
|
||||
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
|
||||
assert before ^ after
|
||||
if after ^ self.normalize_before:
|
||||
return layer_norm(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
def make_generation_fast_(self, need_attn=False, **kwargs):
|
||||
self.need_attn = need_attn
|
||||
|
||||
|
||||
def Embedding(num_embeddings, embedding_dim, padding_idx):
|
||||
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
||||
|
@ -26,6 +26,8 @@ from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
|
||||
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
|
||||
from .transformer_sentence_encoder import TransformerSentenceEncoder
|
||||
from .unfold import unfold1d
|
||||
from .transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
|
||||
from .vggblock import VGGBlock
|
||||
|
||||
__all__ = [
|
||||
'AdaptiveInput',
|
||||
@ -51,5 +53,8 @@ __all__ = [
|
||||
'SinusoidalPositionalEmbedding',
|
||||
'TransformerSentenceEncoderLayer',
|
||||
'TransformerSentenceEncoder',
|
||||
'TransformerDecoderLayer',
|
||||
'TransformerEncoderLayer',
|
||||
'VGGBlock',
|
||||
'unfold1d',
|
||||
]
|
||||
|
279
fairseq/modules/transformer_layer.py
Normal file
@ -0,0 +1,279 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairseq import utils
|
||||
from fairseq.modules import LayerNorm, MultiheadAttention
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""Encoder layer block.
|
||||
|
||||
In the original paper each operation (multi-head attention or FFN) is
|
||||
postprocessed with: `dropout -> add residual -> layernorm`. In the
|
||||
tensor2tensor code they suggest that learning is more robust when
|
||||
preprocessing each layer with layernorm and postprocessing with:
|
||||
`dropout -> add residual`. We default to the approach in the paper, but the
|
||||
tensor2tensor approach can be enabled by setting
|
||||
*args.encoder_normalize_before* to ``True``.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
"""
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.embed_dim = args.encoder_embed_dim
|
||||
self.self_attn = MultiheadAttention(
|
||||
self.embed_dim, args.encoder_attention_heads,
|
||||
dropout=args.attention_dropout, self_attention=True
|
||||
)
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.dropout = args.dropout
|
||||
self.activation_fn = utils.get_activation_fn(
|
||||
activation=getattr(args, 'activation_fn', 'relu')
|
||||
)
|
||||
self.activation_dropout = getattr(args, 'activation_dropout', 0)
|
||||
if self.activation_dropout == 0:
|
||||
# for backwards compatibility with models that use args.relu_dropout
|
||||
self.activation_dropout = getattr(args, 'relu_dropout', 0)
|
||||
self.normalize_before = args.encoder_normalize_before
|
||||
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
|
||||
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
|
||||
self.final_layer_norm = LayerNorm(self.embed_dim)
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
"""
|
||||
Rename layer norm states from `...layer_norms.0.weight` to
|
||||
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
|
||||
`...final_layer_norm.weight`
|
||||
"""
|
||||
layer_norm_map = {
|
||||
'0': 'self_attn_layer_norm',
|
||||
'1': 'final_layer_norm'
|
||||
}
|
||||
for old, new in layer_norm_map.items():
|
||||
for m in ('weight', 'bias'):
|
||||
k = '{}.layer_norms.{}.{}'.format(name, old, m)
|
||||
if k in state_dict:
|
||||
state_dict[
|
||||
'{}.{}.{}'.format(name, new, m)
|
||||
] = state_dict[k]
|
||||
del state_dict[k]
|
||||
|
||||
def forward(self, x, encoder_padding_mask, attn_mask=None):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
||||
`(batch, src_len)` where padding elements are indicated by ``1``.
|
||||
attn_mask (ByteTensor): binary tensor of shape (T_tgt, T_src), where
|
||||
T_tgt is the length of the query and T_src the length of the key
(here both query and key are x). attn_mask[t_tgt, t_src] = 1 means
t_src is excluded (masked out) when computing the output at t_tgt;
= 0 means it is included in attention
|
||||
|
||||
Returns:
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.masked_fill(attn_mask.byte(), -1e8)
|
||||
# anything in original attn_mask = 1, becomes -1e8
|
||||
# anything in original attn_mask = 0, becomes 0
|
||||
# Note that we cannot use -inf here, because at some edge cases,
|
||||
# the attention weight (before softmax) for some padded element in query
|
||||
# will become -inf, which results in NaN in model parameters
|
||||
# TODO: to formally solve this problem, we need to change fairseq's
|
||||
# MultiheadAttention. We will do this later on.
|
||||
x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, attn_mask=attn_mask)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
|
||||
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
|
||||
return x
|
||||
|
||||
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
|
||||
assert before ^ after
|
||||
if after ^ self.normalize_before:
|
||||
return layer_norm(x)
|
||||
else:
|
||||
return x
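# Added note: maybe_layer_norm is the pre-norm/post-norm switch. With
# normalize_before=False (the paper's post-norm default) the norm is applied
# only on after=True calls; with normalize_before=True (tensor2tensor style)
# only on before=True calls, since layer_norm fires exactly when
# (after ^ self.normalize_before) is True.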
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
"""Decoder layer block.
|
||||
|
||||
In the original paper each operation (multi-head attention, encoder
|
||||
attention or FFN) is postprocessed with: `dropout -> add residual ->
|
||||
layernorm`. In the tensor2tensor code they suggest that learning is more
|
||||
robust when preprocessing each layer with layernorm and postprocessing with:
|
||||
`dropout -> add residual`. We default to the approach in the paper, but the
|
||||
tensor2tensor approach can be enabled by setting
|
||||
*args.decoder_normalize_before* to ``True``.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
||||
(default: False).
|
||||
"""
|
||||
|
||||
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
|
||||
super().__init__()
|
||||
self.embed_dim = args.decoder_embed_dim
|
||||
self.self_attn = MultiheadAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=args.decoder_attention_heads,
|
||||
dropout=args.attention_dropout,
|
||||
add_bias_kv=add_bias_kv,
|
||||
add_zero_attn=add_zero_attn,
|
||||
self_attention=True
|
||||
)
|
||||
self.dropout = args.dropout
|
||||
self.activation_fn = utils.get_activation_fn(
|
||||
activation=getattr(args, 'activation_fn', 'relu')
|
||||
)
|
||||
self.activation_dropout = getattr(args, 'activation_dropout', 0)
|
||||
if self.activation_dropout == 0:
|
||||
# for backwards compatibility with models that use args.relu_dropout
|
||||
self.activation_dropout = getattr(args, 'relu_dropout', 0)
|
||||
self.normalize_before = args.decoder_normalize_before
|
||||
|
||||
# use layerNorm rather than FusedLayerNorm for exporting.
|
||||
# char_inputs can be used to determine this.
|
||||
# TODO remove this once we update apex with the fix
|
||||
export = getattr(args, 'char_inputs', False)
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
|
||||
|
||||
if no_encoder_attn:
|
||||
self.encoder_attn = None
|
||||
self.encoder_attn_layer_norm = None
|
||||
else:
|
||||
self.encoder_attn = MultiheadAttention(
|
||||
self.embed_dim,
|
||||
args.decoder_attention_heads,
|
||||
kdim=getattr(args, 'encoder_embed_dim', None),
|
||||
vdim=getattr(args, 'encoder_embed_dim', None),
|
||||
dropout=args.attention_dropout,
|
||||
encoder_decoder_attention=True,
|
||||
)
|
||||
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
|
||||
|
||||
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
|
||||
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
|
||||
|
||||
self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
|
||||
self.need_attn = True
|
||||
|
||||
self.onnx_trace = False
|
||||
|
||||
def prepare_for_onnx_export_(self):
|
||||
self.onnx_trace = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
encoder_out=None,
|
||||
encoder_padding_mask=None,
|
||||
incremental_state=None,
|
||||
prev_self_attn_state=None,
|
||||
prev_attn_state=None,
|
||||
self_attn_mask=None,
|
||||
self_attn_padding_mask=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
|
||||
`(batch, src_len)` where padding elements are indicated by ``1``.
|
||||
|
||||
Returns:
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
|
||||
if prev_self_attn_state is not None:
|
||||
if incremental_state is None:
|
||||
incremental_state = {}
|
||||
prev_key, prev_value = prev_self_attn_state
|
||||
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
|
||||
self.self_attn._set_input_buffer(incremental_state, saved_state)
|
||||
x, attn = self.self_attn(
|
||||
query=x,
|
||||
key=x,
|
||||
value=x,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
incremental_state=incremental_state,
|
||||
need_weights=False,
|
||||
attn_mask=self_attn_mask,
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
|
||||
|
||||
if self.encoder_attn is not None:
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
|
||||
if prev_attn_state is not None:
|
||||
if incremental_state is None:
|
||||
incremental_state = {}
|
||||
prev_key, prev_value = prev_attn_state
|
||||
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
|
||||
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
|
||||
x, attn = self.encoder_attn(
|
||||
query=x,
|
||||
key=encoder_out,
|
||||
value=encoder_out,
|
||||
key_padding_mask=encoder_padding_mask,
|
||||
incremental_state=incremental_state,
|
||||
static_kv=True,
|
||||
need_weights=(not self.training and self.need_attn),
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
|
||||
|
||||
residual = x
|
||||
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
||||
x = self.fc2(x)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
x = residual + x
|
||||
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
|
||||
if self.onnx_trace and incremental_state is not None:
|
||||
saved_state = self.self_attn._get_input_buffer(incremental_state)
|
||||
self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
|
||||
return x, attn, self_attn_state
|
||||
return x, attn
|
||||
|
||||
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
|
||||
assert before ^ after
|
||||
if after ^ self.normalize_before:
|
||||
return layer_norm(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
def make_generation_fast_(self, need_attn=False, **kwargs):
|
||||
self.need_attn = need_attn
|
||||
|
||||
|
||||
def Linear(in_features, out_features, bias=True):
|
||||
m = nn.Linear(in_features, out_features, bias)
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if bias:
|
||||
nn.init.constant_(m.bias, 0.)
|
||||
return m
|
115
fairseq/modules/vggblock.py
Normal file
@ -0,0 +1,115 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import repeat
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def _pair(v):
|
||||
if isinstance(v, Iterable):
|
||||
assert len(v) == 2, "len(v) != 2"
|
||||
return v
|
||||
return tuple(repeat(v, 2))
|
||||
|
||||
|
||||
def infer_conv_output_dim(conv_op, input_dim, sample_inchannel):
|
||||
sample_seq_len = 200
|
||||
sample_bsz = 10
|
||||
x = torch.randn(sample_bsz, sample_inchannel, sample_seq_len, input_dim)
|
||||
# N x C x H x W
|
||||
# N: sample_bsz, C: sample_inchannel, H: sample_seq_len, W: input_dim
|
||||
x = conv_op(x)
|
||||
# N x C x H x W
|
||||
x = x.transpose(1, 2)
|
||||
# N x H x C x W
|
||||
bsz, seq = x.size()[:2]
|
||||
per_channel_dim = x.size()[3]
|
||||
# bsz: N, seq: H, CxW the rest
|
||||
return x.contiguous().view(bsz, seq, -1).size(-1), per_channel_dim
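# Worked example (editor's sketch): the VGGBlock class below uses stride-1
# convs with padding kernel_size // 2, which preserve the feature axis, while
# each MaxPool2d with kernel 2 halves it (ceil mode). For vggtransformer_base,
# input_dim 80 becomes 80 -> 40 -> 20 over two pooled blocks, and with 128
# output channels this function returns (128 * 20, 20) = (2560, 20) -- the
# source of the "input dimension adapter: 2560 x 512" estimate in the
# vggtransformer size notes.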
|
||||
|
||||
|
||||
class VGGBlock(torch.nn.Module):
|
||||
"""
|
||||
VGG-motivated CNN module, cf. https://arxiv.org/pdf/1409.1556.pdf
|
||||
|
||||
Args:
|
||||
in_channels: (int) number of input channels (typically 1)
|
||||
out_channels: (int) number of output channels
|
||||
conv_kernel_size: the size of the convolving kernel.
Can be a single number or a tuple (kH, kW)
|
||||
pooling_kernel_size: the size of the pooling window to take a max over
|
||||
num_conv_layers: (int) number of convolution layers
|
||||
input_dim: (int) input dimension
|
||||
conv_stride: the stride of the convolving kernel.
|
||||
Can be a single number or a tuple (sH, sW) Default: 1
|
||||
padding: implicit paddings on both sides of the input.
|
||||
Can be a single number or a tuple (padH, padW). Default: None
|
||||
layer_norm: (bool) if layer norm is going to be applied. Default: False
|
||||
|
||||
Shape:
|
||||
Input: BxCxTxfeat, i.e. (batch_size, in_channels, timesteps, features)
Output: BxCxTxfeat, i.e. (batch_size, out_channels, timesteps, features)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_kernel_size,
|
||||
pooling_kernel_size,
|
||||
num_conv_layers,
|
||||
input_dim,
|
||||
conv_stride=1,
|
||||
padding=None,
|
||||
layer_norm=False,
|
||||
):
|
||||
assert (
|
||||
input_dim is not None
|
||||
), "Need input_dim for LayerNorm and infer_conv_output_dim"
|
||||
super(VGGBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv_kernel_size = _pair(conv_kernel_size)
|
||||
self.pooling_kernel_size = _pair(pooling_kernel_size)
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.padding = (
|
||||
tuple(e // 2 for e in self.conv_kernel_size)
|
||||
if padding is None
|
||||
else _pair(padding)
|
||||
)
|
||||
self.conv_stride = _pair(conv_stride)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for layer in range(num_conv_layers):
|
||||
conv_op = nn.Conv2d(
|
||||
in_channels if layer == 0 else out_channels,
|
||||
out_channels,
|
||||
self.conv_kernel_size,
|
||||
stride=self.conv_stride,
|
||||
padding=self.padding,
|
||||
)
|
||||
self.layers.append(conv_op)
|
||||
if layer_norm:
|
||||
conv_output_dim, per_channel_dim = infer_conv_output_dim(
|
||||
conv_op, input_dim, in_channels if layer == 0 else out_channels
|
||||
)
|
||||
self.layers.append(nn.LayerNorm(per_channel_dim))
|
||||
input_dim = per_channel_dim
|
||||
self.layers.append(nn.ReLU())
|
||||
|
||||
pool_op = nn.MaxPool2d(kernel_size=pooling_kernel_size, ceil_mode=True)
|
||||
self.layers.append(pool_op)
|
||||
self.total_output_dim, self.output_dim = infer_conv_output_dim(
|
||||
pool_op, input_dim, out_channels
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for layer in self.layers:
    x = layer(x)
|
||||
return x
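# Usage sketch (editor's example, mirroring the "(64, 3, 2, 2, True)" entry
# used by the vggtransformer architectures):
#
#   import torch
#   block = VGGBlock(
#       in_channels=1, out_channels=64, conv_kernel_size=3,
#       pooling_kernel_size=2, num_conv_layers=2, input_dim=80,
#       layer_norm=True,
#   )
#   x = torch.randn(4, 1, 100, 80)   # (batch, channels, time, features)
#   y = block(x)                     # -> shape (4, 64, 50, 40)
#   assert block.output_dim == 40    # per-channel feature dim after pooling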
|
0
tests/speech_recognition/__init__.py
Normal file
549
tests/speech_recognition/asr_test_base.py
Normal file
@ -0,0 +1,549 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import unittest
|
||||
from inspect import currentframe, getframeinfo
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import data_utils as fairseq_data_utils
|
||||
from fairseq.data.dictionary import Dictionary
|
||||
from fairseq.models import (
|
||||
BaseFairseqModel,
|
||||
FairseqDecoder,
|
||||
FairseqEncoder,
|
||||
FairseqEncoderDecoderModel,
|
||||
FairseqEncoderModel,
|
||||
FairseqModel,
|
||||
)
|
||||
from fairseq.tasks.fairseq_task import FairseqTask
|
||||
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
|
||||
|
||||
|
||||
DEFAULT_TEST_VOCAB_SIZE = 100
|
||||
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////
|
||||
# utility function to setup dummy dict/task/input
|
||||
# ///////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
|
||||
dummy_dict = Dictionary()
|
||||
# add dummy symbol to satisfy vocab size
|
||||
for id, _ in enumerate(range(vocab_size)):
|
||||
dummy_dict.add_symbol("{}".format(id), 1000)
|
||||
return dummy_dict
|
||||
|
||||
|
||||
class DummyTask(FairseqTask):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.dictionary = get_dummy_dictionary()
|
||||
if getattr(self.args, "ctc", False):
|
||||
self.dictionary.add_symbol("<ctc_blank>")
|
||||
self.tgt_dict = self.dictionary
|
||||
|
||||
@property
|
||||
def target_dictionary(self):
|
||||
return self.dictionary
|
||||
|
||||
|
||||
def get_dummy_task_and_parser():
|
||||
"""
|
||||
to build a fairseq model, we need a dummy parser and task. This function
creates such a dummy task and parser to facilitate model/criterion tests

Note: we use DummyTask (defined above) as the dummy task. You may want
to use another task by providing a similar helper function
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="test_dummy_s2s_task", argument_default=argparse.SUPPRESS
|
||||
)
|
||||
DummyTask.add_args(parser)
|
||||
args = parser.parse_args([])
|
||||
task = DummyTask.setup_task(args)
|
||||
return task, parser
|
||||
|
||||
|
||||
def get_dummy_input(T=100, D=80, B=5, K=100):
|
||||
forward_input = {}
|
||||
# T max sequence length
|
||||
# D feature vector dimension
|
||||
# B batch size
|
||||
# K target dimension size
|
||||
feature = torch.randn(B, T, D)
|
||||
# this (B, T, D) layout is just a convention, you can override it by
|
||||
# write your own _prepare_forward_input function
|
||||
src_lengths = torch.from_numpy(
|
||||
np.random.randint(low=1, high=T, size=B).astype(np.int64)
|
||||
)
|
||||
src_lengths[0] = T # make sure the maximum length matches
|
||||
prev_output_tokens = []
|
||||
for b in range(B):
|
||||
token_length = np.random.randint(low=1, high=src_lengths[b].item() + 1)
|
||||
tokens = np.random.randint(low=0, high=K, size=token_length)
|
||||
prev_output_tokens.append(torch.from_numpy(tokens))
|
||||
|
||||
prev_output_tokens = fairseq_data_utils.collate_tokens(
|
||||
prev_output_tokens,
|
||||
pad_idx=1,
|
||||
eos_idx=2,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
)
|
||||
src_lengths, sorted_order = src_lengths.sort(descending=True)
|
||||
forward_input["src_tokens"] = feature.index_select(0, sorted_order)
|
||||
forward_input["src_lengths"] = src_lengths
|
||||
forward_input["prev_output_tokens"] = prev_output_tokens
|
||||
|
||||
return forward_input
|
||||
|
||||
|
||||
def get_dummy_encoder_output(encoder_out_shape=(100, 80, 5)):
|
||||
"""
|
||||
This only provides an example to generate dummy encoder output
|
||||
"""
|
||||
(T, B, D) = encoder_out_shape
|
||||
encoder_out = {}
|
||||
|
||||
encoder_out["encoder_out"] = torch.from_numpy(
|
||||
np.random.randn(*encoder_out_shape).astype(np.float32)
|
||||
)
|
||||
seq_lengths = torch.from_numpy(np.random.randint(low=1, high=T, size=B))
|
||||
# some dummy mask
|
||||
encoder_out["encoder_padding_mask"] = torch.arange(T).view(1, T).expand(
|
||||
B, -1
|
||||
) >= seq_lengths.view(B, 1).expand(-1, T)
|
||||
encoder_out["encoder_padding_mask"].t_()
|
||||
|
||||
# encoder_padding_mask is a (T, B) tensor whose (t, b)-th element indicates
# whether encoder_out[t, b] is valid (=0) or padding (=1)
|
||||
return encoder_out
|
||||
|
||||
|
||||
def _current_postion_info():
|
||||
cf = currentframe()
|
||||
frameinfo = " (at {}:{})".format(
|
||||
os.path.basename(getframeinfo(cf).filename), cf.f_back.f_lineno
|
||||
)
|
||||
return frameinfo
|
||||
|
||||
|
||||
def check_encoder_output(encoder_output, batch_size=None):
|
||||
"""we expect encoder_output to be a dict with the following
|
||||
key/value pairs:
|
||||
- encoder_out: a Torch.Tensor
|
||||
- encoder_padding_mask: a binary Torch.Tensor
|
||||
"""
|
||||
if not isinstance(encoder_output, dict):
|
||||
msg = (
|
||||
"FairseqEncoderModel.forward(...) must be a dict" + _current_postion_info()
|
||||
)
|
||||
return False, msg
|
||||
|
||||
if "encoder_out" not in encoder_output:
|
||||
msg = (
|
||||
"FairseqEncoderModel.forward(...) must contain encoder_out"
|
||||
+ _current_postion_info()
|
||||
)
|
||||
return False, msg
|
||||
|
||||
if "encoder_padding_mask" not in encoder_output:
|
||||
msg = (
|
||||
"FairseqEncoderModel.forward(...) must contain encoder_padding_mask"
|
||||
+ _current_postion_info()
|
||||
)
|
||||
return False, msg
|
||||
|
||||
if not isinstance(encoder_output["encoder_out"], torch.Tensor):
|
||||
msg = "encoder_out must be a torch.Tensor" + _current_postion_info()
|
||||
return False, msg
|
||||
|
||||
if encoder_output["encoder_out"].dtype != torch.float32:
|
||||
msg = "encoder_out must have float32 dtype" + _current_postion_info()
|
||||
return False, msg
|
||||
|
||||
mask = encoder_output["encoder_padding_mask"]
|
||||
if mask is not None:
|
||||
if not isinstance(mask, torch.Tensor):
|
||||
msg = (
|
||||
"encoder_padding_mask must be a torch.Tensor" + _current_postion_info()
|
||||
)
|
||||
return False, msg
|
||||
if mask.dtype != torch.uint8:
|
||||
msg = (
|
||||
"encoder_padding_mask must have dtype of uint8"
|
||||
+ _current_postion_info()
|
||||
)
|
||||
return False, msg
|
||||
|
||||
if mask.dim() != 2:
|
||||
msg = (
|
||||
"we expect encoder_padding_mask to be a 2-d tensor, in shape (T, B)"
|
||||
+ _current_postion_info()
|
||||
)
|
||||
return False, msg
|
||||
|
||||
if batch_size is not None and mask.size(1) != batch_size:
|
||||
msg = (
|
||||
"we expect encoder_padding_mask to be a 2-d tensor, with size(1)"
|
||||
+ " being the batch size"
|
||||
+ _current_postion_info()
|
||||
)
|
||||
return False, msg
|
||||
return True, None
|
||||
|
||||
|
||||
def check_decoder_output(decoder_output):
|
||||
"""we expect output from a decoder is a tuple with the following constraint:
|
||||
- the first element is a torch.Tensor
|
||||
- the second element can be anything (reserved for future use)
|
||||
"""
|
||||
if not isinstance(decoder_output, tuple):
|
||||
msg = "FariseqDecoder output must be a tuple" + _current_postion_info()
|
||||
return False, msg
|
||||
|
||||
if len(decoder_output) != 2:
|
||||
msg = "FairseqDecoder output must be 2-elem tuple" + _current_postion_info()
|
||||
return False, msg
|
||||
|
||||
if not isinstance(decoder_output[0], torch.Tensor):
|
||||
msg = (
|
||||
"FariseqDecoder output[0] must be a torch.Tensor" + _current_postion_info()
|
||||
)
|
||||
return False, msg
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////
|
||||
# Base Test class
|
||||
# ///////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
class TestBaseFairseqModelBase(unittest.TestCase):
|
||||
"""
|
||||
This class is used to facilitate writing unittest for any class derived from
|
||||
`BaseFairseqModel`.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if cls is TestBaseFairseqModelBase:
|
||||
raise unittest.SkipTest("Skipping test case in base")
|
||||
super().setUpClass()
|
||||
|
||||
def setUpModel(self, model):
|
||||
self.assertTrue(isinstance(model, BaseFairseqModel))
|
||||
self.model = model
|
||||
|
||||
def setupInput(self):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
self.model = None
|
||||
self.forward_input = None
|
||||
pass
|
||||
|
||||
|
||||
class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
|
||||
"""
|
||||
base class to test FairseqEncoderDecoderModel (formerly known as
`FairseqModel`); concrete test cases must derive from this class
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if cls is TestFairseqEncoderDecoderModelBase:
|
||||
raise unittest.SkipTest("Skipping test case in base")
|
||||
super().setUpClass()
|
||||
|
||||
def setUpModel(self, model_cls, extra_args_setters=None):
|
||||
self.assertTrue(
|
||||
issubclass(model_cls, (FairseqEncoderDecoderModel, FairseqModel)),
|
||||
msg="This class only tests for FairseqModel subclasses",
|
||||
)
|
||||
|
||||
task, parser = get_dummy_task_and_parser()
|
||||
model_cls.add_args(parser)
|
||||
|
||||
args = parser.parse_args([])
|
||||
if extra_args_setters is not None:
|
||||
for args_setter in extra_args_setters:
|
||||
args_setter(args)
|
||||
model = model_cls.build_model(args, task)
|
||||
self.model = model
|
||||
|
||||
def setUpInput(self, input=None):
|
||||
self.forward_input = get_dummy_input() if input is None else input
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def test_forward(self):
|
||||
if self.model and self.forward_input:
|
||||
forward_output = self.model.forward(**self.forward_input)
|
||||
# for FairseqEncoderDecoderModel, forward returns a tuple of two
|
||||
# elements, the first one is a Torch.Tensor
|
||||
succ, msg = check_decoder_output(forward_output)
|
||||
if not succ:
|
||||
self.assertTrue(succ, msg=msg)
|
||||
self.forward_output = forward_output
|
||||
|
||||
def test_get_normalized_probs(self):
|
||||
if self.model and self.forward_input:
|
||||
forward_output = self.model.forward(**self.forward_input)
|
||||
logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
|
||||
prob = self.model.get_normalized_probs(forward_output, log_probs=False)
|
||||
|
||||
# in order for different models/criterion to play with each other
|
||||
# we need to know whether the logprob or prob output is batch_first
|
||||
# or not. We assume an additional attribute will be attached to logprob
|
||||
# or prob. If you find your code failed here, simply override
|
||||
# FairseqModel.get_normalized_probs, see example at
|
||||
# https://fburl.com/batch_first_example
|
||||
self.assertTrue(hasattr(logprob, "batch_first"))
|
||||
self.assertTrue(hasattr(prob, "batch_first"))
|
||||
|
||||
self.assertTrue(torch.is_tensor(logprob))
|
||||
self.assertTrue(torch.is_tensor(prob))
|
||||
|
||||
|
||||
class TestFairseqEncoderModelBase(TestBaseFairseqModelBase):
|
||||
"""
|
||||
base class to test FairseqEncoderModel
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if cls is TestFairseqEncoderModelBase:
|
||||
raise unittest.SkipTest("Skipping test case in base")
|
||||
super().setUpClass()
|
||||
|
||||
def setUpModel(self, model_cls, extra_args_setters=None):
|
||||
self.assertTrue(
|
||||
issubclass(model_cls, FairseqEncoderModel),
|
||||
msg="This class is only used for testing FairseqEncoderModel",
|
||||
)
|
||||
task, parser = get_dummy_task_and_parser()
|
||||
model_cls.add_args(parser)
|
||||
args = parser.parse_args([])
|
||||
if extra_args_setters is not None:
|
||||
for args_setter in extra_args_setters:
|
||||
args_setter(args)
|
||||
|
||||
model = model_cls.build_model(args, task)
|
||||
self.model = model
|
||||
|
||||
def setUpInput(self, input=None):
|
||||
self.forward_input = get_dummy_input() if input is None else input
|
||||
# get_dummy_input() is originally for s2s, here we delete extra dict
|
||||
# items, so it can be used for EncoderModel / Encoder as well
|
||||
self.forward_input.pop("prev_output_tokens", None)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def test_forward(self):
|
||||
if self.forward_input and self.model:
|
||||
bsz = self.forward_input["src_tokens"].size(0)
|
||||
forward_output = self.model.forward(**self.forward_input)
|
||||
|
||||
# we expect forward_output to be a dict with the following
|
||||
# key/value pairs:
|
||||
# - encoder_out: a Torch.Tensor
|
||||
# - encoder_padding_mask: a binary Torch.Tensor
|
||||
succ, msg = check_encoder_output(forward_output, batch_size=bsz)
|
||||
if not succ:
|
||||
self.assertTrue(succ, msg=msg)
|
||||
self.forward_output = forward_output
|
||||
|
||||
def test_get_normalized_probs(self):
|
||||
if self.model and self.forward_input:
|
||||
forward_output = self.model.forward(**self.forward_input)
|
||||
logprob = self.model.get_normalized_probs(forward_output, log_probs=True)
|
||||
prob = self.model.get_normalized_probs(forward_output, log_probs=False)
|
||||
|
||||
# in order for different models/criterion to play with each other
|
||||
# we need to know whether the logprob or prob output is batch_first
|
||||
# or not. We assume an additional attribute will be attached to logprob
|
||||
# or prob. If you find your code failed here, simply override
|
||||
# FairseqModel.get_normalized_probs, see example at
|
||||
# https://fburl.com/batch_first_example
|
||||
self.assertTrue(hasattr(logprob, "batch_first"))
|
||||
self.assertTrue(hasattr(prob, "batch_first"))
|
||||
|
||||
self.assertTrue(torch.is_tensor(logprob))
|
||||
self.assertTrue(torch.is_tensor(prob))
|
||||
|
||||
|
||||
class TestFairseqEncoderBase(unittest.TestCase):
    """
    base class to test FairseqEncoder
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqEncoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpEncoder(self, encoder):
        self.assertTrue(
            isinstance(encoder, FairseqEncoder),
            msg="This class is only used for testing FairseqEncoder",
        )
        self.encoder = encoder

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_input() if input is None else input
        # get_dummy_input() is originally for s2s; here we delete the extra
        # dict items so it can be used for EncoderModel / Encoder as well
        self.forward_input.pop("prev_output_tokens", None)

    def setUp(self):
        self.encoder = None
        self.forward_input = None

    def test_forward(self):
        if self.encoder and self.forward_input:
            bsz = self.forward_input["src_tokens"].size(0)

            forward_output = self.encoder.forward(**self.forward_input)
            succ, msg = check_encoder_output(forward_output, batch_size=bsz)
            if not succ:
                self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output

class TestFairseqDecoderBase(unittest.TestCase):
    """
    base class to test FairseqDecoder
    """

    @classmethod
    def setUpClass(cls):
        if cls is TestFairseqDecoderBase:
            raise unittest.SkipTest("Skipping test case in base")
        super().setUpClass()

    def setUpDecoder(self, decoder):
        self.assertTrue(
            isinstance(decoder, FairseqDecoder),
            msg="This class is only used for testing FairseqDecoder",
        )
        self.decoder = decoder

    def setUpInput(self, input=None):
        self.forward_input = get_dummy_encoder_output() if input is None else input

    def setUpPrevOutputTokens(self, tokens=None):
        if tokens is None:
            self.encoder_input = get_dummy_input()
            self.prev_output_tokens = self.encoder_input["prev_output_tokens"]
        else:
            self.prev_output_tokens = tokens

    def setUp(self):
        self.decoder = None
        self.forward_input = None
        self.prev_output_tokens = None

    def test_forward(self):
        if (
            self.decoder is not None
            and self.forward_input is not None
            and self.prev_output_tokens is not None
        ):
            forward_output = self.decoder.forward(
                prev_output_tokens=self.prev_output_tokens,
                encoder_out=self.forward_input,
            )
            succ, msg = check_decoder_output(forward_output)
            if not succ:
                self.assertTrue(succ, msg=msg)
            self.forward_output = forward_output

class DummyEncoderModel(FairseqEncoderModel):
    def __init__(self, encoder):
        super().__init__(encoder)

    @classmethod
    def build_model(cls, args, task):
        return cls(DummyEncoder())

    def get_logits(self, net_output):
        # Inverse of sigmoid, for use with BinaryCrossEntropyWithLogitsCriterion,
        # since F.binary_cross_entropy_with_logits combines sigmoid and CE
        return torch.log(
            torch.div(net_output["encoder_out"], 1 - net_output["encoder_out"])
        )

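# Sanity check for get_logits above: for p = net_output["encoder_out"],
# log(p / (1 - p)) is the logit (the exact inverse of sigmoid), so
# torch.sigmoid(get_logits(...)) recovers p, which is the form that
# F.binary_cross_entropy_with_logits expects as input.
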
class DummyEncoder(FairseqEncoder):
    def __init__(self):
        super().__init__(None)

    def forward(self, src_tokens, src_lengths):
        mask, max_len = lengths_to_encoder_padding_mask(src_lengths)
        return {"encoder_out": src_tokens, "encoder_padding_mask": mask}

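# A simplified, hypothetical stand-in illustrating what a lengths-to-mask
# helper does (the real lengths_to_encoder_padding_mask may differ in layout
# and options): positions at or beyond each sequence's length are padding.
def sketch_lengths_to_padding_mask(lengths):
    max_len = int(lengths.max())
    # (1, T) position indices compared against (B, 1) lengths broadcast to a
    # (B, T) boolean mask that is True at padded positions
    positions = torch.arange(max_len).unsqueeze(0)
    mask = positions >= lengths.unsqueeze(1)
    return mask, max_len
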
class CrossEntropyCriterionTestBase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if cls is CrossEntropyCriterionTestBase:
            raise unittest.SkipTest("Skipping base class test case")
        super().setUpClass()

    def setUpArgs(self):
        args = argparse.Namespace()
        args.sentence_avg = False
        args.threshold = 0.1  # to use with BinaryCrossEntropyWithLogitsCriterion
        return args

    def setUp(self):
        args = self.setUpArgs()
        self.model = DummyEncoderModel(encoder=DummyEncoder())
        self.criterion = self.criterion_cls(args=args, task=DummyTask(args))

    def get_src_tokens(self, correct_prediction, aggregate):
        """
        correct_prediction: True if the net_output (src_tokens) should
        predict the correct target
        aggregate: True if the criterion expects net_output (src_tokens)
        aggregated across the time axis
        """
        predicted_idx = 0 if correct_prediction else 1
        if aggregate:
            src_tokens = torch.zeros((2, 2), dtype=torch.float)
            for b in range(2):
                src_tokens[b][predicted_idx] = 1.0
        else:
            src_tokens = torch.zeros((2, 10, 2), dtype=torch.float)
            for b in range(2):
                for t in range(10):
                    src_tokens[b][t][predicted_idx] = 1.0
        return src_tokens

    def get_target(self, soft_target):
        if soft_target:
            target = torch.zeros((2, 2), dtype=torch.float)
            for b in range(2):
                target[b][0] = 1.0
        else:
            target = torch.zeros((2, 10), dtype=torch.long)
        return target

    def get_test_sample(self, correct, soft_target, aggregate):
        src_tokens = self.get_src_tokens(correct, aggregate)
        target = self.get_target(soft_target)
        L = src_tokens.size(1)
        return {
            "net_input": {"src_tokens": src_tokens, "src_lengths": torch.tensor([L])},
            "target": target,
            "ntokens": src_tokens.size(0) * src_tokens.size(1),
        }
60
tests/speech_recognition/test_collaters.py
Normal file
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest

import numpy as np
import torch
from examples.speech_recognition.data.collaters import Seq2SeqCollater


class TestSeq2SeqCollator(unittest.TestCase):
    def test_collate(self):
        eos_idx = 1
        pad_idx = 0
        collater = Seq2SeqCollater(
            feature_index=0, label_index=1, pad_index=pad_idx, eos_index=eos_idx
        )

        # 2 frames in the first sample and 3 frames in the second one
        frames1 = np.array([[7, 8], [9, 10]])
        frames2 = np.array([[1, 2], [3, 4], [5, 6]])
        target1 = np.array([4, 2, 3, eos_idx])
        target2 = np.array([3, 2, eos_idx])
        sample1 = {"id": 0, "data": [frames1, target1]}
        sample2 = {"id": 1, "data": [frames2, target2]}
        batch = collater.collate([sample1, sample2])

        # collate() sorts inputs by frame length before creating the batch
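        # Expected layout (read off the assertions below): the longer sample2
        # comes first; ntokens counts target tokens (4 + 3 = 7); src_tokens
        # are padded to the longest frame sequence; prev_output_tokens is each
        # target shifted right with eos moved to the front, then padded.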
self.assertTensorEqual(batch["id"], torch.tensor([1, 0]))
|
||||
self.assertEqual(batch["ntokens"], 7)
|
||||
self.assertTensorEqual(
|
||||
batch["net_input"]["src_tokens"],
|
||||
torch.tensor(
|
||||
[[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [pad_idx, pad_idx]]]
|
||||
),
|
||||
)
|
||||
self.assertTensorEqual(
|
||||
batch["net_input"]["prev_output_tokens"],
|
||||
torch.tensor([[eos_idx, 3, 2, pad_idx], [eos_idx, 4, 2, 3]]),
|
||||
)
|
||||
self.assertTensorEqual(batch["net_input"]["src_lengths"], torch.tensor([3, 2]))
|
||||
self.assertTensorEqual(
|
||||
batch["target"],
|
||||
torch.tensor([[3, 2, eos_idx, pad_idx], [4, 2, 3, eos_idx]]),
|
||||
)
|
||||
self.assertEqual(batch["nsentences"], 2)
|
||||
|
||||
def assertTensorEqual(self, t1, t2):
|
||||
self.assertEqual(t1.size(), t2.size(), "size mismatch")
|
||||
self.assertEqual(t1.ne(t2).long().sum(), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
36
tests/speech_recognition/test_cross_entropy.py
Normal file
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from examples.speech_recognition.criterions.cross_entropy_acc import CrossEntropyWithAccCriterion

from .asr_test_base import CrossEntropyCriterionTestBase


class CrossEntropyWithAccCriterionTest(CrossEntropyCriterionTestBase):
    def setUp(self):
        self.criterion_cls = CrossEntropyWithAccCriterion
        super().setUp()

    def test_cross_entropy_all_correct(self):
        sample = self.get_test_sample(correct=True, soft_target=False, aggregate=False)
        loss, sample_size, logging_output = self.criterion(
            self.model, sample, "sum", log_probs=True
        )
        assert logging_output["correct"] == 20
        assert logging_output["total"] == 20
        assert logging_output["sample_size"] == 20
        assert logging_output["ntokens"] == 20

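    # The counts above follow from get_test_sample(aggregate=False): src_tokens
    # has shape (2, 10, 2), i.e. 2 sentences x 10 time steps = 20 frames, each
    # scored against a (2, 10) all-zeros (class 0) target.
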
    def test_cross_entropy_all_wrong(self):
        sample = self.get_test_sample(correct=False, soft_target=False, aggregate=False)
        loss, sample_size, logging_output = self.criterion(
            self.model, sample, "sum", log_probs=True
        )
        assert logging_output["correct"] == 0
        assert logging_output["total"] == 20
        assert logging_output["sample_size"] == 20
        assert logging_output["ntokens"] == 20
135
tests/speech_recognition/test_vggtransformer.py
Normal file
@@ -0,0 +1,135 @@
#!/usr/bin/env python3

# import models/encoder/decoder to be tested
from examples.speech_recognition.models.vggtransformer import (
    TransformerDecoder,
    VGGTransformerEncoder,
    VGGTransformerModel,
    vggtransformer_1,
    vggtransformer_2,
    vggtransformer_base,
)

# import base test classes
from .asr_test_base import (
    DEFAULT_TEST_VOCAB_SIZE,
    TestFairseqDecoderBase,
    TestFairseqEncoderBase,
    TestFairseqEncoderDecoderModelBase,
    get_dummy_dictionary,
    get_dummy_encoder_output,
    get_dummy_input,
)

class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
    def setUp(self):
        def override_config(args):
            """
            vggtransformer_1 uses 14 transformer layers, which is too
            expensive for testing. For a fast turn-around test, reduce
            the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        extra_args_setters = [vggtransformer_1, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setters)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))

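# Note on the config string above: each 7-tuple is assumed to follow the
# encoder layer format expected by the vggtransformer model, i.e.
# (embedding dim, attention heads, FFN dim, normalize_before, dropout,
# attention dropout, relu dropout); "* 3" repeats that layer spec 3 times.
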
class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
    def setUp(self):
        def override_config(args):
            """
            vggtransformer_2 uses 16 transformer layers, which is too
            expensive for testing. For a fast turn-around test, reduce
            the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        extra_args_setters = [vggtransformer_2, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setters)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))

class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
    def setUp(self):
        def override_config(args):
            """
            vggtransformer_base uses 12 transformer layers, which is too
            expensive for testing. For a fast turn-around test, reduce
            the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
            )

        super().setUp()
        extra_args_setters = [vggtransformer_base, override_config]

        self.setUpModel(VGGTransformerModel, extra_args_setters)
        self.setUpInput(get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE))

class VGGTransformerEncoderTest(TestFairseqEncoderBase):
    def setUp(self):
        super().setUp()

        self.setUpInput(get_dummy_input(T=50, D=80, B=5))

    def test_forward(self):
        print("1. test standard vggtransformer")
        self.setUpEncoder(VGGTransformerEncoder(input_feat_per_channel=80))
        super().test_forward()
        print("2. test vggtransformer with limited right context")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80, transformer_context=(-1, 5)
            )
        )
        super().test_forward()
        print("3. test vggtransformer with limited left context")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80, transformer_context=(5, -1)
            )
        )
        super().test_forward()
        print("4. test vggtransformer with limited right context and sampling")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80,
                transformer_context=(-1, 12),
                transformer_sampling=(2, 2),
            )
        )
        super().test_forward()
        print("5. test vggtransformer with windowed context and sampling")
        self.setUpEncoder(
            VGGTransformerEncoder(
                input_feat_per_channel=80,
                transformer_context=(12, 12),
                transformer_sampling=(2, 2),
            )
        )
        super().test_forward()

class TransformerDecoderTest(TestFairseqDecoderBase):
    def setUp(self):
        super().setUp()

        dictionary = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
        decoder = TransformerDecoder(dictionary)
        dummy_encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))

        self.setUpDecoder(decoder)
        self.setUpInput(dummy_encoder_output)
        self.setUpPrevOutputTokens()