From 993129dae4fab651524a8976db9486448e752d21 Mon Sep 17 00:00:00 2001 From: Yun Tang Date: Tue, 10 May 2022 19:44:00 -0700 Subject: [PATCH] Merge STPT: Step 3 Summary: 1. Add joint pre-training scripts 2. Replace prepend_tgt_lang_tag_no_change with prepend_tgt_lang_tag_as_bos 3. Add readme for the joint pre-training 4. Add test case for the Librispeech model Reviewed By: hygong-fb Differential Revision: D36300953 fbshipit-source-id: cb749689787ed97c1250d122bdefb7f7a2252292 --- examples/speech_text_joint_to_text/README.md | 19 +- .../criterions/multi_modality_compound.py | 180 +++++ .../multi_modality_cross_entropy.py | 101 +++ .../data/pair_denoising_dataset.py | 318 ++++++++ .../docs/ende-mustc.md | 8 +- .../docs/pre-training.md | 188 +++++ .../joint_speech_text_pretrain_transformer.py | 698 ++++++++++++++++++ .../models/s2t_dualinputwavtransformer.py | 526 +++++++++++++ .../scripts/convert_model.py | 71 ++ .../tasks/pair_denoising.py | 447 +++++++++++ .../tasks/speech_text_denoise_pretrain.py | 654 ++++++++++++++++ fairseq/data/audio/multi_modality_dataset.py | 2 + .../audio/speech_to_text_joint_dataset.py | 20 +- fairseq/models/speech_to_text/__init__.py | 2 + .../speech_to_text/multi_modality_model.py | 49 ++ .../models/speech_to_text/s2t_transformer.py | 10 +- .../speech_to_text/s2t_wav_transformer.py | 485 ++++++++++++ .../speech/test_dual_input_wav_transformer.py | 76 ++ 18 files changed, 3831 insertions(+), 23 deletions(-) create mode 100644 examples/speech_text_joint_to_text/criterions/multi_modality_compound.py create mode 100644 examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py create mode 100644 examples/speech_text_joint_to_text/data/pair_denoising_dataset.py create mode 100644 examples/speech_text_joint_to_text/docs/pre-training.md create mode 100644 examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py create mode 100644 examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py create mode 100644 examples/speech_text_joint_to_text/scripts/convert_model.py create mode 100644 examples/speech_text_joint_to_text/tasks/pair_denoising.py create mode 100644 examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py create mode 100644 fairseq/models/speech_to_text/multi_modality_model.py create mode 100644 fairseq/models/speech_to_text/s2t_wav_transformer.py create mode 100644 tests/speech/test_dual_input_wav_transformer.py diff --git a/examples/speech_text_joint_to_text/README.md b/examples/speech_text_joint_to_text/README.md index e071d241e..c1aa11929 100644 --- a/examples/speech_text_joint_to_text/README.md +++ b/examples/speech_text_joint_to_text/README.md @@ -5,17 +5,16 @@ An extension of Fairseq s2t project with the speech to text task enhanced by the Examples of speech text joint training in fairseq - [English-to-German MuST-C model](docs/ende-mustc.md) - [IWSLT 2021 Multilingual Speech Translation](docs/iwslt2021.md) - +- [Speech Text Joint Pre-training ](docs/pre-training.md) ## Citation Please cite as: ``` -@inproceedings{Tang2021AGM, - title={A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks}, - author={Yun Tang and J. Pino and Changhan Wang and Xutai Ma and Dmitriy Genzel}, - booktitle={ICASSP}, - year={2021} +@inproceedings{Tang2022UnifiedSP, + title={Unified Speech-Text Pre-training for Speech Translation and Recognition}, + author={Yun Tang and Hongyu Gong and Ning Dong and Changhan Wang and Wei-Ning Hsu and Jiatao Gu and Alexei Baevski and Xian Li and Abdelrahman Mohamed and Michael Auli and Juan Miguel Pino}, + booktitle={ACL}, + year={2022} } - @inproceedings{Tang2021IST, title = {Improving Speech Translation by Understanding and Learning from the Auxiliary Text Translation Task}, author = {Yun Tang and Juan Pino and Xian Li and Changhan Wang and Dmitriy Genzel}, @@ -29,6 +28,12 @@ Please cite as: booktitle = {IWSLT}, year = {2021}, } +@inproceedings{Tang2021AGM, + title={A General Multi-Task Learning Framework to Leverage Text Data for Speech to Text Tasks}, + author={Yun Tang and J. Pino and Changhan Wang and Xutai Ma and Dmitriy Genzel}, + booktitle={ICASSP}, + year={2021} +} @inproceedings{wang2020fairseqs2t, title = {fairseq S2T: Fast Speech-to-Text Modeling with fairseq}, diff --git a/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py b/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py new file mode 100644 index 000000000..292f6fdf4 --- /dev/null +++ b/examples/speech_text_joint_to_text/criterions/multi_modality_compound.py @@ -0,0 +1,180 @@ +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import math +from dataclasses import dataclass, field + +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterionConfig, +) +from fairseq.logging.meters import safe_round + +from .multi_modality_cross_entropy import SpeechTextPreTrainCrossEntCriterion + +logger = logging.getLogger(__name__) + + +@dataclass +class SpeechTextPreTrainCompoundCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + zero_infinity: bool = field( + default=False, + metadata={"help": "zero inf loss when source length <= target length"}, + ) + post_process: str = field( + default="none", + metadata={ + "help": "how to post process predictions into words. can be letter, " + "wordpiece, BPE symbols, etc. " + "See fairseq.data.data_utils.post_process() for full list of options" + }, + ) + + +@register_criterion( + "speech_text_pretrain_compound", dataclass=SpeechTextPreTrainCompoundCriterionConfig +) +class SpeechTextPreTrainCompoundCriterion(FairseqCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + report_accuracy=False, + zero_infinity=False, + post_process=None, + ): + super().__init__(task) + self.xent = SpeechTextPreTrainCrossEntCriterion( + task, sentence_avg, label_smoothing, report_accuracy + ) + cfg_dict = { + "zero_infinity": zero_infinity, + "sentence_avg": sentence_avg, + "post_process": post_process, + } + cfg_ctc = CtcCriterionConfig(**cfg_dict) + self.ctc = CtcCriterion(cfg_ctc, task) + + def forward(self, model, sample, reduce=True): + mode = sample["net_input"]["mode"] + if mode == "sup_speech_ctc": # CTC + sample["net_input"][ + "src_lengths" + ] = None # get downsampled src_lengths from padding_mask + loss, sample_size, logging_output = self.ctc(model, sample, reduce) + logging_output["mode"] = SpeechTextPreTrainCompoundCriterion.mode2value( + "CTC" + ) + else: + loss, sample_size, logging_output = self.xent(model, sample, reduce) + logging_output["mode"] = SpeechTextPreTrainCompoundCriterion.mode2value( + "xent" + ) + + return loss, sample_size, logging_output + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True + + @staticmethod + def mode2value(mode): # make the logging_outputs_can_be_summed = True + if mode == "CTC": + return 907 # prime number + if mode == "xent": + return 887 # prime number + return 0 + + @staticmethod + def value2mode(value): + if value % 907 == 0: + return "CTC" + if value % 887 == 0: + return "xent" + raise ValueError("Unknow mode") + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + + def _get_mode(logging_outputs): + mds = [ + SpeechTextPreTrainCompoundCriterion.value2mode(log["mode"]) + for log in logging_outputs + ] + if sum([1 if l != mds[0] else 0 for l in mds]) > 0: + raise ValueError("mode in one mini-batch is expected to be the same!") + return mds[0] + + log_mode = _get_mode(logging_outputs) + if log_mode == "xent": + return SpeechTextPreTrainCrossEntCriterion.reduce_metrics(logging_outputs) + + # ctc loss + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar( + "ctc_loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("ctc_ntokens", ntokens) + metrics.log_scalar("ctc_nsentences", nsentences) + if sample_size != ntokens: + metrics.log_scalar( + "ctc_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + + c_errors = sum(log.get("c_errors", 0) for log in logging_outputs) + metrics.log_scalar("_c_errors", c_errors) + c_total = sum(log.get("c_total", 0) for log in logging_outputs) + metrics.log_scalar("_c_total", c_total) + w_errors = sum(log.get("w_errors", 0) for log in logging_outputs) + metrics.log_scalar("_w_errors", w_errors) + wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs) + metrics.log_scalar("_wv_errors", wv_errors) + w_total = sum(log.get("w_total", 0) for log in logging_outputs) + metrics.log_scalar("_w_total", w_total) + + if c_total > 0: + metrics.log_derived( + "uer", + lambda meters: safe_round( + meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3 + ) + if meters["_c_total"].sum > 0 + else float("nan"), + ) + if w_total > 0: + metrics.log_derived( + "wer", + lambda meters: safe_round( + meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) + metrics.log_derived( + "raw_wer", + lambda meters: safe_round( + meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3 + ) + if meters["_w_total"].sum > 0 + else float("nan"), + ) diff --git a/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py b/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py new file mode 100644 index 000000000..6c9cb0f20 --- /dev/null +++ b/examples/speech_text_joint_to_text/criterions/multi_modality_cross_entropy.py @@ -0,0 +1,101 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from fairseq import utils +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, + label_smoothed_nll_loss, +) + + +@register_criterion( + "speech_text_pretrain_cross_entropy", + dataclass=LabelSmoothedCrossEntropyCriterionConfig, +) +class SpeechTextPreTrainCrossEntCriterion(LabelSmoothedCrossEntropyCriterion): + def __init__(self, task, sentence_avg, label_smoothing, report_accuracy=False): + super().__init__( + task, sentence_avg, label_smoothing, report_accuracy=report_accuracy + ) + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + loss, nll_loss, nsentences, ntokens, n_correct = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + sample_size = nsentences if self.sentence_avg else ntokens + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": ntokens, + "nsentences": nsentences, + "sample_size": sample_size, + } + if self.report_accuracy: + logging_output["n_correct"] = utils.item(n_correct) + logging_output["total"] = utils.item(ntokens) + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, model, net_output, sample): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + assert self.ignore_prefix_size == 0 + if self.ignore_prefix_size > 0: + if getattr(lprobs, "batch_first", False): + lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() + target = target[:, self.ignore_prefix_size :].contiguous() + else: + lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous() + target = target[self.ignore_prefix_size :, :].contiguous() + return lprobs, target + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + n_correct = 0 + if isinstance(target, dict): + t_lprobs = target["target_logprobs"] + + if not lprobs.batch_first: + lprobs = lprobs.transpose(0, 1) + t_lprobs = t_lprobs.transpose(0, 1) + nsentences, seq_len = lprobs.size()[:2] + ntokens = nsentences * seq_len + t_probs = t_lprobs.exp() + mask_indices = ( + net_output[1]["mask_indices"][0] + if len(net_output[1]["mask_indices"]) > 0 + else None + ) + + # mask_indices is True for those masking frames + if mask_indices is not None: # B X T + t_probs = t_probs.masked_fill(mask_indices.eq(False).unsqueeze(-1), 0) + ntokens = mask_indices.int().sum() + t_probs = t_probs.detach() + t_lprobs = t_lprobs.detach() + loss = ( + -(t_probs * (lprobs - t_lprobs)).sum() + if reduce + else -(t_probs * (lprobs - t_lprobs)).sum(-1, keepdim=True) + ) + nll_loss = loss + else: + nsentences = target.size(0) + mask = target.ne(self.padding_idx) + loss, nll_loss = label_smoothed_nll_loss( + lprobs.view(-1, lprobs.size(-1)), + target.view(-1), + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) + n_correct = torch.sum( + lprobs.argmax(-1).masked_select(mask).eq(target.masked_select(mask)) + ) + ntokens = torch.sum(mask) + return loss, nll_loss, nsentences, ntokens, n_correct diff --git a/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py b/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py new file mode 100644 index 000000000..fc94fbaf1 --- /dev/null +++ b/examples/speech_text_joint_to_text/data/pair_denoising_dataset.py @@ -0,0 +1,318 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import math +import re + +import torch + +from fairseq.data import data_utils +from fairseq.data.language_pair_dataset import LanguagePairDataset + + +# Part of the code is modified from DenoisingDataset +# compared with DenoisingDataset, no permute_sentences or documents (rotate_ratio, permute_sentence_ratio) +class LanguagePairDenoisingDataset(LanguagePairDataset): + def __init__( + self, + src, + src_sizes, + src_dict, + tgt, + tgt_sizes, + tgt_dict, + mask_idx, + mask_whole_words, + seed, + args, + left_pad_source=True, + left_pad_target=False, + shuffle=True, + input_feeding=True, + remove_eos_from_source=False, + append_eos_to_target=False, + align_dataset=None, + constraints=None, + append_bos=False, + eos=None, + num_buckets=0, + src_lang_id=None, + tgt_lang_id=None, + pad_to_multiple=1, + ): + super().__init__( + src, + src_sizes, + src_dict, + tgt, + tgt_sizes, + tgt_dict, + left_pad_source, + left_pad_target, + shuffle, + input_feeding, + remove_eos_from_source, + append_eos_to_target, + align_dataset, + constraints, + append_bos, + eos, + num_buckets, + src_lang_id, + tgt_lang_id, + pad_to_multiple, + ) + + self.mask_idx = mask_idx + self.mask_whole_word = mask_whole_words + self.mask_ratio = args.mask + self.random_ratio = args.mask_random + self.insert_ratio = args.insert + + self.replace_length = args.replace_length + + if self.replace_length not in [-1, 0, 1]: + raise ValueError(f"invalid arg: replace_length={self.replace_length}") + if args.mask_length not in ["subword", "word", "span-poisson"]: + raise ValueError(f"invalid arg: mask-length={args.mask_length}") + if args.mask_length == "subword" and args.replace_length not in [0, 1]: + raise ValueError("if using subwords, use replace-length=1 or 0") + + self.mask_span_distribution = None + if args.mask_length == "span-poisson": + # Text infilling: "A number of text spans are sampled, with span lengths drawn from a Poisson distribution (λ = 3). Each span is replaced with a single [MASK] token. 0-length spans correspond to the insertion of [MASK] tokens." + _lambda = args.poisson_lambda + + lambda_to_the_k = 1 + e_to_the_minus_lambda = math.exp(-_lambda) + k_factorial = 1 + ps = [] + for k in range(0, 128): + ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial) + lambda_to_the_k *= _lambda + k_factorial *= k + 1 + if ps[-1] < 0.0000001: + break + ps = torch.FloatTensor(ps) + self.mask_span_distribution = torch.distributions.Categorical(ps) + + self.epoch = 0 + self.seed = seed + + def _is_phoneme(x): + if re.search("", + "", + "", + "", + "", + "", + ): + return False + return True + + self.voc_valid_ids = torch.LongTensor( + [i for i, x in enumerate(self.src_dict.symbols) if _is_phoneme(x)] + ) + self.voc_valid_size = self.voc_valid_ids.size(0) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False + + def set_epoch(self, epoch, **unused): + self.epoch = epoch + + def __getitem__(self, index): + tgt_item = self.tgt[index] if self.tgt is not None else None + src_item = copy.deepcopy(self.src[index]) + with data_utils.numpy_seed(self.seed, self.epoch, index): + source = src_item + assert source[-1] == self.eos + if self.mask_ratio > 0: + source = self.add_whole_word_mask(source, self.mask_ratio) + + if self.insert_ratio > 0: + source = self.add_insertion_noise(source, self.insert_ratio) + src_item = source + + if self.append_eos_to_target: + eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos() + if self.tgt and self.tgt[index][-1] != eos: + tgt_item = torch.cat([self.tgt[index], torch.LongTensor([eos])]) + + if self.append_bos: + bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos() + if self.tgt and self.tgt[index][0] != bos: + tgt_item = torch.cat([torch.LongTensor([bos]), self.tgt[index]]) + + bos = self.src_dict.bos() + if src_item[0] != bos: + src_item = torch.cat([torch.LongTensor([bos]), src_item]) + + if self.remove_eos_from_source: + eos = self.src_dict.eos() + if src_item[-1] == eos: + src_item = src_item[:-1] + + example = { + "id": index, + "source": src_item, + "target": tgt_item, + } + if self.align_dataset is not None: + example["alignment"] = self.align_dataset[index] + if self.constraints is not None: + example["constraints"] = self.constraints[index] + if self.src_lang_id is not None: + example["src_lang_id"] = self.src_lang_id + if self.tgt_lang_id is not None: + example["tgt_lang_id"] = self.tgt_lang_id + return example + + # following functions are borrowed from denoising_dataset + def word_starts(self, source): + if self.mask_whole_word is not None: + is_word_start = self.mask_whole_word.gather(0, source) + else: + is_word_start = torch.ones(source.size()) + is_word_start[0] = 0 + is_word_start[-1] = 0 + return is_word_start + + def add_whole_word_mask(self, source, p): + is_word_start = self.word_starts(source) + num_to_mask = int(math.ceil(is_word_start.float().sum() * p)) + num_inserts = 0 + if num_to_mask == 0: + return source + + if self.mask_span_distribution is not None: + lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,)) + + # Make sure we have enough to mask + cum_length = torch.cumsum(lengths, 0) + while cum_length[-1] < num_to_mask: + lengths = torch.cat( + [ + lengths, + self.mask_span_distribution.sample(sample_shape=(num_to_mask,)), + ], + dim=0, + ) + cum_length = torch.cumsum(lengths, 0) + + # Trim to masking budget + i = 0 + while cum_length[i] < num_to_mask: + i += 1 + lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1]) + num_to_mask = i + 1 + lengths = lengths[:num_to_mask] + + # Handle 0-length mask (inserts) separately + lengths = lengths[lengths > 0] + num_inserts = num_to_mask - lengths.size(0) + num_to_mask -= num_inserts + if num_to_mask == 0: + return self.add_insertion_noise(source, num_inserts / source.size(0)) + + assert (lengths > 0).all() + else: + lengths = torch.ones((num_to_mask,)).long() + assert is_word_start[-1] == 0 + word_starts = is_word_start.nonzero(as_tuple=False) + indices = word_starts[ + torch.randperm(word_starts.size(0))[:num_to_mask] + ].squeeze(1) + mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio + + source_length = source.size(0) + assert source_length - 1 not in indices + to_keep = torch.ones(source_length, dtype=torch.bool) + is_word_start[ + -1 + ] = 255 # acts as a long length, so spans don't go over the end of doc + if self.replace_length == 0: + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = self.voc_valid_ids[ + torch.randint(0, self.voc_valid_size - 1, size=(mask_random.sum(),)) + ] + + if self.mask_span_distribution is not None: + assert len(lengths.size()) == 1 + assert lengths.size() == indices.size() + lengths -= 1 + while indices.size(0) > 0: + assert lengths.size() == indices.size() + lengths -= is_word_start[indices + 1].long() + uncompleted = lengths >= 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + lengths = lengths[uncompleted] + if self.replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = self.voc_valid_ids[ + torch.randint( + 0, self.voc_valid_size - 1, size=(mask_random.sum(),) + ) + ] + else: + # A bit faster when all lengths are 1 + while indices.size(0) > 0: + uncompleted = is_word_start[indices + 1] == 0 + indices = indices[uncompleted] + 1 + mask_random = mask_random[uncompleted] + if self.replace_length != -1: + # delete token + to_keep[indices] = 0 + else: + # keep index, but replace it with [MASK] + source[indices] = self.mask_idx + source[indices[mask_random]] = self.voc_valid_ids[ + torch.randint( + 0, self.voc_valid_size - 1, size=(mask_random.sum(),) + ) + ] + + assert source_length - 1 not in indices + + source = source[to_keep] + + if num_inserts > 0: + source = self.add_insertion_noise(source, num_inserts / source.size(0)) + + return source + + def add_insertion_noise(self, tokens, p): + if p == 0.0: + return tokens + + num_tokens = len(tokens) + n = int(math.ceil(num_tokens * p)) + + noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1 + noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool) + noise_mask[noise_indices] = 1 + result = torch.LongTensor(n + len(tokens)).fill_(-1) + + num_random = int(math.ceil(n * self.random_ratio)) + result[noise_indices[num_random:]] = self.mask_idx + result[noise_indices[:num_random]] = self.voc_valid_ids[ + torch.randint(0, self.voc_valid_size - 1, size=(num_random,)) + ] + + result[~noise_mask] = tokens + + assert (result >= 0).all() + return result diff --git a/examples/speech_text_joint_to_text/docs/ende-mustc.md b/examples/speech_text_joint_to_text/docs/ende-mustc.md index ad9e222ce..1acf6e001 100644 --- a/examples/speech_text_joint_to_text/docs/ende-mustc.md +++ b/examples/speech_text_joint_to_text/docs/ende-mustc.md @@ -22,11 +22,17 @@ Enhanced Joint Training: the joint training is enhanced with pre-trained models, --out-path ${must_c_en_de_src_text_pho} ``` - Replace the source text under the "src_text" column in the tsv file with the corresponding phoneme reprentation generated in the step above. +Below is the snapshot for the MuST-C en-de dev tsv +``` +id audio n_frames tgt_text src_text speaker +ted_767_0 en-de/flac.zip:10071514743:48445 56160 Heute spreche ich zu Ihnen über Energie und Klima. ▁AY1 M ▁G OW1 IH0 NG ▁T UW1 ▁T AO1 K ▁T AH0 D EY1 ▁AH0 B AW1 T ▁EH1 N ER0 JH IY0 ▁AH0 N D ▁K L AY1 M AH0 T spk.767_ +ted_767_1 en-de/flac.zip:1214217978:205678 226080 Und das überrascht vielleicht etwas, weil sich meine Vollzeitbeschäftigung bei der Stiftung hauptsächlich um Impfstoffe und Saatgut dreht, um die Dinge, die wir erfinden und liefern müssen um den ärmsten 2 Milliarden ein besseres Leben zu ermöglichen. ▁AH0 N D ▁DH AE1 T ▁M AY1 T ▁S IY1 M ▁AH0 ▁B IH1 T ▁S ER0 P R AY1 Z IH0 NG ▁B IH0 K AO1 Z ▁M AY1 ▁F UH1 L ▁T AY1 M ▁W ER1 K ▁AE1 T ▁DH AH0 ▁F AW0 N D EY1 SH AH0 N ▁IH1 Z ▁M OW1 S T L IY0 ▁AH0 B AW1 T ▁V AE2 K S IY1 N Z ▁AH0 N D ▁S IY1 D Z ▁AH0 B AW1 T ▁DH AH0 ▁TH IH1 NG Z ▁DH AE1 T ▁W IY1 ▁N IY1 D ▁T UW1 ▁IH0 N V EH1 N T ▁AH0 N D ▁D IH0 L IH1 V ER0 ▁T UW1 ▁HH EH1 L P ▁DH AH0 ▁P UH1 R IH0 S T ▁T UW1 ▁B IH1 L Y AH0 N ▁L AY1 V ▁B EH1 T ER0 ▁L IH1 V Z spk.767_ +``` - Prepare phoneme dictionary and save to $MANIFEST_ROOT as [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/must_c/en_de/src_dict.txt) #### Prepare WMT text data - [Download wmt data](https://github.com/pytorch/fairseq/blob/main/examples/translation/prepare-wmt14en2de.sh) - Convert source text (English) into phoneme representation as above -- Generate binary parallel file for training (as translation example) and save data in $parallel_text_data +- Generate binary parallel files with "fairseq-preprocess" from fairseq for training and validation. The source input is English phoneme representation and the target input is German sentencepiece token . The output is saved under $parallel_text_data ## Training The model is trained with 8 v100 GPUs. diff --git a/examples/speech_text_joint_to_text/docs/pre-training.md b/examples/speech_text_joint_to_text/docs/pre-training.md new file mode 100644 index 000000000..20272e616 --- /dev/null +++ b/examples/speech_text_joint_to_text/docs/pre-training.md @@ -0,0 +1,188 @@ +[[Back]](..) + +# Unified Speech-Text Pre-training for Speech Translation and Recognition + +This directory contains the pre-training recipes from paper ["Unified Speech-Text Pre-training for Speech Translation and Recognition"](https://arxiv.org/abs/2204.05409). + +## Librispeech ASR Pre-training +### Prepare Data +#### Download files +#### Prepare pre-training data +- Text to text task (T2T): prepare the binary data following the similar steps in [EN_DE Joint training](./ende-mustc.md). The source data is presented as phomeme token sequence and the target data is coded as subword tokens via SentencePiece. The text data is downloaded from [openslr](https://www.openslr.org/12) +- Self-supervised speech learning task (SSL): The data is prepared as [wav2vec 2.0](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec/README.md) +- Speech to phoneme classification task (S2P): The tsv file contains 5 fields: "id", "audio", "n_frames", "tgt_text", and "align". The tgt_text field is corresponding to the phoneme based representation of the speech data. "align" field contains the alignment information. The phoneme level forced alignment for the labelled speech data (i.e. Librispeech) can be obtained via [kaldi](http://kaldi-asr.org) or [MFA](https://montrealcorpustools.github.io/Montreal-Forced-Aligner/). The segmentation information is normalized to 0$\sim$1 for the whole utterance. The snapshot of the tsv file is below: +``` +id audio n_frames tgt_text align +116-288045-0000 /librispeech/dev-other/116/288045/116-288045-0000.flac 170400 ▁AE1 Z AY1 ▁AH0 P R OW1 CH T ▁DH AH1 ▁S IH1 T IY0 AY1 ▁HH ER1 D ▁B EH1 L Z ▁R IH1 NG IH0 NG ▁AE1 N D AH0 ▁L IH1 T AH0 L ▁L EY1 T ER0 AY1 ▁F AW1 N D ▁DH AH0 ▁S T R IY1 T S ▁AH0 S T IH1 R ▁W IH0 TH ▁TH R AO1 NG Z ▁AH0 V ▁W EH1 L ▁D R EH1 S T ▁P IY1 P AH0 L ▁IH1 N ▁F AE1 M L IY0 ▁G R UW1 P S ▁W EH1 N D IH0 NG ▁DH EH1 R ▁W EY1 ▁HH IH1 DH ER0 ▁AH0 N D ▁TH IH1 DH ER0 0.047977 0.056444 0.064911 0.075259 0.081844 0.089370 0.095014 0.104421 0.109125 0.111947 0.115710 0.120414 0.134525 0.141110 0.143932 0.174036 0.176858 0.190028 0.199436 0.207902 0.218250 0.224835 0.231421 0.242709 0.251176 0.257761 0.263405 0.268109 0.270931 0.290687 0.342427 0.349953 0.353716 0.356538 0.360301 0.363123 0.365945 0.368768 0.371590 0.376294 0.384760 0.394167 0.401693 0.409219 0.419567 0.430856 0.441204 0.444026 0.446849 0.449671 0.456256 0.463782 0.471308 0.477893 0.486359 0.491063 0.494826 0.501411 0.512700 0.517404 0.520226 0.534337 0.540922 0.545626 0.550329 0.559737 0.568203 0.583255 0.592662 0.600188 0.603951 0.611477 0.619003 0.624647 0.634055 0.639699 0.646284 0.653810 0.659454 0.664158 0.670743 0.682032 0.687676 0.692380 0.708373 0.713076 0.719661 0.729069 0.740357 0.744120 0.748824 0.752587 0.761994 0.770461 0.781750 0.790216 0.805268 0.808090 0.823142 0.832549 0.836312 0.840075 0.843838 0.851364 0.854186 0.857008 0.862653 0.878645 0.898401 0.901223 0.906867 0.913452 0.920038 0.926623 0.934149 0.939793 0.942615 0.945437 0.952023 0.957667 0.977422 1.000000 + +``` +- Speech to text task (S2T): The data preparation follow the steps in [EN_DE Joint training](./ende-mustc.md). + +#### Prepare fine-tuning data: +We re-use the data from T2T and S2T tasks in the fine-tuning stage. + +### Model Build +#### Pre-training +``` +python train.py $T2T_DATA \ + --save-dir $SAVE_PRE_PATH --user-dir examples/speech_text_joint_to_text --task speech_text_joint_denoising \ + --criterion speech_text_pretrain_cross_entropy --optimizer adam --weight-decay 0.01 --config-yaml config_s2p.yaml --config-s2s-yaml config.yaml --ddp-backend no_c10d \ + --lang-pairs pho-wrd --num-workers 4 --log-interval 500 --save-interval-updates 5000 --keep-interval-updates 1 --no-emb-update-unsup --report-accuracy --lr 0.001 --end-learning-rate 1e-06 \ + --lr-scheduler polynomial_decay --warmup-updates 10000 --total-num-update 800000 --update-freq 6 --validate-interval-updates 10000 --train-subset train \ + --valid-subset valid,valid_sup_speech,valid_sup_speech_s2s,valid_unsup_speech --dataset-impl mmap \ + --sup-speech-data $S2P_DATA_PATH --sup-speech-train-subset train_960.ali --sup-speech-valid-subset dev-clean.ali --sup-speech-s2s-data $S2T_DATA_PATH \ + --sup-speech-s2s-train-subset train --sup-speech-s2s-valid-subset dev-clean --unsup-speech-train-data $SSL_DATA_PATH/train.tsv --unsup-speech-valid-data $SSL_DATA_PATH/valid.tsv \ + --batch-size 200 --batch-size-valid 150 --max-source-positions 1024 --max-target-positions 1024 --max-text-tokens 3072 --max-speech-positions 600000 \ + --max-sample-size 750000 --min-sample-size 64000 --max-speech-tokens 750000 --max-tokens-valid 750000 --skip-invalid-size-inputs-valid-test \ + --unsupervised-speech-sample-ratio 3.0 --supervised-speech-sample-ratio 5 --supervised-speech-s2s-sample-ratio 5 --text-sample-ratio 1.0 --mask 0.3 --mask-random 0.1 \ + --mask-length span-poisson --speech-sup-mask-prob 0.3 --speech-unsup-mask-prob 0.7 --use-mask-whole-words --arch speech_text_pretrain_bart_base_stack \ + --no-scale-feature --activation-fn gelu --speech-extractor-mode default --stacked-encoder all --encoder-normalize-before --decoder-normalize-before \ + --encoder-learned-pos --decoder-learned-pos --dropout 0.1 --load-pretrained-mbart-encoder-from $BART --load-pretrained-mbart-decoder-from $BART +``` +The current implementation also supports model pre-training without the forced alignment supervised data. In this case, CTC is used to optimize the S2P task. We need to do following changes for the setting: +1. options to be added +``` +--use-sup-speech-ctc --criterion speech_text_pretrain_compound +``` +2. options to be deleted +``` +--same-data-update --criterion speech_text_pretrain_cross_entropy +``` +However, we find the CTC based pre-training is still worse than the forced alignment based setting. It could be partially due to the inferior pre-training setting that we re-use the forced alignment based pre-training setting for the CTC based pre-training. + +#### Fine-tuning +``` +python train.py $S2T_DATA_PATH \ + --save-dir $SAVE_FT_PATH --num-workers 8 --task speech_text_joint_to_text --arch dualinputs2twavtransformer_base_stack \ + --user-dir examples/speech_text_joint_to_text --max-update 100000 --optimizer adam --lr-scheduler inverse_sqrt --lr 0.0003 --update-freq 3 --clip-norm 10.0 \ + --criterion guided_label_smoothed_cross_entropy_with_accuracy --guide-alpha 0.8 --label-smoothing 0.1 --warmup-updates 20000 --attentive-cost-regularization 0.02 \ + --enc-grad-mult 2.0 --max-tokens 800000 --max-source-positions 800000 --max-tokens-text 10000 --max-positions-text 1024 --max-target-positions 1024 --no-scale-feature \ + --activation-fn gelu --load-pretrained-speech-text-encoder $SAVE_PRE_PATH/checkpoint_last.pt --load-pretrained-speech-text-decoder $SAVE_PRE_PATH/checkpoint_last.pt \ + --encoder-normalize-before --decoder-normalize-before --speech-extractor-mode default --speech-mask-channel-length 64 --speech-mask-channel-prob 0.5 \ + --speech-mask-length 10 --speech-mask-prob 0.65 --text-sample-ratio 0.25 --mask-text-ratio 0.3 --mask-text-type random --parallel-text-data text_bin \ + --text-input-cost-ratio 0.5 --langpairs pho-wrd --update-mix-data --log-format json --max-tokens-valid 800000 --ddp-backend no_c10d --log-interval 500 \ + --config-yaml config.yaml --skip-invalid-size-inputs-valid-test --keep-last-epochs 50 --layernorm-embedding --encoder-learned-pos --decoder-learned-pos +``` + +### Evaluation +The last 10 epoch models from fine-tuning is conducted model average to get $FINAL_MODEL +``` +python ./fairseq_cli/generate.py \ + $S2T_DATA_PATH \ + --task speech_text_joint_to_text \ + --max-tokens 800000 \ + --max-source-positions 800000 \ + --nbest 1 \ + --results-path $RESULTS_LOG \ + --batch-size 512 \ + --path $FINAL_MODEL \ + --gen-subset $SUBSET \ + --config-yaml config.yaml \ + --scoring wer \ + --beam 10 --lenpen 1.0 examples/speech_text_joint_to_text +``` + +### Results and models +| | dev-clean | dev-other | test-clean | test-other | +|---|---|---|---|---| +| WER| 2.0 | 4.4 | 2.1 |4.6 | + +**Model Links**: +- [config_s2p.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/config_s2p.yaml): Config for S2P +- [spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/spm.model): Sentence Piece model +- [src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/src_dict.txt): Source Phoneme Dictionary +- [tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/tgt_dict.txt): Target Sentence Piece Dictionary +- [config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/config.yaml): Config for S2T +- [BART](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/bart.pt): trained from Librispeech text data +- [Joint Pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/pretrain/checkpoint6.pt): model pre-trained with 960 hours Librispeech data (S2P, S2T) Librispeech text training data (T2T) and Librilight data (SSL) +- [Fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned/checkpoint_ave_10.pt): the pre-trained model is fined one 960 hours Librispeech speech and text data. (S2T + T2T) + +## MuST-C +### Prepare Data +Compared with the ASR Librispeech ASR recipe, the differences are below: +- Replace the speech data with corresponding MuST-C data +- Parallel text data from WMT is replaced the Librispeech text data + +### Model Build +#### Pre-training +EN-DE is used as an example +``` +python train.py $TXT_DATA \ + --save-dir $SAVE_PRE_PATH --user-dir examples/speech_text_joint_to_text --task speech_text_joint_denoising --criterion speech_text_pretrain_cross_entropy --optimizer adam --weight-decay 0.01 \ + --config-yaml config_s2p.yaml --config-s2s-yaml config.yaml --ddp-backend no_c10d --lang-pairs-bitext en-fr --num-workers 4 --log-interval 500 --save-interval-updates 5000 --keep-interval-updates 1 \ + --no-emb-update-unsup --use-decoder-output-proj --report-accuracy --lr 0.001 --end-learning-rate 1e-06 --lr-scheduler polynomial_decay --warmup-updates 10000 --total-num-update 800000 \ + --update-freq 8 --validate-interval-updates 10000 --train-subset train --valid-subset valid_sup_speech,valid_sup_speech_s2s,valid_unsup_speech --dataset-impl mmap \ + --sup-speech-data $S2P_DATA_PATH --sup-speech-train-subset train --sup-speech-valid-subset dev --sup-speech-s2s-data $S2T_DATA_PATH --sup-speech-s2s-train-subset train \ + --sup-speech-s2s-valid-subset dev --unsup-speech-train-data $SSL_DATA_PATH/train.tsv --unsup-speech-valid-data $SSL_DATA_PATH/valid.tsv --batch-size 200 --batch-size-valid 100 \ + --max-source-positions 1024 --max-target-positions 1024 --max-text-tokens 2048 --max-speech-positions 600000 --max-sample-size 600000 --min-sample-size 64000 \ + --max-speech-tokens 600000 --max-tokens-valid 600000 --skip-invalid-size-inputs-valid-test --unsupervised-speech-sample-ratio 1.2 --supervised-speech-sample-ratio 10 \ + --supervised-speech-s2s-sample-ratio 10 --bitext-sample-ratio 0.5 --mask 0.3 --mask-random 0.1 --mask-length span-poisson --speech-sup-mask-prob 0.3 \ + --speech-unsup-mask-prob 0.7 --use-mask-whole-words --arch speech_text_pretrain_bart_base_stack --no-scale-feature --activation-fn gelu --speech-extractor-mode default \ + --stacked-encoder s2s --encoder-normalize-before --decoder-normalize-before --encoder-learned-pos --decoder-learned-pos --dropout 0.1 \ + --load-pretrained-mbart-encoder-from $EN_FR_NMT --load-pretrained-mbart-decoder-from $EN_FR_NMT +``` +#### Fine-tuning +``` +python train.py $S2T_DATA_PATH \ + --save-dir $SAVE_FT_PATH --num-workers 8 --task speech_text_joint_to_text --arch dualinputs2twavtransformer_base_stack --user-dir examples/speech_text_joint_to_text \ + --max-epoch 25 --update-mix-data --optimizer adam --lr-scheduler inverse_sqrt --lr 0.0003 --update-freq 4 --clip-norm 10.0 --warmup-updates 20000 \ + --criterion guided_label_smoothed_cross_entropy_with_accuracy --guide-alpha 0.8 --attentive-cost-regularization 0.02 --enc-grad-mult 2.0 --label-smoothing 0.1 \ + --max-tokens 800000 --max-source-positions 800000 --max-tokens-text 10000 --max-positions-text 1024 --load-pretrained-speech-text-encoder $SAVE_PRE_PATH/checkpoint_last.pt \ + --load-pretrained-speech-text-decoder $SAVE_PRE_PATH/checkpoint_last.pt --speech-mask-channel-length 64 --speech-mask-channel-prob 0.5 --speech-mask-length 10 \ + --speech-mask-prob 0.65 --text-sample-ratio 0.05 --mask-text-ratio 0.3 --mask-text-type random --parallel-text-data data-bin-wt --text-input-cost-ratio 0.5 \ + --langpairs en-fr --log-format json --max-tokens-valid 800000 --ddp-backend no_c10d --log-interval 100 --config-yaml config.yaml --skip-invalid-size-inputs-valid-test \ + --noise-token '▁NOISE' --keep-last-epochs 40 --layernorm-embedding --encoder-learned-pos --decoder-learned-pos --activation-fn gelu \ + --speech-extractor-mode default --max-target-positions 1024 --encoder-normalize-before --decoder-normalize-before +``` + +### Evaluation +The last 10 epoch models from fine-tuning is conducted model average to get $FINAL_MODEL +``` +python fairseq_cli/generate.py \ + $S2T_DATA_PATH \ + --task speech_text_joint_to_text \ + --nbest 1 \ + --max-tokens 800000 \ + --max-source-positions 800000 \ + --results-path $RESULTS_LOG \ + --batch-size 512 \ + --path $FINAL_MODEL \ + --gen-subset $SUBSET \ + --config-yaml config.yaml \ + --scoring sacrebleu \ + --beam 10 --lenpen 1.0 examples/speech_text_joint_to_text +``` + + +### Results and models +| | en-fr | en-es | en-de | +|---|---|---|---| +| BLEU| 39.7 | 33.2 |29.2 | + + +**Model Links**: +1. DE + - [de config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/config.yaml) + - [de src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/src_dict.txt) + - [de tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/tgt_dict.txt) + - [de spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/spm.model) + - [de pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/nmt.pt) + - [de pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/checkpoint_pretraing.pt) + - [de fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/de/checkpoint_finetune_ave10.pt) +2. ES + - [es config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/config.yaml) + - [es src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/src_dict.txt) + - [es tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/tgt_dict.txt) + - [es spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/spm.model) + - [es pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/nmt.pt) + - [es pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/checkpoint_pretraing.pt) + - [es fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/es/checkpoint_finetune_ave10.pt) +3. FR + - [fr config.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/config.yaml) + - [fr src_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/src_dict.txt) + - [fr tgt_dict.txt](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/tgt_dict.txt) + - [fr spm.model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/spm.model) + - [fr pre-trained nmt model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/nmt.pt) + - [fr pre-trained model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/checkpoint_pretraing.pt) + - [fr fine-tuned model](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/fr/checkpoint_finetune_ave10.pt) +4. [config_s2p.yaml](https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/must_c/config_s2p.yaml) diff --git a/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py b/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py new file mode 100644 index 000000000..6f917398a --- /dev/null +++ b/examples/speech_text_joint_to_text/models/joint_speech_text_pretrain_transformer.py @@ -0,0 +1,698 @@ +#!/usr/bin/env python3 + +import logging +from collections import OrderedDict, namedtuple +from typing import Dict, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from fairseq import checkpoint_utils, utils +from fairseq.file_io import PathManager +from fairseq.models import ( + FairseqDecoder, + FairseqEncoderDecoderModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_text import ( + MultiInputDecoder, + MultiModalityEncoder, + SpeechWavTransformerEncoder, + StackedSpeechWavTransformerEncoder, +) +from fairseq.models.transformer import ( + TransformerDecoder, + TransformerEncoder, + TransformerModel, +) + +logger = logging.getLogger(__name__) + + +class SpeechTextPreTrainEncoder(MultiModalityEncoder): + def __init__( + self, + dictionary, + sup_speech_encoder, + sup_s2s_speech_encoder, + unsup_speech_encoder, + text_encoder, + ): + super().__init__(dictionary) + self.sup_speech_encoder = sup_speech_encoder + self.sup_s2s_speech_encoder = sup_s2s_speech_encoder + self.unsup_speech_encoder = unsup_speech_encoder + self.text_encoder = text_encoder + + @classmethod + def update_transformer_encoder_cfg(cls, args, update_dict): + cfg = dict(args._get_kwargs()) + for fkey in update_dict.keys(): + cfg[fkey] = update_dict[fkey] + cfg.pop("_name", None) # remove keys start with _ + model_args = namedtuple("args", cfg.keys())(*cfg.values()) + return model_args + + @classmethod + def build_text_encoder(cls, args, src_dictionary): + enc_emb = nn.Embedding( + len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad() + ) + model_args = cls.update_transformer_encoder_cfg( + args, {"encoder_layers": args.text_encoder_layers} + ) + text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb) + return text_encoder + + @classmethod + def build_speech_encoder(cls, args): + model_args = cls.update_transformer_encoder_cfg( + args, + { + "encoder_layers": args.speech_encoder_layers, + "speech_mask_prob": args.speech_sup_mask_prob, + }, + ) + speech_encoder = SpeechWavTransformerEncoder(model_args) + return speech_encoder + + @classmethod + def share_layers(cls, src_layers, tgt_layers): # share layer but not dropout + # share parameters in src_layers with tgt_layers + assert len(src_layers) == len(tgt_layers) + for i, ly in enumerate(src_layers): + tly = tgt_layers[i] + tly.self_attn = ly.self_attn + tly.self_attn_layer_norm = ly.self_attn_layer_norm + tly.activation_fn = ly.activation_fn + tly.normalize_before = ly.normalize_before + tly.fc1 = ly.fc1 + tly.fc2 = ly.fc2 + tly.final_layer_norm = ly.final_layer_norm + if hasattr(tly, "encoder_attn"): + tly.encoder_attn = ly.encoder_attn + tly.encoder_attn_layer_norm = ly.encoder_attn_layer_norm + return tgt_layers + + @classmethod + def build_unsup_speech_encoder(cls, args, sup_speech_encoder): + model_args = cls.update_transformer_encoder_cfg( + args, + { + "encoder_layers": args.speech_encoder_layers, + "speech_mask_prob": args.speech_unsup_mask_prob, + "encoder_layerdrop": 0.0, + "decoder_layerdrop": 0.0, + "dropout": args.speech_unsup_dropout, + "activation_dropout": args.speech_unsup_dropout, + "attention_dropout": 0.0, + "dropout_features": args.speech_unsup_feature_dropout, + "dropout_input": args.speech_unsup_feature_dropout, + }, + ) + + unsup_speech_encoder = SpeechWavTransformerEncoder(model_args, alway_mask=True) + unsup_speech_encoder.layer_norm = sup_speech_encoder.layer_norm + unsup_speech_encoder.layers = cls.share_layers( + sup_speech_encoder.layers, unsup_speech_encoder.layers + ) + unsup_speech_encoder.mask_emb = sup_speech_encoder.mask_emb + unsup_speech_encoder.embed_positions = sup_speech_encoder.embed_positions + unsup_speech_encoder.feat_layer_norm = sup_speech_encoder.feat_layer_norm + unsup_speech_encoder.feat_proj = sup_speech_encoder.feat_proj + unsup_speech_encoder.subsample = sup_speech_encoder.subsample + return unsup_speech_encoder + + @classmethod + def build_encoder(cls, args, dictionary): + text_encoder = cls.build_text_encoder(args, dictionary) + if getattr(args, "load_pretrained_mbart_encoder_from", None): + text_encoder = checkpoint_utils.load_pretrained_component_from_model( + component=text_encoder, + checkpoint=args.load_pretrained_mbart_encoder_from, + ) + speech_encoder = cls.build_speech_encoder(args) + if getattr(args, "load_pretrained_feature_extractor_from", None): + + def load_feature_extractor(component, checkpoint): + if not PathManager.exists(checkpoint): + raise IOError("Model file not found: {}".format(checkpoint)) + state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint) + component_state_dict = OrderedDict() + + component_prefix = "feature_extractor" + for key in state["model"].keys(): + if key.startswith(component_prefix): + component_subkey = key[len(component_prefix) + 1 :] + component_state_dict[component_subkey] = state["model"][key] + component.load_state_dict(component_state_dict, strict=True) + return component + + speech_encoder.subsample = load_feature_extractor( + speech_encoder.subsample, args.load_pretrained_feature_extractor_from + ) + speech_s2s_encoder = speech_encoder + unsup_speech_encoder = cls.build_unsup_speech_encoder(args, speech_encoder) + if getattr(args, "stacked_encoder", "none") != "none": + if args.encoder_shared_text_layers_from_begin > 0: + raise ValueError( + "We can not stack encoders and share encoders at the same time!" + ) + speech_s2s_encoder = StackedSpeechWavTransformerEncoder( + speech_encoder, text_encoder.layers, text_encoder.layer_norm + ) + if args.stacked_encoder == "all": + speech_encoder = speech_s2s_encoder + unsup_speech_encoder = StackedSpeechWavTransformerEncoder( + unsup_speech_encoder, text_encoder.layers, text_encoder.layer_norm + ) + else: + cls.share_speech_text_encoder( + speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin + ) + return SpeechTextPreTrainEncoder( + dictionary, + speech_encoder, + speech_s2s_encoder, + unsup_speech_encoder, + text_encoder, + ) + + @classmethod + def share_speech_text_encoder( + cls, speech_encoder, text_encoder, shared_layers_from_begin + ): + if shared_layers_from_begin > 0: + num_text_encoder_layers = len(text_encoder.layers) + assert len(speech_encoder.layers) >= shared_layers_from_begin + assert num_text_encoder_layers >= shared_layers_from_begin + assert len(speech_encoder.layers) >= num_text_encoder_layers + for i, ly in enumerate( + speech_encoder.layers[ + -num_text_encoder_layers : -num_text_encoder_layers + + shared_layers_from_begin + ] + ): + assert isinstance(text_encoder.layers[i], type(ly)) + text_encoder.layers[i] = ly + + def select_encoder(self, mode, **kwargs): + if mode in ("speech", "sup_speech_ctc", "sup_speech_ali", "sup_speech_s2s"): + kwargs["features_only"] = True + if mode == "sup_speech_s2s": + return self.sup_s2s_speech_encoder, kwargs + return self.sup_speech_encoder, kwargs + elif mode == "unsup_speech": + kwargs["features_only"] = False + return self.unsup_speech_encoder, kwargs + elif mode in ("text", "bitext"): + return self.text_encoder, kwargs + else: + raise NotImplementedError(f"{mode} is not supported") + return None, kwargs + + def forward(self, src_tokens, src_lengths=None, mode="", alignment=None, **kwargs): + return super().forward(src_tokens, src_lengths, mode, **kwargs) + + +# SpeechDummyDecoder works as an extension of encoder, so we could fit encoder only training into seq2seq training +class SpeechDummyDecoder(FairseqDecoder): + def __init__( + self, + dictionary, + output_embedding, + no_emb_update_unsup=False, + use_output_proj=False, + ): + super().__init__(dictionary) + self.output_embedding = output_embedding + num_embedding, num_dim = self.output_embedding.weight.size() + self.out_proj = ( + None if use_output_proj is False else nn.Linear(num_dim, num_dim) + ) + self.no_emb_update_unsup = no_emb_update_unsup + + def extend_alignment(self, alignment, src_lengths, prev_output_tokens): + # alignment: B X N + # src_lengths: B X T + # prev_output_tokens: B X (N + 1) + tgt_tokens = prev_output_tokens[ + :, 1: + ] # remove the leading start of sentence token + ext_alignment = ( + torch.ones(len(src_lengths), src_lengths.max(), device=src_lengths.device) + .long() + .fill_(self.dictionary.pad()) + ) + for bs in range(src_lengths.size(0)): + tgt_length = tgt_tokens[bs].ne(self.dictionary.pad()).sum().item() + assert tgt_length == sum(alignment[bs].ne(1)) + 1 + src_st = 0 + for i in range(tgt_length): + tok = tgt_tokens[bs][i] + src_ed = (alignment[bs][i] * src_lengths[bs]).int().item() + ext_alignment[bs][src_st:src_ed].fill_(tok) + src_st = src_ed + return ext_alignment + + def forward( + self, + prev_output_tokens, + encoder_out, + incremental_state=None, + mode="speech", + alignment=None, + **kwargs, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + sup_speech_ctc: + dictionary{"logits": logits, "padding_mask": padding_mask} + sup_speech_ali and unsup_speech: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + emb_weight = self.output_embedding.weight + if ( + mode == "unsup_speech" and self.no_emb_update_unsup + ): # no gradient for embedding here + emb_weight = emb_weight.detach() + enc_out = ( + encoder_out["encoder_out"][0] + if self.out_proj is None + else self.out_proj(encoder_out["encoder_out"][0]) + ) + logits = F.linear(enc_out, emb_weight, None).transpose(0, 1) # B X T X C + others = None + if mode in ( + "speech", + "sup_speech_ctc", + ): # speech data with label, do forcealignment + if len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] + logits = logits.masked_fill(padding_mask, float("-inf")) + else: + seq_len, bsz = encoder_out["encoder_out"][0].size()[:2] + padding_mask = torch.zeros( + bsz, seq_len, device=encoder_out["encoder_out"][0].device + ).bool() + return {"x": logits, "padding_mask": padding_mask} + elif mode == "sup_speech_ali": + src_lengths = None + if len(encoder_out["encoder_padding_mask"]) > 0: + src_lengths = (1 - encoder_out["encoder_padding_mask"][0].long()).sum( + -1 + ) + else: + seq_len, bsz = encoder_out["encoder_out"][0].size()[:2] + src_lengths = ( + torch.ones(bsz, device=encoder_out["encoder_out"][0].device).long() + * seq_len + ) + assert alignment is not None + alignment = self.extend_alignment( + alignment, src_lengths, prev_output_tokens + ) + others = {"pseudo_target_tokens": alignment} + elif mode == "unsup_speech": + enc_out_ori = ( + encoder_out["encoder_unmasked_out"][0] + if self.out_proj is None + else self.out_proj(encoder_out["encoder_unmasked_out"][0]) + ) + logits_ori = F.linear(enc_out_ori, emb_weight, None).transpose(0, 1) + if len(encoder_out["encoder_padding_mask"]) > 0: + encoder_padding_mask = encoder_out["encoder_padding_mask"][0] + logits_ori = logits_ori.masked_fill(encoder_padding_mask, float("-inf")) + pseudo_labels = utils.log_softmax(logits_ori, dim=-1) + others = { + "pseudo_target_logprobs": pseudo_labels, + "padding_mask": encoder_out["encoder_padding_mask"], # B X T + "mask_indices": encoder_out[ + "mask_indices" + ], # True for masked frames B X T + } + return logits, others + + def get_normalized_probs( + self, + net_output: Dict[str, Tensor], + log_probs: bool, + sample: Optional[Dict[str, Tensor]] = None, + ): + return self.get_normalized_probs_scriptable( + (net_output["x"], None), log_probs, sample + ) + + +class SpeechTextPreTrainDecoder(MultiInputDecoder): + def __init__(self, dictionary, speech_decoder, text_decoder): + super().__init__(dictionary) + self.speech_decoder = speech_decoder + self.text_decoder = text_decoder + + def select_decoder(self, mode, **kwargs): + if mode == "unsup_speech": + kwargs["mode"] = mode + return self.speech_decoder, kwargs + if mode in ("text", "bitext"): + return self.text_decoder, kwargs + if mode in ("speech", "sup_speech_ctc", "sup_speech_ali"): + kwargs["mode"] = mode + return self.speech_decoder, kwargs + if mode in ("speech", "sup_speech_s2s"): + if "alignment" in kwargs: + del kwargs["alignment"] + return self.text_decoder, kwargs + + raise NotImplementedError(f"{mode} is not supported") + return None, kwargs + + def get_normalized_probs( + self, + net_output, + log_probs, + sample=None, + ): + """Get normalized probabilities (or log probs) from a net's output.""" + if isinstance(net_output, dict): + return self.speech_decoder.get_normalized_probs( + net_output, log_probs, sample + ) + return self.text_decoder.get_normalized_probs(net_output, log_probs, sample) + + @classmethod + def build_text_decoder(cls, args, tgt_dictionary, dec_emb_share=None): + dec_emb = ( + nn.Embedding( + len(tgt_dictionary), args.decoder_embed_dim, tgt_dictionary.pad() + ) + if dec_emb_share is None + else dec_emb_share + ) + text_decoder = TransformerDecoder(args, tgt_dictionary, dec_emb) + return text_decoder + + @classmethod + def build_dummy_speech_decoder(cls, args, dictionary, dec_emb_share=None): + dec_emb = ( + nn.Embedding(len(dictionary), args.decoder_embed_dim, dictionary.pad()) + if dec_emb_share is None + else dec_emb_share + ) + speech_decoder = SpeechDummyDecoder( + dictionary, + dec_emb, + no_emb_update_unsup=getattr(args, "no_emb_update_unsup", False), + use_output_proj=getattr(args, "use_decoder_output_proj", False), + ) + return speech_decoder + + @classmethod + def build_decoder( + cls, args, text_dictionary, speech_dictionary, speech_output_embedding + ): + text_decoder = cls.build_text_decoder(args, text_dictionary) + speech_decoder = cls.build_dummy_speech_decoder( + args, speech_dictionary, speech_output_embedding + ) + if getattr(args, "load_pretrained_mbart_decoder_from", None): + text_decoder = checkpoint_utils.load_pretrained_component_from_model( + component=text_decoder, + checkpoint=args.load_pretrained_mbart_decoder_from, + ) + return SpeechTextPreTrainDecoder(text_dictionary, speech_decoder, text_decoder) + + +@register_model("speech_text_pretrain_bart") +class SpeechTextPreTrainModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + self.num_updates = 0 + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, src_lang_ids=None, **kwargs + ): + if src_lang_ids is not None: + encoder_out = self.encoder( + src_tokens, src_lengths=src_lengths, src_lang_ids=src_lang_ids, **kwargs + ) + else: + encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs) + decoder_out = self.decoder( + prev_output_tokens, encoder_out=encoder_out, **kwargs + ) + return decoder_out + + def max_positions(self): + return None # it is provided in task + + def get_targets(self, sample, net_output): + mode = sample["net_input"]["mode"] + if mode == "unsup_speech": + return {"target_logprobs": net_output[1]["pseudo_target_logprobs"]} + if mode == "sup_speech_ali": + return net_output[1]["pseudo_target_tokens"] + return sample["target"] + + def get_normalized_probs( + self, + net_output, + log_probs, + sample=None, + ): + # net_output['encoder_out'] is a (B, T, D) tensor + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) + lprobs.batch_first = True + return lprobs + + @staticmethod + def add_args(parser): + TransformerModel.add_args(parser) + SpeechWavTransformerEncoder.add_args(parser) + parser.add_argument( + "--speech-sup-mask-prob", + type=float, + help="probability of replacing a token with mask (sup-speech)", + ) + parser.add_argument( + "--speech-unsup-mask-prob", + type=float, + help="probability of replacing a token with mask (unsup-speech)", + ) + parser.add_argument( + "--load-pretrained-mbart-encoder-from", + type=str, + metavar="STR", + help="model to take text encoder weights from (for initialization)", + ) + + parser.add_argument( + "--load-pretrained-mbart-decoder-from", + type=str, + metavar="STR", + help="model to take text decoder weights from (for initialization)", + ) + + parser.add_argument( + "--load-pretrained-feature-extractor-from", + type=str, + metavar="STR", + help="model to take feature extractor weights from (for initialization)", + ) + + parser.add_argument( + "--speech-unsup-dropout", + type=float, + default=0, + help="dropout for unsupervised speech encoder", + ) + + parser.add_argument( + "--speech-unsup-feature-dropout", + type=float, + default=0, + help="dropout for unsupervised speech feature encoder", + ) + + parser.add_argument( + "--encoder-shared-text-layers-from-begin", + type=int, + help="number of text encoder layers shared with speech encoder (from first layer)", + ) + + parser.add_argument( + "--stacked-encoder", + default="none", + choices=["none", "s2s", "all"], + help="stack speech and text encoders", + ) + + parser.add_argument("--use-decoder-output-proj", action="store_true") + + @classmethod + def build_model(cls, args, task): + encoder = SpeechTextPreTrainEncoder.build_encoder(args, task.src_dict) + decoder = SpeechTextPreTrainDecoder.build_decoder( + args, task.tgt_dict, task.src_dict, encoder.text_encoder.embed_tokens + ) + model = SpeechTextPreTrainModel(encoder, decoder) + return model + + def upgrade_state_dict(self, state_dict): + """Upgrade old state dicts to work with newer code.""" + if "decoder.speech_decoder.output_projection.weight" in state_dict: + del state_dict["decoder.speech_decoder.output_projection.weight"] + self.upgrade_state_dict_named(state_dict, "") + + +@register_model_architecture( + "speech_text_pretrain_bart", "speech_text_pretrain_bart_base" +) +def speech_text_pretrain_bart_base(args): + # speech masking + args.dropout_input = getattr(args, "dropout_input", 0) + args.dropout_features = getattr(args, "dropout_features", 0) + args.speech_mask_length = getattr(args, "speech_mask_length", 10) + args.speech_mask_prob = getattr(args, "speech_mask_prob", 0.65) + args.speech_sup_mask_prob = getattr(args, "speech_sup_mask_prob", 0.3) + args.speech_unsup_mask_prob = getattr( + args, "speech_unsup_mask_prob", args.speech_mask_prob + ) + args.speech_mask_selection = getattr(args, "speech_mask_selection", "static") + args.speech_mask_other = getattr(args, "speech_mask_other", 0) + args.speech_mask_min_space = getattr(args, "speech_mask_min_space", 1) + args.speech_no_mask_overlap = getattr(args, "speech_no_mask_overlap", False) + + args.speech_mask_channel_length = getattr(args, "speech_mask_channel_length", 10) + args.speech_mask_channel_prob = getattr(args, "speech_mask_channel_prob", 0.0) + args.speech_mask_channel_selection = getattr( + args, "speech_mask_channel_selection", "static" + ) + args.speech_mask_channel_other = getattr(args, "speech_mask_channel_other", 0) + args.speech_mask_channel_min_space = getattr( + args, "speech_mask_channel_min_space", 1 + ) + args.speech_no_mask_channel_overlap = getattr( + args, "speech_no_mask_channel_overlap", False + ) + args.no_scale_feature = getattr(args, "", False) + args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) # 0.1 + + # Transformer + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr( + args, "encoder_ffn_embed_dim", args.encoder_embed_dim * 4 + ) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + args.speech_conv_bias = getattr(args, "speech_conv_bias", False) + + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_attention_heads = getattr( + args, "decoder_attention_heads", args.encoder_attention_heads + ) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", args.dropout) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.activation_fn = getattr(args, "activation_fn", "relu") # gelu? + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + + args.speech_unsup_dropout = getattr(args, "speech_unsup_dropout", 0) + args.speech_unsup_feature_dropout = getattr(args, "speech_unsup_feature_dropout", 0) + + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 6) + args.encoder_shared_text_layers_from_begin = getattr( + args, "encoder_shared_text_layers_from_begin", 6 + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + + args.no_emb_update_unsup = getattr(args, "no_emb_update_unsup", False) + + +@register_model_architecture( + "speech_text_pretrain_bart", "speech_text_pretrain_bart_base_stack" +) +def speech_text_pretrain_bart_base_stack(args): + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 6) + args.encoder_shared_text_layers_from_begin = getattr( + args, "encoder_shared_text_layers_from_begin", 0 + ) + args.stacked_encoder = getattr(args, "stacked_encoder", "all") + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + speech_text_pretrain_bart_base(args) + + +@register_model_architecture( + "speech_text_pretrain_bart", "speech_text_pretrain_bart_large" +) +def speech_text_pretrain_bart_large(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 24) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 12) + args.encoder_shared_text_layers_from_begin = getattr( + args, "encoder_shared_text_layers_from_begin", 12 + ) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.dropout = getattr(args, "dropout", 0.3) + speech_text_pretrain_bart_base(args) + + +@register_model_architecture( + "speech_text_pretrain_bart", "speech_text_pretrain_bart_large_stack" +) +def speech_text_pretrain_bart_large_stack(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 12) + args.encoder_shared_text_layers_from_begin = getattr( + args, "encoder_shared_text_layers_from_begin", 0 + ) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.stacked_encoder = getattr(args, "stacked_encoder", "s2s") + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + speech_text_pretrain_bart_base(args) diff --git a/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py b/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py new file mode 100644 index 000000000..66e4b3f1e --- /dev/null +++ b/examples/speech_text_joint_to_text/models/s2t_dualinputwavtransformer.py @@ -0,0 +1,526 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import OrderedDict, namedtuple + +import torch.nn as nn + +from fairseq import checkpoint_utils, utils +from fairseq.checkpoint_utils import load_checkpoint_to_cpu +from fairseq.file_io import PathManager +from fairseq.models import register_model, register_model_architecture +from fairseq.models.speech_to_text import ( + SpeechWavTransformerEncoder, + StackedSpeechWavTransformerEncoder, + TransformerDecoder, +) +from fairseq.models.transformer import TransformerEncoder + +from .s2t_dualinputtransformer import ( + DualInputEncoder, + DualInputS2TTransformerModel, + TransformerMultiInputDecoder, +) + +logger = logging.getLogger(__name__) + + +@register_model("dual_input_wav_transformer") +class DualInputWavTransformerModel(DualInputS2TTransformerModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @staticmethod + def add_args(parser): + def add_transformer_args(parser): + # We can't use TransformerModel.add_args(parser), since it defines max-source-positions which is duplicated with tasks/speech_to_text.py + # Transformer + parser.add_argument( + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + "--relu-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-layers", type=int, metavar="N", help="num encoder layers" + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="N", + help="num encoder attention heads", + ) + parser.add_argument( + "--encoder-normalize-before", + action="store_true", + help="apply layernorm before each encoder block", + ) + parser.add_argument( + "--decoder-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension", + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--share-decoder-input-output-embed", + action="store_true", + help="share decoder input and output embeddings", + ) + parser.add_argument( + "--layernorm-embedding", + action="store_true", + help="add layernorm to embedding", + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + + parser.add_argument( + "--encoder-learned-pos", + action="store_true", + help="use learned positional embeddings", + ) + parser.add_argument( + "--decoder-learned-pos", + action="store_true", + help="use learned positional embeddings", + ) + + add_transformer_args(parser) + SpeechWavTransformerEncoder.add_args(parser) + parser.add_argument( + "--load-pretrained-speech-text-encoder", + type=str, + default="", + metavar="EXPR", + help=""" path to the pretrained speech text encoder from SpeechTextPreTrainModel """, + ) + parser.add_argument( + "--load-pretrained-wav2vec-encoder", + type=str, + default="", + metavar="EXPR", + help=""" path to the pretrained speech text encoder from wav2vec """, + ) + + parser.add_argument( + "--load-pretrained-speech-text-decoder", + type=str, + default="", + metavar="EXPR", + help=""" path to the pretrained speech text decoder from SpeechTextPreTrainModel """, + ) + parser.add_argument( + "--load-pretrained-text-decoder", + type=str, + default="", + metavar="EXPR", + help=""" path to the pretrained text decoder """, + ) + parser.add_argument( + "--load-init-encoder", + type=str, + default="", + metavar="EXPR", + help=""" path to load seed encoder model """, + ) + parser.add_argument( + "--load-init-decoder", + type=str, + default="", + metavar="EXPR", + help=""" path to load seed decoder model """, + ) + + parser.add_argument( + "--text-input-cost-ratio", + type=float, + default=1.0, + metavar="V", + help="text input cost ratio relative to speech input cost", + ) + parser.add_argument( + "--enc-grad-mult", + type=float, + metavar="V", + default=1.0, + help="multiply enc1 and enc2 gradient by V", + ) + parser.add_argument( + "--enc2-along-grad-mult", + type=float, + metavar="V", + default=1.0, + help="multiply enc2 gradient by V if only enc2 is used", + ) + parser.add_argument( + "--no-strict-check-pretrain-model", + action="store_true", + help="Don't apply strict model check for the pretrained model", + ) + + parser.add_argument( + "--stacked-encoder", + action="store_true", + help="stack speech and text encoders", + ) + + @classmethod + def update_transformer_encoder_cfg(cls, args, update_dict): + cfg = dict(args._get_kwargs()) + for fkey in update_dict.keys(): + cfg[fkey] = update_dict[fkey] + cfg.pop("_name", None) # remove keys start with _ + model_args = namedtuple("args", cfg.keys())(*cfg.values()) + return model_args + + @classmethod + def build_text_encoder(cls, args, src_dictionary): + enc_emb = nn.Embedding( + len(src_dictionary), args.encoder_embed_dim, src_dictionary.pad() + ) + model_args = cls.update_transformer_encoder_cfg( + args, + { + "encoder_layers": args.text_encoder_layers, + "max_source_positions": args.max_positions_text, + }, + ) + text_encoder = TransformerEncoder(model_args, src_dictionary, enc_emb) + return text_encoder + + @classmethod + def build_speech_encoder(cls, args): + model_args = cls.update_transformer_encoder_cfg( + args, {"encoder_layers": args.speech_encoder_layers} + ) + speech_encoder = SpeechWavTransformerEncoder(model_args) + return speech_encoder + + @classmethod + def check_args(cls, condition, is_strict, msg): + if condition: + return + if is_strict: + raise ValueError(msg) + logger.warn(msg) + + @classmethod + def build_encoder(cls, args, task): + # text_encoder = cls.build_text_encoder(args, task.source_dictionary ) + text_encoder = cls.build_text_encoder(args, task.src_dict) + speech_encoder = cls.build_speech_encoder(args) + if args.load_pretrained_wav2vec_encoder: + component_pairs = ( + ("feature_extractor", speech_encoder.subsample), + ("post_extract_proj", speech_encoder.feat_proj), + ("layer_norm", speech_encoder.feat_layer_norm), + ("encoder.pos_conv", speech_encoder.embed_positions), + ("encoder.layers", speech_encoder.layers), + ("encoder.layer_norm", speech_encoder.layer_norm), + ("mask_emb", speech_encoder.mask_emb), + ) + state = cls.load_pretrained_speech_text_components( + args.load_pretrained_wav2vec_encoder, component_pairs + ) + cls.check_args( + args.encoder_normalize_before + == state["cfg"]["model"]["layer_norm_first"], + not args.no_strict_check_pretrain_model, + f"encoder_normalize_before {args.encoder_normalize_before} doesn't match with the pretrained model", + ) + cls.check_args( + args.activation_fn == state["cfg"]["model"]["activation_fn"], + not args.no_strict_check_pretrain_model, + f"activation_fn {args.activation_fn} doesn't match with the pretrained model", + ) + + if getattr(args, "stacked_encoder", False): + if args.encoder_shared_text_layers_from_begin > 0: + raise ValueError( + "We can not stack encoders and share encoders at the same time!" + ) + speech_encoder = StackedSpeechWavTransformerEncoder( + speech_encoder, text_encoder.layers, text_encoder.layer_norm + ) + else: + cls.share_speech_text_encoder( + speech_encoder, text_encoder, args.encoder_shared_text_layers_from_begin + ) + + cross_attentive_loss_before_last_layer = ( + 0 if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else -1 + ) + encoder = DualInputEncoder( + args, + speech_encoder, + text_encoder, + task.src_dict, + cross_attentive_loss_before_last_layer, + ) + if args.load_pretrained_speech_text_encoder: + component_pairs = ( + ("encoder.sup_s2s_speech_encoder", encoder.spch_encoder), + ("encoder.text_encoder", encoder.text_encoder), + ) + cls.load_pretrained_speech_text_components( + args.load_pretrained_speech_text_encoder, component_pairs + ) + if getattr(args, "load_init_encoder", "") != "": + checkpoint_utils.load_pretrained_component_from_model( + encoder, args.load_init_encoder + ) + return encoder + + @classmethod + def build_text_decoder(cls, args, tgt_dictionary, dec_emb_share=None): + dec_emb = ( + nn.Embedding( + len(tgt_dictionary), args.decoder_embed_dim, tgt_dictionary.pad() + ) + if dec_emb_share is None + else dec_emb_share + ) + text_decoder = TransformerDecoder(args, tgt_dictionary, dec_emb) + return text_decoder + + @classmethod + def build_decoder(cls, args, task): + text_decoder = cls.build_text_decoder(args, task.target_dictionary) + compute_cross_attentive_loss = ( + True if getattr(args, "attentive_cost_regularization", 0.0) > 0.0 else False + ) + cross_attentive_loss_without_norm = getattr( + args, "attentive_cost_without_normalize", False + ) + cross_attentive_loss_reverse = ( + False # getattr(args, "attentive_cost_reverse", False) + ) + if getattr(args, "load_pretrained_text_decoder", "") != "": + checkpoint_utils.load_pretrained_component_from_model( + text_decoder, args.load_pretrained_text_decoder + ) + + if args.load_pretrained_speech_text_decoder: + component_pairs = (("decoder.text_decoder", text_decoder),) + cls.load_pretrained_speech_text_components( + args.load_pretrained_speech_text_decoder, component_pairs + ) + + decoder = TransformerMultiInputDecoder( + dictionary=task.target_dictionary, + spch_decoder=text_decoder, + text_decoder=text_decoder, + compute_cross_attentive_loss=compute_cross_attentive_loss, + cross_attentive_loss_with_norm=True + if not cross_attentive_loss_without_norm + else False, + cross_attentive_loss_reverse=cross_attentive_loss_reverse, + ) + if getattr(args, "load_init_decoder", "") != "": + checkpoint_utils.load_pretrained_component_from_model( + decoder, args.load_init_decoder + ) + return decoder + + @classmethod + def load_pretrained_speech_text_components(cls, checkpoint, component_pairs): + if not PathManager.exists(checkpoint): + raise IOError("Model file not found: {}".format(checkpoint)) + state = load_checkpoint_to_cpu(checkpoint) + for component_type, component in component_pairs: + if isinstance(component, nn.parameter.Parameter): + component.data.copy_(state["model"][component_type]) + else: + component_state_dict = OrderedDict() + for key in state["model"].keys(): + if key.startswith(component_type): + component_subkey = key[len(component_type) + 1 :] + component_state_dict[component_subkey] = state["model"][key] + component.load_state_dict(component_state_dict, strict=True) + return state + + @classmethod + def share_speech_text_encoder( + cls, speech_encoder, text_encoder, shared_layers_from_begin + ): + if shared_layers_from_begin > 0: + num_text_encoder_layers = len(text_encoder.layers) + assert len(speech_encoder.layers) >= shared_layers_from_begin + assert num_text_encoder_layers >= shared_layers_from_begin + assert len(speech_encoder.layers) >= num_text_encoder_layers + for i, ly in enumerate( + speech_encoder.layers[ + -num_text_encoder_layers : -num_text_encoder_layers + + shared_layers_from_begin + ] + ): + assert isinstance(text_encoder.layers[i], type(ly)) + text_encoder.layers[i] = ly + + +@register_model_architecture( + "dual_input_wav_transformer", "dualinputs2twavtransformer_base" +) +def dualinputs2twavtransformer_base(args): + # speech masking + args.dropout_input = getattr(args, "dropout_input", 0) + args.dropout_features = getattr(args, "dropout_features", 0) + args.speech_mask_length = getattr(args, "speech_mask_length", 10) + args.speech_mask_prob = getattr(args, "speech_mask_prob", 0.65) + args.speech_mask_selection = getattr(args, "speech_mask_selection", "static") + args.speech_mask_other = getattr(args, "speech_mask_other", 0) + args.speech_mask_min_space = getattr(args, "speech_mask_min_space", 1) + args.speech_no_mask_overlap = getattr(args, "speech_no_mask_overlap", False) + args.speech_conv_bias = getattr(args, "speech_conv_bias", False) + args.speech_extractor_mode = getattr(args, "speech_extractor_mode", "default") + args.no_strict_check_pretrain_model = getattr( + args, "no_strict_check_pretrain_model", False + ) + + args.speech_mask_channel_length = getattr(args, "speech_mask_channel_length", 10) + args.speech_mask_channel_prob = getattr(args, "speech_mask_channel_prob", 0.0) + args.speech_mask_channel_selection = getattr( + args, "speech_mask_channel_selection", "static" + ) + args.speech_mask_channel_other = getattr(args, "speech_mask_channel_other", 0) + args.speech_mask_channel_min_space = getattr( + args, "speech_mask_channel_min_space", 1 + ) + args.speech_no_mask_channel_overlap = getattr( + args, "speech_no_mask_channel_overlap", False + ) + args.no_scale_feature = getattr(args, "", False) + args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0) # 0.1 + + # Transformer + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) + args.encoder_ffn_embed_dim = getattr( + args, "encoder_ffn_embed_dim", args.encoder_embed_dim * 4 + ) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) + args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1) + args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False) + + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim + ) + args.decoder_attention_heads = getattr( + args, "decoder_attention_heads", args.encoder_attention_heads + ) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) + args.dropout = getattr(args, "dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0) + args.activation_dropout = getattr(args, "activation_dropout", args.dropout) + args.activation_fn = getattr(args, "activation_fn", "relu") # gelu? + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) + args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.share_decoder_input_output_embed = getattr( + args, "share_decoder_input_output_embed", False + ) + args.no_token_positional_embeddings = getattr( + args, "no_token_positional_embeddings", False + ) + args.adaptive_input = getattr(args, "adaptive_input", False) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.layernorm_embedding = getattr(args, "layernorm_embedding", False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 12) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 6) + args.encoder_shared_text_layers_from_begin = getattr( + args, "encoder_shared_text_layers_from_begin", 6 + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + + +@register_model_architecture( + "dual_input_wav_transformer", "dualinputs2twavtransformer_base_stack" +) +def dualinputs2twavtransformer_base_stack(args): + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 6) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 6) + args.encoder_shared_text_layers_from_begin = getattr( + args, "encoder_shared_text_layers_from_begin", 0 + ) + args.decoder_layers = getattr(args, "decoder_layers", 6) + args.stacked_encoder = getattr(args, "stacked_encoder", True) + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) + dualinputs2twavtransformer_base(args) + + +@register_model_architecture( + "dual_input_wav_transformer", "dualinputs2twavtransformer_large" +) +def dualinputs2twavtransformer_large(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + args.speech_encoder_layers = getattr(args, "speech_encoder_layers", 24) + args.text_encoder_layers = getattr(args, "text_encoder_layers", 12) + args.encoder_shared_text_layers_from_begin = getattr( + args, "encoder_shared_text_layers_from_begin", 12 + ) + args.decoder_layers = getattr(args, "decoder_layers", 12) + dualinputs2twavtransformer_base(args) diff --git a/examples/speech_text_joint_to_text/scripts/convert_model.py b/examples/speech_text_joint_to_text/scripts/convert_model.py new file mode 100644 index 000000000..4923af131 --- /dev/null +++ b/examples/speech_text_joint_to_text/scripts/convert_model.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import re +from collections import OrderedDict + +import torch + +from fairseq.file_io import PathManager + + +def is_update(param_name, module_name): + if module_name in param_name: + return True + return False + + +def load_checkpoint(src_cpt): + + with PathManager.open(src_cpt, "rb") as f: + state_src = torch.load( + f, + map_location=( + lambda s, _: torch.serialization.default_restore_location(s, "cpu") + ), + ) + + return state_src + + +def save_checkpoint(tgt_cpt, states): + + with PathManager.open(tgt_cpt, "wb") as f: + torch.save( + states, + f, + ) + + +# convert the pre-trained model into bart model +def main(): + parser = argparse.ArgumentParser() + # fmt: off + parser.add_argument('--input-model', required=True, + help='Input checkpoint file path.') + parser.add_argument('--output-model', required=True, + help='output checkpoint file path.') + # fmt: on + args = parser.parse_args() + print(args) + + states = load_checkpoint(args.input_model) + model = states["model"] + new_model = OrderedDict() + for key in model.keys(): + if re.search("^encoder.text_encoder", key): + new_key = re.sub("encoder.text_encoder", "encoder", key) + new_model[new_key] = model[key] + elif re.search("^decoder.text_decoder", key): + new_key = re.sub("decoder.text_decoder", "decoder", key) + new_model[new_key] = model[key] + states["model"] = new_model + save_checkpoint(args.output_model, states) + + +if __name__ == "__main__": + main() diff --git a/examples/speech_text_joint_to_text/tasks/pair_denoising.py b/examples/speech_text_joint_to_text/tasks/pair_denoising.py new file mode 100644 index 000000000..b13b1e5ae --- /dev/null +++ b/examples/speech_text_joint_to_text/tasks/pair_denoising.py @@ -0,0 +1,447 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os +import re + +import numpy as np +import torch + +from examples.speech_text_joint_to_text.data.pair_denoising_dataset import ( + LanguagePairDenoisingDataset, +) +from fairseq import utils +from fairseq.data import ( + ConcatDataset, + Dictionary, + LanguagePairDataset, + ResamplingDataset, + TransformEosConcatLangPairDataset, + TransformEosLangPairDataset, + data_utils, + indexed_dataset, +) +from fairseq.data.encoders.utils import get_whole_word_mask +from fairseq.tasks import register_task +from fairseq.tasks.translation import TranslationTask + +logger = logging.getLogger(__name__) + + +def gen_whole_word_mask(args, dictionary): + def is_beginning_of_word(i): + if i < dictionary.nspecial: + # special elements are always considered beginnings + return True + tok = dictionary[i] + if tok.startswith("madeupword"): + return True + + if tok in ["", "", "", ""]: + return True + return tok.startswith("\u2581") + + if args.use_mask_whole_words: + mask_whole_words = torch.ByteTensor( + list(map(is_beginning_of_word, range(len(dictionary)))) + ) + else: + # it will mask every token as word leading token, since no bpe model is loaded for phoneme tokens + return get_whole_word_mask(args, dictionary) + return mask_whole_words + + +@register_task("paired_denoising") +class PairedDenoisingTask(TranslationTask): + + LANG_TAG_TEMPLATE = "" # Tag for language (target) + + @staticmethod + def add_args(parser): + TranslationTask.add_args(parser) + # bart setting + parser.add_argument( + "--mask", + default=0.0, + type=float, + help="fraction of words/subwords that will be masked", + ) + parser.add_argument( + "--mask-random", + default=0.0, + type=float, + help="instead of using [MASK], use random token this often", + ) + parser.add_argument( + "--insert", + default=0.0, + type=float, + help="insert this percentage of additional random tokens", + ) + parser.add_argument( + "--poisson-lambda", + default=3.0, + type=float, + help="randomly shuffle sentences for this proportion of inputs", + ) + parser.add_argument( + "--mask-length", + default="span-poisson", + type=str, + choices=["subword", "word", "span-poisson"], + help="mask length to choose", + ) + parser.add_argument( + "--replace-length", + default=1, + type=int, + help="when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)", + ) + + # multi-lingual + parser.add_argument( + "--multilang-sampling-alpha", + type=float, + default=1.0, + help="smoothing alpha for sample ratios across multiple datasets", + ) + parser.add_argument( + "--lang-pairs", + default="", + metavar="PAIRS", + help="comma-separated list of language pairs (in training order): phnen-en,phnfr-fr,phnit-it. Do masking", + ) + parser.add_argument( + "--lang-pairs-bitext", + default="", + metavar="PAIRS", + help="comma-separated list of language pairs (in training order): en-de,en-fr,de-fr. No masking", + ) + parser.add_argument("--add-src-lang-token", default=False, action="store_true") + parser.add_argument("--add-tgt-lang-token", default=False, action="store_true") + parser.add_argument( + "--no-whole-word-mask-langs", + type=str, + default="", + metavar="N", + help="languages without spacing between words dont support whole word masking", + ) + parser.add_argument( + "--use-mask-whole-words", default=False, action="store_true" + ) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task.""" + paths = args.data.split(":") + assert len(paths) > 0 + src_dict = Dictionary.load( + os.path.join(paths[0], "src_dict.txt") + ) # assume all languages share a source dictionary + tgt_dict = Dictionary.load( + os.path.join(paths[0], "tgt_dict.txt") + ) # assume all languages share a target dictionary + + lang_pairs = args.lang_pairs + "," + args.lang_pairs_bitext + lang_pairs = re.sub(",$", "", re.sub("^,", "", lang_pairs)) + src_langs = [lp.split("-")[0] for lp in lang_pairs.split(",")] + tgt_langs = [lp.split("-")[1] for lp in lang_pairs.split(",")] + + if args.add_src_lang_token: + for lang in src_langs: + assert ( + src_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang)) + != src_dict.unk() + ) + if args.add_tgt_lang_token: + for lang in tgt_langs: + assert ( + tgt_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang)) + != tgt_dict.unk() + ) + + logger.info("source dictionary: {} types".format(len(src_dict))) + logger.info("target dictionary: {} types".format(len(tgt_dict))) + if not hasattr(args, "shuffle_instance"): + args.shuffle_instance = False + return cls(args, src_dict, tgt_dict) + + def __init__(self, args, src_dict, tgt_dict): + super().__init__(args, src_dict, tgt_dict) + # check mask token + self.mask_idx = self.src_dict.index("") + assert self.mask_idx != self.src_dict.unk() + self.lang_pairs = args.lang_pairs + self.lang_pairs_bitext = args.lang_pairs_bitext + self.args = args + + @classmethod + def language_pair_denoising_dataset( + cls, + data_path, + do_mask, + split, + src, + src_dict, + tgt, + tgt_dict, + mask_idx, + mask_whole_words, + seed, + args, + dataset_impl, + combine=False, + left_pad_source=True, + left_pad_target=False, + max_source_positions=1024, + max_target_positions=1024, + shuffle=True, + src_lang_id=None, + tgt_lang_id=None, + ): + def split_exists(split, src, tgt, lang, data_path): + filename = os.path.join( + data_path, "{}.{}-{}.{}".format(split, src, tgt, lang) + ) + return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + + src_datasets = [] + tgt_datasets = [] + + for k in itertools.count(): + split_k = split + (str(k) if k > 0 else "") + + # infer langcode + if split_exists(split_k, src, tgt, src, data_path): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) + elif split_exists(split_k, tgt, src, src, data_path): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) + else: + if k > 0: + break + else: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, data_path) + ) + + src_dataset = data_utils.load_indexed_dataset( + prefix + src, src_dict, dataset_impl + ) + src_datasets.append(src_dataset) + + tgt_dataset = data_utils.load_indexed_dataset( + prefix + tgt, tgt_dict, dataset_impl + ) + if tgt_dataset is not None: + tgt_datasets.append(tgt_dataset) + + logger.info( + "{} {} {}-{} {} examples".format( + data_path, split_k, src, tgt, len(src_datasets[-1]) + ) + ) + + if not combine: + break + + assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 + + if len(src_datasets) == 1: + src_dataset = src_datasets[0] + tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None + else: + sample_ratios = [1] * len(src_datasets) + src_dataset = ConcatDataset(src_datasets, sample_ratios) + if len(tgt_datasets) > 0: + tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) + else: + tgt_dataset = None + + eos = None + + tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None + if not do_mask: + return LanguagePairDataset( + src_dataset, + src_dataset.sizes, + src_dict, + tgt_dataset, + tgt_dataset_sizes, + tgt_dict, + left_pad_source=left_pad_source, + left_pad_target=left_pad_target, + eos=eos, + shuffle=shuffle, + src_lang_id=src_lang_id, + tgt_lang_id=tgt_lang_id, + ) + + return LanguagePairDenoisingDataset( + src_dataset, + src_dataset.sizes, + src_dict, + tgt_dataset, + tgt_dataset_sizes, + tgt_dict, + mask_idx, + mask_whole_words, + seed, + args, + left_pad_source=left_pad_source, + left_pad_target=left_pad_target, + eos=eos, + shuffle=shuffle, + src_lang_id=src_lang_id, + tgt_lang_id=tgt_lang_id, + ) + + def _get_sample_prob(self, dataset_lens): + """ + Get smoothed sampling porbability by languages. This helps low resource + languages by upsampling them. + """ + prob = dataset_lens / dataset_lens.sum() + smoothed_prob = prob ** self.args.multilang_sampling_alpha + smoothed_prob = smoothed_prob / smoothed_prob.sum() + return smoothed_prob + + def resample_datasets(self, lang_datasets, lang_pairs_all, epoch): + # For train subset, additionally up or down sample languages. + if self.args.multilang_sampling_alpha == 1.0: + return lang_datasets + + dataset_lengths = np.array( + [len(d) for d in lang_datasets], + dtype=float, + ) + sample_probs = self._get_sample_prob(dataset_lengths) + logger.info( + "Sample probability by language pair: {}".format( + { + lp: "{0:.4f}".format(sample_probs[id]) + for id, lp in enumerate(lang_pairs_all) + } + ) + ) + size_ratio = (sample_probs * dataset_lengths.sum()) / dataset_lengths + logger.info( + "Up/Down Sampling ratio by language: {}".format( + { + lp: "{0:.2f}".format(size_ratio[id]) + for id, lp in enumerate(lang_pairs_all) + } + ) + ) + + resampled_lang_datasets = [ + ResamplingDataset( + lang_datasets[i], + size_ratio=size_ratio[i], + seed=self.args.seed, + epoch=epoch, + replace=size_ratio[i] >= 1.0, + ) + for i, d in enumerate(lang_datasets) + ] + return resampled_lang_datasets + + def load_dataset_only( + self, split, lang_pairs, do_mask=True, epoch=1, combine=False + ): + paths = utils.split_paths(self.args.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + + # TODO unk token will be considered as first word too, though it might be an unknown phoneme within a word + # get_whole_word_mask returns a tensor (size V by 1 ) to indicate if a token is a word start token + mask_whole_src_words = gen_whole_word_mask(self.args, self.src_dict) + language_without_segmentations = self.args.no_whole_word_mask_langs.split(",") + lang_datasets = [] + eos_bos = [] + lang_pairs = lang_pairs.split(",") if lang_pairs != "" else [] + assert len(lang_pairs) > 0 + for lp in lang_pairs: + src, tgt = lp.split("-") + lang_mask_whole_src_words = ( + mask_whole_src_words + if src not in language_without_segmentations + else None + ) + + end_token = ( + self.source_dictionary.index( + PairedDenoisingTask.LANG_TAG_TEMPLATE.format(src) + ) + if self.args.add_src_lang_token + else None + ) + bos_token = ( + self.target_dictionary.index( + PairedDenoisingTask.LANG_TAG_TEMPLATE.format(tgt) + ) + if self.args.add_tgt_lang_token + else None + ) + src_lang_id = None + + if self.args.add_src_lang_token or self.args.add_tgt_lang_token: + eos_bos.append((end_token, bos_token)) + + dataset = PairedDenoisingTask.language_pair_denoising_dataset( + data_path, + do_mask, + split, + src, + self.source_dictionary, + tgt, + self.target_dictionary, + self.mask_idx, + lang_mask_whole_src_words, + self.args.seed, + self.args, + self.args.dataset_impl, + combine=combine, + left_pad_source=utils.eval_bool(self.args.left_pad_source), + left_pad_target=utils.eval_bool(self.args.left_pad_target), + max_source_positions=self.args.max_source_positions, + max_target_positions=self.args.max_target_positions, + src_lang_id=src_lang_id, + ) + + lang_datasets.append(dataset) + + if len(lang_datasets) == 0: + return + elif len(lang_datasets) == 1: + dataset = lang_datasets[0] + if self.args.add_src_lang_token or self.args.add_tgt_lang_token: + end_token, bos_token = eos_bos[0] + dataset = TransformEosLangPairDataset( + dataset, + src_eos=self.source_dictionary.eos(), + new_src_eos=end_token, + tgt_bos=self.target_dictionary.eos(), + new_tgt_bos=bos_token, + ) + else: + end_tokens = [item[0] for item in eos_bos if item[0] is not None] + bos_tokens = [item[1] for item in eos_bos if item[1] is not None] + lang_datasets = self.resample_datasets(lang_datasets, lang_pairs, epoch) + dataset = TransformEosConcatLangPairDataset( + lang_datasets, + self.source_dictionary.eos(), + self.target_dictionary.eos(), + new_src_eos=end_tokens, + new_tgt_bos=bos_tokens, + ) + return dataset + + # split in (train, valid, test, ...) + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + self.datasets[split] = self.load_dataset_only( + split, self.lang_pairs, epoch=epoch, combine=combine + ) diff --git a/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py b/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py new file mode 100644 index 000000000..f592633c0 --- /dev/null +++ b/examples/speech_text_joint_to_text/tasks/speech_text_denoise_pretrain.py @@ -0,0 +1,654 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import os +import re +from argparse import Namespace +from pathlib import Path + +from fairseq.data import ConcatDataset, Dictionary, encoders +from fairseq.data.audio.multi_modality_dataset import ( + FileAudioDatasetWrapper, + ModalityDatasetItem, + MultiModalityDataset, +) +from fairseq.data.audio.speech_to_text_joint_dataset import ( + S2TJointDataConfig, + SpeechToTextJointDatasetCreator, +) +from fairseq.data.iterators import GroupedEpochBatchIterator +from fairseq.tasks import register_task + +from .pair_denoising import PairedDenoisingTask + +logger = logging.getLogger(__name__) + + +@register_task("speech_text_joint_denoising") +class SpeechTextJointDenoisingPreTask(PairedDenoisingTask): + """ + Joint denoising training task for speech and text. + """ + + SIL_TOKEN = "sil" + + @classmethod + def add_args(cls, parser): + PairedDenoisingTask.add_args(parser) + # set max tokens and position + parser.add_argument( + "--max-text-tokens", + type=int, + metavar="N", + default=1024, + help="maximum samples for encoder text input ", + ) + parser.add_argument( + "--max-speech-tokens", + type=int, + metavar="N", + default=50000, + help="maximum samples for encoder speech input ", + ) + parser.add_argument( + "--max-speech-positions", + type=int, + metavar="N", + default=400, + help="maximum tokens for per encoder text input ", + ) + + parser.add_argument( + "--max-sample-size", + type=int, + metavar="N", + default=32000, + help="max sample size to crop to for batching (unsupervised speech) ", + ) + parser.add_argument( + "--min-sample-size", + type=int, + metavar="N", + default=4000, + help="min sample size to crop to for batching (unsupervised speech) ", + ) + + # set mini-batch ratio for different modalities/subtasks + # s2p + parser.add_argument( + "--supervised-speech-sample-ratio", + default="1", + type=str, + metavar="N", + help="Multiple Ratio for speech dataset with transcripts ", + ) + # s2t + parser.add_argument( + "--supervised-speech-s2s-sample-ratio", + default="1", + type=str, + metavar="N", + help="Multiple Ratio for speech dataset with transcripts ", + ) + # ssl + parser.add_argument( + "--unsupervised-speech-sample-ratio", + default="1", + type=str, + metavar="N", + help="Multiple Ratio for speech dataset without transcripts ", + ) + # t2t with monolingual data (masking) + parser.add_argument( + "--text-sample-ratio", + default="1", + type=str, + metavar="N", + help="Multiple Ratio for text set ", + ) + # t2t with parallel data (no masking) + parser.add_argument( + "--bitext-sample-ratio", + default="1", + type=str, + metavar="N", + help="Multiple Ratio for text set (bitext) ", + ) + # train_subset = "train", 'valid' or so + # parallel data is loaded according to string lang_pairs and lang_pairs_no_mask from args.data + # (un)supervised speech is loaded from args.(un)sup_speech_{train,valid}_subset + parser.add_argument( + "--sup-speech-data", default="", help="path to supervised speech data" + ) + parser.add_argument( + "--sup-speech-train-subset", + default="", + help="supervised speech training subsets", + ) + parser.add_argument( + "--sup-speech-valid-subset", + default="", + help="supervised speech validation subsets", + ) + parser.add_argument( + "--config-yaml", + default="config.yaml", + help="supervised speech configuration yaml file", + ) + parser.add_argument( + "--sup-speech-s2s-data", default="", help="path to supervised speech data" + ) + parser.add_argument( + "--sup-speech-s2s-train-subset", + default="", + help="supervised speech training subsets", + ) + parser.add_argument( + "--sup-speech-s2s-valid-subset", + default="", + help="supervised speech validation subsets", + ) + parser.add_argument( + "--config-s2s-yaml", + default="config.yaml", + help="supervised speech configuration yaml file", + ) + parser.add_argument( + "--unsup-speech-train-data", + default="", + help="path to unsupervised speech training data (tsv)", + ) + parser.add_argument( + "--unsup-speech-valid-data", + default="", + help="path to unsupervised speech valid data (tsv)", + ) + parser.add_argument( + "--sample-rate", + type=int, + metavar="N", + default=16000, + help="input audio sampling rate", + ) + parser.add_argument( + "--no-emb-update-unsup", + default=False, + action="store_true", + help="no update for output embedding during unsupervised_speech mode", + ) + parser.add_argument("--same-data-update", default=False, action="store_true") + + # used for sup_speech_ali + parser.add_argument( + "--use-sup-speech-ctc", + default=False, + action="store_true", + help="use speech_sup_ctc instead of speech_sup_ali", + ) + + @classmethod + def setup_task(cls, args, **kwargs): + """Setup the task.""" + paths = args.data.split(":") + assert len(paths) > 0 + src_dict = Dictionary.load( + os.path.join(paths[0], "src_dict.txt") + ) # assume all languages share a source dictionary + tgt_dict = Dictionary.load( + os.path.join(paths[0], "tgt_dict.txt") + ) # assume all languages share a target dictionary + + lang_pairs = args.lang_pairs + "," + args.lang_pairs_bitext + lang_pairs = re.sub(",$", "", re.sub("^,", "", lang_pairs)) + if lang_pairs != "": + src_langs = [lp.split("-")[0] for lp in lang_pairs.split(",")] + tgt_langs = [lp.split("-")[1] for lp in lang_pairs.split(",")] + else: + src_langs = [] + tgt_langs = [] + + if args.add_src_lang_token: + for lang in src_langs: + assert ( + src_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang)) + != src_dict.unk() + ) + if args.add_tgt_lang_token: + for lang in tgt_langs: + assert ( + tgt_dict.index(PairedDenoisingTask.LANG_TAG_TEMPLATE.format(lang)) + != tgt_dict.unk() + ) + + logger.info("source dictionary: {} types".format(len(src_dict))) + logger.info("target dictionary: {} types".format(len(tgt_dict))) + if not hasattr(args, "shuffle_instance"): + args.shuffle_instance = False + return cls(args, src_dict, tgt_dict) + + def __init__(self, args, src_dict, tgt_dict): + super().__init__(args, src_dict, tgt_dict) + self.data_cfg = S2TJointDataConfig( + Path(args.sup_speech_data) / args.config_yaml + ) + logger.info( + f"load supervised speech data configure from {Path(args.sup_speech_data) / args.config_yaml}" + ) + self.data_s2s_cfg = ( + S2TJointDataConfig(Path(args.sup_speech_s2s_data) / args.config_s2s_yaml) + if args.sup_speech_s2s_train_subset != "" + else None + ) + if self.data_s2s_cfg is not None: + logger.info( + f"load supervised sequece to sequence speech data configure from {Path(args.sup_speech_s2s_data) / args.config_yaml}" + ) + + def parse_data_ratio(sample_ratio): + ratios = sample_ratio.split(",") + if len(ratios) == 1: + return [float(ratios[0])] + epoch_ratios = [] + for item in ratios: + ep, r = item.split(":") + ep = int(ep) + r = float(r) + assert ep > 0 # epoch is 1 based + assert ep >= len(epoch_ratios) + + if len(epoch_ratios) == 0: + epoch_ratios.append( + r + ) # epoch_ratios[0] is not used, but we still set it to the first value to make thing simple. + while len(epoch_ratios) < ep: + epoch_ratios.append(epoch_ratios[-1]) + epoch_ratios.append(r) + return epoch_ratios + + self.sup_ratio = parse_data_ratio(args.supervised_speech_sample_ratio) + self.sup_s2s_ratio = parse_data_ratio(args.supervised_speech_s2s_sample_ratio) + self.text_ratio = parse_data_ratio(args.text_sample_ratio) + self.bitext_ratio = parse_data_ratio(args.bitext_sample_ratio) + self.unsup_ratio = parse_data_ratio(args.unsupervised_speech_sample_ratio) + self.sample_mode = None + + def build_model(self, args): + args.input_feat_per_channel = self.data_cfg.input_feat_per_channel + args.input_channels = self.data_cfg.input_channels + return super().build_model(args) + + def build_tokenizer(self, data_cfg, msg=""): + logger.info(f"pre-tokenizer {msg}: {data_cfg.pre_tokenizer}") + return encoders.build_tokenizer(Namespace(**data_cfg.pre_tokenizer)) + + def build_bpe(self, data_cfg, msg=""): + logger.info(f"tokenizer {msg}: {data_cfg.bpe_tokenizer}") + return encoders.build_bpe(Namespace(**data_cfg.bpe_tokenizer)) + + @classmethod + def resolve_data_type(cls, split, use_sup_speech_ctc): + if len(split.split("_")) == 1: + # default case, train or valid + is_train = split + dtype = "text" + else: + is_train, dtype = split.split("_", 1) + is_train = True if is_train == "train" else False + if dtype == "sup_speech": + dtype = "sup_speech_ctc" if use_sup_speech_ctc else "sup_speech_ali" + assert dtype in ( + "text", + "bitext", + "sup_speech_ali", + "sup_speech_s2s", + "unsup_speech", + "sup_speech_ctc", + ) + return is_train, dtype + + def create_modalitydatasetitem(self, dtype, dataset): + dsitem = None + if dtype in ("text", "bitext"): + dsitem = ModalityDatasetItem( + dtype, + dataset, + (self.args.max_source_positions, self.args.max_target_positions), + self.args.max_text_tokens, + self.args.batch_size, + ) + elif dtype in ("sup_speech_ctc", "sup_speech_ali", "sup_speech_s2s"): + dsitem = ModalityDatasetItem( + dtype, + dataset, + (self.args.max_speech_positions, self.args.max_target_positions), + self.args.max_speech_tokens, + self.args.batch_size, + ) + elif dtype == "unsup_speech": + dsitem = ModalityDatasetItem( + dtype, dataset, 1e8, self.args.max_speech_tokens, self.args.batch_size + ) + else: + raise ValueError(f"{dtype} is not supported") + return dsitem + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + def _get_sup_src_tgt_dict(src_dict, tgt_dict, use_s2s_sup_decoder): + if use_s2s_sup_decoder: + return None, tgt_dict + # use src_dict as tgt_dict here, since we use source dictionary as target for forcealignment + return None, src_dict + + is_train, dtype = self.resolve_data_type(split, self.args.use_sup_speech_ctc) + + # Note we use --add-tgt-lang-token instead of data_cfg.prepend_tgt_lang_tag_no_change to set target language tag in the text dataset + # Verify add_tgt_lang_token and prepend_tgt_lang_tag_no_change are same + + # Note we use --multilang-sampling-alpha instead of data_cfg.sampling_text_alpha to set text data sampling + if is_train: + msets = [] + # train split, load everything into one + if self.lang_pairs != "": + text_dataset = self.load_dataset_only( + "train", self.lang_pairs, epoch=epoch, combine=combine + ) + dsitem = self.create_modalitydatasetitem("text", text_dataset) + msets.append(dsitem) + if self.lang_pairs_bitext != "": # load bitext + bitext_dataset = self.load_dataset_only( + "train_bitext", + self.lang_pairs_bitext, + do_mask=False, + epoch=epoch, + combine=combine, + ) + dsitem = self.create_modalitydatasetitem("bitext", bitext_dataset) + msets.append(dsitem) + if self.args.sup_speech_train_subset != "": + pre_tokenizer = self.build_tokenizer(self.data_cfg) + bpe_tokenizer = self.build_bpe(self.data_cfg) + + append_eos = True + sup_speech_type = "sup_speech_ali" + if self.args.use_sup_speech_ctc: + # CTC mode + sup_speech_type = "sup_speech_ctc" + append_eos = False # CTC doesn't need eos in the target + + src_dict, tgt_dict = _get_sup_src_tgt_dict( + self.src_dict, self.tgt_dict, False + ) + sup_speech_dataset = SpeechToTextJointDatasetCreator.from_tsv( + self.args.sup_speech_data, + self.data_cfg, + self.args.sup_speech_train_subset, + tgt_dict=tgt_dict, + src_dict=src_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + src_pre_tokenizer=None, + src_bpe_tokenizer=None, + is_train_split=is_train, + epoch=epoch, + seed=self.args.seed, + append_eos=append_eos, + ) + dsitem = self.create_modalitydatasetitem( + sup_speech_type, sup_speech_dataset + ) + msets.append(dsitem) + + if self.args.sup_speech_s2s_train_subset != "": + pre_tokenizer = self.build_tokenizer(self.data_s2s_cfg, msg="(s2s)") + bpe_tokenizer = self.build_bpe(self.data_s2s_cfg, msg="(s2s)") + + # make sure self.data_cfg.prepend_tgt_lang_tag_no_change == self.args.add_tgt_lang_token + src_dict, tgt_dict = _get_sup_src_tgt_dict( + self.src_dict, self.tgt_dict, True + ) + sup_speech_s2s_dataset = SpeechToTextJointDatasetCreator.from_tsv( + self.args.sup_speech_s2s_data, + self.data_s2s_cfg, + self.args.sup_speech_s2s_train_subset, + tgt_dict=tgt_dict, + src_dict=src_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + src_pre_tokenizer=None, + src_bpe_tokenizer=None, + is_train_split=is_train, + epoch=epoch, + seed=self.args.seed, + ) + dsitem = self.create_modalitydatasetitem( + "sup_speech_s2s", sup_speech_s2s_dataset + ) + msets.append(dsitem) + if self.args.unsup_speech_train_data != "": + unsup_speech_dataset = FileAudioDatasetWrapper( + self.args.unsup_speech_train_data, + self.args.sample_rate, + max_sample_size=self.args.max_sample_size, + min_sample_size=self.args.min_sample_size, + normalize=False, + ) + dsitem = self.create_modalitydatasetitem( + "unsup_speech", unsup_speech_dataset + ) + msets.append(dsitem) + + pre_train_dataset = MultiModalityDataset(msets) + self.datasets[split] = pre_train_dataset + else: # validation split, load them for each type of data + if dtype == "text": + text_dataset = self.load_dataset_only( + split, self.lang_pairs, epoch=epoch, combine=combine + ) + dsitem = self.create_modalitydatasetitem("text", text_dataset) + self.datasets[split] = MultiModalityDataset([dsitem]) + elif dtype == "bitext": + bitext_dataset = self.load_dataset_only( + split, + self.lang_pairs_bitext, + do_mask=False, + epoch=epoch, + combine=combine, + ) + dsitem = self.create_modalitydatasetitem("bitext", bitext_dataset) + self.datasets[split] = MultiModalityDataset([dsitem]) + + elif dtype in ("sup_speech_ctc", "sup_speech_ali"): + assert self.args.sup_speech_valid_subset != "" + pre_tokenizer = self.build_tokenizer(self.data_cfg) + bpe_tokenizer = self.build_bpe(self.data_cfg) + append_eos = True + if dtype == "sup_speech_ctc": + # CTC mode + append_eos = False # CTC doesn't need eos + assert self.args.use_sup_speech_ctc + + datasets = [] + for split_name in self.args.sup_speech_valid_subset.split(","): + src_dict, tgt_dict = _get_sup_src_tgt_dict( + self.src_dict, self.tgt_dict, False + ) + datasets.append( + SpeechToTextJointDatasetCreator.from_tsv( + self.args.sup_speech_data, + self.data_cfg, + split_name, + tgt_dict=tgt_dict, + src_dict=src_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + src_pre_tokenizer=None, + src_bpe_tokenizer=None, + is_train_split=is_train, + epoch=epoch, + seed=self.args.seed, + append_eos=append_eos, + ) + ) + + dset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets) + dsitem = self.create_modalitydatasetitem(dtype, dset) + self.datasets[split] = MultiModalityDataset([dsitem]) + + elif dtype == "sup_speech_s2s": + assert self.args.sup_speech_s2s_valid_subset != "" + pre_tokenizer = self.build_tokenizer(self.data_s2s_cfg) + bpe_tokenizer = self.build_bpe(self.data_s2s_cfg) + datasets = [] + for split_name in self.args.sup_speech_s2s_valid_subset.split(","): + src_dict, tgt_dict = _get_sup_src_tgt_dict( + self.src_dict, self.tgt_dict, True + ) + datasets.append( + SpeechToTextJointDatasetCreator.from_tsv( + self.args.sup_speech_s2s_data, + self.data_s2s_cfg, + split_name, + tgt_dict=tgt_dict, + src_dict=src_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + src_pre_tokenizer=None, + src_bpe_tokenizer=None, + is_train_split=is_train, + epoch=epoch, + seed=self.args.seed, + ) + ) + + dset = datasets[0] if len(datasets) == 1 else ConcatDataset(datasets) + dsitem = self.create_modalitydatasetitem("sup_speech_s2s", dset) + self.datasets[split] = MultiModalityDataset([dsitem]) + elif dtype == "unsup_speech": + assert self.args.unsup_speech_valid_data != "" + unsup_speech_dataset = FileAudioDatasetWrapper( + self.args.unsup_speech_valid_data, + self.args.sample_rate, + max_sample_size=self.args.max_sample_size, + min_sample_size=self.args.min_sample_size, + normalize=False, + ) + dsitem = self.create_modalitydatasetitem( + "unsup_speech", unsup_speech_dataset + ) + self.datasets[split] = MultiModalityDataset([dsitem]) + else: + raise ValueError(f"Unsupported type {dtype}") + + def get_sample_ratio(self, epoch): + sup_ratio = ( + self.sup_ratio[epoch] if len(self.sup_ratio) > epoch else self.sup_ratio[-1] + ) + sup_s2s_ratio = ( + self.sup_s2s_ratio[epoch] + if len(self.sup_s2s_ratio) > epoch + else self.sup_s2s_ratio[-1] + ) + unsup_ratio = ( + self.unsup_ratio[epoch] + if len(self.unsup_ratio) > epoch + else self.unsup_ratio[-1] + ) + text_ratio = ( + self.text_ratio[epoch] + if len(self.text_ratio) > epoch + else self.text_ratio[-1] + ) + bitext_ratio = ( + self.bitext_ratio[epoch] + if len(self.bitext_ratio) > epoch + else self.bitext_ratio[-1] + ) + return text_ratio, bitext_ratio, sup_ratio, sup_s2s_ratio, unsup_ratio + + def get_batch_iterator( + self, + dataset, + max_tokens=None, + max_sentences=None, + max_positions=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=0, + data_buffer_size=0, + disable_iterator_cache=False, + skip_remainder_batch=False, + grouped_shuffling=False, + update_epoch_batch_itr=False, + ): + + assert isinstance(dataset, MultiModalityDataset) + if len(dataset.id_to_mode) == 1: + max_positions = dataset.max_positions[0] + max_tokens = dataset.max_tokens[0] + max_sentences = dataset.max_sentences[0] + return super().get_batch_iterator( + dataset, + max_tokens, + max_sentences, + max_positions, + ignore_invalid_inputs, + required_batch_size_multiple, + seed, + num_shards, + shard_id, + num_workers, + epoch, + data_buffer_size, + disable_iterator_cache, + skip_remainder_batch=skip_remainder_batch, + ) + + mult_ratio = [] + ( + text_ratio, + bitext_ratio, + sup_ratio, + sup_s2s_ratio, + unsup_ratio, + ) = self.get_sample_ratio(epoch) + for mode in dataset.id_to_mode: + if mode in ("sup_speech_ctc", "sup_speech_ali"): + mult_ratio.append(sup_ratio) + elif mode == "sup_speech_s2s": + mult_ratio.append(sup_s2s_ratio) + elif mode == "text": + mult_ratio.append(text_ratio) + elif mode == "bitext": + mult_ratio.append(bitext_ratio) + elif mode == "unsup_speech": + mult_ratio.append(unsup_ratio) + + # initialize the dataset with the correct starting epoch + dataset.set_epoch(epoch) + + batch_samplers = dataset.get_batch_samplers( + mult_ratio, required_batch_size_multiple, seed + ) + + # return a reusable, sharded iterator + epoch_iter = GroupedEpochBatchIterator( + dataset=dataset, + collate_fn=dataset.collater, + batch_samplers=batch_samplers, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + mult_rate=max(self.args.update_freq) if self.args.same_data_update else 1, + buffer_size=data_buffer_size, + skip_remainder_batch=skip_remainder_batch, + ) + self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch + return epoch_iter diff --git a/fairseq/data/audio/multi_modality_dataset.py b/fairseq/data/audio/multi_modality_dataset.py index 625a16ec9..39551a613 100644 --- a/fairseq/data/audio/multi_modality_dataset.py +++ b/fairseq/data/audio/multi_modality_dataset.py @@ -200,6 +200,8 @@ class LangPairMaskDataset(FairseqDataset): return self.dataset.sizes def get_batch_shapes(self): + if hasattr(self.dataset, "get_batch_shapes"): + return self.dataset.get_batch_shapes() return self.dataset.buckets def num_tokens_vec(self, indices): diff --git a/fairseq/data/audio/speech_to_text_joint_dataset.py b/fairseq/data/audio/speech_to_text_joint_dataset.py index 505ee81f3..06922ea08 100644 --- a/fairseq/data/audio/speech_to_text_joint_dataset.py +++ b/fairseq/data/audio/speech_to_text_joint_dataset.py @@ -5,22 +5,18 @@ import logging from pathlib import Path -from typing import Dict, List, Optional, NamedTuple +from typing import Dict, List, NamedTuple, Optional import torch -from fairseq.data import ( - ConcatDataset, - Dictionary, - ResamplingDataset, - data_utils as fairseq_data_utils, -) + +from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset +from fairseq.data import data_utils as fairseq_data_utils from fairseq.data.audio.speech_to_text_dataset import ( - SpeechToTextDataset, S2TDataConfig, + SpeechToTextDataset, SpeechToTextDatasetCreator, ) - logger = logging.getLogger(__name__) @@ -52,8 +48,12 @@ class S2TJointDataConfig(S2TDataConfig): def prepend_tgt_lang_tag_no_change(self) -> bool: """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for to-many multilingual setting). No change needed during inference. + This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos. """ - return self.config.get("prepend_tgt_lang_tag_no_change", False) + value = self.config.get("prepend_tgt_lang_tag_no_change", None) + if value is None: + return self.config.get("prepend_tgt_lang_tag_as_bos", False) + return value @property def sampling_text_alpha(self): diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index e5d2ede31..f49c88e56 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -5,6 +5,8 @@ from .berard import * # noqa from .convtransformer import * # noqa +from .multi_modality_model import * # noqa from .s2t_transformer import * # noqa +from .s2t_wav_transformer import * # noqa from .xm_transformer import * # noqa from .s2t_conformer import * # noqa diff --git a/fairseq/models/speech_to_text/multi_modality_model.py b/fairseq/models/speech_to_text/multi_modality_model.py new file mode 100644 index 000000000..046421620 --- /dev/null +++ b/fairseq/models/speech_to_text/multi_modality_model.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.models import FairseqDecoder, FairseqEncoder + + +# a container for different encoders with training samples from different modality +# each time, only one encoder is selected +class MultiModalityEncoder(FairseqEncoder): + def __init__(self, dictionary): + super().__init__(dictionary) + + def select_encoder(self, mode, **kwargs): + raise NotImplementedError("Model must implement the select_encoder method") + return None, kwargs + + # def post_encoder(self, encoder_out, src_tokens, src_lengths, mode, **kwargs): + # # Default do nothing + # return encoder_out + + # get sample data from JointSpeechTextDataset + def forward(self, src_tokens, src_lengths=None, mode="", **kwargs): + encoder, kwargs = self.select_encoder(mode, **kwargs) + # return self.post_encoder(encoder(src_tokens, src_lengths, **kwargs), src_tokens, src_lengths, mode, **kwargs) + return encoder(src_tokens, src_lengths, **kwargs) + + +# a container for different decoders with training samples from different modality +# each time, only one decoder is selected +class MultiInputDecoder(FairseqDecoder): + def __init__(self, dictionary): + super().__init__(dictionary) + + def select_decoder(self, mode, **kwargs): + raise NotImplementedError("Model must implement the select_decoder method") + return None, kwargs + + def forward( + self, prev_output_tokens, encoder_out, incremental_state=None, mode="", **kwargs + ): + decoder, kwargs = self.select_decoder(mode, **kwargs) + return decoder( + prev_output_tokens, + encoder_out, + incremental_state=incremental_state, + **kwargs + ) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 4b43e1acb..4ffbeaeec 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -295,15 +295,15 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): return sample["target"], sample["target_lengths"] def get_ctc_output( - self, - net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], - sample: Optional[Dict[str, Tensor]], + self, + net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], + sample: Optional[Dict[str, Tensor]] ): encoder_out = net_output[1]["encoder_out"]["encoder_out"][0] logits = self.encoder.ctc_proj(encoder_out) # T x B x C out = utils.log_softmax(logits.float(), dim=-1) padding_mask = net_output[1]["encoder_out"]["encoder_padding_mask"] - lens = out.new_full((out.shape[1],), out.shape[0]).long() + lens = out.new_full((out.shape[1], ), out.shape[0]).long() if len(padding_mask) > 0: lens -= padding_mask[0].sum(dim=-1) return out, lens @@ -359,7 +359,7 @@ class S2TTransformerEncoder(FairseqEncoder): self.layer_norm = None self.ctc_proj = None - if getattr(args, "ctc_weight", 0.0) > 0.0: + if getattr(args, "ctc_weight", 0.) > 0.: self.ctc_proj = nn.Linear(args.encoder_embed_dim, args.tgt_dict_size) def _forward(self, src_tokens, src_lengths, return_all_hiddens=False): diff --git a/fairseq/models/speech_to_text/s2t_wav_transformer.py b/fairseq/models/speech_to_text/s2t_wav_transformer.py new file mode 100644 index 000000000..f11034818 --- /dev/null +++ b/fairseq/models/speech_to_text/s2t_wav_transformer.py @@ -0,0 +1,485 @@ +#!/usr/bin/env python3 + +import math + +import torch +import torch.nn as nn + +from fairseq.data.data_utils import compute_mask_indices +from fairseq.models import FairseqEncoder +from fairseq.models.wav2vec import ConvFeatureExtractionModel +from fairseq.modules import GradMultiply, LayerNorm, SamePad, TransformerEncoderLayer + + +# Transformer encoder with wave input, it is adopted from wav2vec 2.0 Encoder. +# use wav input +# use trained position embedding so it is easier to match with text input +class SpeechWavTransformerEncoder(FairseqEncoder): + + # extra parameters for speech encoder besides those defined in transformermodel + @staticmethod + def add_args(parser): + parser.add_argument( + "--dropout-input", + type=float, + metavar="D", + help="dropout to apply to the input (after feat extr)", + ) + parser.add_argument( + "--dropout-features", + type=float, + metavar="D", + help="dropout to apply to the unmasked features (after feat extr)", + ) + parser.add_argument( + "--speech-extractor-mode", + type=str, + default="layer_norm", + choices=["default", "layer_norm"], + help="feature extractor norm", + ) + + parser.add_argument( + "--speech-conv-bias", + action="store_true", + help="include bias in speech conv encoder", + ) + + parser.add_argument( + "--conv-feature-layers", + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", + help="string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]", + ) + + parser.add_argument( + "--speech-mask-length", + type=int, + help="repeat the mask indices multiple times", + ) + + parser.add_argument( + "--speech-mask-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--speech-mask-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--speech-mask-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--speech-no-mask-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--speech-mask-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--speech-mask-channel-length", + type=int, + help="repeat the mask indices multiple times", + ) + + parser.add_argument( + "--speech-mask-channel-prob", + type=float, + help="probability of replacing a token with mask", + ) + + parser.add_argument( + "--speech-mask-channel-selection", + type=str, + choices=["static", "uniform", "normal", "poisson"], + help="how to choose masks", + ) + + parser.add_argument( + "--speech-mask-channel-other", + type=float, + help="stdev of the mask length in case of 'normal' selection strategy", + ) + + parser.add_argument( + "--speech-no-mask-channel-overlap", + action="store_true", + help="whether to allow masks to overlap", + ) + + parser.add_argument( + "--no-scale-feature", + action="store_true", + help="no scale for the calculated features", + ) + + parser.add_argument( + "--speech-mask-channel-min-space", + type=int, + help="min space between spans (if no overlap is enabled)", + ) + + parser.add_argument( + "--feature-grad-mult", + type=float, + help="reset feature grad mult in wav2vec 2.0 to this", + ) + + # positional embeddings + parser.add_argument( + "--conv-pos", + type=int, + default=128, + help="number of filters for convolutional positional embeddings", + ) + + parser.add_argument( + "--conv-pos-groups", + type=int, + default=16, + help="number of groups for convolutional positional embedding", + ) + # model configures + parser.add_argument( + "--speech-encoder-layers", + type=int, + help="number of speech encoder layers", + ) + parser.add_argument( + "--text-encoder-layers", + type=int, + help="number of text encoder layers", + ) + + def __init__(self, args, alway_mask=False): + super().__init__(args) + self.args = args + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + self.feat_scale = math.sqrt(args.encoder_embed_dim) + if args.no_scale_feature: + self.feat_scale = 1.0 + + subsample = ConvFeatureExtractionModel( + conv_layers=eval(args.conv_feature_layers), + dropout=0.0, + mode=args.speech_extractor_mode, # default, layer_norm + conv_bias=args.speech_conv_bias, + ) + feature_enc_layers = eval(args.conv_feature_layers) + self.subsample = subsample + self.feat_proj = ( + nn.Linear(feature_enc_layers[-1][0], self.embedding_dim) + if feature_enc_layers[-1][0] != self.embedding_dim + else None + ) + + self.feat_layer_norm = LayerNorm(feature_enc_layers[-1][0]) + + self.embed_positions = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + std = math.sqrt(4 / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.embed_positions.weight, mean=0, std=std) + nn.init.constant_(self.embed_positions.bias, 0) + + self.embed_positions = nn.utils.weight_norm( + self.embed_positions, name="weight", dim=2 + ) + self.embed_positions = nn.Sequential( + self.embed_positions, SamePad(args.conv_pos), nn.GELU() + ) + + self.mask_prob = args.speech_mask_prob + self.mask_selection = args.speech_mask_selection + self.mask_other = args.speech_mask_other + self.mask_length = args.speech_mask_length + self.no_mask_overlap = args.speech_no_mask_overlap + self.mask_min_space = args.speech_mask_min_space + + self.mask_channel_prob = args.speech_mask_channel_prob + self.mask_channel_selection = args.speech_mask_channel_selection + self.mask_channel_other = args.speech_mask_channel_other + self.mask_channel_length = args.speech_mask_channel_length + self.no_mask_channel_overlap = args.speech_no_mask_channel_overlap + self.mask_channel_min_space = args.speech_mask_channel_min_space + + self.dropout_input = nn.Dropout(args.dropout_input) + self.dropout_features = nn.Dropout(args.dropout_features) + + self.feature_grad_mult = args.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(args.encoder_embed_dim).uniform_() + ) + + self.layers = nn.ModuleList( + [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] + ) + self.layer_norm = LayerNorm(args.encoder_embed_dim) + self.normalize_before = args.encoder_normalize_before + self.alway_mask = alway_mask + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward( + self, + src_tokens, + src_lengths, + return_all_hiddens=False, + padding_mask=None, + features_only=True, + ): + mask = self.training or self.alway_mask + if self.feature_grad_mult > 0 and self.training: + features = self.subsample(src_tokens) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.subsample(src_tokens) + features = features.transpose(1, 2) + features = self.feat_layer_norm(features) + if self.feat_proj is not None: + features = self.feat_proj(features) + + if padding_mask is not None: + input_lengths = (1 - padding_mask.long()).sum(-1) + # apply conv formula to get real output_lengths + output_lengths = self._get_feat_extract_output_lengths(input_lengths) + + padding_mask = torch.zeros( + features.shape[:2], dtype=features.dtype, device=features.device + ) + + # these two operations makes sure that all values + # before the output lengths indices are attended to + padding_mask[ + ( + torch.arange(padding_mask.shape[0], device=padding_mask.device), + output_lengths - 1, + ) + ] = 1 + padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool() + + features = self.feat_scale * features if self.feat_scale != 1.0 else features + unmasked_features = features.clone() + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + if mask: + x, mask_indices = self.apply_mask(features, padding_mask) + else: + x = features + mask_indices = None + + def cal_transformer_layers(x, encoder_padding_mask, return_all_hiddens=False): + # x: B x T x C + positions = self.embed_positions(x.transpose(1, 2)).transpose(1, 2) + x = x + positions + if not self.normalize_before: + x = self.layer_norm(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + encoder_states = [] + for layer in self.layers: + x = layer(x, encoder_padding_mask) + if return_all_hiddens: + encoder_states.append(x) + if self.normalize_before: + x = self.layer_norm(x) + return x, encoder_states + + x, encoder_states = cal_transformer_layers(x, padding_mask, return_all_hiddens) + if features_only: + return { + "encoder_out": [x], # [T x B x C] + "encoder_padding_mask": [padding_mask] + if padding_mask is not None + else [], # B x T + "encoder_embedding": [], # + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + "mask_indices": [mask_indices], + } + + x_unmasked = x + if self.mask_prob > 0 or self.mask_channel_prob > 0: + x_unmasked, _ = cal_transformer_layers(unmasked_features, padding_mask) + return { + "encoder_out": [x], # [T x B x C] + "encoder_unmasked_out": [x_unmasked], # [T x B x C] + "encoder_padding_mask": [padding_mask] + if padding_mask is not None + else [], # B x T + "encoder_embedding": [], # + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + "mask_indices": [mask_indices] if mask_indices is not None else [], # B X T + } + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] + if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] + ) + + new_encoder_embedding = ( + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] + ) + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } + + +class StackedSpeechWavTransformerEncoder(FairseqEncoder): + def __init__(self, speech_enc, text_enc_layers, text_layer_norm): + super().__init__(None) + self.speech_encoder = speech_enc + self.text_encoder_layers = text_enc_layers + self.final_layer_norm = text_layer_norm + + def forward( + self, + src_tokens, + src_lengths=None, + return_all_hiddens=False, + padding_mask=None, + features_only=True, + ): + + out = self.speech_encoder.forward( + src_tokens, + src_lengths, + return_all_hiddens, + padding_mask=padding_mask, + features_only=features_only, + ) + x = out["encoder_out"][0] + encoder_padding_mask = None + if len(out["encoder_padding_mask"]) > 0: + encoder_padding_mask = out["encoder_padding_mask"][0] + + def cal_text_layers(x, padding_mask, return_all_hiddens=False): + encoder_states = [] + for layer in self.text_encoder_layers: + x = layer(x, padding_mask) + if return_all_hiddens: + encoder_states.append(x) + if self.final_layer_norm is not None: + x = self.final_layer_norm(x) + return x, encoder_states + + x, encoder_states = cal_text_layers(x, encoder_padding_mask, return_all_hiddens) + if features_only: + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask is not None + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + x_u = out["encoder_unmasked_out"][0] + x_u, _ = cal_text_layers(x_u, encoder_padding_mask) + + return { + "encoder_out": [x], # [T x B x C] + "encoder_unmasked_out": [x_u], # [T x B x C] + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask is not None + else [], # B x T + "encoder_embedding": [], # + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + "mask_indices": out["mask_indices"], # B X T + } + + def reorder_encoder_out(self, encoder_out, new_order): + return self.speech_encoder.reorder_encoder_out(encoder_out, new_order) diff --git a/tests/speech/test_dual_input_wav_transformer.py b/tests/speech/test_dual_input_wav_transformer.py new file mode 100644 index 000000000..3581bc199 --- /dev/null +++ b/tests/speech/test_dual_input_wav_transformer.py @@ -0,0 +1,76 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from collections import namedtuple +from pathlib import Path + +import torch +from tqdm import tqdm + +import fairseq +from fairseq import utils +from fairseq.checkpoint_utils import load_model_ensemble_and_task +from fairseq.scoring.bleu import SacrebleuScorer +from fairseq.tasks import import_tasks +from tests.speech import S3_BASE_URL, TestFairseqSpeech + + +@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") +class TestLibrispeechDualInputWavTransformer(TestFairseqSpeech): + def setUp(self): + dataset_id = "librispeech_wvtrasnformer" + base_url = "https://dl.fbaipublicfiles.com/joint_speech_text_4_s2t/acl2022/librispeech/finetuned" + data_filenames = [ + "checkpoint_ave_10.pt", + "spm.model", + "src_dict.txt", + "tgt_dict.txt", + "config.yaml", + ] + self._set_up( + dataset_id, + "s2t", + [ + "librispeech_flac_test-other.tsv", + "librispeech_flac_test-other.zip", + ], + ) + for filename in data_filenames: + self.download(base_url, self.root, filename) + + def import_user_module(self): + user_dir = ( + Path(fairseq.__file__).parent.parent / "examples/speech_text_joint_to_text" + ) + Arg = namedtuple("Arg", ["user_dir"]) + arg = Arg(user_dir.__str__()) + utils.import_user_module(arg) + + @torch.no_grad() + def test_librispeech_dualinput_wav_transformer_checkpoint(self): + self.import_user_module() + checkpoint_filename = "checkpoint_ave_10.pt" + arg_overrides = { + "config_yaml": "config.yaml", + "load_pretrained_speech_text_encoder": "", + "load_pretrained_speech_text_decoder": "", + "beam": 10, + "nbest": 1, + "lenpen": 1.0, + "load_speech_only": True, + } + self.base_test( + checkpoint_filename, + 4.6, + dataset="librispeech_flac_test-other", + max_tokens=800000, + max_positions=(800000, 1024), + arg_overrides=arg_overrides, + ) + + +if __name__ == "__main__": + unittest.main()