Merge STPT: Step 3

Summary:
1. Add joint pre-training scripts
2. Replace prepend_tgt_lang_tag_no_change with prepend_tgt_lang_tag_as_bos
3. Add readme for the joint pre-training
4. Add test case for the Librispeech model

Reviewed By: hygong-fb

Differential Revision: D36300953

fbshipit-source-id: cb749689787ed97c1250d122bdefb7f7a2252292
Yun Tang 2022-05-10 19:44:00 -07:00 committed by Facebook GitHub Bot
parent 4368ede817
commit 993129dae4
18 changed files with 3831 additions and 23 deletions


@@ -5,17 +5,16 @@ An extension of Fairseq s2t project with the speech to text task enhanced by the
Examples of speech text joint training in fairseq
- [English-to-German MuST-C model](docs/ende-mustc.md)
- [IWSLT 2021 Multilingual Speech Translation](docs/iwslt2021.md)
- [Speech Text Joint Pre-training](docs/pre-training.md)
## Citation
Please cite as:
```
@inproceedings{Tang2021AGM,
title={A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks},
author={Yun Tang and J. Pino and Changhan Wang and Xutai Ma and Dmitriy Genzel},
booktitle={ICASSP},
year={2021}
@inproceedings{Tang2022UnifiedSP,
title={Unified Speech-Text Pre-training for Speech Translation and Recognition},
author={Yun Tang and Hongyu Gong and Ning Dong and Changhan Wang and Wei-Ning Hsu and Jiatao Gu and Alexei Baevski and Xian Li and Abdelrahman Mohamed and Michael Auli and Juan Miguel Pino},
booktitle={ACL},
year={2022}
}
@inproceedings{Tang2021IST,
title = {Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task},
author = {Yun Tang and Juan Pino and Xian Li and Changhan Wang and Dmitriy Genzel},
@@ -29,6 +28,12 @@ Please cite as:
booktitle = {IWSLT},
year = {2021},
}
@inproceedings{Tang2021AGM,
title={A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks},
author={Yun Tang and J. Pino and Changhan Wang and Xutai Ma and Dmitriy Genzel},
booktitle={ICASSP},
year={2021}
}
@inproceedings{wang2020fairseqs2t,
title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq},


@@ -0,0 +1,180 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass, field
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterionConfig,
)
from fairseq.logging.meters import safe_round
from .multi_modality_cross_entropy import SpeechTextPreTrainCrossEntCriterion
logger = logging.getLogger(__name__)
@dataclass
class SpeechTextPreTrainCompoundCriterionConfig(
LabelSmoothedCrossEntropyCriterionConfig
):
zero_infinity: bool = field(
default=False,
metadata={"help": "zero inf loss when source length <= target length"},
)
post_process: str = field(
default="none",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
},
)
@register_criterion(
"speech_text_pretrain_compound", dataclass=SpeechTextPreTrainCompoundCriterionConfig
)
class SpeechTextPreTrainCompoundCriterion(FairseqCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
report_accuracy=False,
zero_infinity=False,
post_process=None,
):
super().__init__(task)
self.xent = SpeechTextPreTrainCrossEntCriterion(
task, sentence_avg, label_smoothing, report_accuracy
)
cfg_dict = {
"zero_infinity": zero_infinity,
"sentence_avg": sentence_avg,
"post_process": post_process,
}
cfg_ctc = CtcCriterionConfig(**cfg_dict)
self.ctc = CtcCriterion(cfg_ctc, task)
def forward(self, model, sample, reduce=True):
mode = sample["net_input"]["mode"]
if mode == "sup_speech_ctc": # CTC
sample["net_input"][
"src_lengths"
] = None # get downsampled src_lengths from padding_mask
loss, sample_size, logging_output = self.ctc(model, sample, reduce)
logging_output["mode"] = SpeechTextPreTrainCompoundCriterion.mode2value(
"CTC"
)
else:
loss, sample_size, logging_output = self.xent(model, sample, reduce)
logging_output["mode"] = SpeechTextPreTrainCompoundCriterion.mode2value(
"xent"
)
return loss, sample_size, logging_output
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return True
@staticmethod
def mode2value(mode): # encode the mode as a prime number so that logging_outputs_can_be_summed can stay True
if mode == "CTC":
return 907 # prime number
if mode == "xent":
return 887 # prime number
return 0
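# Note: if the logging outputs of N workers are summed before reduce_metrics,
# the "mode" entry becomes N * 907 for CTC batches (or N * 887 for xent).
# The sum stays divisible by the original prime, so value2mode below can
# still recover the mode after summation.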
@staticmethod
def value2mode(value):
if value % 907 == 0:
return "CTC"
if value % 887 == 0:
return "xent"
raise ValueError("Unknow mode")
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
def _get_mode(logging_outputs):
mds = [
SpeechTextPreTrainCompoundCriterion.value2mode(log["mode"])
for log in logging_outputs
]
if sum([1 if l != mds[0] else 0 for l in mds]) > 0:
raise ValueError("mode in one mini-batch is expected to be the same!")
return mds[0]
log_mode = _get_mode(logging_outputs)
if log_mode == "xent":
return SpeechTextPreTrainCrossEntCriterion.reduce_metrics(logging_outputs)
# ctc loss
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ctc_ntokens", ntokens)
metrics.log_scalar("ctc_nsentences", nsentences)
if sample_size != ntokens:
metrics.log_scalar(
"ctc_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)


@@ -0,0 +1,101 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq import utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig,
label_smoothed_nll_loss,
)
@register_criterion(
"speech_text_pretrain_cross_entropy",
dataclass=LabelSmoothedCrossEntropyCriterionConfig,
)
class SpeechTextPreTrainCrossEntCriterion(LabelSmoothedCrossEntropyCriterion):
def __init__(self, task, sentence_avg, label_smoothing, report_accuracy=False):
super().__init__(
task, sentence_avg, label_smoothing, report_accuracy=report_accuracy
)
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
loss, nll_loss, nsentences, ntokens, n_correct = self.compute_loss(
model, net_output, sample, reduce=reduce
)
sample_size = nsentences if self.sentence_avg else ntokens
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": ntokens,
"nsentences": nsentences,
"sample_size": sample_size,
}
if self.report_accuracy:
logging_output["n_correct"] = utils.item(n_correct)
logging_output["total"] = utils.item(ntokens)
return loss, sample_size, logging_output
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = model.get_targets(sample, net_output)
assert self.ignore_prefix_size == 0
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
target = target[self.ignore_prefix_size :, :].contiguous()
return lprobs, target
def compute_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
n_correct = 0
if isinstance(target, dict):
t_lprobs = target["target_logprobs"]
if not lprobs.batch_first:
lprobs = lprobs.transpose(0, 1)
t_lprobs = t_lprobs.transpose(0, 1)
nsentences, seq_len = lprobs.size()[:2]
ntokens = nsentences * seq_len
t_probs = t_lprobs.exp()
mask_indices = (
net_output[1]["mask_indices"][0]
if len(net_output[1]["mask_indices"]) > 0
else None
)
# mask_indices is True for those masking frames
if mask_indices is not None: # B X T
t_probs = t_probs.masked_fill(mask_indices.eq(False).unsqueeze(-1), 0)
ntokens = mask_indices.int().sum()
t_probs = t_probs.detach()
t_lprobs = t_lprobs.detach()
loss = (
-(t_probs * (lprobs - t_lprobs)).sum()
if reduce
else -(t_probs * (lprobs - t_lprobs)).sum(-1, keepdim=True)
)
nll_loss = loss
else:
nsentences = target.size(0)
mask = target.ne(self.padding_idx)
loss, nll_loss = label_smoothed_nll_loss(
lprobs.view(-1, lprobs.size(-1)),
target.view(-1),
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
n_correct = torch.sum(
lprobs.argmax(-1).masked_select(mask).eq(target.masked_select(mask))
)
ntokens = torch.sum(mask)
return loss, nll_loss, nsentences, ntokens, n_correct


@@ -0,0 +1,318 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import math
import re
import torch
from fairseq.data import data_utils
from fairseq.data.language_pair_dataset import LanguagePairDataset
# Part of the code is modified from DenoisingDataset
# Compared with DenoisingDataset, there is no sentence permutation or document rotation (permute_sentence_ratio and rotate_ratio are not used)
class LanguagePairDenoisingDataset(LanguagePairDataset):
def __init__(
self,
src,
src_sizes,
src_dict,
tgt,
tgt_sizes,
tgt_dict,
mask_idx,
mask_whole_words,
seed,
args,
left_pad_source=True,
left_pad_target=False,
shuffle=True,
input_feeding=True,
remove_eos_from_source=False,
append_eos_to_target=False,
align_dataset=None,
constraints=None,
append_bos=False,
eos=None,
num_buckets=0,
src_lang_id=None,
tgt_lang_id=None,
pad_to_multiple=1,
):
super().__init__(
src,
src_sizes,
src_dict,
tgt,
tgt_sizes,
tgt_dict,
left_pad_source,
left_pad_target,
shuffle,
input_feeding,
remove_eos_from_source,
append_eos_to_target,
align_dataset,
constraints,
append_bos,
eos,
num_buckets,
src_lang_id,
tgt_lang_id,
pad_to_multiple,
)
self.mask_idx = mask_idx
self.mask_whole_word = mask_whole_words
self.mask_ratio = args.mask
self.random_ratio = args.mask_random
self.insert_ratio = args.insert
self.replace_length = args.replace_length
if self.replace_length not in [-1, 0, 1]:
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
if args.mask_length not in ["subword", "word", "span-poisson"]:
raise ValueError(f"invalid arg: mask-length={args.mask_length}")
if args.mask_length == "subword" and args.replace_length not in [0, 1]:
raise ValueError("if using subwords, use replace-length=1 or 0")
self.mask_span_distribution = None
if args.mask_length == "span-poisson":
# Text infilling: "A number of text spans are sampled, with span lengths drawn from a Poisson distribution (λ = 3). Each span is replaced with a single [MASK] token. 0-length spans correspond to the insertion of [MASK] tokens."
_lambda = args.poisson_lambda
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-_lambda)
k_factorial = 1
ps = []
for k in range(0, 128):
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
lambda_to_the_k *= _lambda
k_factorial *= k + 1
if ps[-1] < 0.0000001:
break
ps = torch.FloatTensor(ps)
self.mask_span_distribution = torch.distributions.Categorical(ps)
self.epoch = 0
self.seed = seed
def _is_phoneme(x):
if re.search("<lang:", x) or x in (
"<mask>",
"<sil>",
"<pad>",
"<s>",
"</s>",
"<unk>",
):
return False
return True
self.voc_valid_ids = torch.LongTensor(
[i for i, x in enumerate(self.src_dict.symbols) if _is_phoneme(x)]
)
self.voc_valid_size = self.voc_valid_ids.size(0)
@property
def can_reuse_epoch_itr_across_epochs(self):
return False
def set_epoch(self, epoch, **unused):
self.epoch = epoch
def __getitem__(self, index):
tgt_item = self.tgt[index] if self.tgt is not None else None
src_item = copy.deepcopy(self.src[index])
with data_utils.numpy_seed(self.seed, self.epoch, index):
source = src_item
assert source[-1] == self.eos
if self.mask_ratio > 0:
source = self.add_whole_word_mask(source, self.mask_ratio)
if self.insert_ratio > 0:
source = self.add_insertion_noise(source, self.insert_ratio)
src_item = source
if self.append_eos_to_target:
eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
if self.tgt and self.tgt[index][-1] != eos:
tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])])
if self.append_bos:
bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
if self.tgt and self.tgt[index][0] != bos:
tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]])
bos = self.src_dict.bos()
if src_item[0] != bos:
src_item = torch.cat([torch.LongTensor([bos]), src_item])
if self.remove_eos_from_source:
eos = self.src_dict.eos()
if src_item[-1] == eos:
src_item = src_item[:-1]
example = {
"id": index,
"source": src_item,
"target": tgt_item,
}
if self.align_dataset is not None:
example["alignment"] = self.align_dataset[index]
if self.constraints is not None:
example["constraints"] = self.constraints[index]
if self.src_lang_id is not None:
example["src_lang_id"] = self.src_lang_id
if self.tgt_lang_id is not None:
example["tgt_lang_id"] = self.tgt_lang_id
return example
# following functions are borrowed from denoising_dataset
def word_starts(self, source):
if self.mask_whole_word is not None:
is_word_start = self.mask_whole_word.gather(0, source)
else:
is_word_start = torch.ones(source.size())
is_word_start[0] = 0
is_word_start[-1] = 0
return is_word_start
def add_whole_word_mask(self, source, p):
is_word_start = self.word_starts(source)
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
num_inserts = 0
if num_to_mask == 0:
return source
if self.mask_span_distribution is not None:
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
# Make sure we have enough to mask
cum_length = torch.cumsum(lengths, 0)
while cum_length[-1] < num_to_mask:
lengths = torch.cat(
[
lengths,
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
],
dim=0,
)
cum_length = torch.cumsum(lengths, 0)
# Trim to masking budget
i = 0
while cum_length[i] < num_to_mask:
i += 1
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
num_to_mask = i + 1
lengths = lengths[:num_to_mask]
# Handle 0-length mask (inserts) separately
lengths = lengths[lengths > 0]
num_inserts = num_to_mask - lengths.size(0)
num_to_mask -= num_inserts
if num_to_mask == 0:
return self.add_insertion_noise(source, num_inserts / source.size(0))
assert (lengths > 0).all()
else:
lengths = torch.ones((num_to_mask,)).long()
assert is_word_start[-1] == 0
word_starts = is_word_start.nonzero(as_tuple=False)
indices = word_starts[
torch.randperm(word_starts.size(0))[:num_to_mask]
].squeeze(1)
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
source_length = source.size(0)
assert source_length - 1 not in indices
to_keep = torch.ones(source_length, dtype=torch.bool)
is_word_start[
-1
] = 255 # acts as a long length, so spans don't go over the end of doc
if self.replace_length == 0:
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = self.voc_valid_ids[
torch.randint(0, self.voc_valid_size - 1, size=(mask_random.sum(),))
]
if self.mask_span_distribution is not None:
assert len(lengths.size()) == 1
assert lengths.size() == indices.size()
lengths -= 1
while indices.size(0) > 0:
assert lengths.size() == indices.size()
lengths -= is_word_start[indices + 1].long()
uncompleted = lengths >= 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
lengths = lengths[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = self.voc_valid_ids[
torch.randint(
0, self.voc_valid_size - 1, size=(mask_random.sum(),)
)
]
else:
# A bit faster when all lengths are 1
while indices.size(0) > 0:
uncompleted = is_word_start[indices + 1] == 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = self.voc_valid_ids[
torch.randint(
0, self.voc_valid_size - 1, size=(mask_random.sum(),)
)
]
assert source_length - 1 not in indices
source = source[to_keep]
if num_inserts > 0:
source = self.add_insertion_noise(source, num_inserts / source.size(0))
return source
def add_insertion_noise(self, tokens, p):
if p == 0.0:
return tokens
num_tokens = len(tokens)
n = int(math.ceil(num_tokens * p))
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
noise_mask[noise_indices] = 1
result = torch.LongTensor(n + len(tokens)).fill_(-1)
num_random = int(math.ceil(n * self.random_ratio))
result[noise_indices[num_random:]] = self.mask_idx
result[noise_indices[:num_random]] = self.voc_valid_ids[
torch.randint(0, self.voc_valid_size - 1, size=(num_random,))
]
result[~noise_mask] = tokens
assert (result >= 0).all()
return result


@@ -22,11 +22,17 @@ Enhanced Joint Training: the joint training is enhanced with pre-trained models,
--out-path ${must_c_en_de_src_text_pho}
```
- Replace the source text under the "src_text" column in the tsv file with the corresponding phoneme representation generated in the step above (one way to script this is sketched after this list).
Below is a snapshot of the MuST-C en-de dev tsv:
```
id audio n_frames tgt_text src_text speaker
ted_767_0 en-de/flac.zip:10071514743:48445 56160 Heute spreche ich zu Ihnen über Energie und Klima. ▁AY1 M ▁G OW1 IH0 NG ▁T UW1 ▁T AO1 K ▁T AH0 D EY1 ▁AH0 B AW1 T ▁EH1 N ER0 JH IY0 ▁AH0 N D ▁K L AY1 M AH0 T spk.767_
ted_767_1 en-de/flac.zip:1214217978:205678 226080 Und das überrascht vielleicht etwas, weil sich meine Vollzeitbeschäftigung bei der Stiftung hauptsächlich um Impfstoffe und Saatgut dreht, um die Dinge, die wir erfinden und liefern müssen um den ärmsten 2 Milliarden ein besseres Leben zu ermöglichen. ▁AH0 N D ▁DH AE1 T ▁M AY1 T ▁S IY1 M ▁AH0 ▁B IH1 T ▁S ER0 P R AY1 Z IH0 NG ▁B IH0 K AO1 Z ▁M AY1 ▁F UH1 L ▁T AY1 M ▁W ER1 K ▁AE1 T ▁DH AH0 ▁F AW0 N D EY1 SH AH0 N ▁IH1 Z ▁M OW1 S T L IY0 ▁AH0 B AW1 T ▁V AE2 K S IY1 N Z ▁AH0 N D ▁S IY1 D Z ▁AH0 B AW1 T ▁DH AH0 ▁TH IH1 NG Z ▁DH AE1 T ▁W IY1 ▁N IY1 D ▁T UW1 ▁IH0 N V EH1 N T ▁AH0 N D ▁D IH0 L IH1 V ER0 ▁T UW1 ▁HH EH1 L P ▁DH AH0 ▁P UH1 R IH0 S T ▁T UW1 ▁B IH1 L Y AH0 N ▁L AY1 V ▁B EH1 T ER0 ▁L IH1 V Z spk.767_
```
- Prepare phoneme dictionary and save to $MANIFEST_ROOT as [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/src_dict.txt)
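The snippet below is a minimal, illustrative sketch of the "src_text" replacement step (it is not part of the recipe): it assumes the phoneme conversion wrote one phoneme line per tsv row, and all file names are placeholders.
```
# Illustrative only: swap the "src_text" column of a MuST-C tsv with the
# phoneme lines produced by the conversion step. Assumes both files have the
# same number of rows; file names are placeholders.
import sys

tsv_in, pho_in, tsv_out = sys.argv[1:4]

with open(pho_in, encoding="utf-8") as f:
    phoneme_lines = [line.rstrip("\n") for line in f]

with open(tsv_in, encoding="utf-8") as f_in, open(tsv_out, "w", encoding="utf-8") as f_out:
    header = f_in.readline().rstrip("\n").split("\t")
    src_col = header.index("src_text")
    f_out.write("\t".join(header) + "\n")
    for row, pho in zip(f_in, phoneme_lines):
        fields = row.rstrip("\n").split("\t")
        fields[src_col] = pho
        f_out.write("\t".join(fields) + "\n")
```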
#### Prepare WMT text data
- [Download wmt data](https://github.com/pytorch/fairseq/blob/main/examples/translation/prepare-wmt14en2de.sh)
- Convert source text (English) into phoneme representation as above
- Generate binary parallel files for training (as in the translation example) and save the data in $parallel_text_data
- Generate binary parallel files with "fairseq-preprocess" from fairseq for training and validation. The source side is the English phoneme representation and the target side is the German SentencePiece tokens (one way to produce them is sketched below). The output is saved under $parallel_text_data
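A minimal sketch of producing the SentencePiece tokens for the target side is shown below; it assumes a trained SentencePiece model ("spm.model" and the file names are placeholders). The resulting phoneme/SentencePiece text pairs can then be binarized with "fairseq-preprocess".
```
# Illustrative only: encode raw German target text into SentencePiece subword
# tokens with a trained SentencePiece model. Paths/file names are placeholders.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="spm.model")

with open("train.de", encoding="utf-8") as f_in, \
        open("train.spm.de", "w", encoding="utf-8") as f_out:
    for line in f_in:
        pieces = sp.encode(line.strip(), out_type=str)
        f_out.write(" ".join(pieces) + "\n")
```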
## Training
The model is trained with 8 v100 GPUs.


@@ -0,0 +1,188 @@
[[Back]](..)
# Unified Speech-Text Pre-training for Speech Translation and Recognition
This directory contains the pre-training recipes from paper ["Unified Speech-Text Pre-training for Speech Translation and Recognition"](https://arxiv.org/abs/2204.05409).
## Librispeech ASR Pre-training
### Prepare Data
#### Download files
#### Prepare pre-training data
- Text to text task (T2T): prepare the binary data following similar steps to [EN_DE Joint training](./ende-mustc.md). The source data is presented as a phoneme token sequence and the target data is encoded as subword tokens via SentencePiece. The text data is downloaded from [openslr](https://www.openslr.org/12)
- Self-supervised speech learning task (SSL): The data is prepared as in [wav2vec 2.0](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec/README.md)
- Speech to phoneme classification task (S2P): The tsv file contains 5 fields: "id", "audio", "n_frames", "tgt_text", and "align". The "tgt_text" field corresponds to the phoneme-based representation of the speech data, and the "align" field contains the alignment information. The phoneme-level forced alignment for the labelled speech data (i.e., Librispeech) can be obtained via [kaldi](http://kaldi-asr.org) or [MFA](https://montrealcorpustools.github.io/Montreal-Forced-Aligner/). The segment boundaries are normalized to 0$\sim$1 over the whole utterance (see the sketch after this list for how they map to frames). A snapshot of the tsv file is below:
```
id audio n_frames tgt_text align
116-288045-0000 /librispeech/dev-other/116/288045/116-288045-0000.flac 170400 <sil> ▁AE1 Z AY1 ▁AH0 P R OW1 CH T ▁DH AH1 ▁S IH1 T IY0 <sil> AY1 ▁HH ER1 D ▁B EH1 L Z ▁R IH1 NG IH0 NG <sil> ▁AE1 N D AH0 ▁L IH1 T AH0 L ▁L EY1 T ER0 AY1 ▁F AW1 N D ▁DH AH0 ▁S T R IY1 T S ▁AH0 S T IH1 R ▁W IH0 TH ▁TH R AO1 NG Z ▁AH0 V ▁W EH1 L ▁D R EH1 S T ▁P IY1 P AH0 L ▁IH1 N ▁F AE1 M L IY0 ▁G R UW1 P S <sil> ▁W EH1 N D IH0 NG ▁DH EH1 R ▁W EY1 <sil> ▁HH IH1 DH ER0 ▁AH0 N D ▁TH IH1 DH ER0 <sil> 0.047977 0.056444 0.064911 0.075259 0.081844 0.089370 0.095014 0.104421 0.109125 0.111947 0.115710 0.120414 0.134525 0.141110 0.143932 0.174036 0.176858 0.190028 0.199436 0.207902 0.218250 0.224835 0.231421 0.242709 0.251176 0.257761 0.263405 0.268109 0.270931 0.290687 0.342427 0.349953 0.353716 0.356538 0.360301 0.363123 0.365945 0.368768 0.371590 0.376294 0.384760 0.394167 0.401693 0.409219 0.419567 0.430856 0.441204 0.444026 0.446849 0.449671 0.456256 0.463782 0.471308 0.477893 0.486359 0.491063 0.494826 0.501411 0.512700 0.517404 0.520226 0.534337 0.540922 0.545626 0.550329 0.559737 0.568203 0.583255 0.592662 0.600188 0.603951 0.611477 0.619003 0.624647 0.634055 0.639699 0.646284 0.653810 0.659454 0.664158 0.670743 0.682032 0.687676 0.692380 0.708373 0.713076 0.719661 0.729069 0.740357 0.744120 0.748824 0.752587 0.761994 0.770461 0.781750 0.790216 0.805268 0.808090 0.823142 0.832549 0.836312 0.840075 0.843838 0.851364 0.854186 0.857008 0.862653 0.878645 0.898401 0.901223 0.906867 0.913452 0.920038 0.926623 0.934149 0.939793 0.942615 0.945437 0.952023 0.957667 0.977422 1.000000
```
- Speech to text task (S2T): The data preparation follows the steps in [EN_DE Joint training](./ende-mustc.md).
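To make the "align" field concrete, here is a purely illustrative sketch (with made-up values) of how the normalized segment boundaries expand into a per-frame phoneme label sequence; the model code in this commit performs an analogous expansion internally (see SpeechDummyDecoder.extend_alignment).
```
# Illustrative only: expand normalized alignment boundaries (one value in
# [0, 1] per phoneme, as in the "align" column above) into per-frame labels
# for an utterance whose encoder output has num_frames frames.
def expand_alignment(phonemes, boundaries, num_frames):
    assert len(phonemes) == len(boundaries)
    labels, start = [], 0
    for pho, boundary in zip(phonemes, boundaries):
        end = int(round(boundary * num_frames))
        labels.extend([pho] * max(end - start, 0))
        start = end
    labels.extend([phonemes[-1]] * (num_frames - len(labels)))  # absorb rounding
    return labels

print(expand_alignment(["<sil>", "▁AE1", "Z"], [0.25, 0.75, 1.0], 8))
# ['<sil>', '<sil>', '▁AE1', '▁AE1', '▁AE1', '▁AE1', 'Z', 'Z']
```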
#### Prepare fine-tuning data:
We re-use the data from T2T and S2T tasks in the fine-tuning stage.
### Model Build
#### Pre-training
```
python train.py $T2T_DATA \
--save-dir $SAVE_PRE_PATH --user-dir examples/speech_text_joint_to_text --task speech_text_joint_denoising \
--criterion speech_text_pretrain_cross_entropy --optimizer adam --weight-decay 0.01 --config-yaml config_s2p.yaml --config-s2s-yaml config.yaml --ddp-backend no_c10d \
--lang-pairs pho-wrd --num-workers 4 --log-interval 500 --save-interval-updates 5000 --keep-interval-updates 1 --no-emb-update-unsup --report-accuracy --lr 0.001 --end-learning-rate 1e-06 \
--lr-scheduler polynomial_decay --warmup-updates 10000 --total-num-update 800000 --update-freq 6 --validate-interval-updates 10000 --train-subset train \
--valid-subset valid,valid_sup_speech,valid_sup_speech_s2s,valid_unsup_speech --dataset-impl mmap \
--sup-speech-data $S2P_DATA_PATH --sup-speech-train-subset train_960.ali --sup-speech-valid-subset dev-clean.ali --sup-speech-s2s-data $S2T_DATA_PATH \
--sup-speech-s2s-train-subset train --sup-speech-s2s-valid-subset dev-clean --unsup-speech-train-data $SSL_DATA_PATH/train.tsv --unsup-speech-valid-data $SSL_DATA_PATH/valid.tsv \
--batch-size 200 --batch-size-valid 150 --max-source-positions 1024 --max-target-positions 1024 --max-text-tokens 3072 --max-speech-positions 600000 \
--max-sample-size 750000 --min-sample-size 64000 --max-speech-tokens 750000 --max-tokens-valid 750000 --skip-invalid-size-inputs-valid-test \
--unsupervised-speech-sample-ratio 3.0 --supervised-speech-sample-ratio 5 --supervised-speech-s2s-sample-ratio 5 --text-sample-ratio 1.0 --mask 0.3 --mask-random 0.1 \
--mask-length span-poisson --speech-sup-mask-prob 0.3 --speech-unsup-mask-prob 0.7 --use-mask-whole-words --arch speech_text_pretrain_bart_base_stack \
--no-scale-feature --activation-fn gelu --speech-extractor-mode default --stacked-encoder all --encoder-normalize-before --decoder-normalize-before \
--encoder-learned-pos --decoder-learned-pos --dropout 0.1 --load-pretrained-mbart-encoder-from $BART --load-pretrained-mbart-decoder-from $BART
```
The current implementation also supports model pre-training without the forced-alignment supervised data. In this case, CTC is used to optimize the S2P task. The following changes to the setting are needed:
1. options to be added
```
--use-sup-speech-ctc --criterion speech_text_pretrain_compound
```
2. options to be deleted
```
--same-data-update --criterion speech_text_pretrain_cross_entropy
```
However, we find that the CTC-based pre-training still performs worse than the forced-alignment-based setting. This could be partly because we re-use the forced-alignment-based pre-training configuration for the CTC-based pre-training, which may be suboptimal.
#### Fine-tuning
```
python train.py $S2T_DATA_PATH \
--save-dir $SAVE_FT_PATH --num-workers 8 --task speech_text_joint_to_text --arch dualinputs2twavtransformer_base_stack \
--user-dir examples/speech_text_joint_to_text --max-update 100000 --optimizer adam --lr-scheduler inverse_sqrt --lr 0.0003 --update-freq 3 --clip-norm 10.0 \
--criterion guided_label_smoothed_cross_entropy_with_accuracy --guide-alpha 0.8 --label-smoothing 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \
--enc-grad-mult 2.0 --max-tokens 800000 --max-source-positions 800000 --max-tokens-text 10000 --max-positions-text 1024 --max-target-positions 1024 --no-scale-feature \
--activation-fn gelu --load-pretrained-speech-text-encoder $SAVE_PRE_PATH/checkpoint_last.pt --load-pretrained-speech-text-decoder $SAVE_PRE_PATH/checkpoint_last.pt \
--encoder-normalize-before --decoder-normalize-before --speech-extractor-mode default --speech-mask-channel-length 64 --speech-mask-channel-prob 0.5 \
--speech-mask-length 10 --speech-mask-prob 0.65 --text-sample-ratio 0.25 --mask-text-ratio 0.3 --mask-text-type random --parallel-text-data text_bin \
--text-input-cost-ratio 0.5 --langpairs pho-wrd --update-mix-data --log-format json --max-tokens-valid 800000 --ddp-backend no_c10d --log-interval 500 \
--config-yaml config.yaml --skip-invalid-size-inputs-valid-test --keep-last-epochs 50 --layernorm-embedding --encoder-learned-pos --decoder-learned-pos
```
### Evaluation
The checkpoints from the last 10 fine-tuning epochs are averaged to obtain $FINAL_MODEL (e.g., with fairseq's scripts/average_checkpoints.py, or a manual average as sketched below).
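A minimal sketch of the averaging step, assuming standard fairseq checkpoints and placeholder paths (fairseq's scripts/average_checkpoints.py does the same thing more robustly):
```
# Illustrative only: average the "model" weights of the last N epoch checkpoints.
# Non-floating-point entries (e.g. integer buffers) are taken from the first
# checkpoint unchanged. Paths are placeholders.
import torch

def average_model_checkpoints(paths, out_path):
    base = torch.load(paths[0], map_location="cpu")  # keep the wrapper (cfg, etc.)
    summed = {
        k: v.clone().double() if torch.is_floating_point(v) else v
        for k, v in base["model"].items()
    }
    for path in paths[1:]:
        model = torch.load(path, map_location="cpu")["model"]
        for k, v in model.items():
            if torch.is_floating_point(v):
                summed[k] += v.double()
    base["model"] = {
        k: (v / len(paths)).to(base["model"][k].dtype)
        if torch.is_floating_point(base["model"][k])
        else v
        for k, v in summed.items()
    }
    torch.save(base, out_path)

# e.g. average_model_checkpoints(
#     [f"{SAVE_FT_PATH}/checkpoint{i}.pt" for i in range(16, 26)],
#     f"{SAVE_FT_PATH}/checkpoint_ave_10.pt",
# )
```
The averaged checkpoint is then used for decoding: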
```
python ./fairseq_cli/generate.py \
$S2T_DATA_PATH \
--task speech_text_joint_to_text \
--max-tokens 800000 \
--max-source-positions 800000 \
--nbest 1 \
--results-path $RESULTS_LOG \
--batch-size 512 \
--path $FINAL_MODEL \
--gen-subset $SUBSET \
--config-yaml config.yaml \
--scoring wer \
--beam 10 --lenpen 1.0 examples/speech_text_joint_to_text
```
### Results and models
| | dev-clean | dev-other | test-clean | test-other |
|---|---|---|---|---|
| WER | 2.0 | 4.4 | 2.1 | 4.6 |
**Model Links**:
- [config_s2p.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/config_s2p.yaml): Config for S2P
- [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/spm.model): Sentence Piece model
- [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/src_dict.txt): Source Phoneme Dictionary
- [tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/tgt_dict.txt): Target Sentence Piece Dictionary
- [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/config.yaml): Config for S2T
- [BART](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/bart.pt): BART model trained on the Librispeech text data
- [Joint Pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/checkpoint6.pt): model pre-trained with 960 hours of Librispeech data (S2P, S2T), the Librispeech text training data (T2T), and Libri-Light data (SSL)
- [Fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/checkpoint_ave_10.pt): the pre-trained model fine-tuned on the 960-hour Librispeech speech and text data (S2T + T2T)
## MuST-C
### Prepare Data
Compared with the Librispeech ASR recipe above, the differences are as follows:
- Replace the speech data with the corresponding MuST-C data
- Replace the Librispeech text data with parallel text data from WMT
### Model Build
#### Pre-training
EN-DE is used as an example
```
python train.py $TXT_DATA \
--save-dir $SAVE_PRE_PATH --user-dir examples/speech_text_joint_to_text --task speech_text_joint_denoising --criterion speech_text_pretrain_cross_entropy --optimizer adam --weight-decay 0.01 \
--config-yaml config_s2p.yaml --config-s2s-yaml config.yaml --ddp-backend no_c10d --lang-pairs-bitext en-fr --num-workers 4 --log-interval 500 --save-interval-updates 5000 --keep-interval-updates 1 \
--no-emb-update-unsup --use-decoder-output-proj --report-accuracy --lr 0.001 --end-learning-rate 1e-06 --lr-scheduler polynomial_decay --warmup-updates 10000 --total-num-update 800000 \
--update-freq 8 --validate-interval-updates 10000 --train-subset train --valid-subset valid_sup_speech,valid_sup_speech_s2s,valid_unsup_speech --dataset-impl mmap \
--sup-speech-data $S2P_DATA_PATH --sup-speech-train-subset train --sup-speech-valid-subset dev --sup-speech-s2s-data $S2T_DATA_PATH --sup-speech-s2s-train-subset train \
--sup-speech-s2s-valid-subset dev --unsup-speech-train-data $SSL_DATA_PATH/train.tsv --unsup-speech-valid-data $SSL_DATA_PATH/valid.tsv --batch-size 200 --batch-size-valid 100 \
--max-source-positions 1024 --max-target-positions 1024 --max-text-tokens 2048 --max-speech-positions 600000 --max-sample-size 600000 --min-sample-size 64000 \
--max-speech-tokens 600000 --max-tokens-valid 600000 --skip-invalid-size-inputs-valid-test --unsupervised-speech-sample-ratio 1.2 --supervised-speech-sample-ratio 10 \
--supervised-speech-s2s-sample-ratio 10 --bitext-sample-ratio 0.5 --mask 0.3 --mask-random 0.1 --mask-length span-poisson --speech-sup-mask-prob 0.3 \
--speech-unsup-mask-prob 0.7 --use-mask-whole-words --arch speech_text_pretrain_bart_base_stack --no-scale-feature --activation-fn gelu --speech-extractor-mode default \
--stacked-encoder s2s --encoder-normalize-before --decoder-normalize-before --encoder-learned-pos --decoder-learned-pos --dropout 0.1 \
--load-pretrained-mbart-encoder-from $EN_FR_NMT --load-pretrained-mbart-decoder-from $EN_FR_NMT
```
#### Fine-tuning
```
python train.py $S2T_DATA_PATH \
--save-dir $SAVE_FT_PATH --num-workers 8 --task speech_text_joint_to_text --arch dualinputs2twavtransformer_base_stack --user-dir examples/speech_text_joint_to_text \
--max-epoch 25 --update-mix-data --optimizer adam --lr-scheduler inverse_sqrt --lr 0.0003 --update-freq 4 --clip-norm 10.0 --warmup-updates 20000 \
--criterion guided_label_smoothed_cross_entropy_with_accuracy --guide-alpha 0.8 --attentive-cost-regularization 0.02 --enc-grad-mult 2.0 --label-smoothing 0.1 \
--max-tokens 800000 --max-source-positions 800000 --max-tokens-text 10000 --max-positions-text 1024 --load-pretrained-speech-text-encoder $SAVE_PRE_PATH/checkpoint_last.pt \
--load-pretrained-speech-text-decoder $SAVE_PRE_PATH/checkpoint_last.pt --speech-mask-channel-length 64 --speech-mask-channel-prob 0.5 --speech-mask-length 10 \
--speech-mask-prob 0.65 --text-sample-ratio 0.05 --mask-text-ratio 0.3 --mask-text-type random --parallel-text-data data-bin-wt --text-input-cost-ratio 0.5 \
--langpairs en-fr --log-format json --max-tokens-valid 800000 --ddp-backend no_c10d --log-interval 100 --config-yaml config.yaml --skip-invalid-size-inputs-valid-test \
--noise-token '▁NOISE' --keep-last-epochs 40 --layernorm-embedding --encoder-learned-pos --decoder-learned-pos --activation-fn gelu \
--speech-extractor-mode default --max-target-positions 1024 --encoder-normalize-before --decoder-normalize-before
```
### Evaluation
The checkpoints from the last 10 fine-tuning epochs are averaged to obtain $FINAL_MODEL, following the same procedure as in the Librispeech recipe above.
```
python fairseq_cli/generate.py \
$S2T_DATA_PATH \
--task speech_text_joint_to_text \
--nbest 1 \
--max-tokens 800000 \
--max-source-positions 800000 \
--results-path $RESULTS_LOG \
--batch-size 512 \
--path $FINAL_MODEL \
--gen-subset $SUBSET \
--config-yaml config.yaml \
--scoring sacrebleu \
--beam 10 --lenpen 1.0 examples/speech_text_joint_to_text
```
### Results and models
| | en-fr | en-es | en-de |
|---|---|---|---|
| BLEU | 39.7 | 33.2 | 29.2 |
**Model Links**:
1. DE
- [de config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/config.yaml)
- [de src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/src_dict.txt)
- [de tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/tgt_dict.txt)
- [de spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/spm.model)
- [de pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/nmt.pt)
- [de pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/checkpoint_pretraing.pt)
- [de fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/checkpoint_finetune_ave10.pt)
2. ES
- [es config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/config.yaml)
- [es src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/src_dict.txt)
- [es tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/tgt_dict.txt)
- [es spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/spm.model)
- [es pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/nmt.pt)
- [es pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/checkpoint_pretraing.pt)
- [es fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/checkpoint_finetune_ave10.pt)
3. FR
- [fr config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/config.yaml)
- [fr src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/src_dict.txt)
- [fr tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/tgt_dict.txt)
- [fr spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/spm.model)
- [fr pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/nmt.pt)
- [fr pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/checkpoint_pretraing.pt)
- [fr fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/checkpoint_finetune_ave10.pt)
4. [config_s2p.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/config_s2p.yaml)


@@ -0,0 +1,698 @@
#!/usr/bin/env python3
import logging
from collections import OrderedDict, namedtuple
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from fairseq import checkpoint_utils, utils
from fairseq.file_io import PathManager
from fairseq.models import (
FairseqDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_text import (
MultiInputDecoder,
MultiModalityEncoder,
SpeechWavTransformerEncoder,
StackedSpeechWavTransformerEncoder,
)
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
TransformerModel,
)
logger = logging.getLogger(__name__)
class SpeechTextPreTrainEncoder(MultiModalityEncoder):
def __init__(
self,
dictionary,
sup_speech_encoder,
sup_s2s_speech_encoder,
unsup_speech_encoder,
text_encoder,
):
super().__init__(dictionary)
self.sup_speech_encoder = sup_speech_encoder
self.sup_s2s_speech_encoder = sup_s2s_speech_encoder
self.unsup_speech_encoder = unsup_speech_encoder
self.text_encoder = text_encoder
@classmethod
def update_transformer_encoder_cfg(cls, args, update_dict):
cfg = dict(args._get_kwargs())
for fkey in update_dict.keys():
cfg[fkey] = update_dict[fkey]
cfg.pop("_name", None) # remove keys start with _
model_args = namedtuple("args", cfg.keys())(*cfg.values())
return model_args
@classmethod
def build_text_encoder(cls, args, src_dictionary):
enc_emb = nn.Embedding(
len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad()
)
model_args = cls.update_transformer_encoder_cfg(
args, {"encoder_layers": args.text_encoder_layers}
)
text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
return text_encoder
@classmethod
def build_speech_encoder(cls, args):
model_args = cls.update_transformer_encoder_cfg(
args,
{
"encoder_layers": args.speech_encoder_layers,
"speech_mask_prob": args.speech_sup_mask_prob,
},
)
speech_encoder = SpeechWavTransformerEncoder(model_args)
return speech_encoder
@classmethod
def share_layers(cls, src_layers, tgt_layers): # share layer but not dropout
# share parameters in src_layers with tgt_layers
assert len(src_layers) == len(tgt_layers)
for i, ly in enumerate(src_layers):
tly = tgt_layers[i]
tly.self_attn = ly.self_attn
tly.self_attn_layer_norm = ly.self_attn_layer_norm
tly.activation_fn = ly.activation_fn
tly.normalize_before = ly.normalize_before
tly.fc1 = ly.fc1
tly.fc2 = ly.fc2
tly.final_layer_norm = ly.final_layer_norm
if hasattr(tly, "encoder_attn"):
tly.encoder_attn = ly.encoder_attn
tly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm
return tgt_layers
@classmethod
def build_unsup_speech_encoder(cls, args, sup_speech_encoder):
model_args = cls.update_transformer_encoder_cfg(
args,
{
"encoder_layers": args.speech_encoder_layers,
"speech_mask_prob": args.speech_unsup_mask_prob,
"encoder_layerdrop": 0.0,
"decoder_layerdrop": 0.0,
"dropout": args.speech_unsup_dropout,
"activation_dropout": args.speech_unsup_dropout,
"attention_dropout": 0.0,
"dropout_features": args.speech_unsup_feature_dropout,
"dropout_input": args.speech_unsup_feature_dropout,
},
)
unsup_speech_encoder = SpeechWavTransformerEncoder(model_args, alway_mask=True)
unsup_speech_encoder.layer_norm = sup_speech_encoder.layer_norm
unsup_speech_encoder.layers = cls.share_layers(
sup_speech_encoder.layers, unsup_speech_encoder.layers
)
unsup_speech_encoder.mask_emb = sup_speech_encoder.mask_emb
unsup_speech_encoder.embed_positions = sup_speech_encoder.embed_positions
unsup_speech_encoder.feat_layer_norm = sup_speech_encoder.feat_layer_norm
unsup_speech_encoder.feat_proj = sup_speech_encoder.feat_proj
unsup_speech_encoder.subsample = sup_speech_encoder.subsample
return unsup_speech_encoder
@classmethod
def build_encoder(cls, args, dictionary):
text_encoder = cls.build_text_encoder(args, dictionary)
if getattr(args, "load_pretrained_mbart_encoder_from", None):
text_encoder = checkpoint_utils.load_pretrained_component_from_model(
component=text_encoder,
checkpoint=args.load_pretrained_mbart_encoder_from,
)
speech_encoder = cls.build_speech_encoder(args)
if getattr(args, "load_pretrained_feature_extractor_from", None):
def load_feature_extractor(component, checkpoint):
if not PathManager.exists(checkpoint):
raise IOError("Model file not found: {}".format(checkpoint))
state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
component_state_dict = OrderedDict()
component_prefix = "feature_extractor"
for key in state["model"].keys():
if key.startswith(component_prefix):
component_subkey = key[len(component_prefix) + 1 :]
component_state_dict[component_subkey] = state["model"][key]
component.load_state_dict(component_state_dict, strict=True)
return component
speech_encoder.subsample = load_feature_extractor(
speech_encoder.subsample, args.load_pretrained_feature_extractor_from
)
speech_s2s_encoder = speech_encoder
unsup_speech_encoder = cls.build_unsup_speech_encoder(args, speech_encoder)
if getattr(args, "stacked_encoder", "none") != "none":
if args.encoder_shared_text_layers_from_begin > 0:
raise ValueError(
"We can not stack encoders and share encoders at the same time!"
)
speech_s2s_encoder = StackedSpeechWavTransformerEncoder(
speech_encoder, text_encoder.layers, text_encoder.layer_norm
)
if args.stacked_encoder == "all":
speech_encoder = speech_s2s_encoder
unsup_speech_encoder = StackedSpeechWavTransformerEncoder(
unsup_speech_encoder, text_encoder.layers, text_encoder.layer_norm
)
else:
cls.share_speech_text_encoder(
speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin
)
return SpeechTextPreTrainEncoder(
dictionary,
speech_encoder,
speech_s2s_encoder,
unsup_speech_encoder,
text_encoder,
)
@classmethod
def share_speech_text_encoder(
cls, speech_encoder, text_encoder, shared_layers_from_begin
):
if shared_layers_from_begin > 0:
num_text_encoder_layers = len(text_encoder.layers)
assert len(speech_encoder.layers) >= shared_layers_from_begin
assert num_text_encoder_layers >= shared_layers_from_begin
assert len(speech_encoder.layers) >= num_text_encoder_layers
for i, ly in enumerate(
speech_encoder.layers[
-num_text_encoder_layers : -num_text_encoder_layers
+ shared_layers_from_begin
]
):
assert isinstance(text_encoder.layers[i], type(ly))
text_encoder.layers[i] = ly
def select_encoder(self, mode, **kwargs):
if mode in ("speech", "sup_speech_ctc", "sup_speech_ali", "sup_speech_s2s"):
kwargs["features_only"] = True
if mode == "sup_speech_s2s":
return self.sup_s2s_speech_encoder, kwargs
return self.sup_speech_encoder, kwargs
elif mode == "unsup_speech":
kwargs["features_only"] = False
return self.unsup_speech_encoder, kwargs
elif mode in ("text", "bitext"):
return self.text_encoder, kwargs
else:
raise NotImplementedError(f"{mode} is not supported")
return None, kwargs
def forward(self, src_tokens, src_lengths=None, mode="", alignment=None, **kwargs):
return super().forward(src_tokens, src_lengths, mode, **kwargs)
# SpeechDummyDecoder works as an extension of encoder, so we could fit encoder only training into seq2seq training
class SpeechDummyDecoder(FairseqDecoder):
def __init__(
self,
dictionary,
output_embedding,
no_emb_update_unsup=False,
use_output_proj=False,
):
super().__init__(dictionary)
self.output_embedding = output_embedding
num_embedding, num_dim = self.output_embedding.weight.size()
self.out_proj = (
None if use_output_proj is False else nn.Linear(num_dim, num_dim)
)
self.no_emb_update_unsup = no_emb_update_unsup
def extend_alignment(self, alignment, src_lengths, prev_output_tokens):
# alignment: B X N
# src_lengths: B X T
# prev_output_tokens: B X (N + 1)
tgt_tokens = prev_output_tokens[
:, 1:
] # remove the leading start of sentence token
ext_alignment = (
torch.ones(len(src_lengths), src_lengths.max(), device=src_lengths.device)
.long()
.fill_(self.dictionary.pad())
)
for bs in range(src_lengths.size(0)):
tgt_length = tgt_tokens[bs].ne(self.dictionary.pad()).sum().item()
assert tgt_length == sum(alignment[bs].ne(1)) + 1
src_st = 0
for i in range(tgt_length):
tok = tgt_tokens[bs][i]
src_ed = (alignment[bs][i] * src_lengths[bs]).int().item()
ext_alignment[bs][src_st:src_ed].fill_(tok)
src_st = src_ed
return ext_alignment
def forward(
self,
prev_output_tokens,
encoder_out,
incremental_state=None,
mode="speech",
alignment=None,
**kwargs,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
sup_speech_ctc:
dictionary{"logits": logits, "padding_mask": padding_mask}
sup_speech_ali and unsup_speech:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
emb_weight = self.output_embedding.weight
if (
mode == "unsup_speech" and self.no_emb_update_unsup
): # no gradient for embedding here
emb_weight = emb_weight.detach()
enc_out = (
encoder_out["encoder_out"][0]
if self.out_proj is None
else self.out_proj(encoder_out["encoder_out"][0])
)
logits = F.linear(enc_out, emb_weight, None).transpose(0, 1) # B X T X C
others = None
if mode in (
"speech",
"sup_speech_ctc",
): # speech data with label, do forcealignment
if len(encoder_out["encoder_padding_mask"]) > 0:
padding_mask = encoder_out["encoder_padding_mask"][0]
logits = logits.masked_fill(padding_mask, float("-inf"))
else:
seq_len, bsz = encoder_out["encoder_out"][0].size()[:2]
padding_mask = torch.zeros(
bsz, seq_len, device=encoder_out["encoder_out"][0].device
).bool()
return {"x": logits, "padding_mask": padding_mask}
elif mode == "sup_speech_ali":
src_lengths = None
if len(encoder_out["encoder_padding_mask"]) > 0:
src_lengths = (1 - encoder_out["encoder_padding_mask"][0].long()).sum(
-1
)
else:
seq_len, bsz = encoder_out["encoder_out"][0].size()[:2]
src_lengths = (
torch.ones(bsz, device=encoder_out["encoder_out"][0].device).long()
* seq_len
)
assert alignment is not None
alignment = self.extend_alignment(
alignment, src_lengths, prev_output_tokens
)
others = {"pseudo_target_tokens": alignment}
elif mode == "unsup_speech":
enc_out_ori = (
encoder_out["encoder_unmasked_out"][0]
if self.out_proj is None
else self.out_proj(encoder_out["encoder_unmasked_out"][0])
)
logits_ori = F.linear(enc_out_ori, emb_weight, None).transpose(0, 1)
if len(encoder_out["encoder_padding_mask"]) > 0:
encoder_padding_mask = encoder_out["encoder_padding_mask"][0]
logits_ori = logits_ori.masked_fill(encoder_padding_mask, float("-inf"))
pseudo_labels = utils.log_softmax(logits_ori, dim=-1)
others = {
"pseudo_target_logprobs": pseudo_labels,
"padding_mask": encoder_out["encoder_padding_mask"], # B X T
"mask_indices": encoder_out[
"mask_indices"
], # True for masked frames B X T
}
return logits, others
def get_normalized_probs(
self,
net_output: Dict[str, Tensor],
log_probs: bool,
sample: Optional[Dict[str, Tensor]] = None,
):
return self.get_normalized_probs_scriptable(
(net_output["x"], None), log_probs, sample
)
class SpeechTextPreTrainDecoder(MultiInputDecoder):
def __init__(self, dictionary, speech_decoder, text_decoder):
super().__init__(dictionary)
self.speech_decoder = speech_decoder
self.text_decoder = text_decoder
def select_decoder(self, mode, **kwargs):
if mode == "unsup_speech":
kwargs["mode"] = mode
return self.speech_decoder, kwargs
if mode in ("text", "bitext"):
return self.text_decoder, kwargs
if mode in ("speech", "sup_speech_ctc", "sup_speech_ali"):
kwargs["mode"] = mode
return self.speech_decoder, kwargs
if mode in ("speech", "sup_speech_s2s"):
if "alignment" in kwargs:
del kwargs["alignment"]
return self.text_decoder, kwargs
raise NotImplementedError(f"{mode} is not supported")
return None, kwargs
def get_normalized_probs(
self,
net_output,
log_probs,
sample=None,
):
"""Get normalized probabilities (or log probs) from a net's output."""
if isinstance(net_output, dict):
return self.speech_decoder.get_normalized_probs(
net_output, log_probs, sample
)
return self.text_decoder.get_normalized_probs(net_output, log_probs, sample)
@classmethod
def build_text_decoder(cls, args, tgt_dictionary, dec_emb_share=None):
dec_emb = (
nn.Embedding(
len(tgt_dictionary), args.decoder_embed_dim, tgt_dictionary.pad()
)
if dec_emb_share is None
else dec_emb_share
)
text_decoder = TransformerDecoder(args, tgt_dictionary, dec_emb)
return text_decoder
@classmethod
def build_dummy_speech_decoder(cls, args, dictionary, dec_emb_share=None):
dec_emb = (
nn.Embedding(len(dictionary), args.decoder_embed_dim, dictionary.pad())
if dec_emb_share is None
else dec_emb_share
)
speech_decoder = SpeechDummyDecoder(
dictionary,
dec_emb,
no_emb_update_unsup=getattr(args, "no_emb_update_unsup", False),
use_output_proj=getattr(args, "use_decoder_output_proj", False),
)
return speech_decoder
@classmethod
def build_decoder(
cls, args, text_dictionary, speech_dictionary, speech_output_embedding
):
text_decoder = cls.build_text_decoder(args, text_dictionary)
speech_decoder = cls.build_dummy_speech_decoder(
args, speech_dictionary, speech_output_embedding
)
if getattr(args, "load_pretrained_mbart_decoder_from", None):
text_decoder = checkpoint_utils.load_pretrained_component_from_model(
component=text_decoder,
checkpoint=args.load_pretrained_mbart_decoder_from,
)
return SpeechTextPreTrainDecoder(text_dictionary, speech_decoder, text_decoder)
@register_model("speech_text_pretrain_bart")
class SpeechTextPreTrainModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.num_updates = 0
def forward(
self, src_tokens, src_lengths, prev_output_tokens, src_lang_ids=None, **kwargs
):
if src_lang_ids is not None:
encoder_out = self.encoder(
src_tokens, src_lengths=src_lengths, src_lang_ids=src_lang_ids, **kwargs
)
else:
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(
prev_output_tokens, encoder_out=encoder_out, **kwargs
)
return decoder_out
def max_positions(self):
return None # it is provided in task
def get_targets(self, sample, net_output):
mode = sample["net_input"]["mode"]
if mode == "unsup_speech":
return {"target_logprobs": net_output[1]["pseudo_target_logprobs"]}
if mode == "sup_speech_ali":
return net_output[1]["pseudo_target_tokens"]
return sample["target"]
def get_normalized_probs(
self,
net_output,
log_probs,
sample=None,
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
@staticmethod
def add_args(parser):
TransformerModel.add_args(parser)
SpeechWavTransformerEncoder.add_args(parser)
parser.add_argument(
"--speech-sup-mask-prob",
type=float,
help="probability of replacing a token with mask (sup-speech)",
)
parser.add_argument(
"--speech-unsup-mask-prob",
type=float,
help="probability of replacing a token with mask (unsup-speech)",
)
parser.add_argument(
"--load-pretrained-mbart-encoder-from",
type=str,
metavar="STR",
help="model to take text encoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-mbart-decoder-from",
type=str,
metavar="STR",
help="model to take text decoder weights from (for initialization)",
)
parser.add_argument(
"--load-pretrained-feature-extractor-from",
type=str,
metavar="STR",
help="model to take feature extractor weights from (for initialization)",
)
parser.add_argument(
"--speech-unsup-dropout",
type=float,
default=0,
help="dropout for unsupervised speech encoder",
)
parser.add_argument(
"--speech-unsup-feature-dropout",
type=float,
default=0,
help="dropout for unsupervised speech feature encoder",
)
parser.add_argument(
"--encoder-shared-text-layers-from-begin",
type=int,
help="number of text encoder layers shared with speech encoder (from first layer)",
)
parser.add_argument(
"--stacked-encoder",
default="none",
choices=["none", "s2s", "all"],
help="stack speech and text encoders",
)
parser.add_argument("--use-decoder-output-proj", action="store_true")
@classmethod
def build_model(cls, args, task):
encoder = SpeechTextPreTrainEncoder.build_encoder(args, task.src_dict)
decoder = SpeechTextPreTrainDecoder.build_decoder(
args, task.tgt_dict, task.src_dict, encoder.text_encoder.embed_tokens
)
model = SpeechTextPreTrainModel(encoder, decoder)
return model
def upgrade_state_dict(self, state_dict):
"""Upgrade old state dicts to work with newer code."""
if "decoder.speech_decoder.output_projection.weight" in state_dict:
del state_dict["decoder.speech_decoder.output_projection.weight"]
self.upgrade_state_dict_named(state_dict, "")
@register_model_architecture(
"speech_text_pretrain_bart", "speech_text_pretrain_bart_base"
)
def speech_text_pretrain_bart_base(args):
# speech masking
args.dropout_input = getattr(args, "dropout_input", 0)
args.dropout_features = getattr(args, "dropout_features", 0)
args.speech_mask_length = getattr(args, "speech_mask_length", 10)
args.speech_mask_prob = getattr(args, "speech_mask_prob", 0.65)
args.speech_sup_mask_prob = getattr(args, "speech_sup_mask_prob", 0.3)
args.speech_unsup_mask_prob = getattr(
args, "speech_unsup_mask_prob", args.speech_mask_prob
)
args.speech_mask_selection = getattr(args, "speech_mask_selection", "static")
args.speech_mask_other = getattr(args, "speech_mask_other", 0)
args.speech_mask_min_space = getattr(args, "speech_mask_min_space", 1)
args.speech_no_mask_overlap = getattr(args, "speech_no_mask_overlap", False)
args.speech_mask_channel_length = getattr(args, "speech_mask_channel_length", 10)
args.speech_mask_channel_prob = getattr(args, "speech_mask_channel_prob", 0.0)
args.speech_mask_channel_selection = getattr(
args, "speech_mask_channel_selection", "static"
)
args.speech_mask_channel_other = getattr(args, "speech_mask_channel_other", 0)
args.speech_mask_channel_min_space = getattr(
args, "speech_mask_channel_min_space", 1
)
args.speech_no_mask_channel_overlap = getattr(
args, "speech_no_mask_channel_overlap", False
)
args.no_scale_feature = getattr(args, "no_scale_feature", False)
args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) # 0.1
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(
args, "encoder_ffn_embed_dim", args.encoder_embed_dim * 4
)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.speech_conv_bias = getattr(args, "speech_conv_bias", False)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_attention_heads = getattr(
args, "decoder_attention_heads", args.encoder_attention_heads
)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
args.activation_fn = getattr(args, "activation_fn", "relu") # gelu?
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.speech_unsup_dropout = getattr(args, "speech_unsup_dropout", 0)
args.speech_unsup_feature_dropout = getattr(args, "speech_unsup_feature_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.encoder_shared_text_layers_from_begin = getattr(
args, "encoder_shared_text_layers_from_begin", 6
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.no_emb_update_unsup = getattr(args, "no_emb_update_unsup", False)
@register_model_architecture(
"speech_text_pretrain_bart", "speech_text_pretrain_bart_base_stack"
)
def speech_text_pretrain_bart_base_stack(args):
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.encoder_shared_text_layers_from_begin = getattr(
args, "encoder_shared_text_layers_from_begin", 0
)
args.stacked_encoder = getattr(args, "stacked_encoder", "all")
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
speech_text_pretrain_bart_base(args)
@register_model_architecture(
"speech_text_pretrain_bart", "speech_text_pretrain_bart_large"
)
def speech_text_pretrain_bart_large(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 24)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
args.encoder_shared_text_layers_from_begin = getattr(
args, "encoder_shared_text_layers_from_begin", 12
)
args.decoder_layers = getattr(args, "decoder_layers", 12)
args.dropout = getattr(args, "dropout", 0.3)
speech_text_pretrain_bart_base(args)
@register_model_architecture(
"speech_text_pretrain_bart", "speech_text_pretrain_bart_large_stack"
)
def speech_text_pretrain_bart_large_stack(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
args.encoder_shared_text_layers_from_begin = getattr(
args, "encoder_shared_text_layers_from_begin", 0
)
args.decoder_layers = getattr(args, "decoder_layers", 12)
args.stacked_encoder = getattr(args, "stacked_encoder", "s2s")
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
speech_text_pretrain_bart_base(args)
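# Note on the architecture helpers above (illustrative only, nothing here is
# executed by this module): each *_stack / *_large variant overrides a handful
# of fields and then delegates to speech_text_pretrain_bart_base, which fills
# in every remaining default, e.g.
#
#     from argparse import Namespace
#     args = Namespace()
#     speech_text_pretrain_bart_large(args)
#     assert args.encoder_embed_dim == 1024  # set by the large variant
#     assert args.activation_fn == "relu"    # inherited from the base defaults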

View File

@@ -0,0 +1,526 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from collections import OrderedDict, namedtuple
import torch.nn as nn
from fairseq import checkpoint_utils, utils
from fairseq.checkpoint_utils import load_checkpoint_to_cpu
from fairseq.file_io import PathManager
from fairseq.models import register_model, register_model_architecture
from fairseq.models.speech_to_text import (
SpeechWavTransformerEncoder,
StackedSpeechWavTransformerEncoder,
TransformerDecoder,
)
from fairseq.models.transformer import TransformerEncoder
from .s2t_dualinputtransformer import (
DualInputEncoder,
DualInputS2TTransformerModel,
TransformerMultiInputDecoder,
)
logger = logging.getLogger(__name__)
@register_model("dual_input_wav_transformer")
class DualInputWavTransformerModel(DualInputS2TTransformerModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
def add_transformer_args(parser):
# We can't use TransformerModel.add_args(parser), since it defines max-source-positions which is duplicated with tasks/speech_to_text.py
# Transformer
parser.add_argument(
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
"--relu-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="N",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="N",
help="num encoder attention heads",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--decoder-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension",
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--share-decoder-input-output-embed",
action="store_true",
help="share decoder input and output embeddings",
)
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument(
"--encoder-learned-pos",
action="store_true",
help="use learned positional embeddings",
)
parser.add_argument(
"--decoder-learned-pos",
action="store_true",
help="use learned positional embeddings",
)
add_transformer_args(parser)
SpeechWavTransformerEncoder.add_args(parser)
parser.add_argument(
"--load-pretrained-speech-text-encoder",
type=str,
default="",
metavar="EXPR",
help=""" path to the pretrained speech text encoder from SpeechTextPreTrainModel """,
)
parser.add_argument(
"--load-pretrained-wav2vec-encoder",
type=str,
default="",
metavar="EXPR",
help=""" path to the pretrained speech text encoder from wav2vec """,
)
parser.add_argument(
"--load-pretrained-speech-text-decoder",
type=str,
default="",
metavar="EXPR",
help=""" path to the pretrained speech text decoder from SpeechTextPreTrainModel """,
)
parser.add_argument(
"--load-pretrained-text-decoder",
type=str,
default="",
metavar="EXPR",
help=""" path to the pretrained text decoder """,
)
parser.add_argument(
"--load-init-encoder",
type=str,
default="",
metavar="EXPR",
help=""" path to load seed encoder model """,
)
parser.add_argument(
"--load-init-decoder",
type=str,
default="",
metavar="EXPR",
help=""" path to load seed decoder model """,
)
parser.add_argument(
"--text-input-cost-ratio",
type=float,
default=1.0,
metavar="V",
help="text input cost ratio relative to speech input cost",
)
parser.add_argument(
"--enc-grad-mult",
type=float,
metavar="V",
default=1.0,
help="multiply enc1 and enc2 gradient by V",
)
parser.add_argument(
"--enc2-along-grad-mult",
type=float,
metavar="V",
default=1.0,
help="multiply enc2 gradient by V if only enc2 is used",
)
parser.add_argument(
"--no-strict-check-pretrain-model",
action="store_true",
help="Don't apply strict model check for the pretrained model",
)
parser.add_argument(
"--stacked-encoder",
action="store_true",
help="stack speech and text encoders",
)
@classmethod
def update_transformer_encoder_cfg(cls, args, update_dict):
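        # Copy the argparse namespace into a plain dict, apply the overrides in
        # update_dict, and return an immutable namedtuple so each sub-encoder
        # (speech vs. text) gets its own view of the config without mutating
        # the shared args.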
cfg = dict(args._get_kwargs())
for fkey in update_dict.keys():
cfg[fkey] = update_dict[fkey]
cfg.pop("_name", None) # remove keys start with _
model_args = namedtuple("args", cfg.keys())(*cfg.values())
return model_args
@classmethod
def build_text_encoder(cls, args, src_dictionary):
enc_emb = nn.Embedding(
len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad()
)
model_args = cls.update_transformer_encoder_cfg(
args,
{
"encoder_layers": args.text_encoder_layers,
"max_source_positions": args.max_positions_text,
},
)
text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb)
return text_encoder
@classmethod
def build_speech_encoder(cls, args):
model_args = cls.update_transformer_encoder_cfg(
args, {"encoder_layers": args.speech_encoder_layers}
)
speech_encoder = SpeechWavTransformerEncoder(model_args)
return speech_encoder
@classmethod
def check_args(cls, condition, is_strict, msg):
if condition:
return
if is_strict:
raise ValueError(msg)
        logger.warning(msg)
@classmethod
def build_encoder(cls, args, task):
# text_encoder = cls.build_text_encoder(args, task.source_dictionary )
text_encoder = cls.build_text_encoder(args, task.src_dict)
speech_encoder = cls.build_speech_encoder(args)
if args.load_pretrained_wav2vec_encoder:
component_pairs = (
("feature_extractor", speech_encoder.subsample),
("post_extract_proj", speech_encoder.feat_proj),
("layer_norm", speech_encoder.feat_layer_norm),
("encoder.pos_conv", speech_encoder.embed_positions),
("encoder.layers", speech_encoder.layers),
("encoder.layer_norm", speech_encoder.layer_norm),
("mask_emb", speech_encoder.mask_emb),
)
state = cls.load_pretrained_speech_text_components(
args.load_pretrained_wav2vec_encoder, component_pairs
)
cls.check_args(
args.encoder_normalize_before
== state["cfg"]["model"]["layer_norm_first"],
not args.no_strict_check_pretrain_model,
f"encoder_normalize_before {args.encoder_normalize_before} doesn't match with the pretrained model",
)
cls.check_args(
args.activation_fn == state["cfg"]["model"]["activation_fn"],
not args.no_strict_check_pretrain_model,
f"activation_fn {args.activation_fn} doesn't match with the pretrained model",
)
if getattr(args, "stacked_encoder", False):
if args.encoder_shared_text_layers_from_begin > 0:
raise ValueError(
"We can not stack encoders and share encoders at the same time!"
)
speech_encoder = StackedSpeechWavTransformerEncoder(
speech_encoder, text_encoder.layers, text_encoder.layer_norm
)
else:
cls.share_speech_text_encoder(
speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin
)
cross_attentive_loss_before_last_layer = (
0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1
)
encoder = DualInputEncoder(
args,
speech_encoder,
text_encoder,
task.src_dict,
cross_attentive_loss_before_last_layer,
)
if args.load_pretrained_speech_text_encoder:
component_pairs = (
("encoder.sup_s2s_speech_encoder", encoder.spch_encoder),
("encoder.text_encoder", encoder.text_encoder),
)
cls.load_pretrained_speech_text_components(
args.load_pretrained_speech_text_encoder, component_pairs
)
if getattr(args, "load_init_encoder", "") != "":
checkpoint_utils.load_pretrained_component_from_model(
encoder, args.load_init_encoder
)
return encoder
@classmethod
def build_text_decoder(cls, args, tgt_dictionary, dec_emb_share=None):
dec_emb = (
nn.Embedding(
len(tgt_dictionary), args.decoder_embed_dim, tgt_dictionary.pad()
)
if dec_emb_share is None
else dec_emb_share
)
text_decoder = TransformerDecoder(args, tgt_dictionary, dec_emb)
return text_decoder
@classmethod
def build_decoder(cls, args, task):
text_decoder = cls.build_text_decoder(args, task.target_dictionary)
compute_cross_attentive_loss = (
True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False
)
cross_attentive_loss_without_norm = getattr(
args, "attentive_cost_without_normalize", False
)
cross_attentive_loss_reverse = (
False # getattr(args, "attentive_cost_reverse", False)
)
if getattr(args, "load_pretrained_text_decoder", "") != "":
checkpoint_utils.load_pretrained_component_from_model(
text_decoder, args.load_pretrained_text_decoder
)
if args.load_pretrained_speech_text_decoder:
component_pairs = (("decoder.text_decoder", text_decoder),)
cls.load_pretrained_speech_text_components(
args.load_pretrained_speech_text_decoder, component_pairs
)
decoder = TransformerMultiInputDecoder(
dictionary=task.target_dictionary,
spch_decoder=text_decoder,
text_decoder=text_decoder,
compute_cross_attentive_loss=compute_cross_attentive_loss,
cross_attentive_loss_with_norm=True
if not cross_attentive_loss_without_norm
else False,
cross_attentive_loss_reverse=cross_attentive_loss_reverse,
)
if getattr(args, "load_init_decoder", "") != "":
checkpoint_utils.load_pretrained_component_from_model(
decoder, args.load_init_decoder
)
return decoder
@classmethod
def load_pretrained_speech_text_components(cls, checkpoint, component_pairs):
if not PathManager.exists(checkpoint):
raise IOError("Model file not found: {}".format(checkpoint))
state = load_checkpoint_to_cpu(checkpoint)
for component_type, component in component_pairs:
if isinstance(component, nn.parameter.Parameter):
component.data.copy_(state["model"][component_type])
else:
component_state_dict = OrderedDict()
for key in state["model"].keys():
if key.startswith(component_type):
component_subkey = key[len(component_type) + 1 :]
component_state_dict[component_subkey] = state["model"][key]
component.load_state_dict(component_state_dict, strict=True)
return state
@classmethod
def share_speech_text_encoder(
cls, speech_encoder, text_encoder, shared_layers_from_begin
):
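        # Align the text encoder with the top len(text_encoder.layers) layers of
        # the speech encoder, then make the first shared_layers_from_begin text
        # layers share parameters with (i.e. point to) the matching speech layers.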
if shared_layers_from_begin > 0:
num_text_encoder_layers = len(text_encoder.layers)
assert len(speech_encoder.layers) >= shared_layers_from_begin
assert num_text_encoder_layers >= shared_layers_from_begin
assert len(speech_encoder.layers) >= num_text_encoder_layers
for i, ly in enumerate(
speech_encoder.layers[
-num_text_encoder_layers : -num_text_encoder_layers
+ shared_layers_from_begin
]
):
assert isinstance(text_encoder.layers[i], type(ly))
text_encoder.layers[i] = ly
@register_model_architecture(
"dual_input_wav_transformer", "dualinputs2twavtransformer_base"
)
def dualinputs2twavtransformer_base(args):
# speech masking
args.dropout_input = getattr(args, "dropout_input", 0)
args.dropout_features = getattr(args, "dropout_features", 0)
args.speech_mask_length = getattr(args, "speech_mask_length", 10)
args.speech_mask_prob = getattr(args, "speech_mask_prob", 0.65)
args.speech_mask_selection = getattr(args, "speech_mask_selection", "static")
args.speech_mask_other = getattr(args, "speech_mask_other", 0)
args.speech_mask_min_space = getattr(args, "speech_mask_min_space", 1)
args.speech_no_mask_overlap = getattr(args, "speech_no_mask_overlap", False)
args.speech_conv_bias = getattr(args, "speech_conv_bias", False)
args.speech_extractor_mode = getattr(args, "speech_extractor_mode", "default")
args.no_strict_check_pretrain_model = getattr(
args, "no_strict_check_pretrain_model", False
)
args.speech_mask_channel_length = getattr(args, "speech_mask_channel_length", 10)
args.speech_mask_channel_prob = getattr(args, "speech_mask_channel_prob", 0.0)
args.speech_mask_channel_selection = getattr(
args, "speech_mask_channel_selection", "static"
)
args.speech_mask_channel_other = getattr(args, "speech_mask_channel_other", 0)
args.speech_mask_channel_min_space = getattr(
args, "speech_mask_channel_min_space", 1
)
args.speech_no_mask_channel_overlap = getattr(
args, "speech_no_mask_channel_overlap", False
)
args.no_scale_feature = getattr(args, "", False)
args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0) # 0.1
# Transformer
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = getattr(
args, "encoder_ffn_embed_dim", args.encoder_embed_dim * 4
)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1)
args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
)
args.decoder_attention_heads = getattr(
args, "decoder_attention_heads", args.encoder_attention_heads
)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0)
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
args.activation_fn = getattr(args, "activation_fn", "relu") # gelu?
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.share_decoder_input_output_embed = getattr(
args, "share_decoder_input_output_embed", False
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.encoder_shared_text_layers_from_begin = getattr(
args, "encoder_shared_text_layers_from_begin", 6
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
@register_model_architecture(
"dual_input_wav_transformer", "dualinputs2twavtransformer_base_stack"
)
def dualinputs2twavtransformer_base_stack(args):
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 6)
args.encoder_shared_text_layers_from_begin = getattr(
args, "encoder_shared_text_layers_from_begin", 0
)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.stacked_encoder = getattr(args, "stacked_encoder", True)
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
dualinputs2twavtransformer_base(args)
@register_model_architecture(
"dual_input_wav_transformer", "dualinputs2twavtransformer_large"
)
def dualinputs2twavtransformer_large(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 24)
args.text_encoder_layers = getattr(args, "text_encoder_layers", 12)
args.encoder_shared_text_layers_from_begin = getattr(
args, "encoder_shared_text_layers_from_begin", 12
)
args.decoder_layers = getattr(args, "decoder_layers", 12)
dualinputs2twavtransformer_base(args)

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import re
from collections import OrderedDict
import torch
from fairseq.file_io import PathManager
def is_update(param_name, module_name):
    return module_name in param_name
def load_checkpoint(src_cpt):
with PathManager.open(src_cpt, "rb") as f:
state_src = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
),
)
return state_src
def save_checkpoint(tgt_cpt, states):
with PathManager.open(tgt_cpt, "wb") as f:
torch.save(
states,
f,
)
# convert the joint pre-trained checkpoint into a plain BART-style encoder-decoder checkpoint
def main():
parser = argparse.ArgumentParser()
# fmt: off
parser.add_argument('--input-model', required=True,
help='Input checkpoint file path.')
parser.add_argument('--output-model', required=True,
help='output checkpoint file path.')
# fmt: on
args = parser.parse_args()
print(args)
states = load_checkpoint(args.input_model)
model = states["model"]
new_model = OrderedDict()
for key in model.keys():
if re.search("^encoder.text_encoder", key):
new_key = re.sub("encoder.text_encoder", "encoder", key)
new_model[new_key] = model[key]
elif re.search("^decoder.text_decoder", key):
new_key = re.sub("decoder.text_decoder", "decoder", key)
new_model[new_key] = model[key]
states["model"] = new_model
save_checkpoint(args.output_model, states)
if __name__ == "__main__":
main()
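# Typical invocation (file names are illustrative):
#   python <this_script>.py --input-model pretrained_checkpoint.pt --output-model bart_style_checkpoint.pt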

View File

@@ -0,0 +1,447 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import logging
import os
import re
import numpy as np
import torch
from examples.speech_text_joint_to_text.data.pair_denoising_dataset import (
LanguagePairDenoisingDataset,
)
from fairseq import utils
from fairseq.data import (
ConcatDataset,
Dictionary,
LanguagePairDataset,
ResamplingDataset,
TransformEosConcatLangPairDataset,
TransformEosLangPairDataset,
data_utils,
indexed_dataset,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask
logger = logging.getLogger(__name__)
def gen_whole_word_mask(args, dictionary):
def is_beginning_of_word(i):
if i < dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = dictionary[i]
if tok.startswith("madeupword"):
return True
if tok in ["<unk>", "<s>", "</s>", "<pad>"]:
return True
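        # "\u2581" (the SentencePiece meta symbol) marks the first piece of a word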
return tok.startswith("\u2581")
if args.use_mask_whole_words:
mask_whole_words = torch.ByteTensor(
list(map(is_beginning_of_word, range(len(dictionary))))
)
else:
        # it will mask every token as a word-leading token, since no bpe model is loaded for phoneme tokens
return get_whole_word_mask(args, dictionary)
return mask_whole_words
@register_task("paired_denoising")
class PairedDenoisingTask(TranslationTask):
LANG_TAG_TEMPLATE = "<lang:{}>" # Tag for language (target)
@staticmethod
def add_args(parser):
TranslationTask.add_args(parser)
# bart setting
parser.add_argument(
"--mask",
default=0.0,
type=float,
help="fraction of words/subwords that will be masked",
)
parser.add_argument(
"--mask-random",
default=0.0,
type=float,
help="instead of using [MASK], use random token this often",
)
parser.add_argument(
"--insert",
default=0.0,
type=float,
help="insert this percentage of additional random tokens",
)
parser.add_argument(
"--poisson-lambda",
default=3.0,
type=float,
help="randomly shuffle sentences for this proportion of inputs",
)
parser.add_argument(
"--mask-length",
default="span-poisson",
type=str,
choices=["subword", "word", "span-poisson"],
help="mask length to choose",
)
parser.add_argument(
"--replace-length",
default=1,
type=int,
help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)",
)
# multi-lingual
parser.add_argument(
"--multilang-sampling-alpha",
type=float,
default=1.0,
help="smoothing alpha for sample ratios across multiple datasets",
)
parser.add_argument(
"--lang-pairs",
default="",
metavar="PAIRS",
help="comma-separated list of language pairs (in training order): phnen-en,phnfr-fr,phnit-it. Do masking",
)
parser.add_argument(
"--lang-pairs-bitext",
default="",
metavar="PAIRS",
help="comma-separated list of language pairs (in training order): en-de,en-fr,de-fr. No masking",
)
parser.add_argument("--add-src-lang-token", default=False, action="store_true")
parser.add_argument("--add-tgt-lang-token", default=False, action="store_true")
parser.add_argument(
"--no-whole-word-mask-langs",
type=str,
default="",
metavar="N",
help="languages without spacing between words dont support whole word masking",
)
parser.add_argument(
"--use-mask-whole-words", default=False, action="store_true"
)
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task."""
paths = args.data.split(":")
assert len(paths) > 0
src_dict = Dictionary.load(
os.path.join(paths[0], "src_dict.txt")
) # assume all languages share a source dictionary
tgt_dict = Dictionary.load(
os.path.join(paths[0], "tgt_dict.txt")
) # assume all languages share a target dictionary
lang_pairs = args.lang_pairs + "," + args.lang_pairs_bitext
lang_pairs = re.sub(",$", "", re.sub("^,", "", lang_pairs))
src_langs = [lp.split("-")[0] for lp in lang_pairs.split(",")]
tgt_langs = [lp.split("-")[1] for lp in lang_pairs.split(",")]
if args.add_src_lang_token:
for lang in src_langs:
assert (
src_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
!= src_dict.unk()
)
if args.add_tgt_lang_token:
for lang in tgt_langs:
assert (
tgt_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
!= tgt_dict.unk()
)
logger.info("source dictionary: {} types".format(len(src_dict)))
logger.info("target dictionary: {} types".format(len(tgt_dict)))
if not hasattr(args, "shuffle_instance"):
args.shuffle_instance = False
return cls(args, src_dict, tgt_dict)
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args, src_dict, tgt_dict)
# check mask token
self.mask_idx = self.src_dict.index("<mask>")
assert self.mask_idx != self.src_dict.unk()
self.lang_pairs = args.lang_pairs
self.lang_pairs_bitext = args.lang_pairs_bitext
self.args = args
@classmethod
def language_pair_denoising_dataset(
cls,
data_path,
do_mask,
split,
src,
src_dict,
tgt,
tgt_dict,
mask_idx,
mask_whole_words,
seed,
args,
dataset_impl,
combine=False,
left_pad_source=True,
left_pad_target=False,
max_source_positions=1024,
max_target_positions=1024,
shuffle=True,
src_lang_id=None,
tgt_lang_id=None,
):
def split_exists(split, src, tgt, lang, data_path):
filename = os.path.join(
data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)
)
return indexed_dataset.dataset_exists(filename, impl=dataset_impl)
src_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else "")
# infer langcode
if split_exists(split_k, src, tgt, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src, data_path):
prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, data_path)
)
src_dataset = data_utils.load_indexed_dataset(
prefix + src, src_dict, dataset_impl
)
src_datasets.append(src_dataset)
tgt_dataset = data_utils.load_indexed_dataset(
prefix + tgt, tgt_dict, dataset_impl
)
if tgt_dataset is not None:
tgt_datasets.append(tgt_dataset)
logger.info(
"{} {} {}-{} {} examples".format(
data_path, split_k, src, tgt, len(src_datasets[-1])
)
)
if not combine:
break
assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0
if len(src_datasets) == 1:
src_dataset = src_datasets[0]
tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
else:
sample_ratios = [1] * len(src_datasets)
src_dataset = ConcatDataset(src_datasets, sample_ratios)
if len(tgt_datasets) > 0:
tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
else:
tgt_dataset = None
eos = None
tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
if not do_mask:
return LanguagePairDataset(
src_dataset,
src_dataset.sizes,
src_dict,
tgt_dataset,
tgt_dataset_sizes,
tgt_dict,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
eos=eos,
shuffle=shuffle,
src_lang_id=src_lang_id,
tgt_lang_id=tgt_lang_id,
)
return LanguagePairDenoisingDataset(
src_dataset,
src_dataset.sizes,
src_dict,
tgt_dataset,
tgt_dataset_sizes,
tgt_dict,
mask_idx,
mask_whole_words,
seed,
args,
left_pad_source=left_pad_source,
left_pad_target=left_pad_target,
eos=eos,
shuffle=shuffle,
src_lang_id=src_lang_id,
tgt_lang_id=tgt_lang_id,
)
def _get_sample_prob(self, dataset_lens):
"""
        Get smoothed sampling probability by language. This helps low-resource
languages by upsampling them.
"""
prob = dataset_lens / dataset_lens.sum()
smoothed_prob = prob ** self.args.multilang_sampling_alpha
smoothed_prob = smoothed_prob / smoothed_prob.sum()
return smoothed_prob
def resample_datasets(self, lang_datasets, lang_pairs_all, epoch):
# For train subset, additionally up or down sample languages.
if self.args.multilang_sampling_alpha == 1.0:
return lang_datasets
dataset_lengths = np.array(
[len(d) for d in lang_datasets],
dtype=float,
)
sample_probs = self._get_sample_prob(dataset_lengths)
logger.info(
"Sample probability by language pair: {}".format(
{
lp: "{0:.4f}".format(sample_probs[id])
for id, lp in enumerate(lang_pairs_all)
}
)
)
size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths
logger.info(
"Up/Down Sampling ratio by language: {}".format(
{
lp: "{0:.2f}".format(size_ratio[id])
for id, lp in enumerate(lang_pairs_all)
}
)
)
resampled_lang_datasets = [
ResamplingDataset(
lang_datasets[i],
size_ratio=size_ratio[i],
seed=self.args.seed,
epoch=epoch,
replace=size_ratio[i] >= 1.0,
)
for i, d in enumerate(lang_datasets)
]
return resampled_lang_datasets
def load_dataset_only(
self, split, lang_pairs, do_mask=True, epoch=1, combine=False
):
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
# TODO unk token will be considered as first word too, though it might be an unknown phoneme within a word
# get_whole_word_mask returns a tensor (size V by 1 ) to indicate if a token is a word start token
mask_whole_src_words = gen_whole_word_mask(self.args, self.src_dict)
language_without_segmentations = self.args.no_whole_word_mask_langs.split(",")
lang_datasets = []
eos_bos = []
lang_pairs = lang_pairs.split(",") if lang_pairs != "" else []
assert len(lang_pairs) > 0
for lp in lang_pairs:
src, tgt = lp.split("-")
lang_mask_whole_src_words = (
mask_whole_src_words
if src not in language_without_segmentations
else None
)
end_token = (
self.source_dictionary.index(
PairedDenoisingTask.LANG_TAG_TEMPLATE.format(src)
)
if self.args.add_src_lang_token
else None
)
bos_token = (
self.target_dictionary.index(
PairedDenoisingTask.LANG_TAG_TEMPLATE.format(tgt)
)
if self.args.add_tgt_lang_token
else None
)
src_lang_id = None
if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
eos_bos.append((end_token, bos_token))
dataset = PairedDenoisingTask.language_pair_denoising_dataset(
data_path,
do_mask,
split,
src,
self.source_dictionary,
tgt,
self.target_dictionary,
self.mask_idx,
lang_mask_whole_src_words,
self.args.seed,
self.args,
self.args.dataset_impl,
combine=combine,
left_pad_source=utils.eval_bool(self.args.left_pad_source),
left_pad_target=utils.eval_bool(self.args.left_pad_target),
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
src_lang_id=src_lang_id,
)
lang_datasets.append(dataset)
if len(lang_datasets) == 0:
return
elif len(lang_datasets) == 1:
dataset = lang_datasets[0]
if self.args.add_src_lang_token or self.args.add_tgt_lang_token:
end_token, bos_token = eos_bos[0]
dataset = TransformEosLangPairDataset(
dataset,
src_eos=self.source_dictionary.eos(),
new_src_eos=end_token,
tgt_bos=self.target_dictionary.eos(),
new_tgt_bos=bos_token,
)
else:
end_tokens = [item[0] for item in eos_bos if item[0] is not None]
bos_tokens = [item[1] for item in eos_bos if item[1] is not None]
lang_datasets = self.resample_datasets(lang_datasets, lang_pairs, epoch)
dataset = TransformEosConcatLangPairDataset(
lang_datasets,
self.source_dictionary.eos(),
self.target_dictionary.eos(),
new_src_eos=end_tokens,
new_tgt_bos=bos_tokens,
)
return dataset
# split in (train, valid, test, ...)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
self.datasets[split] = self.load_dataset_only(
split, self.lang_pairs, epoch=epoch, combine=combine
)

View File

@@ -0,0 +1,654 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import re
from argparse import Namespace
from pathlib import Path
from fairseq.data import ConcatDataset, Dictionary, encoders
from fairseq.data.audio.multi_modality_dataset import (
FileAudioDatasetWrapper,
ModalityDatasetItem,
MultiModalityDataset,
)
from fairseq.data.audio.speech_to_text_joint_dataset import (
S2TJointDataConfig,
SpeechToTextJointDatasetCreator,
)
from fairseq.data.iterators import GroupedEpochBatchIterator
from fairseq.tasks import register_task
from .pair_denoising import PairedDenoisingTask
logger = logging.getLogger(__name__)
@register_task("speech_text_joint_denoising")
class SpeechTextJointDenoisingPreTask(PairedDenoisingTask):
"""
Joint denoising training task for speech and text.
"""
SIL_TOKEN = "sil"
@classmethod
def add_args(cls, parser):
PairedDenoisingTask.add_args(parser)
# set max tokens and position
parser.add_argument(
"--max-text-tokens",
type=int,
metavar="N",
default=1024,
help="maximum samples for encoder text input ",
)
parser.add_argument(
"--max-speech-tokens",
type=int,
metavar="N",
default=50000,
help="maximum samples for encoder speech input ",
)
parser.add_argument(
"--max-speech-positions",
type=int,
metavar="N",
default=400,
help="maximum tokens for per encoder text input ",
)
parser.add_argument(
"--max-sample-size",
type=int,
metavar="N",
default=32000,
help="max sample size to crop to for batching (unsupervised speech) ",
)
parser.add_argument(
"--min-sample-size",
type=int,
metavar="N",
default=4000,
help="min sample size to crop to for batching (unsupervised speech) ",
)
# set mini-batch ratio for different modalities/subtasks
# s2p
parser.add_argument(
"--supervised-speech-sample-ratio",
default="1",
type=str,
metavar="N",
help="Multiple Ratio for speech dataset with transcripts ",
)
# s2t
parser.add_argument(
"--supervised-speech-s2s-sample-ratio",
default="1",
type=str,
metavar="N",
help="Multiple Ratio for speech dataset with transcripts ",
)
# ssl
parser.add_argument(
"--unsupervised-speech-sample-ratio",
default="1",
type=str,
metavar="N",
help="Multiple Ratio for speech dataset without transcripts ",
)
# t2t with monolingual data (masking)
parser.add_argument(
"--text-sample-ratio",
default="1",
type=str,
metavar="N",
help="Multiple Ratio for text set ",
)
# t2t with parallel data (no masking)
parser.add_argument(
"--bitext-sample-ratio",
default="1",
type=str,
metavar="N",
help="Multiple Ratio for text set (bitext) ",
)
# train_subset = "train", 'valid' or so
# parallel data is loaded according to string lang_pairs and lang_pairs_no_mask from args.data
# (un)supervised speech is loaded from args.(un)sup_speech_{train,valid}_subset
parser.add_argument(
"--sup-speech-data", default="", help="path to supervised speech data"
)
parser.add_argument(
"--sup-speech-train-subset",
default="",
help="supervised speech training subsets",
)
parser.add_argument(
"--sup-speech-valid-subset",
default="",
help="supervised speech validation subsets",
)
parser.add_argument(
"--config-yaml",
default="config.yaml",
help="supervised speech configuration yaml file",
)
        parser.add_argument(
            "--sup-speech-s2s-data", default="", help="path to supervised speech s2s data"
        )
        parser.add_argument(
            "--sup-speech-s2s-train-subset",
            default="",
            help="supervised speech s2s training subsets",
        )
        parser.add_argument(
            "--sup-speech-s2s-valid-subset",
            default="",
            help="supervised speech s2s validation subsets",
        )
        parser.add_argument(
            "--config-s2s-yaml",
            default="config.yaml",
            help="supervised speech s2s configuration yaml file",
        )
parser.add_argument(
"--unsup-speech-train-data",
default="",
help="path to unsupervised speech training data (tsv)",
)
parser.add_argument(
"--unsup-speech-valid-data",
default="",
help="path to unsupervised speech valid data (tsv)",
)
parser.add_argument(
"--sample-rate",
type=int,
metavar="N",
default=16000,
help="input audio sampling rate",
)
parser.add_argument(
"--no-emb-update-unsup",
default=False,
action="store_true",
help="no update for output embedding during unsupervised_speech mode",
)
parser.add_argument("--same-data-update", default=False, action="store_true")
# used for sup_speech_ali
parser.add_argument(
"--use-sup-speech-ctc",
default=False,
action="store_true",
help="use speech_sup_ctc instead of speech_sup_ali",
)
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task."""
paths = args.data.split(":")
assert len(paths) > 0
src_dict = Dictionary.load(
os.path.join(paths[0], "src_dict.txt")
) # assume all languages share a source dictionary
tgt_dict = Dictionary.load(
os.path.join(paths[0], "tgt_dict.txt")
) # assume all languages share a target dictionary
lang_pairs = args.lang_pairs + "," + args.lang_pairs_bitext
lang_pairs = re.sub(",$", "", re.sub("^,", "", lang_pairs))
if lang_pairs != "":
src_langs = [lp.split("-")[0] for lp in lang_pairs.split(",")]
tgt_langs = [lp.split("-")[1] for lp in lang_pairs.split(",")]
else:
src_langs = []
tgt_langs = []
if args.add_src_lang_token:
for lang in src_langs:
assert (
src_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
!= src_dict.unk()
)
if args.add_tgt_lang_token:
for lang in tgt_langs:
assert (
tgt_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang))
!= tgt_dict.unk()
)
logger.info("source dictionary: {} types".format(len(src_dict)))
logger.info("target dictionary: {} types".format(len(tgt_dict)))
if not hasattr(args, "shuffle_instance"):
args.shuffle_instance = False
return cls(args, src_dict, tgt_dict)
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args, src_dict, tgt_dict)
self.data_cfg = S2TJointDataConfig(
Path(args.sup_speech_data) / args.config_yaml
)
logger.info(
f"load supervised speech data configure from {Path(args.sup_speech_data) / args.config_yaml}"
)
self.data_s2s_cfg = (
S2TJointDataConfig(Path(args.sup_speech_s2s_data) / args.config_s2s_yaml)
if args.sup_speech_s2s_train_subset != ""
else None
)
if self.data_s2s_cfg is not None:
logger.info(
f"load supervised sequece to sequence speech data configure from {Path(args.sup_speech_s2s_data) / args.config_yaml}"
)
def parse_data_ratio(sample_ratio):
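            # Accepts a single ratio ("0.5") or an epoch schedule such as
            # "1:0.5,5:1.0": each "epoch:ratio" pair makes the ratio take effect
            # from that (1-based) epoch onward, and the returned list is indexed
            # by epoch in get_sample_ratio().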
ratios = sample_ratio.split(",")
if len(ratios) == 1:
return [float(ratios[0])]
epoch_ratios = []
for item in ratios:
ep, r = item.split(":")
ep = int(ep)
r = float(r)
assert ep > 0 # epoch is 1 based
assert ep >= len(epoch_ratios)
if len(epoch_ratios) == 0:
epoch_ratios.append(
r
                    )  # epoch_ratios[0] is never used, but we set it to the first value to keep indexing simple.
while len(epoch_ratios) < ep:
epoch_ratios.append(epoch_ratios[-1])
epoch_ratios.append(r)
return epoch_ratios
self.sup_ratio = parse_data_ratio(args.supervised_speech_sample_ratio)
self.sup_s2s_ratio = parse_data_ratio(args.supervised_speech_s2s_sample_ratio)
self.text_ratio = parse_data_ratio(args.text_sample_ratio)
self.bitext_ratio = parse_data_ratio(args.bitext_sample_ratio)
self.unsup_ratio = parse_data_ratio(args.unsupervised_speech_sample_ratio)
self.sample_mode = None
def build_model(self, args):
args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
args.input_channels = self.data_cfg.input_channels
return super().build_model(args)
def build_tokenizer(self, data_cfg, msg=""):
logger.info(f"pre-tokenizer {msg}: {data_cfg.pre_tokenizer}")
return encoders.build_tokenizer(Namespace(**data_cfg.pre_tokenizer))
def build_bpe(self, data_cfg, msg=""):
logger.info(f"tokenizer {msg}: {data_cfg.bpe_tokenizer}")
return encoders.build_bpe(Namespace(**data_cfg.bpe_tokenizer))
@classmethod
def resolve_data_type(cls, split, use_sup_speech_ctc):
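        # Maps a split name to (is_train, data type), e.g. "train" -> (True, "text"),
        # "valid_bitext" -> (False, "bitext"), "train_unsup_speech" -> (True, "unsup_speech"),
        # "valid_sup_speech" -> (False, "sup_speech_ctc" or "sup_speech_ali").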
if len(split.split("_")) == 1:
# default case, train or valid
            is_train = split == "train"
dtype = "text"
else:
is_train, dtype = split.split("_", 1)
is_train = True if is_train == "train" else False
if dtype == "sup_speech":
dtype = "sup_speech_ctc" if use_sup_speech_ctc else "sup_speech_ali"
assert dtype in (
"text",
"bitext",
"sup_speech_ali",
"sup_speech_s2s",
"unsup_speech",
"sup_speech_ctc",
)
return is_train, dtype
def create_modalitydatasetitem(self, dtype, dataset):
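        # Wrap each sub-dataset with its own (max positions, max tokens, batch size)
        # budget so MultiModalityDataset can batch every modality separately.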
dsitem = None
if dtype in ("text", "bitext"):
dsitem = ModalityDatasetItem(
dtype,
dataset,
(self.args.max_source_positions, self.args.max_target_positions),
self.args.max_text_tokens,
self.args.batch_size,
)
elif dtype in ("sup_speech_ctc", "sup_speech_ali", "sup_speech_s2s"):
dsitem = ModalityDatasetItem(
dtype,
dataset,
(self.args.max_speech_positions, self.args.max_target_positions),
self.args.max_speech_tokens,
self.args.batch_size,
)
elif dtype == "unsup_speech":
dsitem = ModalityDatasetItem(
dtype, dataset, 1e8, self.args.max_speech_tokens, self.args.batch_size
)
else:
raise ValueError(f"{dtype} is not supported")
return dsitem
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
def _get_sup_src_tgt_dict(src_dict, tgt_dict, use_s2s_sup_decoder):
if use_s2s_sup_decoder:
return None, tgt_dict
            # use src_dict as tgt_dict here, since we use the source dictionary as the target for force alignment
return None, src_dict
is_train, dtype = self.resolve_data_type(split, self.args.use_sup_speech_ctc)
        # Note: we use --add-tgt-lang-token instead of data_cfg.prepend_tgt_lang_tag_no_change to set the target language tag in the text dataset
        # Verify that add_tgt_lang_token and prepend_tgt_lang_tag_no_change are the same
        # Note: we use --multilang-sampling-alpha instead of data_cfg.sampling_text_alpha to set the text data sampling ratio
if is_train:
msets = []
# train split, load everything into one
if self.lang_pairs != "":
text_dataset = self.load_dataset_only(
"train", self.lang_pairs, epoch=epoch, combine=combine
)
dsitem = self.create_modalitydatasetitem("text", text_dataset)
msets.append(dsitem)
if self.lang_pairs_bitext != "": # load bitext
bitext_dataset = self.load_dataset_only(
"train_bitext",
self.lang_pairs_bitext,
do_mask=False,
epoch=epoch,
combine=combine,
)
dsitem = self.create_modalitydatasetitem("bitext", bitext_dataset)
msets.append(dsitem)
if self.args.sup_speech_train_subset != "":
pre_tokenizer = self.build_tokenizer(self.data_cfg)
bpe_tokenizer = self.build_bpe(self.data_cfg)
append_eos = True
sup_speech_type = "sup_speech_ali"
if self.args.use_sup_speech_ctc:
# CTC mode
sup_speech_type = "sup_speech_ctc"
append_eos = False # CTC doesn't need eos in the target
src_dict, tgt_dict = _get_sup_src_tgt_dict(
self.src_dict, self.tgt_dict, False
)
sup_speech_dataset = SpeechToTextJointDatasetCreator.from_tsv(
self.args.sup_speech_data,
self.data_cfg,
self.args.sup_speech_train_subset,
tgt_dict=tgt_dict,
src_dict=src_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
src_pre_tokenizer=None,
src_bpe_tokenizer=None,
is_train_split=is_train,
epoch=epoch,
seed=self.args.seed,
append_eos=append_eos,
)
dsitem = self.create_modalitydatasetitem(
sup_speech_type, sup_speech_dataset
)
msets.append(dsitem)
if self.args.sup_speech_s2s_train_subset != "":
pre_tokenizer = self.build_tokenizer(self.data_s2s_cfg, msg="(s2s)")
bpe_tokenizer = self.build_bpe(self.data_s2s_cfg, msg="(s2s)")
# make sure self.data_cfg.prepend_tgt_lang_tag_no_change == self.args.add_tgt_lang_token
src_dict, tgt_dict = _get_sup_src_tgt_dict(
self.src_dict, self.tgt_dict, True
)
sup_speech_s2s_dataset = SpeechToTextJointDatasetCreator.from_tsv(
self.args.sup_speech_s2s_data,
self.data_s2s_cfg,
self.args.sup_speech_s2s_train_subset,
tgt_dict=tgt_dict,
src_dict=src_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
src_pre_tokenizer=None,
src_bpe_tokenizer=None,
is_train_split=is_train,
epoch=epoch,
seed=self.args.seed,
)
dsitem = self.create_modalitydatasetitem(
"sup_speech_s2s", sup_speech_s2s_dataset
)
msets.append(dsitem)
if self.args.unsup_speech_train_data != "":
unsup_speech_dataset = FileAudioDatasetWrapper(
self.args.unsup_speech_train_data,
self.args.sample_rate,
max_sample_size=self.args.max_sample_size,
min_sample_size=self.args.min_sample_size,
normalize=False,
)
dsitem = self.create_modalitydatasetitem(
"unsup_speech", unsup_speech_dataset
)
msets.append(dsitem)
pre_train_dataset = MultiModalityDataset(msets)
self.datasets[split] = pre_train_dataset
else: # validation split, load them for each type of data
if dtype == "text":
text_dataset = self.load_dataset_only(
split, self.lang_pairs, epoch=epoch, combine=combine
)
dsitem = self.create_modalitydatasetitem("text", text_dataset)
self.datasets[split] = MultiModalityDataset([dsitem])
elif dtype == "bitext":
bitext_dataset = self.load_dataset_only(
split,
self.lang_pairs_bitext,
do_mask=False,
epoch=epoch,
combine=combine,
)
dsitem = self.create_modalitydatasetitem("bitext", bitext_dataset)
self.datasets[split] = MultiModalityDataset([dsitem])
elif dtype in ("sup_speech_ctc", "sup_speech_ali"):
assert self.args.sup_speech_valid_subset != ""
pre_tokenizer = self.build_tokenizer(self.data_cfg)
bpe_tokenizer = self.build_bpe(self.data_cfg)
append_eos = True
if dtype == "sup_speech_ctc":
# CTC mode
append_eos = False # CTC doesn't need eos
assert self.args.use_sup_speech_ctc
datasets = []
for split_name in self.args.sup_speech_valid_subset.split(","):
src_dict, tgt_dict = _get_sup_src_tgt_dict(
self.src_dict, self.tgt_dict, False
)
datasets.append(
SpeechToTextJointDatasetCreator.from_tsv(
self.args.sup_speech_data,
self.data_cfg,
split_name,
tgt_dict=tgt_dict,
src_dict=src_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
src_pre_tokenizer=None,
src_bpe_tokenizer=None,
is_train_split=is_train,
epoch=epoch,
seed=self.args.seed,
append_eos=append_eos,
)
)
dset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
dsitem = self.create_modalitydatasetitem(dtype, dset)
self.datasets[split] = MultiModalityDataset([dsitem])
elif dtype == "sup_speech_s2s":
assert self.args.sup_speech_s2s_valid_subset != ""
pre_tokenizer = self.build_tokenizer(self.data_s2s_cfg)
bpe_tokenizer = self.build_bpe(self.data_s2s_cfg)
datasets = []
for split_name in self.args.sup_speech_s2s_valid_subset.split(","):
src_dict, tgt_dict = _get_sup_src_tgt_dict(
self.src_dict, self.tgt_dict, True
)
datasets.append(
SpeechToTextJointDatasetCreator.from_tsv(
self.args.sup_speech_s2s_data,
self.data_s2s_cfg,
split_name,
tgt_dict=tgt_dict,
src_dict=src_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
src_pre_tokenizer=None,
src_bpe_tokenizer=None,
is_train_split=is_train,
epoch=epoch,
seed=self.args.seed,
)
)
dset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets)
dsitem = self.create_modalitydatasetitem("sup_speech_s2s", dset)
self.datasets[split] = MultiModalityDataset([dsitem])
elif dtype == "unsup_speech":
assert self.args.unsup_speech_valid_data != ""
unsup_speech_dataset = FileAudioDatasetWrapper(
self.args.unsup_speech_valid_data,
self.args.sample_rate,
max_sample_size=self.args.max_sample_size,
min_sample_size=self.args.min_sample_size,
normalize=False,
)
dsitem = self.create_modalitydatasetitem(
"unsup_speech", unsup_speech_dataset
)
self.datasets[split] = MultiModalityDataset([dsitem])
else:
raise ValueError(f"Unsupported type {dtype}")
def get_sample_ratio(self, epoch):
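        # Look up the per-modality sampling ratio for this epoch; schedules shorter
        # than the current epoch fall back to their last value.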
sup_ratio = (
self.sup_ratio[epoch] if len(self.sup_ratio) > epoch else self.sup_ratio[-1]
)
sup_s2s_ratio = (
self.sup_s2s_ratio[epoch]
if len(self.sup_s2s_ratio) > epoch
else self.sup_s2s_ratio[-1]
)
unsup_ratio = (
self.unsup_ratio[epoch]
if len(self.unsup_ratio) > epoch
else self.unsup_ratio[-1]
)
text_ratio = (
self.text_ratio[epoch]
if len(self.text_ratio) > epoch
else self.text_ratio[-1]
)
bitext_ratio = (
self.bitext_ratio[epoch]
if len(self.bitext_ratio) > epoch
else self.bitext_ratio[-1]
)
return text_ratio, bitext_ratio, sup_ratio, sup_s2s_ratio, unsup_ratio
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=0,
data_buffer_size=0,
disable_iterator_cache=False,
skip_remainder_batch=False,
grouped_shuffling=False,
update_epoch_batch_itr=False,
):
assert isinstance(dataset, MultiModalityDataset)
if len(dataset.id_to_mode) == 1:
max_positions = dataset.max_positions[0]
max_tokens = dataset.max_tokens[0]
max_sentences = dataset.max_sentences[0]
return super().get_batch_iterator(
dataset,
max_tokens,
max_sentences,
max_positions,
ignore_invalid_inputs,
required_batch_size_multiple,
seed,
num_shards,
shard_id,
num_workers,
epoch,
data_buffer_size,
disable_iterator_cache,
skip_remainder_batch=skip_remainder_batch,
)
mult_ratio = []
(
text_ratio,
bitext_ratio,
sup_ratio,
sup_s2s_ratio,
unsup_ratio,
) = self.get_sample_ratio(epoch)
for mode in dataset.id_to_mode:
if mode in ("sup_speech_ctc", "sup_speech_ali"):
mult_ratio.append(sup_ratio)
elif mode == "sup_speech_s2s":
mult_ratio.append(sup_s2s_ratio)
elif mode == "text":
mult_ratio.append(text_ratio)
elif mode == "bitext":
mult_ratio.append(bitext_ratio)
elif mode == "unsup_speech":
mult_ratio.append(unsup_ratio)
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
batch_samplers = dataset.get_batch_samplers(
mult_ratio, required_batch_size_multiple, seed
)
# return a reusable, sharded iterator
epoch_iter = GroupedEpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_samplers=batch_samplers,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
mult_rate=max(self.args.update_freq) if self.args.same_data_update else 1,
buffer_size=data_buffer_size,
skip_remainder_batch=skip_remainder_batch,
)
self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch
return epoch_iter

View File

@@ -200,6 +200,8 @@ class LangPairMaskDataset(FairseqDataset):
return self.dataset.sizes
def get_batch_shapes(self):
if hasattr(self.dataset, "get_batch_shapes"):
return self.dataset.get_batch_shapes()
return self.dataset.buckets
def num_tokens_vec(self, indices):

View File

@@ -5,22 +5,18 @@
import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional
import torch
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
)
logger = logging.getLogger(__name__)
@@ -52,8 +48,12 @@ class S2TJointDataConfig(S2TDataConfig):
def prepend_tgt_lang_tag_no_change(self) -> bool:
"""Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
to-many multilingual setting). No change needed during inference.
This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
"""
return self.config.get("prepend_tgt_lang_tag_no_change", False)
value = self.config.get("prepend_tgt_lang_tag_no_change", None)
if value is None:
return self.config.get("prepend_tgt_lang_tag_as_bos", False)
return value
@property
def sampling_text_alpha(self):

View File

@@ -5,6 +5,8 @@
from .berard import * # noqa
from .convtransformer import * # noqa
from .multi_modality_model import * # noqa
from .s2t_transformer import * # noqa
from .s2t_wav_transformer import * # noqa
from .xm_transformer import * # noqa
from .s2t_conformer import * # noqa

View File

@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.models import FairseqDecoder, FairseqEncoder
# a container for different encoders handling training samples from different modalities;
# each forward call selects exactly one encoder
class MultiModalityEncoder(FairseqEncoder):
def __init__(self, dictionary):
super().__init__(dictionary)
def select_encoder(self, mode, **kwargs):
raise NotImplementedError("Model must implement the select_encoder method")
return None, kwargs
# def post_encoder(self, encoder_out, src_tokens, src_lengths, mode, **kwargs):
# # Default do nothing
# return encoder_out
# get sample data from JointSpeechTextDataset
def forward(self, src_tokens, src_lengths=None, mode="", **kwargs):
encoder, kwargs = self.select_encoder(mode, **kwargs)
# return self.post_encoder(encoder(src_tokens, src_lengths, **kwargs), src_tokens, src_lengths, mode, **kwargs)
return encoder(src_tokens, src_lengths, **kwargs)
# a container for different decoders handling training samples from different modalities;
# each forward call selects exactly one decoder
class MultiInputDecoder(FairseqDecoder):
def __init__(self, dictionary):
super().__init__(dictionary)
def select_decoder(self, mode, **kwargs):
raise NotImplementedError("Model must implement the select_decoder method")
return None, kwargs
def forward(
self, prev_output_tokens, encoder_out, incremental_state=None, mode="", **kwargs
):
decoder, kwargs = self.select_decoder(mode, **kwargs)
return decoder(
prev_output_tokens,
encoder_out,
incremental_state=incremental_state,
**kwargs
)

View File

@@ -295,15 +295,15 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
return sample["target"], sample["target_lengths"]
def get_ctc_output(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        sample: Optional[Dict[str, Tensor]],
):
encoder_out = net_output[1]["encoder_out"]["encoder_out"][0]
logits = self.encoder.ctc_proj(encoder_out) # T x B x C
out = utils.log_softmax(logits.float(), dim=-1)
padding_mask = net_output[1]["encoder_out"]["encoder_padding_mask"]
        lens = out.new_full((out.shape[1],), out.shape[0]).long()
if len(padding_mask) > 0:
lens -= padding_mask[0].sum(dim=-1)
return out, lens
@@ -359,7 +359,7 @@ class S2TTransformerEncoder(FairseqEncoder):
self.layer_norm = None
self.ctc_proj = None
if getattr(args, "ctc_weight", 0.0) > 0.0:
if getattr(args, "ctc_weight", 0.) > 0.:
self.ctc_proj = nn.Linear(args.encoder_embed_dim, args.tgt_dict_size)
def _forward(self, src_tokens, src_lengths, return_all_hiddens=False):

View File

@@ -0,0 +1,485 @@
#!/usr/bin/env python3
import math
import torch
import torch.nn as nn
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import FairseqEncoder
from fairseq.models.wav2vec import ConvFeatureExtractionModel
from fairseq.modules import GradMultiply, LayerNorm, SamePad, TransformerEncoderLayer
# Transformer encoder with waveform input, adapted from the wav2vec 2.0 encoder.
# It consumes raw wav input and uses a trainable convolutional positional
# embedding so that its output is easier to align with the text encoder.
class SpeechWavTransformerEncoder(FairseqEncoder):
    # extra parameters for the speech encoder besides those defined in TransformerModel
@staticmethod
def add_args(parser):
parser.add_argument(
"--dropout-input",
type=float,
metavar="D",
help="dropout to apply to the input (after feat extr)",
)
parser.add_argument(
"--dropout-features",
type=float,
metavar="D",
help="dropout to apply to the unmasked features (after feat extr)",
)
parser.add_argument(
"--speech-extractor-mode",
type=str,
default="layer_norm",
choices=["default", "layer_norm"],
help="feature extractor norm",
)
parser.add_argument(
"--speech-conv-bias",
action="store_true",
help="include bias in speech conv encoder",
)
parser.add_argument(
"--conv-feature-layers",
default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]",
)
        parser.add_argument(
            "--speech-mask-length",
            type=int,
            help="length of each masked span over the speech frames",
        )
        parser.add_argument(
            "--speech-mask-prob",
            type=float,
            help="probability of replacing a frame with mask",
        )
parser.add_argument(
"--speech-mask-selection",
type=str,
choices=["static", "uniform", "normal", "poisson"],
help="how to choose masks",
)
parser.add_argument(
"--speech-mask-other",
type=float,
help="stdev of the mask length in case of 'normal' selection strategy",
)
        parser.add_argument(
            "--speech-no-mask-overlap",
            action="store_true",
            help="do not allow masked spans to overlap",
        )
parser.add_argument(
"--speech-mask-min-space",
type=int,
help="min space between spans (if no overlap is enabled)",
)
        parser.add_argument(
            "--speech-mask-channel-length",
            type=int,
            help="length of each masked channel span",
        )
        parser.add_argument(
            "--speech-mask-channel-prob",
            type=float,
            help="probability of replacing a feature channel with mask",
        )
parser.add_argument(
"--speech-mask-channel-selection",
type=str,
choices=["static", "uniform", "normal", "poisson"],
help="how to choose masks",
)
parser.add_argument(
"--speech-mask-channel-other",
type=float,
help="stdev of the mask length in case of 'normal' selection strategy",
)
parser.add_argument(
"--speech-no-mask-channel-overlap",
action="store_true",
help="whether to allow masks to overlap",
)
parser.add_argument(
"--no-scale-feature",
action="store_true",
help="no scale for the calculated features",
)
parser.add_argument(
"--speech-mask-channel-min-space",
type=int,
help="min space between spans (if no overlap is enabled)",
)
parser.add_argument(
"--feature-grad-mult",
type=float,
help="reset feature grad mult in wav2vec 2.0 to this",
)
# positional embeddings
parser.add_argument(
"--conv-pos",
type=int,
default=128,
help="number of filters for convolutional positional embeddings",
)
parser.add_argument(
"--conv-pos-groups",
type=int,
default=16,
help="number of groups for convolutional positional embedding",
)
# model configures
parser.add_argument(
"--speech-encoder-layers",
type=int,
help="number of speech encoder layers",
)
parser.add_argument(
"--text-encoder-layers",
type=int,
help="number of text encoder layers",
)
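The `--conv-feature-layers` default above is the wav2vec 2.0 stack: strides 5, 2, 2, 2, 2, 2, 2, i.e. a 320x down-sampling of the 16 kHz waveform (roughly one frame every 20 ms). The following self-contained sketch shows the output-length arithmetic this implies; the helper name is illustrative and not taken from this file.

```python
# Illustrative sketch only: map waveform sample counts to encoder frame counts
# for the default conv spec, layer by layer, with the usual conv length formula.
import torch

CONV_FEATURE_LAYERS = "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]"

def conv_output_lengths(input_lengths: torch.Tensor, conv_spec: str) -> torch.Tensor:
    for _dim, kernel_size, stride in eval(conv_spec):
        input_lengths = (
            torch.div(input_lengths - kernel_size, stride, rounding_mode="floor") + 1
        )
    return input_lengths

# One second of 16 kHz audio -> 49 frames with this stack.
print(conv_output_lengths(torch.tensor([16000]), CONV_FEATURE_LAYERS))  # tensor([49])
```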
def __init__(self, args, alway_mask=False):
super().__init__(args)
self.args = args
self.dropout = args.dropout
self.embedding_dim = args.encoder_embed_dim
self.feat_scale = math.sqrt(args.encoder_embed_dim)
if args.no_scale_feature:
self.feat_scale = 1.0
subsample = ConvFeatureExtractionModel(
conv_layers=eval(args.conv_feature_layers),
dropout=0.0,
mode=args.speech_extractor_mode, # default, layer_norm
conv_bias=args.speech_conv_bias,
)
feature_enc_layers = eval(args.conv_feature_layers)
self.subsample = subsample
self.feat_proj = (
nn.Linear(feature_enc_layers[-1][0], self.embedding_dim)
if feature_enc_layers[-1][0] != self.embedding_dim
else None
)
self.feat_layer_norm = LayerNorm(feature_enc_layers[-1][0])
self.embed_positions = nn.Conv1d(
self.embedding_dim,
self.embedding_dim,
kernel_size=args.conv_pos,
padding=args.conv_pos // 2,
groups=args.conv_pos_groups,
)
std = math.sqrt(4 / (args.conv_pos * self.embedding_dim))
nn.init.normal_(self.embed_positions.weight, mean=0, std=std)
nn.init.constant_(self.embed_positions.bias, 0)
self.embed_positions = nn.utils.weight_norm(
self.embed_positions, name="weight", dim=2
)
self.embed_positions = nn.Sequential(
self.embed_positions, SamePad(args.conv_pos), nn.GELU()
)
self.mask_prob = args.speech_mask_prob
self.mask_selection = args.speech_mask_selection
self.mask_other = args.speech_mask_other
self.mask_length = args.speech_mask_length
self.no_mask_overlap = args.speech_no_mask_overlap
self.mask_min_space = args.speech_mask_min_space
self.mask_channel_prob = args.speech_mask_channel_prob
self.mask_channel_selection = args.speech_mask_channel_selection
self.mask_channel_other = args.speech_mask_channel_other
self.mask_channel_length = args.speech_mask_channel_length
self.no_mask_channel_overlap = args.speech_no_mask_channel_overlap
self.mask_channel_min_space = args.speech_mask_channel_min_space
self.dropout_input = nn.Dropout(args.dropout_input)
self.dropout_features = nn.Dropout(args.dropout_features)
self.feature_grad_mult = args.feature_grad_mult
self.mask_emb = nn.Parameter(
torch.FloatTensor(args.encoder_embed_dim).uniform_()
)
self.layers = nn.ModuleList(
[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
self.layer_norm = LayerNorm(args.encoder_embed_dim)
self.normalize_before = args.encoder_normalize_before
self.alway_mask = alway_mask
def apply_mask(self, x, padding_mask):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space,
)
mask_indices = torch.from_numpy(mask_indices).to(x.device)
x[mask_indices] = self.mask_emb
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space,
)
mask_channel_indices = (
torch.from_numpy(mask_channel_indices)
.to(x.device)
.unsqueeze(1)
.expand(-1, T, -1)
)
x[mask_channel_indices] = 0
return x, mask_indices
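`apply_mask` defers all span sampling to fairseq's `compute_mask_indices`, mirroring wav2vec 2.0. Here is a tiny standalone sketch of what that call produces; it is not part of this file, and the probabilities and lengths are placeholder values rather than the recipe's.

```python
# Illustrative sketch only: sample time masks the way apply_mask does and
# inspect how many frames were masked per utterance.
import torch
from fairseq.data.data_utils import compute_mask_indices

B, T, C = 2, 100, 768
features = torch.randn(B, T, C)

mask_indices = compute_mask_indices(
    (B, T),
    None,                # no padding in this toy batch
    mask_prob=0.65,      # placeholder for --speech-mask-prob
    mask_length=10,      # placeholder for --speech-mask-length
    mask_type="static",  # --speech-mask-selection
    min_masks=2,
)
mask_indices = torch.from_numpy(mask_indices)
features[mask_indices] = 0.0     # the real encoder writes its learned mask_emb here
print(mask_indices.sum(dim=-1))  # number of masked frames per utterance
```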
def forward(
self,
src_tokens,
src_lengths,
return_all_hiddens=False,
padding_mask=None,
features_only=True,
):
mask = self.training or self.alway_mask
if self.feature_grad_mult > 0 and self.training:
features = self.subsample(src_tokens)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with torch.no_grad():
features = self.subsample(src_tokens)
features = features.transpose(1, 2)
features = self.feat_layer_norm(features)
if self.feat_proj is not None:
features = self.feat_proj(features)
if padding_mask is not None:
input_lengths = (1 - padding_mask.long()).sum(-1)
# apply conv formula to get real output_lengths
output_lengths = self._get_feat_extract_output_lengths(input_lengths)
padding_mask = torch.zeros(
features.shape[:2], dtype=features.dtype, device=features.device
)
# these two operations make sure that all values
# before the output length indices are attended to
padding_mask[
(
torch.arange(padding_mask.shape[0], device=padding_mask.device),
output_lengths - 1,
)
] = 1
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
features = self.feat_scale * features if self.feat_scale != 1.0 else features
unmasked_features = features.clone()
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask)
else:
x = features
mask_indices = None
def cal_transformer_layers(x, encoder_padding_mask, return_all_hiddens=False):
# x: B x T x C
positions = self.embed_positions(x.transpose(1, 2)).transpose(1, 2)
x = x + positions
if not self.normalize_before:
x = self.layer_norm(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_states = []
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
if self.normalize_before:
x = self.layer_norm(x)
return x, encoder_states
x, encoder_states = cal_transformer_layers(x, padding_mask, return_all_hiddens)
if features_only:
return {
"encoder_out": [x], # [T x B x C]
"encoder_padding_mask": [padding_mask]
if padding_mask is not None
else [], # B x T
"encoder_embedding": [], #
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
"mask_indices": [mask_indices],
}
x_unmasked = x
if self.mask_prob > 0 or self.mask_channel_prob > 0:
x_unmasked, _ = cal_transformer_layers(unmasked_features, padding_mask)
return {
"encoder_out": [x], # [T x B x C]
"encoder_unmasked_out": [x_unmasked], # [T x B x C]
"encoder_padding_mask": [padding_mask]
if padding_mask is not None
else [], # B x T
"encoder_embedding": [], #
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
"mask_indices": [mask_indices] if mask_indices is not None else [], # B X T
}
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = (
[]
if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
new_encoder_padding_mask = (
[]
if len(encoder_out["encoder_padding_mask"]) == 0
else [
x.index_select(0, new_order)
for x in encoder_out["encoder_padding_mask"]
]
)
new_encoder_embedding = (
[]
if len(encoder_out["encoder_embedding"]) == 0
else [
x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
]
)
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [], # B x T
"src_lengths": [], # B x 1
}
class StackedSpeechWavTransformerEncoder(FairseqEncoder):
def __init__(self, speech_enc, text_enc_layers, text_layer_norm):
super().__init__(None)
self.speech_encoder = speech_enc
self.text_encoder_layers = text_enc_layers
self.final_layer_norm = text_layer_norm
def forward(
self,
src_tokens,
src_lengths=None,
return_all_hiddens=False,
padding_mask=None,
features_only=True,
):
out = self.speech_encoder.forward(
src_tokens,
src_lengths,
return_all_hiddens,
padding_mask=padding_mask,
features_only=features_only,
)
x = out["encoder_out"][0]
encoder_padding_mask = None
if len(out["encoder_padding_mask"]) > 0:
encoder_padding_mask = out["encoder_padding_mask"][0]
def cal_text_layers(x, padding_mask, return_all_hiddens=False):
encoder_states = []
for layer in self.text_encoder_layers:
x = layer(x, padding_mask)
if return_all_hiddens:
encoder_states.append(x)
if self.final_layer_norm is not None:
x = self.final_layer_norm(x)
return x, encoder_states
x, encoder_states = cal_text_layers(x, encoder_padding_mask, return_all_hiddens)
if features_only:
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask]
if encoder_padding_mask is not None
else [], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
x_u = out["encoder_unmasked_out"][0]
x_u, _ = cal_text_layers(x_u, encoder_padding_mask)
return {
"encoder_out": [x], # [T x B x C]
"encoder_unmasked_out": [x_u], # [T x B x C]
"encoder_padding_mask": [encoder_padding_mask]
if encoder_padding_mask is not None
else [], # B x T
"encoder_embedding": [], #
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
"mask_indices": out["mask_indices"], # B X T
}
def reorder_encoder_out(self, encoder_out, new_order):
return self.speech_encoder.reorder_encoder_out(encoder_out, new_order)
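Taken together, the two classes above let a joint model share one waveform front-end with a stack of text-encoder layers. Below is a rough end-to-end sketch, assuming both classes are already in scope; the argparse namespace only fills the fields read in `__init__`, and every value in it is an assumption rather than the pre-training recipe.

```python
# Illustrative sketch only, not part of this file.
import argparse

import torch
from fairseq.modules import LayerNorm, TransformerEncoderLayer

args = argparse.Namespace(
    # transformer fields consumed by TransformerEncoderLayer / the encoder itself
    encoder_embed_dim=768, encoder_ffn_embed_dim=3072, encoder_attention_heads=12,
    encoder_layers=2, encoder_normalize_before=True,
    dropout=0.1, attention_dropout=0.1, activation_dropout=0.0, activation_fn="relu",
    # waveform front-end fields
    dropout_input=0.0, dropout_features=0.0, feature_grad_mult=1.0,
    no_scale_feature=False, speech_extractor_mode="layer_norm", speech_conv_bias=False,
    conv_feature_layers="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
    conv_pos=128, conv_pos_groups=16,
    # masking fields (placeholder values)
    speech_mask_prob=0.65, speech_mask_length=10, speech_mask_selection="static",
    speech_mask_other=0.0, speech_no_mask_overlap=False, speech_mask_min_space=1,
    speech_mask_channel_prob=0.0, speech_mask_channel_length=10,
    speech_mask_channel_selection="static", speech_mask_channel_other=0.0,
    speech_no_mask_channel_overlap=False, speech_mask_channel_min_space=1,
)

speech_enc = SpeechWavTransformerEncoder(args)
shared_text_layers = torch.nn.ModuleList(
    [TransformerEncoderLayer(args) for _ in range(2)]
)
stacked = StackedSpeechWavTransformerEncoder(
    speech_enc, shared_text_layers, LayerNorm(args.encoder_embed_dim)
)

wav = torch.randn(2, 16000)  # B x samples: one second of 16 kHz audio per utterance
out = stacked(wav, src_lengths=None, features_only=True)
print(out["encoder_out"][0].shape)  # T x B x C, T == 49 for this conv stack
```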

View File

@@ -0,0 +1,76 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from collections import namedtuple
from pathlib import Path
import torch
from tqdm import tqdm
import fairseq
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.scoring.bleu import SacrebleuScorer
from fairseq.tasks import import_tasks
from tests.speech import S3_BASE_URL, TestFairseqSpeech
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
class TestLibrispeechDualInputWavTransformer(TestFairseqSpeech):
def setUp(self):
dataset_id = "librispeech_wvtrasnformer"
base_url = "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned"
data_filenames = [
"checkpoint_ave_10.pt",
"spm.model",
"src_dict.txt",
"tgt_dict.txt",
"config.yaml",
]
self._set_up(
dataset_id,
"s2t",
[
"librispeech_flac_test-other.tsv",
"librispeech_flac_test-other.zip",
],
)
for filename in data_filenames:
self.download(base_url, self.root, filename)
def import_user_module(self):
user_dir = (
Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text"
)
Arg = namedtuple("Arg", ["user_dir"])
arg = Arg(user_dir.__str__())
utils.import_user_module(arg)
@torch.no_grad()
def test_librispeech_dualinput_wav_transformer_checkpoint(self):
self.import_user_module()
checkpoint_filename = "checkpoint_ave_10.pt"
arg_overrides = {
"config_yaml": "config.yaml",
"load_pretrained_speech_text_encoder": "",
"load_pretrained_speech_text_decoder": "",
"beam": 10,
"nbest": 1,
"lenpen": 1.0,
"load_speech_only": True,
}
self.base_test(
checkpoint_filename,
4.6,
dataset="librispeech_flac_test-other",
max_tokens=800000,
max_positions=(800000, 1024),
arg_overrides=arg_overrides,
)
if __name__ == "__main__":
unittest.main()
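For reference outside the test harness, the same checkpoint can be loaded with the utilities the file already imports; `base_test` presumably does the equivalent internally. A hedged sketch follows: the local paths are placeholders, and the `data` override is an assumption about what `base_test` injects.

```python
# Illustrative sketch only, not part of the test.
from argparse import Namespace

import torch
from fairseq import utils
from fairseq.checkpoint_utils import load_model_ensemble_and_task

# Register the joint speech-text architectures first, as import_user_module() above does.
utils.import_user_module(Namespace(user_dir="examples/speech_text_joint_to_text"))

models, saved_cfg, task = load_model_ensemble_and_task(
    ["/path/to/checkpoint_ave_10.pt"],  # downloaded next to config.yaml, spm.model, dicts
    arg_overrides={
        "data": "/path/to/downloaded_files",  # assumption: base_test points this at self.root
        "config_yaml": "config.yaml",
        "load_pretrained_speech_text_encoder": "",
        "load_pretrained_speech_text_decoder": "",
        "beam": 10,
        "nbest": 1,
        "lenpen": 1.0,
        "load_speech_only": True,
    },
)
model = models[0].eval()
if torch.cuda.is_available():
    model = model.cuda()
# base_test then decodes librispeech_flac_test-other and checks the score
# against the 4.6 threshold passed above.
```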