UnitY implementation (#4670)

* Add UnitY implementation

* Rename for consistency

* Refactor conformer encoder construction

* Change the order of arguments for rdrop_alpha

* Add compute_loss_with_rdrop

* Move build_multitask_decoder to xm_transformer_unity.py

* Fix generator selection

* Fix check in build_criterion

* Modularize Rdrop

* Minor fix

* Refine class names

* Refactor submodules

* Fix CE

* Fix import

* Fix arguments for datasets

* Add description to AugTransformerDecoderBase

* Fix SpeechToTextDatasetCreator

* Fix metavar in arguments

* Uncomment override_decoder_args

* Fix comment in warning

* Add is_first_pass_decoder flag

* Change Translatotron2SpeechGenerator to MultiDecoderSpeechGenerator

* Move inference code to examples/speech_to_speech/unity

* Fix rdrop default value in aux tasks

* Add language tag mapping option to multitask-config-yaml

* Rename encoder_out2 and encoder_outs2

* Rename UnitYXMTransformerModel to XMTransformerModelUnitY

* Support num_best_checkpoints in average_checkpoints

* Fix has_multitask

* Inherit SequenceGenerator

* Reflect recent updates

* Minor fix in logging

* Fix typo

* Refactor SpeechToSpectrogram2passMultitaskTaskCriterion

* Minor update for multitask
Hirofumi Inaguma 2022-10-06 19:38:32 -07:00 committed by GitHub
parent 6d90f79883
commit b4001184f4
26 changed files with 3480 additions and 45 deletions
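For orientation, the following is a condensed sketch of the two-pass flow this commit wires together (speech encoder -> first-pass text decoder -> optional text-to-unit encoder -> second-pass unit/spectrogram decoder). It paraphrases the model and generator code in the diffs below; the function itself is illustrative and not part of the commit.

```python
def unity_forward_sketch(model, src_tokens, src_lengths, prev_output_tokens_mt, prev_output_tokens):
    """Illustrative two-pass forward; attribute names follow the model code below."""
    # 1. speech encoder
    encoder_out = model.encoder(src_tokens, src_lengths=src_lengths)
    # 2. first-pass text (MT) decoder attends to the speech encoder output
    mt_decoder = getattr(model, f"{model.mt_task_name}_decoder")
    mt_decoder_out = mt_decoder(prev_output_tokens_mt, encoder_out=encoder_out)
    x = mt_decoder_out[1]["inner_states"][-1]  # text decoder states, T x B x C
    # 3. optional text-to-unit (T2U) encoder on top of the text decoder states
    if model.synthesizer_encoder is not None:
        t2u_encoder_out = model.synthesizer_encoder(x, None)
    else:
        t2u_encoder_out = {"encoder_out": [x], "encoder_padding_mask": []}
    # 4. second-pass decoder conditions on the T2U encoder output
    return model.decoder(prev_output_tokens, encoder_out=t2u_encoder_out)
```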

View File

@ -2,3 +2,5 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import unity # noqa

View File

@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import sequence_generator # noqa
from . import sequence_generator_multi_decoder # noqa

View File

@ -0,0 +1,626 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import sys
from typing import Dict, List, Optional
import torch
from torch import Tensor
from fairseq.sequence_generator import EnsembleModel as EnsembleModelBase
from fairseq.sequence_generator import SequenceGenerator as SequenceGeneratorBase
class SequenceGenerator(SequenceGeneratorBase):
def __init__(
self,
models,
tgt_dict,
beam_size=1,
max_len_a=0,
max_len_b=200,
max_len=0,
min_len=1,
normalize_scores=True,
len_penalty=1.0,
unk_penalty=0.0,
temperature=1.0,
match_source_len=False,
no_repeat_ngram_size=0,
search_strategy=None,
eos=None,
symbols_to_strip_from_output=None,
lm_model=None,
lm_weight=1.0,
tokens_to_suppress=(),
):
"""Generates translations of a given source sentence.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models,
currently support fairseq.models.TransformerModel for scripting
beam_size (int, optional): beam width (default: 1)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length
max_len (int, optional): the maximum length of the generated output
(not including end-of-sentence)
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
normalize_scores (bool, optional): normalize scores by the length
of the output (default: True)
len_penalty (float, optional): length penalty, where <1.0 favors
shorter, >1.0 favors longer sentences (default: 1.0)
unk_penalty (float, optional): unknown word penalty, where <0
produces more unks, >0 produces fewer (default: 0.0)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
match_source_len (bool, optional): outputs should match the source
length (default: False)
"""
super().__init__(
models=models,
tgt_dict=tgt_dict,
beam_size=beam_size,
max_len_a=max_len_a,
max_len_b=max_len_b,
max_len=max_len,
min_len=min_len,
normalize_scores=normalize_scores,
len_penalty=len_penalty,
unk_penalty=unk_penalty,
temperature=temperature,
match_source_len=match_source_len,
no_repeat_ngram_size=no_repeat_ngram_size,
search_strategy=search_strategy,
eos=eos,
symbols_to_strip_from_output=symbols_to_strip_from_output,
lm_model=lm_model,
lm_weight=lm_weight,
tokens_to_suppress=tokens_to_suppress,
)
if isinstance(models, EnsembleModel):
self.model = models
else:
self.model = EnsembleModel(models)
self.model.set_decoder_beam_size(self.beam_size)
self.model.eval()
def _generate(
self,
sample: Dict[str, Dict[str, Tensor]],
prefix_tokens: Optional[Tensor] = None,
constraints: Optional[Tensor] = None,
bos_token: Optional[int] = None,
):
net_input = sample["net_input"]
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
# length of the source text being the character length except EndOfSentence and pad
# if src_lengths exists in net_input (speech_to_text dataset case), then use it
if "src_lengths" in net_input:
src_lengths = net_input["src_lengths"]
else:
src_lengths = (
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad))
.long()
.sum(dim=1)
)
elif "source" in net_input:
src_tokens = net_input["source"]
src_lengths = (
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
if net_input["padding_mask"] is not None
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
)
elif "features" in net_input:
src_tokens = net_input["features"]
src_lengths = (
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
if net_input["padding_mask"] is not None
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
)
else:
raise Exception(
"expected src_tokens or source in net input. input keys: "
+ str(net_input.keys())
)
if constraints is not None and not self.search.supports_constraints:
raise NotImplementedError(
"Target-side constraints were provided, but search method doesn't support them"
)
# Initialize constraints, when active
self.search.init_constraints(constraints, self.beam_size)
# compute the encoder output for each beam
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
encoder_outs = self.model.forward_encoder(net_input)
finalized = self.generate_decoder(
encoder_outs,
src_tokens,
src_lengths,
sample,
prefix_tokens,
constraints,
bos_token,
)
return finalized
def generate_decoder(
self,
encoder_outs,
src_tokens,
src_lengths,
sample: Dict[str, Dict[str, Tensor]],
prefix_tokens: Optional[Tensor] = None,
constraints: Optional[Tensor] = None,
bos_token: Optional[int] = None,
aux_task_name="",
encoder_outs_aug: Optional[
Tensor
] = None, # an additional/augmented encoder_outs
):
incremental_states = torch.jit.annotate(
List[Dict[str, Dict[str, Optional[Tensor]]]],
[
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
for i in range(self.model.models_size)
],
)
# bsz: total number of sentences in beam
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
bsz, src_len = src_tokens.size()[:2]
beam_size = self.beam_size
decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder"
max_len: int = -1
if self.match_source_len:
max_len = src_lengths.max().item()
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
self.max_len - 1,
)
assert (
self.min_len <= max_len
), "min_len cannot be larger than max_len, please adjust these!"
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
# ensure encoder_outs is a List.
assert encoder_outs is not None
if encoder_outs_aug is not None:
encoder_outs_aug = self.model.reorder_encoder_out(
encoder_outs_aug, new_order
)
# initialize buffers
scores = (
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
) # +1 for eos; pad is never chosen for scoring
tokens = (
torch.zeros(bsz * beam_size, max_len + 2)
.to(src_tokens)
.long()
.fill_(self.pad)
) # +2 for eos and pad
tokens[:, 0] = self.eos if bos_token is None else bos_token
attn: Optional[Tensor] = None
# A list that indicates candidates that should be ignored.
# For example, suppose we're sampling and have already finalized 2/5
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
# so that we only finalize the remaining 3 samples.
cands_to_ignore = (
torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
) # forward and backward-compatible False mask
# list of completed sentences
finalized = torch.jit.annotate(
List[List[Dict[str, Tensor]]],
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
) # contains lists of dictionaries of information about the hypothesis being finalized at each step
# a boolean array indicating if the sentence at the index is finished or not
finished = [False for i in range(bsz)]
num_remaining_sent = bsz # number of sentences remaining
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (
(torch.arange(0, bsz) * beam_size)
.unsqueeze(1)
.type_as(tokens)
.to(src_tokens.device)
)
cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
reorder_state: Optional[Tensor] = None
batch_idxs: Optional[Tensor] = None
original_batch_idxs: Optional[Tensor] = None
if "id" in sample and isinstance(sample["id"], Tensor):
original_batch_idxs = sample["id"]
else:
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
for step in range(max_len + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
batch_idxs
)
reorder_state.view(-1, beam_size).add_(
corr.unsqueeze(-1) * beam_size
)
original_batch_idxs = original_batch_idxs[batch_idxs]
self.model.reorder_incremental_state(
incremental_states, reorder_state, decoder_name
)
encoder_outs = self.model.reorder_encoder_out(
encoder_outs, reorder_state
)
if encoder_outs_aug is not None:
encoder_outs_aug = self.model.reorder_encoder_out(
encoder_outs_aug, reorder_state
)
with torch.autograd.profiler.record_function(
"EnsembleModel: forward_decoder"
):
lprobs, avg_attn_scores = self.model.forward_decoder(
tokens[:, : step + 1],
encoder_outs,
incremental_states,
self.temperature,
decoder_name=decoder_name,
encoder_outs_aug=encoder_outs_aug,
)
if self.lm_model is not None and not aux_task_name:
lm_out = self.lm_model(tokens[:, : step + 1])
probs = self.lm_model.get_normalized_probs(
lm_out, log_probs=True, sample=None
)
probs = probs[:, -1, :] * self.lm_weight
lprobs += probs
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
# handle max length constraint
if step >= max_len:
lprobs[:, : self.eos] = -math.inf
lprobs[:, self.eos + 1 :] = -math.inf
# handle prefix tokens (possibly with different lengths)
if (
prefix_tokens is not None
and step < prefix_tokens.size(1)
and step < max_len
):
lprobs, tokens, scores = self._prefix_tokens(
step, lprobs, scores, tokens, prefix_tokens, beam_size
)
else:
if step < self.min_len:
# minimum length constraint (does not apply if using prefix_tokens)
lprobs[:, self.eos] = -math.inf
if self.token_indices_to_suppress is not None:
lprobs[:, self.token_indices_to_suppress] = -math.inf
# Record attention scores, only support avg_attn_scores is a Tensor
if avg_attn_scores is not None:
if attn is None:
attn = torch.empty(
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
).to(scores)
attn[:, :, step + 1].copy_(avg_attn_scores)
scores = scores.type_as(lprobs)
eos_bbsz_idx = torch.empty(0).to(
tokens
) # indices of hypothesis ending with eos (finished sentences)
eos_scores = torch.empty(0).to(
scores
) # scores of hypothesis ending with eos (finished sentences)
if self.should_set_src_lengths:
self.search.set_src_lengths(src_lengths)
if self.repeat_ngram_blocker is not None:
lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)
# Shape: (batch, cand_size)
cand_scores, cand_indices, cand_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
tokens[:, : step + 1],
original_batch_idxs,
)
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
# Shape of eos_mask: (batch size, beam size)
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
# only consider eos when it's among the top beam_size indices
# Now we know what beam item(s) to finish
# Shape: 1d list of absolute-numbered
eos_bbsz_idx = torch.masked_select(
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents: List[int] = []
if eos_bbsz_idx.numel() > 0:
eos_scores = torch.masked_select(
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents = self.finalize_hypos(
step,
eos_bbsz_idx,
eos_scores,
tokens,
scores,
finalized,
finished,
beam_size,
attn,
src_lengths,
max_len,
)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
if self.search.stop_on_max_len and step >= max_len:
break
assert step < max_len, f"{step} < {max_len}"
# Remove finalized sentences (ones for which {beam_size}
# finished hypotheses have been generated) from the batch.
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = torch.ones(
bsz, dtype=torch.bool, device=cand_indices.device
)
batch_mask[finalized_sents] = False
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
batch_idxs = torch.arange(
bsz, device=cand_indices.device
).masked_select(batch_mask)
# Choose the subset of the hypothesized constraints that will continue
self.search.prune_sentences(batch_idxs)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
cand_scores = cand_scores[batch_idxs]
cand_indices = cand_indices[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
src_lengths = src_lengths[batch_idxs]
cands_to_ignore = cands_to_ignore[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(
new_bsz * beam_size, attn.size(1), -1
)
bsz = new_bsz
else:
batch_idxs = None
# Set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
# Rewrite the operator since the element wise or is not supported in torchscript.
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
active_mask = torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[: eos_mask.size(1)],
)
# get the top beam_size active hypotheses, which are just
# the hypos with the smallest values in active_mask.
# {active_hypos} indicates which {beam_size} hypotheses
# from the list of {2 * beam_size} candidates were
# selected. Shapes: (batch size, beam size)
new_cands_to_ignore, active_hypos = torch.topk(
active_mask, k=beam_size, dim=1, largest=False
)
# update cands_to_ignore to ignore any finalized hypos.
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
# Make sure there is at least one active item for each sentence in the batch.
assert (~cands_to_ignore).any(dim=1).all()
# update cands_to_ignore to ignore any finalized hypos
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
# can be selected more than once).
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
# Set the tokens for each beam (can select the same row more than once)
tokens[:, : step + 1] = torch.index_select(
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
)
# Select the next token for each of them
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
cand_indices, dim=1, index=active_hypos
)
if step > 0:
scores[:, :step] = torch.index_select(
scores[:, :step], dim=0, index=active_bbsz_idx
)
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
cand_scores, dim=1, index=active_hypos
)
# Update constraints based on which candidates were selected for the next beam
self.search.update_constraints(active_hypos)
# copy attention for active hypotheses
if attn is not None:
attn[:, :, : step + 2] = torch.index_select(
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
)
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(len(finalized)):
scores = torch.tensor(
[float(elem["score"].item()) for elem in finalized[sent]]
)
_, sorted_scores_indices = torch.sort(scores, descending=True)
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
finalized[sent] = torch.jit.annotate(
List[Dict[str, Tensor]], finalized[sent]
)
return finalized
class EnsembleModel(EnsembleModelBase):
"""A wrapper around an ensemble of models."""
def __init__(self, models):
super().__init__(models)
@torch.jit.export
def forward_decoder(
self,
tokens,
encoder_outs: List[Dict[str, List[Tensor]]],
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
temperature: float = 1.0,
decoder_name="decoder",
encoder_outs_aug: List[Dict[str, List[Tensor]]] = None,
):
log_probs = []
avg_attn: Optional[Tensor] = None
encoder_out: Optional[Dict[str, List[Tensor]]] = None
encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None
for i, model in enumerate(self.models):
if self.has_encoder():
encoder_out = encoder_outs[i]
if encoder_outs_aug is not None:
encoder_out_aug = encoder_outs_aug[i]
# decode each model
if self.has_incremental_states():
if encoder_out_aug is not None:
decoder_out = getattr(model, decoder_name).forward(
tokens,
encoder_out=encoder_out,
encoder_out_aug=encoder_out_aug,
incremental_state=incremental_states[i],
)
else:
decoder_out = getattr(model, decoder_name).forward(
tokens,
encoder_out=encoder_out,
incremental_state=incremental_states[i],
)
else:
if hasattr(model, decoder_name):
decoder_out = getattr(model, decoder_name).forward(
tokens, encoder_out=encoder_out
)
else:
decoder_out = model.forward(tokens)
attn: Optional[Tensor] = None
decoder_len = len(decoder_out)
if decoder_len > 1 and decoder_out[1] is not None:
if isinstance(decoder_out[1], Tensor):
attn = decoder_out[1]
else:
attn_holder = decoder_out[1]["attn"]
if isinstance(attn_holder, Tensor):
attn = attn_holder
elif attn_holder is not None:
attn = attn_holder[0]
if attn is not None:
attn = attn[:, -1, :]
decoder_out_tuple = (
decoder_out[0][:, -1:, :].div_(temperature),
None if decoder_len <= 1 else decoder_out[1],
)
probs = getattr(model, decoder_name).get_normalized_probs(
decoder_out_tuple, log_probs=True, sample=None
)
probs = probs[:, -1, :]
if self.models_size == 1:
return probs, attn
log_probs.append(probs)
if attn is not None:
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
self.models_size
)
if avg_attn is not None:
avg_attn.div_(self.models_size)
return avg_probs, avg_attn
@torch.jit.export
def reorder_incremental_state(
self,
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
new_order,
decoder_name="decoder",
):
if not self.has_incremental_states():
return
for i, model in enumerate(self.models):
getattr(model, decoder_name).reorder_incremental_state_scripting(
incremental_states[i], new_order
)
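A note on the refactor above: _generate() now only runs the encoder and hands off to generate_decoder(), which picks the decoder by attribute name so the same beam-search loop can drive either the default (unit) decoder or an auxiliary first-pass decoder. A self-contained toy illustration of that getattr routing follows; the task name "target_letter" is hypothetical.

```python
# Toy stand-ins; on the real fairseq model these attributes are nn.Module decoders.
class ToyModel:
    def __init__(self):
        self.decoder = "second-pass unit decoder"
        self.target_letter_decoder = "first-pass text decoder"


model = ToyModel()
for aux_task_name in ("", "target_letter"):
    decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder"
    print(decoder_name, "->", getattr(model, decoder_name))
# decoder -> second-pass unit decoder
# target_letter_decoder -> first-pass text decoder
```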

View File

@ -0,0 +1,260 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from fairseq import search
class MultiDecoderSequenceGenerator(nn.Module):
def __init__(
self,
models,
tgt_dict,
tgt_dict_mt,
beam_size=1,
beam_size_mt=1,
max_len_a=0,
max_len_b=200,
max_len_a_mt=0,
max_len_b_mt=200,
max_len=0,
min_len=1,
normalize_scores=True,
len_penalty=1.0,
len_penalty_mt=1.0,
unk_penalty=0.0,
temperature=1.0,
match_source_len=False,
no_repeat_ngram_size=0,
eos=None,
eos_mt=None,
symbols_to_strip_from_output=None,
lm_model=None,
lm_weight=1.0,
):
"""Generates translations of a given source sentence.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models,
currently support fairseq.models.TransformerModel for scripting
beam_size (int, optional): beam width (default: 1)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length for the second pass
max_len_a_mt/b_mt (int, optional): generate sequences of maximum length
ax + b, where x is the source length for the first pass
max_len (int, optional): the maximum length of the generated output
(not including end-of-sentence)
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
normalize_scores (bool, optional): normalize scores by the length
of the output (default: True)
len_penalty (float, optional): length penalty in the second pass, where <1.0 favors
shorter, >1.0 favors longer sentences (default: 1.0)
len_penalty_mt (float, optional): length penalty in the first pass, where <1.0 favors
shorter, >1.0 favors longer sentences (default: 1.0)
unk_penalty (float, optional): unknown word penalty, where <0
produces more unks, >0 produces fewer (default: 0.0)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
match_source_len (bool, optional): outputs should match the source
length (default: False)
"""
super().__init__()
from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator
self.generator = SequenceGenerator(
models,
tgt_dict,
beam_size=beam_size,
max_len_a=max_len_a,
max_len_b=max_len_b,
max_len=max_len,
min_len=min_len,
normalize_scores=normalize_scores,
len_penalty=len_penalty,
unk_penalty=unk_penalty,
temperature=temperature,
match_source_len=match_source_len,
no_repeat_ngram_size=no_repeat_ngram_size,
search_strategy=search.BeamSearch(tgt_dict),
eos=eos,
symbols_to_strip_from_output=symbols_to_strip_from_output,
lm_model=lm_model,
lm_weight=lm_weight,
)
self.eos = self.generator.eos
self.generator_mt = SequenceGenerator(
models,
tgt_dict_mt,
beam_size=beam_size_mt,
max_len_a=max_len_a_mt,
max_len_b=max_len_b_mt,
max_len=max_len,
min_len=min_len,
normalize_scores=normalize_scores,
len_penalty=len_penalty_mt,
unk_penalty=unk_penalty,
temperature=temperature,
match_source_len=match_source_len,
no_repeat_ngram_size=no_repeat_ngram_size,
search_strategy=search.BeamSearch(tgt_dict_mt),
eos=eos_mt,
symbols_to_strip_from_output=symbols_to_strip_from_output,
)
@torch.no_grad()
def generate(
self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
) -> List[List[Dict[str, Tensor]]]:
"""Generate translations. Match the api of other fairseq generators.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
constraints (torch.LongTensor, optional): force decoder to include
the list of constraints
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
return self._generate(sample, **kwargs)
def _generate(
self,
sample: Dict[str, Dict[str, Tensor]],
prefix_tokens: Optional[Tensor] = None,
constraints: Optional[Tensor] = None,
bos_token: Optional[int] = None,
):
net_input = sample["net_input"]
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
# length of the source text being the character length except EndOfSentence and pad
src_lengths = (
(src_tokens.ne(self.generator.eos) & src_tokens.ne(self.generator.pad))
.long()
.sum(dim=1)
)
else:
raise Exception(
"expected src_tokens or source in net input. input keys: "
+ str(net_input.keys())
)
if constraints is not None and not self.generator.search.supports_constraints:
raise NotImplementedError(
"Target-side constraints were provided, but search method doesn't support them"
)
# Initialize constraints, when active
self.generator.search.init_constraints(constraints, self.generator.beam_size)
self.generator_mt.search.init_constraints(
constraints, self.generator_mt.beam_size
)
# compute the encoder output for each beam
with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
encoder_outs = self.generator.model.forward_encoder(net_input)
single_model = self.generator.model.single_model
mt_decoder = getattr(single_model, f"{single_model.mt_task_name}_decoder")
# 1. MT decoder
finalized_mt = self.generator_mt.generate_decoder(
encoder_outs,
src_tokens,
src_lengths,
sample,
prefix_tokens,
constraints,
bos_token,
aux_task_name=single_model.mt_task_name,
)
# extract decoder output corresponding to the best hypothesis
max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt])
prev_output_tokens_mt = (
src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len)
.fill_(mt_decoder.padding_idx)
.int()
) # B x T
for i, hypo in enumerate(finalized_mt):
i_beam = 0
tmp = hypo[i_beam]["tokens"].int() # hyp + eos
prev_output_tokens_mt[i, 0] = self.generator_mt.eos
if tmp[-1] == self.generator_mt.eos:
tmp = tmp[:-1]
prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp
text = "".join([self.generator_mt.tgt_dict[c] for c in tmp])
text = text.replace("_", " ")
text = text.replace("", " ")
text = text.replace("<unk>", " ")
text = text.replace("<s>", "")
text = text.replace("</s>", "")
if len(text) > 0 and text[0] == " ":
text = text[1:]
sample_id = sample["id"].tolist()[i]
print("{} (None-{})".format(text, sample_id))
x = mt_decoder(
prev_output_tokens_mt,
encoder_out=encoder_outs[0],
features_only=True,
)[0].transpose(0, 1)
if getattr(single_model, "proj", None) is not None:
x = single_model.proj(x)
mt_decoder_padding_mask = None
if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)
# 2. T2U encoder
if getattr(single_model, "synthesizer_encoder", None) is not None:
t2u_encoder_out = single_model.synthesizer_encoder(
x,
mt_decoder_padding_mask,
)
else:
t2u_encoder_out = {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [mt_decoder_padding_mask]
if mt_decoder_padding_mask is not None
else [], # B x T
"encoder_embedding": [],
"encoder_states": [],
"src_tokens": [],
"src_lengths": [],
}
if getattr(single_model, "t2u_augmented_cross_attn", False):
encoder_outs_aug = [t2u_encoder_out]
else:
encoder_outs = [t2u_encoder_out]
encoder_outs_aug = None
# 3. T2U decoder
finalized = self.generator.generate_decoder(
encoder_outs,
src_tokens,
src_lengths,
sample,
prefix_tokens,
constraints,
bos_token,
encoder_outs_aug=encoder_outs_aug,
)
return finalized
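To make the hand-off between the two passes concrete, here is a small self-contained sketch of how the best first-pass hypothesis becomes prev_output_tokens for the second pass (trailing <eos> dropped, <eos> moved to position 0), mirroring the loop above; the token values are made up.

```python
import torch

eos, pad = 2, 1
hypo_tokens = torch.tensor([45, 812, 9, eos])  # best MT hypothesis, ends with <eos>
max_tgt_len = 5

prev_output_tokens_mt = torch.full((1, max_tgt_len), pad, dtype=torch.int)
prev_output_tokens_mt[0, 0] = eos
tmp = hypo_tokens[:-1] if hypo_tokens[-1] == eos else hypo_tokens
prev_output_tokens_mt[0, 1 : len(tmp) + 1] = tmp
print(prev_output_tokens_mt)  # tensor([[  2,  45, 812,   9,   1]], dtype=torch.int32)
```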

View File

@ -111,8 +111,12 @@ class MultitaskCriterion:
for key in ["target", "target_lengths", "ntokens"]:
task_sample[key] = sample["multitask"][task_name][key]
if task_name == getattr(model, "mt_task_name", None):
decoder_out = model_out["mt_decoder_out"]
else:
decoder_out = None
task_loss, task_sample_size, task_logging_output = task_criterion(
-model.multitask_decoders[task_name], task_sample
+model.multitask_decoders[task_name], task_sample, net_output=decoder_out
)
loss = loss + self.multitask_loss_weight[task_name] * task_loss
@ -251,6 +255,80 @@ class SpeechToUnitMultitaskTaskCriterion(
return False
@register_criterion(
"speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion):
def __init__(
self,
task,
sentence_avg,
label_smoothing,
ignore_prefix_size=0,
report_accuracy=False,
rdrop_alpha=0.0,
):
super().__init__(
task,
sentence_avg,
label_smoothing,
ignore_prefix_size,
report_accuracy,
rdrop_alpha,
)
def forward(self, model, sample, reduce=True):
net_input_concat = {
"src_tokens": sample["net_input"]["src_tokens"],
"src_lengths": sample["net_input"]["src_lengths"],
"prev_output_tokens": sample["net_input"]["prev_output_tokens"],
"prev_output_tokens_mt": sample["multitask"][model.mt_task_name][
"net_input"
]["prev_output_tokens"],
"tgt_speaker": sample["net_input"].get("tgt_speaker", None),
"return_all_hiddens": True,
}
if getattr(model, "asr_task_name", None) is not None:
net_input_concat["prev_output_tokens_asr"] = sample["multitask"][
model.asr_task_name
]["net_input"]["prev_output_tokens"]
if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
net_input_concat = duplicate_input(net_input_concat)
net_output, extra = model(**net_input_concat)
loss, nll_loss, rdrop_kl_loss = self.compute_loss(
model, [net_output], sample, reduce=reduce
)
sample_size = (
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
)
logging_output = {
"loss": loss.data,
"nll_loss": nll_loss.data,
"ntokens": sample["ntokens"],
"nsentences": sample["target"].size(0),
"sample_size": sample_size,
}
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, [net_output], sample)
logging_output["n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
if self.rdrop_alpha > 0:
logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
def __init__(
@ -351,3 +429,88 @@ class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCri
return
MultitaskCriterion.reduce_metrics(logging_outputs)
@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogram2passMultitaskTaskCriterion(
SpeechToSpectrogramMultitaskTaskCriterion
):
def __init__(
self,
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(
task,
sentence_avg,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
)
def forward(self, model, sample, reduction="mean"):
bsz, max_len, _ = sample["target"].size()
feat_tgt = sample["target"]
feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).to(sample["target"].device)
eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
feat_out, eos_out, extra = model(
src_tokens=sample["net_input"]["src_tokens"],
src_lengths=sample["net_input"]["src_lengths"],
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][
"prev_output_tokens"
],
tgt_speaker=sample["net_input"]["tgt_speaker"],
target_lengths=sample["target_lengths"],
return_all_hiddens=True,
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"],
feat_out,
eos_out,
feat_tgt,
eos_tgt,
sample["target_lengths"],
reduction,
)
attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(
extra["attn"],
sample["net_input"]["src_lengths"],
sample["target_lengths"],
reduction,
)
loss = (
l1_loss + mse_loss + eos_loss + attn_loss
) # do not include ctc loss as there's no text target
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
"nsentences": sample["nsentences"],
"sample_size": sample_size,
"l1_loss": utils.item(l1_loss.data),
"mse_loss": utils.item(mse_loss.data),
"eos_loss": utils.item(eos_loss.data),
"attn_loss": utils.item(attn_loss.data),
}
if len(self.multitask_criterion) == 0:
return loss, sample_size, logging_output
# multitask
multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
loss += multitask_loss
logging_output["multitask"] = multitask_log
return loss, sample_size, logging_output
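The eos_tgt construction in the two-pass spectrogram criterion above marks only the last valid frame of each utterance. A minimal runnable illustration with toy lengths:

```python
import torch

bsz, max_len = 2, 5
target_lengths = torch.tensor([3, 5])  # stands in for sample["target_lengths"]

feat_len = target_lengths.view(bsz, 1).expand(-1, max_len)
eos_tgt = torch.arange(max_len).view(1, max_len).expand(bsz, -1)
eos_tgt = (eos_tgt == (feat_len - 1)).float()
print(eos_tgt)
# tensor([[0., 0., 1., 0., 0.],
#         [0., 0., 0., 0., 1.]])
```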

View File

@ -257,6 +257,24 @@ class MultitaskConfig(object):
assert name in self.config, f"multitask '{name}' does not exist!"
return self.config[name]
@property
def first_pass_decoder_task_index(self):
"""Return the task index of the first-pass text decoder.
If multiple tasks set 'is_first_pass_decoder: True' in the config file,
the last such task is used for the first-pass decoder.
If no task sets 'is_first_pass_decoder: True',
the last task whose name starts with 'target' and whose decoder_type is 'transformer' is used.
"""
idx = -1
for i, (k, v) in enumerate(self.config.items()):
if v.is_first_pass_decoder:
idx = i
if idx < 0:
for i, (k, v) in enumerate(self.config.items()):
if k.startswith("target") and v.decoder_type == "transformer":
idx = i
return idx
class SingleTaskConfig(object):
def __init__(self, name, config):
@ -336,6 +354,34 @@ class SingleTaskConfig(object):
)
return weight
@property
def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
"""Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)
@property
def eos_token(self):
"""EOS token during generation"""
return self.config.get("eos_token", "<eos>")
@property
def rdrop_alpha(self):
return self.config.get("rdrop_alpha", None)
return self.config.get("rdrop_alpha", 0.0)
@property
def is_first_pass_decoder(self):
flag = self.config.get("is_first_pass_decoder", False)
if flag:
if self.decoder_type == "ctc":
raise ValueError(
"First-pass decoder in the multi-decoder model must not be CTC."
)
if "target" not in self.task_name:
raise Warning(
'The name of the first-pass decoder does not include "target".'
)
return flag
@property
def get_lang_tag_mapping(self):
return self.config.get("lang_tag_mapping", {})

View File

@ -247,8 +247,9 @@ class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
multitask_target = {}
sample_id = self.ids[index]
tgt_lang = self.tgt_langs[index]
for task_name, task_dataset in self.multitask_data.items():
-multitask_target[task_name] = task_dataset.get(sample_id)
+multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)
return s2s_data, multitask_target
@ -318,7 +319,7 @@ class SpeechToSpeechDatasetCreator(object):
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
-has_multitask = len(multitask) > 0
+has_multitask = multitask is not None and len(multitask.keys()) > 0
dataset_cls = (
SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
)

View File

@ -391,6 +391,7 @@ class SpeechToTextDataset(FairseqDataset):
class TextTargetMultitaskData(object):
# mandatory columns
KEY_ID, KEY_TEXT = "id", "tgt_text"
LANG_TAG_TEMPLATE = "<lang:{}>"
def __init__(self, args, split, tgt_dict):
samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
@ -399,6 +400,16 @@ class TextTargetMultitaskData(object):
self.append_eos = args.decoder_type != "ctc"
self.pre_tokenizer = self.build_tokenizer(args)
self.bpe_tokenizer = self.build_bpe(args)
self.prepend_bos_and_append_tgt_lang_tag = (
args.prepend_bos_and_append_tgt_lang_tag
)
self.eos_token = args.eos_token
self.lang_tag_mapping = args.get_lang_tag_mapping
@classmethod
def is_lang_tag(cls, token):
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
return re.match(pattern, token)
@classmethod
def tokenize(cls, tokenizer, text: str):
@ -409,6 +420,13 @@ class TextTargetMultitaskData(object):
text = self.tokenize(self.bpe_tokenizer, text)
return text
def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
lang_tag_idx = dictionary.index(lang_tag)
assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
return lang_tag_idx
def build_tokenizer(self, args):
pre_tokenizer = args.config.get("pre_tokenizer")
if pre_tokenizer is not None:
@ -425,14 +443,21 @@ class TextTargetMultitaskData(object):
else:
return None
-def get(self, sample_id):
+def get(self, sample_id, tgt_lang=None):
if sample_id in self.data:
tokenized = self.get_tokenized_tgt_text(sample_id)
-return self.dict.encode_line(
+target = self.dict.encode_line(
tokenized,
add_if_not_exist=False,
append_eos=self.append_eos,
)
if self.prepend_bos_and_append_tgt_lang_tag:
bos = torch.LongTensor([self.dict.bos()])
lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
assert lang_tag_idx != self.dict.unk()
lang_tag_idx = torch.LongTensor([lang_tag_idx])
target = torch.cat((bos, target, lang_tag_idx), 0)
return target
else:
logger.warning(f"no target for {sample_id}")
return torch.IntTensor([])
@ -441,7 +466,7 @@ class TextTargetMultitaskData(object):
out = fairseq_data_utils.collate_tokens(
samples,
self.dict.pad(),
-eos_idx=self.dict.eos(),
+eos_idx=None,
left_pad=False,
move_eos_to_beginning=False,
).long()
@ -449,7 +474,7 @@ class TextTargetMultitaskData(object):
prev_out = fairseq_data_utils.collate_tokens(
samples,
self.dict.pad(),
-eos_idx=self.dict.eos(),
+eos_idx=None,
left_pad=False,
move_eos_to_beginning=True,
).long()
@ -551,7 +576,7 @@ class SpeechToTextDatasetCreator(object):
src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
-has_multitask = len(multitask) > 0
+has_multitask = multitask is not None and len(multitask.keys()) > 0
dataset_cls = (
SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
)
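When prepend_bos_and_append_tgt_lang_tag is enabled, TextTargetMultitaskData.get() wraps the encoded target as <bos> ... <eos> <lang tag>, as in the hunk further above. A toy example with made-up dictionary indices:

```python
import torch

bos = torch.LongTensor([0])                 # self.dict.bos()
target = torch.LongTensor([45, 812, 9, 2])  # encode_line(...) output, ending in <eos>
lang_tag_idx = torch.LongTensor([64000])    # self.dict.index("<lang:es>"), hypothetical

target = torch.cat((bos, target, lang_tag_idx), 0)
print(target)  # tensor([    0,    45,   812,     9,     2, 64000])
```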

View File

@ -811,6 +811,10 @@ class GenerationConfig(FairseqDataclass):
default=5,
metadata={"help": "beam size"},
)
beam_mt: int = field(
default=0,
metadata={"help": "beam size for the first-pass decoder"},
)
nbest: int = field(
default=1,
metadata={"help": "number of hypotheses to output"},
@ -827,6 +831,18 @@ class GenerationConfig(FairseqDataclass):
"help": "generate sequences of maximum length ax + b, where x is the source length"
},
)
max_len_a_mt: float = field(
default=0,
metadata={
"help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder"
},
)
max_len_b_mt: int = field(
default=200,
metadata={
"help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder"
},
)
min_len: int = field(
default=1,
metadata={"help": "minimum generation length"},
@ -853,6 +869,12 @@ class GenerationConfig(FairseqDataclass):
"help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
},
)
lenpen_mt: float = field(
default=1,
metadata={
"help": "length penalty for the first-pass decoder: <1.0 favors shorter, >1.0 favors longer sentences"
},
)
unkpen: float = field(
default=0,
metadata={

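One plausible way the new *_mt generation options map onto MultiDecoderSequenceGenerator is sketched below. This is an assumption for illustration: the beam_mt == 0 fallback to beam is guessed, and the actual generator selection lives in the task code, which is not shown in this diff.

```python
from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
    MultiDecoderSequenceGenerator,
)


def build_two_pass_generator(models, tgt_dict, tgt_dict_mt, gen_cfg):
    # gen_cfg is the GenerationConfig dataclass extended above.
    return MultiDecoderSequenceGenerator(
        models,
        tgt_dict,     # second-pass (unit) dictionary
        tgt_dict_mt,  # first-pass (text) dictionary
        beam_size=gen_cfg.beam,
        beam_size_mt=gen_cfg.beam_mt if gen_cfg.beam_mt > 0 else gen_cfg.beam,
        max_len_a=gen_cfg.max_len_a,
        max_len_b=gen_cfg.max_len_b,
        max_len_a_mt=gen_cfg.max_len_a_mt,
        max_len_b_mt=gen_cfg.max_len_b_mt,
        len_penalty=gen_cfg.lenpen,
        len_penalty_mt=gen_cfg.lenpen_mt,
    )
```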
View File

@ -4,4 +4,6 @@
# LICENSE file in the root directory of this source tree.
from .s2s_conformer import * # noqa
from .s2s_conformer_translatotron2 import * # noqa
from .s2s_conformer_unity import * # noqa
from .s2s_transformer import * # noqa

View File

@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional
from torch import Tensor
from fairseq.models.transformer import Linear
from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder
class AugTransformerUnitDecoder(AugTransformerDecoder):
"""Based on Transformer decoder, with support to decoding stacked units"""
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
):
super().__init__(
args, dictionary, embed_tokens, no_encoder_attn, output_projection
)
self.n_frames_per_step = args.n_frames_per_step
self.out_proj_n_frames = (
Linear(
self.output_embed_dim,
self.output_embed_dim * self.n_frames_per_step,
bias=False,
)
if self.n_frames_per_step > 1
else None
)
def forward(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention, should be of size T x B x C
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
encoder_out_aug=encoder_out_aug,
incremental_state=incremental_state,
full_context_alignment=full_context_alignment,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)
if not features_only:
bsz, seq_len, d = x.size()
if self.out_proj_n_frames:
x = self.out_proj_n_frames(x)
x = self.output_layer(x.view(bsz, seq_len, self.n_frames_per_step, d))
x = x.view(bsz, seq_len * self.n_frames_per_step, -1)
if (
incremental_state is None and self.n_frames_per_step > 1
): # teacher-forcing mode in training
x = x[
:, : -(self.n_frames_per_step - 1), :
] # remove extra frames after <eos>
return x, extra
def upgrade_state_dict_named(self, state_dict, name):
if self.n_frames_per_step > 1:
move_keys = [
(
f"{name}.project_in_dim.weight",
f"{name}.embed_tokens.project_in_dim.weight",
)
]
for from_k, to_k in move_keys:
if from_k in state_dict and to_k not in state_dict:
state_dict[to_k] = state_dict[from_k]
del state_dict[from_k]
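The stacked-unit handling above projects each decoder state to n_frames_per_step frames before the output layer. A self-contained shape check, with plain nn.Linear modules standing in for out_proj_n_frames and the decoder's output projection:

```python
import torch
import torch.nn as nn

bsz, seq_len, d, k, vocab = 2, 4, 8, 3, 10  # k stands in for n_frames_per_step
out_proj_n_frames = nn.Linear(d, d * k, bias=False)
output_layer = nn.Linear(d, vocab)  # stand-in for the decoder's output projection

x = torch.randn(bsz, seq_len, d)
x = out_proj_n_frames(x)                      # (bsz, seq_len, d * k)
x = output_layer(x.view(bsz, seq_len, k, d))  # (bsz, seq_len, k, vocab)
x = x.view(bsz, seq_len * k, -1)              # (bsz, seq_len * k, vocab)
x = x[:, : -(k - 1), :]                       # teacher forcing: drop frames after <eos>
print(x.shape)  # torch.Size([2, 10, 10])
```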

View File

@ -0,0 +1,85 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from fairseq.models import FairseqEncoder
from fairseq.modules import LayerNorm, TransformerEncoderLayer
class TransformerEncoderNoEmb(FairseqEncoder):
"""Transformer encoder without token embeddings."""
def __init__(self, args):
super().__init__(None)
self.layers = nn.ModuleList(
[TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
)
if args.encoder_normalize_before:
self.layer_norm = LayerNorm(args.encoder_embed_dim)
else:
self.layer_norm = None
def forward(self, x, encoder_padding_mask, return_all_hiddens=False):
encoder_states = []
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
if self.layer_norm is not None:
x = self.layer_norm(x)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask]
if encoder_padding_mask is not None and encoder_padding_mask.any()
else [], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [],
}
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = (
[]
if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
new_encoder_padding_mask = (
[]
if len(encoder_out["encoder_padding_mask"]) == 0
else [
x.index_select(0, new_order)
for x in encoder_out["encoder_padding_mask"]
]
)
new_encoder_embedding = (
[]
if len(encoder_out["encoder_embedding"]) == 0
else [
x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
]
)
encoder_states = encoder_out["encoder_states"]
if len(encoder_states) > 0:
for idx, state in enumerate(encoder_states):
encoder_states[idx] = state.index_select(1, new_order)
return {
"encoder_out": new_encoder_out, # T x B x C
"encoder_padding_mask": new_encoder_padding_mask, # B x T
"encoder_embedding": new_encoder_embedding, # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [], # B x T
"src_lengths": [], # B x 1
}
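reorder_encoder_out above is what tiles (and later reorders) the embedding-free encoder's states across beams during search. A toy illustration of the beam-expansion pattern the generator uses: new_order repeats each batch index beam_size times, then index_select is applied along the batch axis.

```python
import torch

T, B, C, beam_size = 3, 2, 4, 2
x = torch.arange(B).float().view(1, B, 1).expand(T, B, C)  # batch index in every cell

new_order = torch.arange(B).view(-1, 1).repeat(1, beam_size).view(-1)
print(new_order.tolist())                # [0, 0, 1, 1]
x_beamed = x.index_select(1, new_order)  # T x (B * beam_size) x C
print(x_beamed.shape)                    # torch.Size([3, 4, 4])
```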

View File

@ -11,7 +11,9 @@ import torch
from fairseq import checkpoint_utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.speech_to_speech.s2s_transformer import (
S2SpecTTransformerModel,
S2UTTransformerModel,
s2spect_architecture_base,
s2ut_architecture_base,
)
from fairseq.models.speech_to_text import S2TConformerEncoder
@ -97,6 +99,34 @@ class S2UTConformerModel(S2UTTransformerModel):
return build_s2s_conformer_encoder(args)
@register_model("s2spect_conformer")
class S2SpecTConformerModel(S2SpecTTransformerModel):
"""
Direct speech-to-speech translation model with Conformer encoder + TTS Transformer decoder
"""
@staticmethod
def add_args(parser):
S2SpecTTransformerModel.add_args(parser)
parser.add_argument("--depthwise-conv-kernel-size", type=int, default=31)
parser.add_argument(
"--attn-type",
type=str,
default=None,
help="If not specified uses fairseq MHA. Other valid option is espnet for using conformer",
)
parser.add_argument(
"--pos-enc-type",
type=str,
default="abs",
help="Must be specified in addition to attn-type=espnet for rel_pos and rope",
)
@classmethod
def build_encoder(cls, args):
return build_s2s_conformer_encoder(args)
@register_model_architecture("s2ut_conformer", "s2ut_conformer")
def s2ut_conformer_architecture_base(args):
args.attn_type = getattr(args, "attn_type", None)
@ -111,3 +141,32 @@ def s2ut_conformer_architecture_base(args):
args.encoder_layers = getattr(args, "encoder_layers", 16)
args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31)
s2ut_architecture_base(args)
@register_model_architecture("s2spect_conformer", "s2spect_conformer")
def s2spect_conformer_architecture_base(args):
args.attn_type = getattr(args, "attn_type", None)
args.pos_enc_type = getattr(args, "pos_enc_type", "abs")
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
args.input_channels = getattr(args, "input_channels", 1)
args.max_source_positions = getattr(args, "max_source_positions", 6000)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
args.encoder_layers = getattr(args, "encoder_layers", 16)
args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31)
s2spect_architecture_base(args)
@register_model_architecture("s2spect_conformer", "s2spect_conformer_fisher")
def s2spect_architecture_fisher(args):
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
# decoder
args.prenet_dim = getattr(args, "prenet_dim", 32)
s2spect_conformer_architecture_base(args)

View File

@ -0,0 +1,262 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from fairseq.models import (
FairseqEncoderModel,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder
from fairseq.models.speech_to_speech.modules.transformer_encoder import (
TransformerEncoderNoEmb,
)
from fairseq.models.speech_to_speech.s2s_conformer import S2SpecTConformerModel
from fairseq.models.speech_to_speech.s2s_conformer_unity import (
multitask_text_transformer_decoder_arch,
)
from fairseq.models.speech_to_speech.s2s_transformer import (
base_multitask_text_transformer_decoder_arch,
s2spect_architecture_base,
)
from fairseq.models.text_to_speech import TTSTransformerDecoder
from fairseq.models.transformer import TransformerDecoder, TransformerModelBase
logger = logging.getLogger(__name__)
@register_model("s2spect2_conformer")
class S2SpecT2ConformerModel(S2SpecTConformerModel):
"""
Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + TTS Transformer decoder
"""
@staticmethod
def add_args(parser):
S2SpecTConformerModel.add_args(parser)
parser.add_argument(
"--translation-decoder-layers",
type=int,
default=4,
metavar="N",
help="num decoder layers in the first-pass translation module",
)
parser.add_argument(
"--synthesizer",
default="transformer",
choices=["transformer"],
help="",
)
parser.add_argument(
"--synthesizer-encoder-layers",
type=int,
default=0,
metavar="N",
help="num encoder layers in the second-pass synthesizer module",
)
@classmethod
def build_multitask_decoder(
cls,
args,
tgt_dict,
in_dim,
is_mt_decoder,
decoder_layers,
decoder_embed_dim,
decoder_attention_heads,
):
decoder_args = args.decoder_args
decoder_args.encoder_embed_dim = in_dim
if args.decoder_type == "transformer":
if is_mt_decoder:
multitask_text_transformer_decoder_arch(
decoder_args,
decoder_layers,
decoder_embed_dim,
decoder_attention_heads,
) # 4L
else:
base_multitask_text_transformer_decoder_arch(decoder_args) # 2L
task_decoder = TransformerDecoder(
decoder_args,
tgt_dict,
embed_tokens=TransformerModelBase.build_embedding(
decoder_args,
tgt_dict,
decoder_args.decoder_embed_dim,
),
)
elif args.decoder_type == "ctc":
task_decoder = CTCDecoder(
dictionary=tgt_dict,
in_dim=in_dim,
)
else:
raise NotImplementedError(
"currently only support multitask decoder_type 'transformer', 'ctc'"
)
return task_decoder
@classmethod
def build_decoder(cls, args):
_args = copy.deepcopy(args)
_args.encoder_embed_dim = args.decoder_embed_dim
if args.synthesizer == "transformer":
return TTSTransformerDecoder(_args, None, padding_idx=1)
else:
raise NotImplementedError(args.synthesizer)
@classmethod
def build_model(cls, args, task):
encoder = cls.build_encoder(args)
decoder = cls.build_decoder(args)
base_model = cls(encoder, decoder)
# set up multitask decoders
base_model.mt_task_name = None
base_model.multitask_decoders = {}
has_first_pass_decoder = False
for task_name, task_obj in task.multitask_tasks.items():
if task_obj.is_first_pass_decoder:
has_first_pass_decoder = True
base_model.mt_task_name = task_name
in_dim = (
args.encoder_embed_dim
if task_obj.args.input_from == "encoder"
else args.decoder_embed_dim
)
task_decoder = cls.build_multitask_decoder(
task_obj.args,
task_obj.target_dictionary,
in_dim,
task_obj.is_first_pass_decoder,
getattr(args, "translation_decoder_layers", 4),
getattr(args, "decoder_embed_dim", 256),
getattr(args, "decoder_attention_heads", 4),
)
setattr(base_model, f"{task_name}_decoder", task_decoder)
decoder_model_cls = (
FairseqEncoderModel
if task_obj.args.decoder_type == "ctc"
else FairseqLanguageModel
)
base_model.multitask_decoders[task_name] = decoder_model_cls(
getattr(base_model, f"{task_name}_decoder")
)
assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder"
# set up encoder on top of the auxiliary MT decoder
if getattr(args, "synthesizer_encoder_layers", 0) > 0:
base_model.synthesizer_encoder = cls.build_text_encoder(args)
else:
base_model.synthesizer_encoder = None
return base_model
@classmethod
def build_text_encoder(cls, args):
_args = copy.deepcopy(args)
_args.encoder_layers = args.synthesizer_encoder_layers
_args.encoder_embed_dim = args.decoder_embed_dim
_args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim
_args.encoder_attention_heads = args.decoder_attention_heads
_args.encoder_normalize_before = True
return TransformerEncoderNoEmb(_args)
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens,
prev_output_tokens_mt,
tgt_speaker=None,
incremental_state=None,
target_lengths=None,
speaker=None,
return_all_hiddens=False,
):
encoder_out = self.encoder(
src_tokens,
src_lengths=src_lengths,
tgt_speaker=tgt_speaker,
return_all_hiddens=return_all_hiddens,
)
# 1. MT decoder
mt_decoder = getattr(self, f"{self.mt_task_name}_decoder")
mt_decoder_out = mt_decoder(
prev_output_tokens_mt,
encoder_out=encoder_out,
)
x = mt_decoder_out[1]["inner_states"][-1]
if mt_decoder.layer_norm is not None:
x = mt_decoder.layer_norm(x)
mt_decoder_padding_mask = None
if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)
# 2. TTS encoder
if self.synthesizer_encoder is not None:
tts_encoder_out = self.synthesizer_encoder(
x,
mt_decoder_padding_mask,
return_all_hiddens=return_all_hiddens,
)
else:
tts_encoder_out = {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [mt_decoder_padding_mask], # B x T
}
# 3. TTS decoder
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=tts_encoder_out,
incremental_state=incremental_state,
target_lengths=target_lengths,
speaker=speaker,
)
if return_all_hiddens:
decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"]
decoder_out[-1]["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
]
decoder_out[-1]["mt_decoder_out"] = mt_decoder_out
return decoder_out
@register_model_architecture(
model_name="s2spect2_conformer", arch_name="s2spect2_conformer"
)
def s2spect2_conformer_architecture_base(args):
args.conv_version = getattr(args, "conv_version", "convtransformer")
args.attn_type = getattr(args, "attn_type", None)
args.pos_enc_type = getattr(args, "pos_enc_type", "abs")
args.max_source_positions = getattr(args, "max_source_positions", 6000)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
args.encoder_layers = getattr(args, "encoder_layers", 16)
args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31)
s2spect_architecture_base(args)
# for old naming
@register_model_architecture(
model_name="s2spect2_conformer", arch_name="s2spect_conformer_translatotron2"
)
def s2spect2_conformer_architecture_base_legacy(args):
s2spect2_conformer_architecture_base(args)

View File

@ -0,0 +1,298 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder
from fairseq.models.speech_to_speech.modules.stacked_embedding import StackedEmbedding
from fairseq.models.speech_to_speech.modules.transformer_decoder_aug import (
AugTransformerUnitDecoder,
)
from fairseq.models.speech_to_speech.modules.transformer_encoder import (
TransformerEncoderNoEmb,
)
from fairseq.models.speech_to_speech.s2s_conformer import S2UTConformerModel
from fairseq.models.speech_to_speech.s2s_transformer import (
TransformerUnitDecoder,
base_multitask_text_transformer_decoder_arch,
s2ut_architecture_base,
)
from fairseq.models.transformer import TransformerDecoder, TransformerModelBase
logger = logging.getLogger(__name__)
def multitask_text_transformer_decoder_arch(
args, decoder_layers, decoder_embed_dim=256, decoder_attention_heads=4
):
args.decoder_layers = decoder_layers
args.decoder_embed_dim = decoder_embed_dim
args.decoder_attention_heads = decoder_attention_heads
base_multitask_text_transformer_decoder_arch(args)
@register_model("unity_conformer")
class UnityConformerModel(S2UTConformerModel):
"""
Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + Transformer discrete unit decoder
"""
@staticmethod
def add_args(parser):
S2UTConformerModel.add_args(parser)
parser.add_argument(
"--translation-decoder-layers",
type=int,
default=4,
metavar="N",
help="num decoder layers in the first-pass translation module",
)
parser.add_argument(
"--synthesizer",
default="transformer",
choices=["transformer"],
help="",
)
parser.add_argument(
"--synthesizer-encoder-layers",
type=int,
default=0,
metavar="N",
help="num encoder layers in the second-pass synthesizer module",
)
parser.add_argument(
"--synthesizer-augmented-cross-attention",
action="store_true",
default=False,
help="augmented cross-attention over speech encoder output",
)
@classmethod
def build_multitask_decoder(
cls,
args,
tgt_dict,
in_dim,
is_first_pass_decoder,
decoder_layers,
decoder_embed_dim,
decoder_attention_heads,
):
decoder_args = args.decoder_args
decoder_args.encoder_embed_dim = in_dim
if args.decoder_type == "transformer":
if is_first_pass_decoder:
multitask_text_transformer_decoder_arch(
decoder_args,
decoder_layers,
decoder_embed_dim,
decoder_attention_heads,
) # 4L
else:
base_multitask_text_transformer_decoder_arch(decoder_args) # 2L
task_decoder = TransformerDecoder(
decoder_args,
tgt_dict,
embed_tokens=TransformerModelBase.build_embedding(
decoder_args,
tgt_dict,
decoder_args.decoder_embed_dim,
),
)
elif args.decoder_type == "ctc":
task_decoder = CTCDecoder(
dictionary=tgt_dict,
in_dim=in_dim,
)
else:
raise NotImplementedError(
"currently only support multitask decoder_type 'transformer', 'ctc'"
)
return task_decoder
@classmethod
def build_decoder(cls, args, tgt_dict, aug_attn=False):
num_embeddings = len(tgt_dict)
padding_idx = tgt_dict.pad()
embed_tokens = StackedEmbedding(
num_embeddings,
args.decoder_embed_dim,
padding_idx,
num_stacked=args.n_frames_per_step,
)
_args = copy.deepcopy(args)
_args.encoder_embed_dim = args.decoder_embed_dim
decoder_cls = AugTransformerUnitDecoder if aug_attn else TransformerUnitDecoder
return decoder_cls(
_args,
tgt_dict,
embed_tokens,
)
@classmethod
def build_model(cls, args, task):
encoder = cls.build_encoder(args)
decoder = cls.build_decoder(
args,
task.target_dictionary,
aug_attn=getattr(args, "synthesizer_augmented_cross_attention", False),
)
base_model = cls(encoder, decoder)
base_model.t2u_augmented_cross_attn = getattr(
args, "synthesizer_augmented_cross_attention", False
)
# set up multitask decoders
base_model.mt_task_name = None
base_model.multitask_decoders = {}
has_first_pass_decoder = False
for task_name, task_obj in task.multitask_tasks.items():
if task_obj.is_first_pass_decoder:
has_first_pass_decoder = True
base_model.mt_task_name = task_name
in_dim = (
args.encoder_embed_dim
if task_obj.args.input_from == "encoder"
else args.decoder_embed_dim
)
task_decoder = cls.build_multitask_decoder(
task_obj.args,
task_obj.target_dictionary,
in_dim,
task_obj.is_first_pass_decoder,
getattr(args, "translation_decoder_layers", 4),
getattr(args, "decoder_embed_dim", 256),
getattr(args, "decoder_attention_heads", 4),
)
setattr(base_model, f"{task_name}_decoder", task_decoder)
decoder_model_cls = (
FairseqEncoderModel
if task_obj.args.decoder_type == "ctc"
else FairseqLanguageModel
)
base_model.multitask_decoders[task_name] = decoder_model_cls(
getattr(base_model, f"{task_name}_decoder")
)
assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder"
# set up encoder on top of the auxiliary MT decoder
if getattr(args, "synthesizer_encoder_layers", 0) > 0:
base_model.synthesizer_encoder = cls.build_text_encoder(args)
else:
base_model.synthesizer_encoder = None
return base_model
@classmethod
def build_text_encoder(cls, args):
_args = copy.deepcopy(args)
_args.encoder_layers = args.synthesizer_encoder_layers
_args.encoder_embed_dim = args.decoder_embed_dim
_args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim
_args.encoder_attention_heads = args.decoder_attention_heads
_args.encoder_normalize_before = True
return TransformerEncoderNoEmb(_args)
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens,
prev_output_tokens_mt,
tgt_speaker=None,
return_all_hiddens=False,
):
mt_decoder = getattr(self, f"{self.mt_task_name}_decoder")
encoder_out = self.encoder(
src_tokens,
src_lengths=src_lengths,
tgt_speaker=tgt_speaker,
return_all_hiddens=return_all_hiddens,
)
# 1. MT decoder
mt_decoder_out = mt_decoder(
prev_output_tokens_mt,
encoder_out=encoder_out,
)
x = mt_decoder_out[1]["inner_states"][-1]
if mt_decoder.layer_norm is not None:
x = mt_decoder.layer_norm(x)
mt_decoder_padding_mask = None
if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)
# 2. T2U encoder
if self.synthesizer_encoder is not None:
t2u_encoder_out = self.synthesizer_encoder(
x,
mt_decoder_padding_mask,
return_all_hiddens=return_all_hiddens,
)
else:
t2u_encoder_out = {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [mt_decoder_padding_mask], # B x T
}
# 3. T2U decoder
if self.t2u_augmented_cross_attn:
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
encoder_out_aug=t2u_encoder_out,
)
else:
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=t2u_encoder_out,
)
if return_all_hiddens:
decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"]
decoder_out[-1]["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
]
decoder_out[-1]["mt_decoder_out"] = mt_decoder_out
return decoder_out
@register_model_architecture(model_name="unity_conformer", arch_name="unity_conformer")
def unity_conformer_architecture_base(args):
args.conv_version = getattr(args, "conv_version", "convtransformer")
args.attn_type = getattr(args, "attn_type", None)
args.pos_enc_type = getattr(args, "pos_enc_type", "abs")
args.max_source_positions = getattr(args, "max_source_positions", 6000)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.dropout = getattr(args, "dropout", 0.1)
args.encoder_layers = getattr(args, "encoder_layers", 16)
args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31)
s2ut_architecture_base(args)
# for old naming
@register_model_architecture(
model_name="unity_conformer", arch_name="s2ut_conformer_translatotron2"
)
def unity_conformer_architecture_base_legacy(args):
unity_conformer_architecture_base(args)
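The forward pass above is the essence of the two-pass design: the first-pass MT decoder runs teacher-forced on prev_output_tokens_mt, its last inner state is (optionally) re-encoded by the synthesizer encoder, and that representation conditions the second-pass unit decoder. Below is a minimal sketch of the same data flow in plain PyTorch; every name here (ToyTwoPassModel, etc.) is a hypothetical stand-in rather than a fairseq class, and cross-attention is approximated by passing hidden states as initial decoder states.

import torch
import torch.nn as nn

class ToyTwoPassModel(nn.Module):
    # Hypothetical stand-in for the UnitY flow: speech encoder -> first-pass
    # text decoder -> (optional) text encoder -> second-pass unit decoder.
    def __init__(self, dim=16, text_vocab=100, unit_vocab=200):
        super().__init__()
        self.speech_encoder = nn.GRU(dim, dim, batch_first=True)
        self.text_embed = nn.Embedding(text_vocab, dim)
        self.text_decoder = nn.GRU(dim, dim, batch_first=True)   # first pass
        self.t2u_encoder = nn.GRU(dim, dim, batch_first=True)    # optional bridge
        self.unit_embed = nn.Embedding(unit_vocab, dim)
        self.unit_decoder = nn.GRU(dim, dim, batch_first=True)   # second pass
        self.unit_out = nn.Linear(dim, unit_vocab)

    def forward(self, speech_feats, prev_text_tokens, prev_unit_tokens):
        enc, _ = self.speech_encoder(speech_feats)  # B x T_s x C
        # 1. first-pass text decoder conditioned on the speech encoder
        #    (toy: use its final state instead of cross-attention)
        h0 = enc[:, -1].unsqueeze(0)
        text_states, _ = self.text_decoder(self.text_embed(prev_text_tokens), h0)
        # 2. optional encoder on top of the first-pass decoder states
        bridge, _ = self.t2u_encoder(text_states)
        # 3. second-pass unit decoder conditioned on the bridge representation
        unit_states, _ = self.unit_decoder(
            self.unit_embed(prev_unit_tokens), bridge[:, -1].unsqueeze(0)
        )
        return self.unit_out(unit_states)

model = ToyTwoPassModel()
logits = model(torch.randn(2, 50, 16),          # speech features
               torch.randint(0, 100, (2, 7)),   # teacher-forced text tokens
               torch.randint(0, 200, (2, 30)))  # teacher-forced unit tokens
print(logits.shape)  # torch.Size([2, 30, 200])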

View File

@ -10,3 +10,4 @@ from .s2t_conformer import * # noqa
from .s2t_transformer import * # noqa
from .s2t_wav_transformer import * # noqa
from .xm_transformer import * # noqa
from .xm_transformer_unity import * # noqa

View File

@ -623,6 +623,7 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
_args.dropout = args.decoder_dropout
_args.attention_dropout = args.decoder_attention_dropout
_args.activation_dropout = args.decoder_activation_dropout
_args.layerdrop = _args.decoder_layerdrop
decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
decoder = cls.maybe_load_pretrained(

View File

@ -0,0 +1,351 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from fairseq.models import (
FairseqEncoderModel,
FairseqLanguageModel,
register_model,
register_model_architecture,
)
from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder
from fairseq.models.speech_to_speech.modules.transformer_encoder import (
TransformerEncoderNoEmb,
)
from fairseq.models.speech_to_text.xm_transformer import XMTransformerModel
from fairseq.models.speech_to_text.xm_transformer import (
base_architecture as xm_t_base_architecture,
)
from fairseq.models.speech_to_text.xm_transformer import (
build_embedding,
need_finetuning,
set_default_adaptor_args,
set_default_general_args,
set_default_transformer_decoder_args,
set_default_w2v_encoder_args,
)
from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase
from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder
logger = logging.getLogger(__name__)
def unit_transformer_decoder_arch_base(
args, decoder_layers=6, decoder_embed_dim=768, decoder_attention_heads=12
):
args.encoder_layers = decoder_layers
args.decoder_layers = decoder_layers
args.decoder_embed_dim = decoder_embed_dim
args.decoder_ffn_embed_dim = decoder_embed_dim * 4
args.decoder_attention_heads = decoder_attention_heads
args.encoder_embed_dim = args.decoder_embed_dim
args.decoder_output_dim = decoder_embed_dim
args.decoder_input_dim = decoder_embed_dim
def unit_transformer_decoder_arch_large(
args, decoder_layers=12, decoder_embed_dim=1024, decoder_attention_heads=16
):
args.encoder_layers = decoder_layers
args.decoder_layers = decoder_layers
args.decoder_embed_dim = decoder_embed_dim
args.decoder_ffn_embed_dim = decoder_embed_dim * 4
args.decoder_attention_heads = decoder_attention_heads
args.encoder_embed_dim = args.decoder_embed_dim
args.decoder_output_dim = decoder_embed_dim
args.decoder_input_dim = decoder_embed_dim
@register_model("unity_xm_transformer")
class XMTransformerModelUnitY(XMTransformerModel):
@classmethod
def hub_models(cls):
base_url = "http://dl.fbaipublicfiles.com/fairseq/s2t"
model_ids = []
return {i: f"{base_url}/{i}.tar.gz" for i in model_ids}
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@classmethod
def add_args(cls, parser):
"""Add model-specific arguments to the parser."""
XMTransformerModel.add_args(parser)
parser.add_argument(
"--translation-decoder-layers",
type=int,
default=4,
metavar="N",
help="num decoder layers in the first-pass translation module",
)
parser.add_argument(
"--synthesizer-encoder-layers",
type=int,
default=0,
metavar="N",
help="num encoder layers in the second-pass synthesizer module",
)
parser.add_argument(
"--synthesizer-augmented-cross-attention",
action="store_true",
default=False,
help="augmented cross-attention over speech encoder output",
)
parser.add_argument(
"--load-pretrained-aux-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
@classmethod
def build_text_decoder(cls, args, tgt_dict):
_args = copy.deepcopy(args)
if args.adaptor_proj or args.encoder_proj: # not V0 arch
_args.encoder_embed_dim = _args.decoder_embed_dim
_args.dropout = args.decoder_dropout
_args.attention_dropout = args.decoder_attention_dropout
_args.activation_dropout = args.decoder_activation_dropout
_args.layerdrop = _args.decoder_layerdrop
_args.decoder_layers = _args.translation_decoder_layers
embed_tokens = build_embedding(tgt_dict, _args.decoder_embed_dim)
decoder = TransformerDecoder(_args, tgt_dict, embed_tokens)
if getattr(args, "load_pretrained_aux_decoder_from", None) is not None:
decoder = cls.maybe_load_pretrained(
decoder, getattr(args, "load_pretrained_aux_decoder_from", None)
)
for k, p in decoder.named_parameters():
p.requires_grad = need_finetuning(args.finetune_decoder_params, k)
return decoder
@classmethod
def build_decoder(cls, args, task, aug_attn=False):
_args = copy.deepcopy(args)
_args.layerdrop = 0.0 # turn off layerdrop for shallow layers
_args.encoder_embed_dim = args.decoder_embed_dim
proj = None
if args.decoder_embed_dim != _args.decoder_embed_dim:
proj = Linear(args.decoder_embed_dim, _args.decoder_embed_dim)
embed_tokens = build_embedding(task.target_dictionary, _args.decoder_embed_dim)
decoder_cls = AugTransformerDecoder if aug_attn else TransformerDecoder
decoder = decoder_cls(_args, task.target_dictionary, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None) is not None:
# load all layers first and then discard the bottom layers
embed_tokens = build_embedding(
task.target_dictionary, _args.decoder_embed_dim
)
decoder_tmp = decoder_cls(_args, task.target_dictionary, embed_tokens)
decoder_tmp = cls.maybe_load_pretrained(
decoder_tmp, getattr(_args, "load_pretrained_decoder_from", None)
)
state_dict = decoder_tmp.state_dict()
for k, p in decoder.named_parameters():
p.data = state_dict[k].data
p.requires_grad = need_finetuning(_args.finetune_decoder_params, k)
decoder.layers = decoder.layers[-_args.decoder_layers :]
return decoder, proj, _args
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
xm_t_base_architecture(args)
encoder = cls.build_encoder(args)
decoder, proj, unit_args = cls.build_decoder(
args,
task,
aug_attn=getattr(args, "synthesizer_augmented_cross_attention", False),
)
base_model = cls(encoder, decoder)
setattr(base_model, "proj", proj)
base_model.t2u_augmented_cross_attn = getattr(
args, "synthesizer_augmented_cross_attention", False
)
# set up multitask decoders
base_model.mt_task_name = None
base_model.multitask_decoders = {}
has_first_pass_decoder = False
for task_name, task_obj in task.multitask_tasks.items():
if task_obj.is_first_pass_decoder:
has_first_pass_decoder = True
base_model.mt_task_name = task_name
task_decoder = cls.build_multitask_decoder(
args,
task_obj.args,
task_obj.target_dictionary,
args.decoder_embed_dim,
task_obj.is_first_pass_decoder,
)
setattr(base_model, f"{task_name}_decoder", task_decoder)
decoder_model_cls = (
FairseqEncoderModel
if task_obj.args.decoder_type == "ctc"
else FairseqLanguageModel
)
base_model.multitask_decoders[task_name] = decoder_model_cls(
getattr(base_model, f"{task_name}_decoder")
)
assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder"
# set up encoder on top of the auxiliary MT decoder
if getattr(args, "synthesizer_encoder_layers", 0) > 0:
base_model.synthesizer_encoder = cls.build_t2u_encoder(unit_args)
else:
base_model.synthesizer_encoder = None
return base_model
@classmethod
def build_multitask_decoder(
cls, args, mtl_args, tgt_dict, in_dim, is_first_pass_decoder
):
decoder_args = mtl_args.decoder_args
decoder_args.encoder_embed_dim = in_dim
if mtl_args.decoder_type == "transformer":
if is_first_pass_decoder:
task_decoder = cls.build_text_decoder(args, tgt_dict)
else:
from fairseq.models.speech_to_speech import (
base_multitask_text_transformer_decoder_arch,
)
base_multitask_text_transformer_decoder_arch(decoder_args) # 2L
task_decoder = TransformerDecoder(
decoder_args,
tgt_dict,
embed_tokens=TransformerModelBase.build_embedding(
decoder_args,
tgt_dict,
decoder_args.decoder_embed_dim,
),
)
elif args.decoder_type == "ctc":
task_decoder = CTCDecoder(
dictionary=tgt_dict,
in_dim=in_dim,
)
else:
raise NotImplementedError(
"currently only support multitask decoder_type 'transformer', 'ctc'"
)
return task_decoder
@classmethod
def build_t2u_encoder(cls, args):
_args = copy.deepcopy(args)
_args.encoder_layers = _args.synthesizer_encoder_layers
_args.encoder_embed_dim = args.decoder_embed_dim
_args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim
_args.encoder_attention_heads = args.decoder_attention_heads
_args.encoder_normalize_before = True
return TransformerEncoderNoEmb(_args)
def forward(
self,
src_tokens,
src_lengths,
prev_output_tokens,
prev_output_tokens_mt,
return_all_hiddens=False,
tgt_speaker=None,
**kwargs,
):
"""
The forward method inherited from the base class has a **kwargs
argument in its input, which is not supported in torchscript. This
method overwrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(
src_tokens=src_tokens, src_lengths=src_lengths, **kwargs
)
# 1. MT decoder
mt_decoder = getattr(self, f"{self.mt_task_name}_decoder")
mt_decoder_out = mt_decoder(
prev_output_tokens_mt,
encoder_out=encoder_out,
)
x = mt_decoder_out[1]["inner_states"][-1]
if mt_decoder.layer_norm is not None:
x = mt_decoder.layer_norm(x)
if self.proj is not None:
x = self.proj(x)
mt_decoder_padding_mask = None
if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)
# 2. T2U encoder
if self.synthesizer_encoder is not None:
t2u_encoder_out = self.synthesizer_encoder(
x,
mt_decoder_padding_mask,
)
else:
t2u_encoder_out = {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [mt_decoder_padding_mask], # B x T
}
# 3. T2U decoder
if self.t2u_augmented_cross_attn:
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
encoder_out_aug=t2u_encoder_out,
)
else:
decoder_out = self.decoder(
prev_output_tokens,
encoder_out=t2u_encoder_out,
)
if return_all_hiddens:
decoder_out[-1]["encoder_states"] = encoder_out["encoder_out"]
# NOTE: from the top layer
decoder_out[-1]["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"
]
decoder_out[-1]["mt_decoder_out"] = mt_decoder_out
return decoder_out
@register_model_architecture(
model_name="unity_xm_transformer", arch_name="unity_xm_transformer"
)
def base_architecture_unity(args):
set_default_general_args(args)
set_default_w2v_encoder_args(args)
set_default_adaptor_args(args)
set_default_transformer_decoder_args(args)
args.layernorm_embedding = False
args.decoder_learned_pos = False
# for old models
@register_model_architecture(
model_name="unity_xm_transformer", arch_name="xm_transformer_t2"
)
def base_architecture_unity_legacy(args):
base_architecture_unity(args)
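A detail worth noting in build_decoder above: when --load-pretrained-decoder-from is set, a temporary decoder with the full pretrained depth is instantiated, its weights are copied parameter-by-parameter into the new decoder, and only the top decoder_layers layers are kept. A small sketch of that truncation idea, assuming toy modules rather than the fairseq decoders:

import torch.nn as nn

class ToyDecoder(nn.Module):
    # Hypothetical decoder: just a stack of layers with named parameters.
    def __init__(self, depth=12, dim=8):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])

def init_from_pretrained_top_layers(fresh, pretrained, num_layers):
    # Copy every parameter from the pretrained decoder into the freshly built
    # one, then keep only the top `num_layers` layers (as build_decoder does).
    state_dict = pretrained.state_dict()
    for name, param in fresh.named_parameters():
        param.data = state_dict[name].data
    fresh.layers = fresh.layers[-num_layers:]
    return fresh

pretrained = ToyDecoder(depth=12)
decoder = init_from_pretrained_top_layers(ToyDecoder(depth=12), pretrained, num_layers=2)
print(len(decoder.layers))  # 2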

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.distributed import fsdp_wrap
@ -25,7 +26,6 @@ from fairseq.modules import (
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
# rewrite name for backward compatibility in `make_generation_fast_`
@ -42,7 +42,7 @@ class TransformerDecoderBase(FairseqIncrementalDecoder):
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
cfg (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs

View File

@ -0,0 +1,392 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models.transformer import TransformerConfig
from fairseq.models.transformer.transformer_decoder import TransformerDecoderBase
from fairseq.modules import (
LayerDropModuleList,
SinusoidalPositionalEmbedding,
transformer_layer_aug,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
class AugTransformerDecoderBase(TransformerDecoderBase):
"""
Transformer decoder augmented with an additional cross-attention. Each layer
is a :class:`AugTransformerDecoderLayerBase`.
Args:
cfg (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
        encoder_attn_merge_type (str, optional): how to combine the outputs of
            the two cross-attention modules. With "sequential" they are stacked
            one after the other; with "parallel" they are applied in parallel
            and their outputs are combined before being fed to the FFN
            (default: "sequential").
        dropnet_ratio (float, optional): probability of dropping one of the two
            cross-attention modules during training (default: 0.0).
"""
def __init__(
self,
cfg,
dictionary,
embed_tokens,
output_projection=None,
encoder_attn_merge_type="sequential",
dropnet_ratio=0.0,
):
super().__init__(
cfg,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=output_projection,
)
# assert cfg.cross_self_attention
self.cross_self_attention = cfg.cross_self_attention
if self.decoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.layers.extend(
[
self.build_decoder_layer(cfg, encoder_attn_merge_type, dropnet_ratio)
for _ in range(cfg.decoder.layers)
]
)
def build_decoder_layer(
self,
cfg,
encoder_attn_merge_type="sequential",
dropnet_ratio=0,
):
layer = transformer_layer_aug.AugTransformerDecoderLayerBase(
cfg,
no_encoder_attn=False,
encoder_attn_merge_type=encoder_attn_merge_type,
dropnet_ratio=dropnet_ratio,
)
checkpoint = cfg.checkpoint_activations
if checkpoint:
offload_to_cpu = cfg.offload_activations
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
def forward(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention, should be of size T x B x C
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
encoder_out_aug=encoder_out_aug,
incremental_state=incremental_state,
full_context_alignment=full_context_alignment,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
encoder_out_aug: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
return self.extract_features_scriptable(
prev_output_tokens,
encoder_out,
encoder_out_aug,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
"""
A scriptable subclass of this class has an extract_features method and calls
super().extract_features, but super() is not supported in torchscript. A copy of
this function is made to be used in the subclass instead.
"""
def extract_features_scriptable(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
encoder_out_aug: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
bs, slen = prev_output_tokens.size()
if alignment_layer is None:
alignment_layer = self.num_layers - 1
enc: Optional[Tensor] = None
padding_mask: Optional[Tensor] = None
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
enc = encoder_out["encoder_out"][0]
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
padding_mask = encoder_out["encoder_padding_mask"][0]
enc_aug: Optional[Tensor] = None
padding_mask_aug: Optional[Tensor] = None
if encoder_out_aug is not None and len(encoder_out_aug["encoder_out"]) > 0:
enc_aug = encoder_out_aug["encoder_out"][0]
if (
encoder_out_aug is not None
and len(encoder_out_aug["encoder_padding_mask"]) > 0
):
padding_mask_aug = encoder_out_aug["encoder_padding_mask"][0]
# embed positions
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# Prevent torchscript exporting issue for dynamic quant embedding
prev_output_tokens = prev_output_tokens.contiguous()
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
# decoder layers
attn: Optional[Tensor] = None
attn_aug: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn, layer_attn_aug, _ = layer(
x,
enc,
padding_mask,
enc_aug,
padding_mask_aug,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer)),
need_head_weights=bool((idx == alignment_layer)),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if layer_attn_aug is not None and idx == alignment_layer:
attn_aug = layer_attn_aug.float().to(x)
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if attn_aug is not None:
if alignment_heads is not None:
attn_aug = attn_aug[:alignment_heads]
# average probabilities over heads
attn_aug = attn_aug.mean(dim=0)
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "attn_aug": [attn_aug], "inner_states": inner_states}
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
if f"{name}.output_projection.weight" not in state_dict:
if self.share_input_output_embed:
embed_out_key = f"{name}.embed_tokens.weight"
else:
embed_out_key = f"{name}.embed_out"
if embed_out_key in state_dict:
state_dict[f"{name}.output_projection.weight"] = state_dict[
embed_out_key
]
if not self.share_input_output_embed:
del state_dict[embed_out_key]
for i in range(self.num_layers):
# update layer norms
layer_norm_map = {
"0": "self_attn_layer_norm",
"1": "encoder_attn_layer_norm",
"2": "encoder_attn_layer_norm2",
"3": "final_layer_norm",
}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
if k in state_dict:
state_dict[
"{}.layers.{}.{}.{}".format(name, i, new, m)
] = state_dict[k]
del state_dict[k]
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
class AugTransformerDecoder(AugTransformerDecoderBase):
def __init__(
self,
args,
dictionary,
embed_tokens,
output_projection=None,
):
self.args = args
super().__init__(
TransformerConfig.from_namespace(args),
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=output_projection,
encoder_attn_merge_type=getattr(
args, "synthesizer_augmented_cross_attention_merge_type", "sequential"
),
dropnet_ratio=getattr(args, "dropnet_ratio", 0),
)
def build_output_projection(self, args, dictionary, embed_tokens):
super().build_output_projection(
TransformerConfig.from_namespace(args), dictionary, embed_tokens
)
def build_decoder_layer(
self,
args,
encoder_attn_merge_type="sequential",
dropnet_ratio=0,
):
return super().build_decoder_layer(
TransformerConfig.from_namespace(args),
encoder_attn_merge_type=encoder_attn_merge_type,
dropnet_ratio=dropnet_ratio,
)
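AugTransformerDecoderBase wires a second cross-attention into every layer and merges it with the first one either sequentially or in parallel (see AugTransformerDecoderLayerBase in the next file). A minimal sketch of the two merge types with vanilla nn.MultiheadAttention, ignoring layer norms, dropout, dropnet, and incremental decoding, all of which the real layer handles:

import torch
import torch.nn as nn

class ToyDualCrossAttention(nn.Module):
    # Hypothetical illustration of the "sequential" vs "parallel" merge types.
    def __init__(self, dim=16, heads=4, merge_type="sequential"):
        super().__init__()
        self.attn1 = nn.MultiheadAttention(dim, heads)  # over the first encoder output
        self.attn2 = nn.MultiheadAttention(dim, heads)  # over the additional encoder output
        self.merge_type = merge_type

    def forward(self, x, enc, enc_aug):
        # x, enc, enc_aug: T x B x C (same layout as the fairseq layer)
        if self.merge_type == "sequential":
            x = x + self.attn1(x, enc, enc)[0]          # first cross-attention
            x = x + self.attn2(x, enc_aug, enc_aug)[0]  # second cross-attention
        else:  # "parallel"
            a = self.attn1(x, enc, enc)[0]
            b = self.attn2(x, enc_aug, enc_aug)[0]
            x = x + 0.5 * a + 0.5 * b                   # mix the two branches
        return x

layer = ToyDualCrossAttention(merge_type="parallel")
out = layer(torch.randn(5, 2, 16), torch.randn(7, 2, 16), torch.randn(9, 2, 16))
print(out.shape)  # torch.Size([5, 2, 16])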

View File

@ -28,7 +28,7 @@ class TransformerEncoderLayerBase(nn.Module):
*cfg.encoder.normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
cfg (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, cfg, return_fc=False):

View File

@ -0,0 +1,315 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional
import torch
from numpy.random import uniform
from torch import Tensor
from fairseq.modules import LayerNorm
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase
class AugTransformerDecoderLayerBase(TransformerDecoderLayerBase):
"""Decoder layer block augmented with an additional cross-attention.
This decoder block is processed with the sequence of the following sub-modules.
self-attention -> cross-attention (first) -> cross-attention (second) -> FFN
Args:
cfg (argparse.Namespace): parsed command-line arguments
        encoder_attn_merge_type (str, optional): how to combine the outputs of
            the two cross-attention modules. With "sequential" they are stacked
            one after the other; with "parallel" they are applied in parallel
            and their outputs are combined before being fed to the FFN
            (default: "sequential").
        dropnet_ratio (float, optional): probability of dropping one of the two
            cross-attention modules during training (default: 0.0).
"""
def __init__(
self,
cfg,
add_bias_kv=False,
add_zero_attn=False,
encoder_attn_merge_type="sequential",
dropnet_ratio=0.0,
):
super().__init__(
cfg,
no_encoder_attn=False,
add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
)
self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.encoder_attn2 = self.build_encoder_attention(self.embed_dim, cfg)
if encoder_attn_merge_type == "sequential":
self.encoder_attn_layer_norm2 = LayerNorm(self.embed_dim, export=cfg.export)
else:
self.encoder_attn_layer_norm2 = None
self.encoder_attn_merge_type = encoder_attn_merge_type
self.dropnet_ratio = dropnet_ratio
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
encoder_out_aug: Optional[torch.Tensor] = None,
encoder_padding_mask2: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): output from the primary encoder, of
                shape `(src_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            encoder_out_aug (Tensor, optional): output from the additional
                encoder attended by the second cross-attention
            encoder_padding_mask2 (ByteTensor, optional): padding mask for
                ``encoder_out_aug``, of shape `(batch, src_len_aug)`
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
if prev_self_attn_state is not None:
prev_key, prev_value = prev_self_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
assert incremental_state is not None
self.self_attn._set_input_buffer(incremental_state, saved_state)
_self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
if self.cross_self_attention and not (
incremental_state is not None
and _self_attn_input_buffer is not None
and "prev_key" in _self_attn_input_buffer
):
if self_attn_mask is not None:
assert encoder_out is not None
self_attn_mask = torch.cat(
(x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
assert encoder_out is not None
encoder_padding_mask = self_attn_padding_mask.new_zeros(
encoder_out.size(1), encoder_out.size(0)
)
self_attn_padding_mask = torch.cat(
(encoder_padding_mask, self_attn_padding_mask), dim=1
)
assert encoder_out is not None
y = torch.cat((encoder_out, x), dim=0)
else:
y = x
x, attn = self.self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
if self.c_attn is not None:
tgt_len, bsz = x.size(0), x.size(1)
x = x.view(tgt_len, bsz, self.nh, self.head_dim)
x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
x = x.reshape(tgt_len, bsz, self.embed_dim)
if self.attn_ln is not None:
x = self.attn_ln(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
assert encoder_out is not None
assert encoder_out_aug is not None
if self.encoder_attn_merge_type == "sequential":
ratios = self.get_dropnet_ratio()
# first encoder attention
if ratios[0] > 0:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
x = ratios[0] * x
# second encoder attention
if ratios[1] > 0:
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm2(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn2._set_input_buffer(incremental_state, saved_state)
x, attn2 = self.encoder_attn2(
query=x,
key=encoder_out_aug,
value=encoder_out_aug,
key_padding_mask=encoder_padding_mask2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm2(x)
x = ratios[1] * x
elif self.encoder_attn_merge_type == "parallel":
residual = x
if self.normalize_before:
x = self.encoder_attn_layer_norm(x)
if prev_attn_state is not None:
prev_key, prev_value = prev_attn_state[:2]
saved_state: Dict[str, Optional[Tensor]] = {
"prev_key": prev_key,
"prev_value": prev_value,
}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
assert incremental_state is not None
self.encoder_attn._set_input_buffer(incremental_state, saved_state)
x1, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x2, attn2 = self.encoder_attn2(
query=x,
key=encoder_out_aug,
value=encoder_out_aug,
key_padding_mask=encoder_padding_mask2,
incremental_state=incremental_state,
static_kv=True,
need_weights=need_attn or (not self.training and self.need_attn),
need_head_weights=need_head_weights,
)
x1 = self.dropout_module(x1)
x2 = self.dropout_module(x2)
ratios = self.get_dropnet_ratio()
x = ratios[0] * x1 + ratios[1] * x2
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.encoder_attn_layer_norm(x)
else:
raise NotImplementedError(self.encoder_attn_merge_type)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
if self.ffn_layernorm is not None:
x = self.ffn_layernorm(x)
x = self.fc2(x)
x = self.dropout_module(x)
if self.w_resid is not None:
residual = torch.mul(self.w_resid, residual)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, attn2, self_attn_state
return x, attn, attn2, None
def get_dropnet_ratio(self):
if self.encoder_attn_merge_type == "sequential":
if self.dropnet_ratio > 0:
frand = float(uniform(0, 1))
if frand < self.dropnet_ratio and self.training:
return [2, 0]
elif frand > 1 - self.dropnet_ratio and self.training:
return [0, 2]
else:
return [1, 1]
else:
return [1, 1]
elif self.encoder_attn_merge_type == "parallel":
if self.dropnet_ratio > 0:
frand = float(uniform(0, 1))
if frand < self.dropnet_ratio and self.training:
return [1, 0]
elif frand > 1 - self.dropnet_ratio and self.training:
return [0, 1]
else:
return [0.5, 0.5]
else:
return [0.5, 0.5]
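get_dropnet_ratio implements a dropnet-style stochastic choice between the two cross-attention branches: during training, with probability dropnet_ratio one branch is dropped and the other is re-weighted so that the expected contribution of each branch stays unchanged; at inference both branches are always kept. A standalone sketch of the sequential-case sampling, mirroring the logic above:

from numpy.random import uniform

def sample_dropnet_ratios(dropnet_ratio, training, n=10):
    # Reproduce the "sequential" branch of get_dropnet_ratio: each sample either
    # keeps only the first cross-attention (scaled by 2), only the second, or both.
    ratios = []
    for _ in range(n):
        frand = float(uniform(0, 1))
        if training and frand < dropnet_ratio:
            ratios.append([2, 0])
        elif training and frand > 1 - dropnet_ratio:
            ratios.append([0, 2])
        else:
            ratios.append([1, 1])
    return ratios

print(sample_dropnet_ratios(0.25, training=True))
print(sample_dropnet_ratios(0.25, training=False))  # always [1, 1] at inference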

View File

@ -76,7 +76,203 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
incremental_state=incremental_state,
target_lengths=cur_out_lens,
speaker=sample["speaker"],
**kwargs
**kwargs,
)
cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
feat.append(cur_extra["feature_out"])
attn.append(cur_extra["attn"])
eos_prob.append(cur_eos_prob)
cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold
out_lens.masked_fill_((~finished) & cur_finished, step + 1)
finished = finished | cur_finished
if finished.sum().item() == bsz:
break
prev_feat_out = cur_extra["feature_out"]
feat = torch.cat(feat, dim=1)
feat = model.decoder.postnet(feat) + feat
eos_prob = torch.cat(eos_prob, dim=1)
attn = torch.cat(attn, dim=2)
alignment = attn.max(dim=1)[1]
feat = feat.reshape(bsz, -1, raw_dim)
feat = self.gcmvn_denormalize(feat)
eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1)
attn = attn.repeat_interleave(n_frames_per_step, dim=2)
alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
out_lens = out_lens * n_frames_per_step
finalized = [
{
"feature": feat[b, :out_len],
"eos_prob": eos_prob[b, :out_len],
"attn": attn[b, :, :out_len],
"alignment": alignment[b, :out_len],
"waveform": self.get_waveform(feat[b, :out_len]),
}
for b, out_len in zip(range(bsz), out_lens)
]
if has_targ:
assert sample["target"].size(-1) == out_dim
tgt_feats = sample["target"].view(bsz, -1, raw_dim)
tgt_feats = self.gcmvn_denormalize(tgt_feats)
tgt_lens = sample["target_lengths"] * n_frames_per_step
for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)):
finalized[b]["targ_feature"] = f[:l]
finalized[b]["targ_waveform"] = self.get_waveform(f[:l])
return finalized
class MultiDecoderSpeechGenerator(SpeechGenerator):
def __init__(
self,
models,
args,
vocoder,
data_cfg,
tgt_dict_mt,
max_iter: int = 6000,
eos_prob_threshold: float = 0.5,
eos_mt=None,
symbols_to_strip_from_output=None,
):
super().__init__(models[0], vocoder, data_cfg)
self.max_iter = max_iter
self.eos_prob_threshold = eos_prob_threshold
self.tgt_dict_mt = tgt_dict_mt
self.eos_mt = eos_mt
from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator
from fairseq import search
self.text_generator = SequenceGenerator(
models,
tgt_dict_mt,
beam_size=max(1, getattr(args, "beam", 5)),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
search_strategy=search.BeamSearch(tgt_dict_mt),
eos=eos_mt,
symbols_to_strip_from_output=symbols_to_strip_from_output,
)
@torch.no_grad()
def generate(self, model, sample, has_targ=False, **kwargs):
model.eval()
src_tokens = sample["net_input"]["src_tokens"]
src_lengths = sample["net_input"]["src_lengths"]
bsz, src_len = src_tokens.size()[:2]
n_frames_per_step = model.decoder.n_frames_per_step
out_dim = model.decoder.out_dim
raw_dim = out_dim // n_frames_per_step
# initialize
encoder_out = model.forward_encoder(
src_tokens, src_lengths, speaker=sample["speaker"]
)
prefix_tokens = None
constraints = None
bos_token = None
mt_decoder = getattr(model, f"{model.mt_task_name}_decoder")
# 1. MT decoder
finalized_mt = self.text_generator.generate_decoder(
[encoder_out],
src_tokens,
src_lengths,
sample,
prefix_tokens,
constraints,
bos_token,
aux_task_name=model.mt_task_name,
)
# extract decoder output corresponding to the best hypothesis
max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt])
prev_output_tokens_mt = (
src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len)
.fill_(mt_decoder.padding_idx)
.int()
) # B x T
for i, hypo in enumerate(finalized_mt):
i_beam = 0
tmp = hypo[i_beam]["tokens"].int() # hyp + eos
prev_output_tokens_mt[i, 0] = self.text_generator.eos
if tmp[-1] == self.text_generator.eos:
tmp = tmp[:-1]
prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp
text = "".join([self.tgt_dict_mt[c] for c in tmp])
text = text.replace("_", " ")
text = text.replace("", " ")
text = text.replace("<unk>", " ")
text = text.replace("<s>", "")
text = text.replace("</s>", "")
if len(text) > 0 and text[0] == " ":
text = text[1:]
sample_id = sample["id"].tolist()[i]
print("{} (None-{})".format(text, sample_id))
mt_decoder_out = mt_decoder(
prev_output_tokens_mt,
encoder_out=encoder_out,
features_only=True,
)
x = mt_decoder_out[0].transpose(0, 1)
mt_decoder_padding_mask = None
if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)
# 2. TTS encoder
if getattr(model, "synthesizer_encoder", None) is not None:
synthesizer_encoder_out = model.synthesizer_encoder(
x,
mt_decoder_padding_mask,
)
else:
synthesizer_encoder_out = {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [mt_decoder_padding_mask]
if mt_decoder_padding_mask is not None
else [], # B x T
"encoder_embedding": [],
"encoder_states": [],
"src_tokens": [],
"src_lengths": [],
}
# 3. TTS decoder
incremental_state = {}
feat, attn, eos_prob = [], [], []
finished = src_tokens.new_zeros((bsz,)).bool()
out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter)
prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim)
for step in range(self.max_iter):
cur_out_lens = out_lens.clone()
cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1)
_, cur_eos_out, cur_extra = model.forward_decoder(
prev_feat_out,
encoder_out=synthesizer_encoder_out,
incremental_state=incremental_state,
target_lengths=cur_out_lens,
speaker=sample["speaker"],
**kwargs,
)
cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
feat.append(cur_extra["feature_out"])
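Before the second pass, the generator converts the best first-pass hypothesis of each sentence into a teacher-forcing input for the MT decoder: EOS is placed at position 0 as the begin-of-sentence symbol, the trailing EOS is stripped, and the batch is right-padded. A small sketch of just that tensor construction, with made-up eos/pad indices:

import torch

def hypotheses_to_prev_output_tokens(hypos, eos_idx, pad_idx):
    # hypos: list of 1-D LongTensors, each the best hypothesis of one sentence,
    # ending with eos. Returns a B x T batch shaped like prev_output_tokens_mt.
    max_len = max(len(h) for h in hypos)
    out = torch.full((len(hypos), max_len), pad_idx, dtype=torch.long)
    for i, h in enumerate(hypos):
        if h[-1] == eos_idx:
            h = h[:-1]            # drop the trailing eos
        out[i, 0] = eos_idx       # eos acts as the begin-of-sentence token
        out[i, 1 : len(h) + 1] = h
    return out

hypos = [torch.tensor([11, 12, 13, 2]), torch.tensor([21, 2])]
print(hypotheses_to_prev_output_tokens(hypos, eos_idx=2, pad_idx=1))
# tensor([[ 2, 11, 12, 13],
#         [ 2, 21,  1,  1]])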

View File

@ -8,6 +8,7 @@ import logging
import math
from argparse import Namespace
from pathlib import Path
from typing import List
import torch
import torch.nn as nn
@ -16,7 +17,10 @@ from fairseq import utils
from fairseq.data import Dictionary
from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig
from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.speech_to_text import DummyMultiTask
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
@ -209,15 +213,35 @@ class SpeechToSpeechTask(LegacyFairseqTask):
super().__init__(args)
self.tgt_dict = tgt_dict
self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)
self.multitask_tasks = {}
self.tgt_dict_mt = None
self.eos_token_mt = None
if getattr(args, "multitask_config_yaml", None) is not None:
multitask_cfg = MultitaskConfig(
Path(args.data) / args.multitask_config_yaml
)
for task_name, task_config in multitask_cfg.get_all_tasks().items():
self.multitask_tasks[task_name] = DummyMultiTask(
task_config, task_config.tgt_dict
first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
for i, (task_name, task_config) in enumerate(
multitask_cfg.get_all_tasks().items()
):
task_obj = DummyMultiTask(
task_config,
task_config.tgt_dict,
first_pass=i == first_pass_task_idx,
)
self.multitask_tasks[task_name] = task_obj
if task_obj.is_first_pass_decoder:
self.tgt_dict_mt = task_obj.target_dictionary
if task_config.prepend_bos_and_append_tgt_lang_tag:
self.eos_token_mt = task_config.eos_token
assert not isinstance(self.eos_token_mt, List)
if not self.eos_token_mt:
raise Warning(
"Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
)
self._infer_tgt_lang_id = infer_tgt_lang_id
@classmethod
@ -267,11 +291,13 @@ class SpeechToSpeechTask(LegacyFairseqTask):
from fairseq import criterions
if len(self.multitask_tasks) > 0:
if self.args.target_is_code and args._name != "speech_to_unit":
if self.args.target_is_code and not args._name.startswith("speech_to_unit"):
raise ValueError(
"set --criterion speech_to_unit for speech-to-unit loss with multitask"
)
elif not self.args.target_is_code and args._name != "speech_to_spectrogram":
elif not self.args.target_is_code and not args._name.startswith(
"speech_to_spectrogram"
):
raise ValueError(
"set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask"
)
@ -296,6 +322,10 @@ class SpeechToSpeechTask(LegacyFairseqTask):
def target_dictionary(self):
return self.tgt_dict
@property
def target_dictionary_mt(self):
return self.tgt_dict_mt
@property
def source_dictionary(self):
return None
@ -326,6 +356,36 @@ class SpeechToSpeechTask(LegacyFairseqTask):
return model
def build_generator_dual_decoder(
self,
models,
args,
extra_gen_cls_kwargs=None,
):
from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
MultiDecoderSequenceGenerator,
)
return MultiDecoderSequenceGenerator(
models,
self.target_dictionary,
self.target_dictionary_mt,
beam_size=max(1, getattr(args, "beam", 1)),
beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
max_len_a_mt=getattr(args, "max_len_a_mt", 0),
max_len_b_mt=getattr(args, "max_len_b_mt", 200),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
**extra_gen_cls_kwargs,
)
def build_generator(
self,
models,
@ -344,14 +404,23 @@ class SpeechToSpeechTask(LegacyFairseqTask):
else self.vocoder.cpu()
)
has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None
if self.args.target_is_code:
if self.args.n_frames_per_step == 1:
seq_generator = super().build_generator(
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
if has_dual_decoder:
seq_generator = self.build_generator_dual_decoder(
models,
args,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
else:
seq_generator = super().build_generator(
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
else:
assert (
getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1
@ -361,24 +430,64 @@ class SpeechToSpeechTask(LegacyFairseqTask):
self.args.target_code_size,
)
else:
if getattr(args, "teacher_forcing", False):
from fairseq.speech_generator import (
TeacherForcingAutoRegressiveSpeechGenerator,
if has_dual_decoder:
if getattr(args, "teacher_forcing", False):
raise NotImplementedError
else:
from fairseq.speech_generator import MultiDecoderSpeechGenerator
generator = MultiDecoderSpeechGenerator
lang_token_ids_aux = {
i
for s, i in self.tgt_dict_mt.indices.items()
if TextTargetMultitaskData.is_lang_tag(s)
}
if extra_gen_cls_kwargs is None:
extra_gen_cls_kwargs = {}
extra_gen_cls_kwargs[
"symbols_to_strip_from_output"
] = lang_token_ids_aux
eos_id_mt = (
self.tgt_dict_mt.index(self.eos_token_mt)
if self.eos_token_mt
else None
)
assert eos_id_mt != self.tgt_dict_mt.unk()
extra_gen_cls_kwargs["eos_mt"] = eos_id_mt
generator = TeacherForcingAutoRegressiveSpeechGenerator
logger.info("Teacher forcing mode for generation")
seq_generator = generator(
models,
args,
self.vocoder,
self.data_cfg,
self.target_dictionary_mt,
max_iter=self.args.max_target_positions,
eos_prob_threshold=self.args.eos_prob_threshold,
**extra_gen_cls_kwargs,
)
else:
from fairseq.speech_generator import AutoRegressiveSpeechGenerator
if getattr(args, "teacher_forcing", False):
from fairseq.speech_generator import (
TeacherForcingAutoRegressiveSpeechGenerator,
)
generator = AutoRegressiveSpeechGenerator
seq_generator = generator(
models[0],
self.vocoder,
self.data_cfg,
max_iter=self.args.max_target_positions,
eos_prob_threshold=self.args.eos_prob_threshold,
)
generator = TeacherForcingAutoRegressiveSpeechGenerator
logger.info("Teacher forcing mode for generation")
else:
from fairseq.speech_generator import AutoRegressiveSpeechGenerator
generator = AutoRegressiveSpeechGenerator
seq_generator = generator(
models[0],
self.vocoder,
self.data_cfg,
max_iter=self.args.max_target_positions,
eos_prob_threshold=self.args.eos_prob_threshold,
)
return seq_generator
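For the dual-decoder generator, the task also resolves two pieces of bookkeeping from the MT dictionary: the set of language-tag symbols to strip from the first-pass output, and the index of eos_token_mt, which replaces the real EOS as the stop symbol when a target-language tag is appended. A minimal sketch of that lookup over a plain dict standing in for tgt_dict_mt; the tag format checked here ("__xx__") is an assumed convention, not necessarily the one TextTargetMultitaskData uses:

import re

# Hypothetical symbol -> index mapping standing in for tgt_dict_mt.indices.
indices = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3,
           "__en__": 4, "__de__": 5, "hello": 6}

def is_lang_tag(symbol):
    # Assumed tag convention for illustration only.
    return re.match(r"^__[a-z]+__$", symbol) is not None

lang_token_ids = {i for s, i in indices.items() if is_lang_tag(s)}
eos_token_mt = "__de__"                 # from the multitask config
eos_id_mt = indices.get(eos_token_mt)
assert eos_id_mt is not None and eos_id_mt != indices["<unk>"]

print(lang_token_ids)  # {4, 5} -> stripped from the first-pass output
print(eos_id_mt)       # 5     -> used as the stop symbol for the MT decoder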

View File

@ -6,6 +6,7 @@
import logging
from argparse import Namespace
from pathlib import Path
from typing import List
from fairseq.data import Dictionary, encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
@ -14,6 +15,7 @@ from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig,
SpeechToTextDataset,
SpeechToTextDatasetCreator,
TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task
@ -66,13 +68,32 @@ class SpeechToTextTask(LegacyFairseqTask):
)
self.multitask_tasks = {}
self.tgt_dict_mt = None
self.eos_token_mt = None
if getattr(args, "multitask_config_yaml", None) is not None:
multitask_cfg = MultitaskConfig(
Path(args.data) / args.multitask_config_yaml
)
for task_name, task_config in multitask_cfg.get_all_tasks().items():
task_obj = DummyMultiTask(task_config, task_config.tgt_dict)
first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
for i, (task_name, task_config) in enumerate(
multitask_cfg.get_all_tasks().items()
):
task_obj = DummyMultiTask(
task_config,
task_config.tgt_dict,
first_pass=i == first_pass_task_idx,
)
self.multitask_tasks[task_name] = task_obj
if task_obj.is_first_pass_decoder:
self.tgt_dict_mt = task_obj.target_dictionary
if task_config.prepend_bos_and_append_tgt_lang_tag:
self.eos_token_mt = task_config.eos_token
assert not isinstance(self.eos_token_mt, List)
if not self.eos_token_mt:
raise Warning(
"Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
)
def _get_speaker_to_id(self):
speaker_to_id = None
@ -124,12 +145,17 @@ class SpeechToTextTask(LegacyFairseqTask):
epoch=epoch,
seed=self.args.seed,
speaker_to_id=self.speaker_to_id,
multitask=self.multitask_tasks,
)
@property
def target_dictionary(self):
return self.tgt_dict
@property
def target_dictionary_mt(self):
return self.tgt_dict_mt
@property
def source_dictionary(self):
return None
@ -143,6 +169,51 @@ class SpeechToTextTask(LegacyFairseqTask):
args.speaker_to_id = self.speaker_to_id
return super(SpeechToTextTask, self).build_model(args, from_checkpoint)
def build_generator_dual_decoder(
self,
models,
args,
extra_gen_cls_kwargs,
):
from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
MultiDecoderSequenceGenerator,
)
lang_token_ids_aux = {
i
for s, i in self.tgt_dict_mt.indices.items()
if TextTargetMultitaskData.is_lang_tag(s)
}
extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux)
eos_id_mt = (
self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None
)
assert eos_id_mt != self.tgt_dict_mt.unk()
extra_gen_cls_kwargs["eos_mt"] = eos_id_mt
return MultiDecoderSequenceGenerator(
models,
self.target_dictionary,
self.target_dictionary_mt,
beam_size=max(1, getattr(args, "beam", 1)),
beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
max_len_a_mt=getattr(args, "max_len_a_mt", 0),
max_len_b_mt=getattr(args, "max_len_b_mt", 0),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
len_penalty_mt=getattr(args, "lenpen_mt", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
**extra_gen_cls_kwargs,
)
def build_generator(
self,
models,
@ -179,9 +250,21 @@ class SpeechToTextTask(LegacyFairseqTask):
eos_id = self.tgt_dict.index(eos_token) if eos_token else None
extra_gen_cls_kwargs["eos"] = eos_id
return super().build_generator(
models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None
if has_dual_decoder:
return self.build_generator_dual_decoder(
models,
args,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
else:
return super().build_generator(
models,
args,
seq_gen_cls=None,
extra_gen_cls_kwargs=extra_gen_cls_kwargs,
)
def train_step(
self, sample, model, criterion, optimizer, update_num, ignore_grad=False
@ -225,14 +308,19 @@ class SpeechToTextTask(LegacyFairseqTask):
class DummyMultiTask(LegacyFairseqTask):
def __init__(self, args, tgt_dict):
def __init__(self, args, tgt_dict, first_pass=False):
super().__init__(args)
self.tgt_dict = tgt_dict
self.first_pass = first_pass
@property
def target_dictionary(self):
return self.tgt_dict
@property
def is_first_pass_decoder(self):
return self.first_pass
def inference_step(
self, generator, models, sample, prefix_tokens=None, constraints=None
):

View File

@ -10,6 +10,7 @@ import os
import re
import torch
from fairseq.file_io import PathManager
@ -113,6 +114,9 @@ def main():
num_group.add_argument('--num-update-checkpoints', type=int,
help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by'
' input, and average last this many of them.')
num_group.add_argument('--num-best-checkpoints', type=int, default=0,
help='if set, will try to find checkpoints with names checkpoint_best_ee_xx.pt in the path specified by'
                           ' input, and average this many of them with the best metric values.')
parser.add_argument('--checkpoint-upper-bound', type=int,
help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, '
'when using --num-update-checkpoints, this will set an upper bound on which update to use'
@ -150,6 +154,18 @@ def main():
)
print("averaging checkpoints: ", args.inputs)
if args.num_best_checkpoints > 0:
args.inputs = list(
sorted(
args.inputs,
key=lambda x: float(
os.path.basename(x).split("_")[-1].replace(".pt", "")
),
)
)
args.inputs = args.inputs[: args.num_best_checkpoints]
for path in args.inputs:
print(os.path.basename(path))
new_state = average_checkpoints(args.inputs)
with PathManager.open(args.output, "wb") as f:
torch.save(new_state, f)
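The new --num-best-checkpoints path relies on the metric value being embedded at the end of each checkpoint filename: the inputs are sorted by that value in ascending order (so lower-is-better metrics such as loss come first) and only the first N are averaged. A standalone sketch of the same selection with made-up filenames:

import os

def select_best_checkpoints(paths, num_best):
    # Sort by the float suffix in the filename (e.g. "..._3.25.pt" -> 3.25)
    # and keep the num_best smallest, as average_checkpoints now does.
    ranked = sorted(
        paths,
        key=lambda x: float(os.path.basename(x).split("_")[-1].replace(".pt", "")),
    )
    return ranked[:num_best]

paths = [
    "ckpts/checkpoint.best_loss_3.41.pt",
    "ckpts/checkpoint.best_loss_3.25.pt",
    "ckpts/checkpoint.best_loss_3.67.pt",
]
print(select_best_checkpoints(paths, 2))
# ['ckpts/checkpoint.best_loss_3.25.pt', 'ckpts/checkpoint.best_loss_3.41.pt']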