Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-10-03 20:28:26 +03:00)
UnitY implementation (#4670)
* Add UnitY implementation
* Rename for consistency
* Refactor conformer encoder construction
* Change the order of arguments for rdrop_alpha
* Add compute_loss_with_rdrop
* Move build_multitask_decoder to xm_transformer_unity.py
* Fix generator selection
* Fix check in build_criterion
* Modularize Rdrop
* Minor fix
* Refine class names
* Refactor submodules
* Fix CE
* Fix import
* Fix arguments for datasets
* Add description to AugTransformerDecoderBase
* Fix SpeechToTextDatasetCreator
* Fix metavar in arguments
* Uncomment override_decoder_args
* Fix comment in warning
* Add is_first_pass_decoder flag
* Change Translatotron2SpeechGenerator to MultiDecoderSpeechGenerator
* Move inference code to examples/speech_to_speech/unity
* Fix rdrop default value in aux tasks
* Add language tag mapping option to multitask-config-yaml
* Rename encoder_out2 and encoder_outs2
* Rename UnitYXMTransformerModel to XMTransformerModelUnitY
* Support num_best_checkpoints in average_checkpoints
* Fix has_multitask
* Inherit SequenceGenerator
* Reflect recent updates
* Minor fix in logging
* Fix typo
* Refactor SpeechToSpectrogram2passMultitaskTaskCriterion
* Minor update for multitask
This commit is contained in:
parent 6d90f79883
commit b4001184f4
@@ -2,3 +2,5 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import unity  # noqa
7 examples/speech_to_speech/unity/__init__.py Normal file
@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import sequence_generator  # noqa
from . import sequence_generator_multi_decoder  # noqa
626 examples/speech_to_speech/unity/sequence_generator.py Normal file
@@ -0,0 +1,626 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import sys
from typing import Dict, List, Optional

import torch
from torch import Tensor

from fairseq.sequence_generator import EnsembleModel as EnsembleModelBase
from fairseq.sequence_generator import SequenceGenerator as SequenceGeneratorBase


class SequenceGenerator(SequenceGeneratorBase):
    def __init__(
        self,
        models,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        max_len=0,
        min_len=1,
        normalize_scores=True,
        len_penalty=1.0,
        unk_penalty=0.0,
        temperature=1.0,
        match_source_len=False,
        no_repeat_ngram_size=0,
        search_strategy=None,
        eos=None,
        symbols_to_strip_from_output=None,
        lm_model=None,
        lm_weight=1.0,
        tokens_to_suppress=(),
    ):
        """Generates translations of a given source sentence.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting
            beam_size (int, optional): beam width (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            max_len (int, optional): the maximum length of the generated output
                (not including end-of-sentence)
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
        """
        super().__init__(
            models=models,
            tgt_dict=tgt_dict,
            beam_size=beam_size,
            max_len_a=max_len_a,
            max_len_b=max_len_b,
            max_len=max_len,
            min_len=min_len,
            normalize_scores=normalize_scores,
            len_penalty=len_penalty,
            unk_penalty=unk_penalty,
            temperature=temperature,
            match_source_len=match_source_len,
            no_repeat_ngram_size=no_repeat_ngram_size,
            search_strategy=search_strategy,
            eos=eos,
            symbols_to_strip_from_output=symbols_to_strip_from_output,
            lm_model=lm_model,
            lm_weight=lm_weight,
            tokens_to_suppress=tokens_to_suppress,
        )

        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)

        self.model.set_decoder_beam_size(self.beam_size)
        self.model.eval()

    def _generate(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        net_input = sample["net_input"]

        if "src_tokens" in net_input:
            src_tokens = net_input["src_tokens"]
            # length of the source text being the character length except EndOfSentence and pad
            # if src_lengths exists in net_input (speech_to_text dataset case), then use it
            if "src_lengths" in net_input:
                src_lengths = net_input["src_lengths"]
            else:
                src_lengths = (
                    (src_tokens.ne(self.eos) & src_tokens.ne(self.pad))
                    .long()
                    .sum(dim=1)
                )
        elif "source" in net_input:
            src_tokens = net_input["source"]
            src_lengths = (
                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
                if net_input["padding_mask"] is not None
                else torch.tensor(src_tokens.size(-1)).to(src_tokens)
            )
        elif "features" in net_input:
            src_tokens = net_input["features"]
            src_lengths = (
                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
                if net_input["padding_mask"] is not None
                else torch.tensor(src_tokens.size(-1)).to(src_tokens)
            )
        else:
            raise Exception(
                "expected src_tokens or source in net input. input keys: "
                + str(net_input.keys())
            )

        if constraints is not None and not self.search.supports_constraints:
            raise NotImplementedError(
                "Target-side constraints were provided, but search method doesn't support them"
            )

        # Initialize constraints, when active
        self.search.init_constraints(constraints, self.beam_size)

        # compute the encoder output for each beam
        with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
            encoder_outs = self.model.forward_encoder(net_input)

        finalized = self.generate_decoder(
            encoder_outs,
            src_tokens,
            src_lengths,
            sample,
            prefix_tokens,
            constraints,
            bos_token,
        )
        return finalized

    def generate_decoder(
        self,
        encoder_outs,
        src_tokens,
        src_lengths,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
        aux_task_name="",
        encoder_outs_aug: Optional[
            Tensor
        ] = None,  # an additional/augmented encoder_outs
    ):
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(self.model.models_size)
            ],
        )

        # bsz: total number of sentences in beam
        # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
        bsz, src_len = src_tokens.size()[:2]
        beam_size = self.beam_size

        decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder"

        max_len: int = -1
        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                self.max_len - 1,
            )
        assert (
            self.min_len <= max_len
        ), "min_len cannot be larger than max_len, please adjust these!"

        # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
        # ensure encoder_outs is a List.
        assert encoder_outs is not None
        if encoder_outs_aug is not None:
            encoder_outs_aug = self.model.reorder_encoder_out(
                encoder_outs_aug, new_order
            )

        # initialize buffers
        scores = (
            torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
        )  # +1 for eos; pad is never chosen for scoring
        tokens = (
            torch.zeros(bsz * beam_size, max_len + 2)
            .to(src_tokens)
            .long()
            .fill_(self.pad)
        )  # +2 for eos and pad
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        attn: Optional[Tensor] = None

        # A list that indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then cands_to_ignore would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        cands_to_ignore = (
            torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
        )  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = torch.jit.annotate(
            List[List[Dict[str, Tensor]]],
            [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
        )  # contains lists of dictionaries of information about the hypothesis being finalized at each step

        # a boolean array indicating if the sentence at the index is finished or not
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz  # number of sentences remaining

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (
            (torch.arange(0, bsz) * beam_size)
            .unsqueeze(1)
            .type_as(tokens)
            .to(src_tokens.device)
        )
        cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)

        reorder_state: Optional[Tensor] = None
        batch_idxs: Optional[Tensor] = None

        original_batch_idxs: Optional[Tensor] = None
        if "id" in sample and isinstance(sample["id"], Tensor):
            original_batch_idxs = sample["id"]
        else:
            original_batch_idxs = torch.arange(0, bsz).type_as(tokens)

        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
                        batch_idxs
                    )
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size
                    )
                    original_batch_idxs = original_batch_idxs[batch_idxs]
                self.model.reorder_incremental_state(
                    incremental_states, reorder_state, decoder_name
                )
                encoder_outs = self.model.reorder_encoder_out(
                    encoder_outs, reorder_state
                )
                if encoder_outs_aug is not None:
                    encoder_outs_aug = self.model.reorder_encoder_out(
                        encoder_outs_aug, reorder_state
                    )
            with torch.autograd.profiler.record_function(
                "EnsembleModel: forward_decoder"
            ):
                lprobs, avg_attn_scores = self.model.forward_decoder(
                    tokens[:, : step + 1],
                    encoder_outs,
                    incremental_states,
                    self.temperature,
                    decoder_name=decoder_name,
                    encoder_outs_aug=encoder_outs_aug,
                )

            if self.lm_model is not None and not aux_task_name:
                lm_out = self.lm_model(tokens[:, : step + 1])
                probs = self.lm_model.get_normalized_probs(
                    lm_out, log_probs=True, sample=None
                )
                probs = probs[:, -1, :] * self.lm_weight
                lprobs += probs

            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # handle max length constraint
            if step >= max_len:
                lprobs[:, : self.eos] = -math.inf
                lprobs[:, self.eos + 1 :] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            if (
                prefix_tokens is not None
                and step < prefix_tokens.size(1)
                and step < max_len
            ):
                lprobs, tokens, scores = self._prefix_tokens(
                    step, lprobs, scores, tokens, prefix_tokens, beam_size
                )
            else:
                if step < self.min_len:
                    # minimum length constraint (does not apply if using prefix_tokens)
                    lprobs[:, self.eos] = -math.inf

                if self.token_indices_to_suppress is not None:
                    lprobs[:, self.token_indices_to_suppress] = -math.inf

            # Record attention scores, only support avg_attn_scores is a Tensor
            if avg_attn_scores is not None:
                if attn is None:
                    attn = torch.empty(
                        bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                    ).to(scores)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            eos_bbsz_idx = torch.empty(0).to(
                tokens
            )  # indices of hypothesis ending with eos (finished sentences)
            eos_scores = torch.empty(0).to(
                scores
            )  # scores of hypothesis ending with eos (finished sentences)

            if self.should_set_src_lengths:
                self.search.set_src_lengths(src_lengths)

            if self.repeat_ngram_blocker is not None:
                lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)

            # Shape: (batch, cand_size)
            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
                tokens[:, : step + 1],
                original_batch_idxs,
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            # Shape of eos_mask: (batch size, beam size)
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)

            # only consider eos when it's among the top beam_size indices
            # Now we know what beam item(s) to finish
            # Shape: 1d list of absolute-numbered
            eos_bbsz_idx = torch.masked_select(
                cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            finalized_sents: List[int] = []
            if eos_bbsz_idx.numel() > 0:
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
                )

                finalized_sents = self.finalize_hypos(
                    step,
                    eos_bbsz_idx,
                    eos_scores,
                    tokens,
                    scores,
                    finalized,
                    finished,
                    beam_size,
                    attn,
                    src_lengths,
                    max_len,
                )
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            if self.search.stop_on_max_len and step >= max_len:
                break
            assert step < max_len, f"{step} < {max_len}"

            # Remove finalized sentences (ones for which {beam_size}
            # finished hypotheses have been generated) from the batch.
            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = torch.ones(
                    bsz, dtype=torch.bool, device=cand_indices.device
                )
                batch_mask[finalized_sents] = False
                # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
                batch_idxs = torch.arange(
                    bsz, device=cand_indices.device
                ).masked_select(batch_mask)

                # Choose the subset of the hypothesized constraints that will continue
                self.search.prune_sentences(batch_idxs)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                cands_to_ignore = cands_to_ignore[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1
                    )
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos

            # Rewrite the operator since the element wise or is not supported in torchscript.

            eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[: eos_mask.size(1)],
            )

            # get the top beam_size active hypotheses, which are just
            # the hypos with the smallest values in active_mask.
            # {active_hypos} indicates which {beam_size} hypotheses
            # from the list of {2 * beam_size} candidates were
            # selected. Shapes: (batch size, beam size)
            new_cands_to_ignore, active_hypos = torch.topk(
                active_mask, k=beam_size, dim=1, largest=False
            )

            # update cands_to_ignore to ignore any finalized hypos.
            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
            # Make sure there is at least one active item for each sentence in the batch.
            assert (~cands_to_ignore).any(dim=1).all()

            # update cands_to_ignore to ignore any finalized hypos

            # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
            # can be selected more than once).
            active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
            active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses

            # Set the tokens for each beam (can select the same row more than once)
            tokens[:, : step + 1] = torch.index_select(
                tokens[:, : step + 1], dim=0, index=active_bbsz_idx
            )
            # Select the next token for each of them
            tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
                cand_indices, dim=1, index=active_hypos
            )
            if step > 0:
                scores[:, :step] = torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx
                )
            scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
                cand_scores, dim=1, index=active_hypos
            )

            # Update constraints based on which candidates were selected for the next beam
            self.search.update_constraints(active_hypos)

            # copy attention for active hypotheses
            if attn is not None:
                attn[:, :, : step + 2] = torch.index_select(
                    attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
                )

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            scores = torch.tensor(
                [float(elem["score"].item()) for elem in finalized[sent]]
            )
            _, sorted_scores_indices = torch.sort(scores, descending=True)
            finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
            finalized[sent] = torch.jit.annotate(
                List[Dict[str, Tensor]], finalized[sent]
            )
        return finalized


class EnsembleModel(EnsembleModelBase):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__(models)

    @torch.jit.export
    def forward_decoder(
        self,
        tokens,
        encoder_outs: List[Dict[str, List[Tensor]]],
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        temperature: float = 1.0,
        decoder_name="decoder",
        encoder_outs_aug: List[Dict[str, List[Tensor]]] = None,
    ):
        log_probs = []
        avg_attn: Optional[Tensor] = None
        encoder_out: Optional[Dict[str, List[Tensor]]] = None
        encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None
        for i, model in enumerate(self.models):
            if self.has_encoder():
                encoder_out = encoder_outs[i]
                if encoder_outs_aug is not None:
                    encoder_out_aug = encoder_outs_aug[i]
            # decode each model
            if self.has_incremental_states():
                if encoder_out_aug is not None:
                    decoder_out = getattr(model, decoder_name).forward(
                        tokens,
                        encoder_out=encoder_out,
                        encoder_out_aug=encoder_out_aug,
                        incremental_state=incremental_states[i],
                    )
                else:
                    decoder_out = getattr(model, decoder_name).forward(
                        tokens,
                        encoder_out=encoder_out,
                        incremental_state=incremental_states[i],
                    )
            else:
                if hasattr(model, decoder_name):
                    decoder_out = getattr(model, decoder_name).forward(
                        tokens, encoder_out=encoder_out
                    )
                else:
                    decoder_out = model.forward(tokens)

            attn: Optional[Tensor] = None
            decoder_len = len(decoder_out)
            if decoder_len > 1 and decoder_out[1] is not None:
                if isinstance(decoder_out[1], Tensor):
                    attn = decoder_out[1]
                else:
                    attn_holder = decoder_out[1]["attn"]
                    if isinstance(attn_holder, Tensor):
                        attn = attn_holder
                    elif attn_holder is not None:
                        attn = attn_holder[0]
                if attn is not None:
                    attn = attn[:, -1, :]

            decoder_out_tuple = (
                decoder_out[0][:, -1:, :].div_(temperature),
                None if decoder_len <= 1 else decoder_out[1],
            )
            probs = getattr(model, decoder_name).get_normalized_probs(
                decoder_out_tuple, log_probs=True, sample=None
            )
            probs = probs[:, -1, :]
            if self.models_size == 1:
                return probs, attn

            log_probs.append(probs)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)

        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
            self.models_size
        )

        if avg_attn is not None:
            avg_attn.div_(self.models_size)
        return avg_probs, avg_attn

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        new_order,
        decoder_name="decoder",
    ):
        if not self.has_incremental_states():
            return
        for i, model in enumerate(self.models):
            getattr(model, decoder_name).reorder_incremental_state_scripting(
                incremental_states[i], new_order
            )
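A note on the `decoder_name` indirection above: the same beam-search loop can drive either the default unit decoder or an auxiliary decoder attached under the attribute `{aux_task_name}_decoder`. A minimal sketch of that routing, using toy stand-in classes rather than the fairseq API (the task names below are illustrative):

# Sketch (assumptions: toy classes, not the fairseq API) of how
# `decoder_name = f"{aux_task_name}_decoder"` selects a decoder per pass.
class ToyModel:
    def __init__(self):
        self.decoder = "unit-decoder"          # second pass (default name)
        self.source_unigram_decoder = "ctc"    # hypothetical aux task
        self.target_letter_decoder = "text"    # hypothetical first-pass task

def pick_decoder(model, aux_task_name=""):
    decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder"
    return getattr(model, decoder_name)

m = ToyModel()
assert pick_decoder(m) == "unit-decoder"
assert pick_decoder(m, "target_letter") == "text"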
260 examples/speech_to_speech/unity/sequence_generator_multi_decoder.py Normal file
@@ -0,0 +1,260 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import search


class MultiDecoderSequenceGenerator(nn.Module):
    def __init__(
        self,
        models,
        tgt_dict,
        tgt_dict_mt,
        beam_size=1,
        beam_size_mt=1,
        max_len_a=0,
        max_len_b=200,
        max_len_a_mt=0,
        max_len_b_mt=200,
        max_len=0,
        min_len=1,
        normalize_scores=True,
        len_penalty=1.0,
        len_penalty_mt=1.0,
        unk_penalty=0.0,
        temperature=1.0,
        match_source_len=False,
        no_repeat_ngram_size=0,
        eos=None,
        eos_mt=None,
        symbols_to_strip_from_output=None,
        lm_model=None,
        lm_weight=1.0,
    ):
        """Generates translations of a given source sentence.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently support fairseq.models.TransformerModel for scripting
            beam_size (int, optional): beam width (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length for the second pass
            max_len_a_mt/b_mt (int, optional): generate sequences of maximum length
                ax + b, where x is the source length for the first pass
            max_len (int, optional): the maximum length of the generated output
                (not including end-of-sentence)
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty in the second pass, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            len_penalty_mt (float, optional): length penalty in the first pass, where <1.0 favors
                shorter, >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks, >0 produces fewer (default: 0.0)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
        """
        super().__init__()

        from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator

        self.generator = SequenceGenerator(
            models,
            tgt_dict,
            beam_size=beam_size,
            max_len_a=max_len_a,
            max_len_b=max_len_b,
            max_len=max_len,
            min_len=min_len,
            normalize_scores=normalize_scores,
            len_penalty=len_penalty,
            unk_penalty=unk_penalty,
            temperature=temperature,
            match_source_len=match_source_len,
            no_repeat_ngram_size=no_repeat_ngram_size,
            search_strategy=search.BeamSearch(tgt_dict),
            eos=eos,
            symbols_to_strip_from_output=symbols_to_strip_from_output,
            lm_model=lm_model,
            lm_weight=lm_weight,
        )
        self.eos = self.generator.eos

        self.generator_mt = SequenceGenerator(
            models,
            tgt_dict_mt,
            beam_size=beam_size_mt,
            max_len_a=max_len_a_mt,
            max_len_b=max_len_b_mt,
            max_len=max_len,
            min_len=min_len,
            normalize_scores=normalize_scores,
            len_penalty=len_penalty_mt,
            unk_penalty=unk_penalty,
            temperature=temperature,
            match_source_len=match_source_len,
            no_repeat_ngram_size=no_repeat_ngram_size,
            search_strategy=search.BeamSearch(tgt_dict_mt),
            eos=eos_mt,
            symbols_to_strip_from_output=symbols_to_strip_from_output,
        )

    @torch.no_grad()
    def generate(
        self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
    ) -> List[List[Dict[str, Tensor]]]:
        """Generate translations. Match the api of other fairseq generators.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            constraints (torch.LongTensor, optional): force decoder to include
                the list of constraints
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        return self._generate(sample, **kwargs)

    def _generate(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        net_input = sample["net_input"]

        if "src_tokens" in net_input:
            src_tokens = net_input["src_tokens"]
            # length of the source text being the character length except EndOfSentence and pad
            src_lengths = (
                (src_tokens.ne(self.generator.eos) & src_tokens.ne(self.generator.pad))
                .long()
                .sum(dim=1)
            )
        else:
            raise Exception(
                "expected src_tokens or source in net input. input keys: "
                + str(net_input.keys())
            )

        if constraints is not None and not self.generator.search.supports_constraints:
            raise NotImplementedError(
                "Target-side constraints were provided, but search method doesn't support them"
            )

        # Initialize constraints, when active
        self.generator.search.init_constraints(constraints, self.generator.beam_size)
        self.generator_mt.search.init_constraints(
            constraints, self.generator_mt.beam_size
        )

        # compute the encoder output for each beam
        with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"):
            encoder_outs = self.generator.model.forward_encoder(net_input)

        single_model = self.generator.model.single_model
        mt_decoder = getattr(single_model, f"{single_model.mt_task_name}_decoder")

        # 1. MT decoder
        finalized_mt = self.generator_mt.generate_decoder(
            encoder_outs,
            src_tokens,
            src_lengths,
            sample,
            prefix_tokens,
            constraints,
            bos_token,
            aux_task_name=single_model.mt_task_name,
        )

        # extract decoder output corresponding to the best hypothesis
        max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt])
        prev_output_tokens_mt = (
            src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len)
            .fill_(mt_decoder.padding_idx)
            .int()
        )  # B x T
        for i, hypo in enumerate(finalized_mt):
            i_beam = 0
            tmp = hypo[i_beam]["tokens"].int()  # hyp + eos
            prev_output_tokens_mt[i, 0] = self.generator_mt.eos
            if tmp[-1] == self.generator_mt.eos:
                tmp = tmp[:-1]
            prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp

            text = "".join([self.generator_mt.tgt_dict[c] for c in tmp])
            text = text.replace("_", " ")
            text = text.replace("▁", " ")
            text = text.replace("<unk>", " ")
            text = text.replace("<s>", "")
            text = text.replace("</s>", "")
            if len(text) > 0 and text[0] == " ":
                text = text[1:]
            sample_id = sample["id"].tolist()[i]
            print("{} (None-{})".format(text, sample_id))

        x = mt_decoder(
            prev_output_tokens_mt,
            encoder_out=encoder_outs[0],
            features_only=True,
        )[0].transpose(0, 1)

        if getattr(single_model, "proj", None) is not None:
            x = single_model.proj(x)

        mt_decoder_padding_mask = None
        if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
            mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)

        # 2. T2U encoder
        if getattr(single_model, "synthesizer_encoder", None) is not None:
            t2u_encoder_out = single_model.synthesizer_encoder(
                x,
                mt_decoder_padding_mask,
            )
        else:
            t2u_encoder_out = {
                "encoder_out": [x],  # T x B x C
                "encoder_padding_mask": [mt_decoder_padding_mask]
                if mt_decoder_padding_mask is not None
                else [],  # B x T
                "encoder_embedding": [],
                "encoder_states": [],
                "src_tokens": [],
                "src_lengths": [],
            }

        if getattr(single_model, "t2u_augmented_cross_attn", False):
            encoder_outs_aug = [t2u_encoder_out]
        else:
            encoder_outs = [t2u_encoder_out]
            encoder_outs_aug = None

        # 3. T2U decoder
        finalized = self.generator.generate_decoder(
            encoder_outs,
            src_tokens,
            src_lengths,
            sample,
            prefix_tokens,
            constraints,
            bos_token,
            encoder_outs_aug=encoder_outs_aug,
        )
        return finalized
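The `_generate` method above chains the two passes: text beam search, a forward pass through the first-pass decoder to recover its hidden states, optional T2U encoding, and unit beam search. A condensed, runnable sketch of that control flow with trivial stand-in functions (illustrative only, not the fairseq API):

# Condensed view of the two-pass flow implemented above:
#   1. speech encoder -> encoder_outs
#   2. generator_mt.generate_decoder(...)      -> first-pass text hypotheses
#   3. mt_decoder(prev_output_tokens_mt, ...)  -> decoder states for best text
#   4. synthesizer_encoder (T2U) or identity   -> t2u_encoder_out
#   5. generator.generate_decoder(...)         -> discrete unit hypotheses
def two_pass_flow(encode, decode_text, reencode, decode_units, speech):
    enc = encode(speech)
    text = decode_text(enc)
    return decode_units(reencode(text, enc))

# e.g. with trivial stand-in functions:
out = two_pass_flow(
    encode=lambda s: s + ["<enc>"],
    decode_text=lambda e: ["hello"],
    reencode=lambda t, e: t + ["<t2u>"],
    decode_units=lambda x: x + ["<units>"],
    speech=["<audio>"],
)
assert out == ["hello", "<t2u>", "<units>"]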
fairseq/criterions/speech_to_speech_criterion.py
@@ -111,8 +111,12 @@ class MultitaskCriterion:
            for key in ["target", "target_lengths", "ntokens"]:
                task_sample[key] = sample["multitask"][task_name][key]

            if task_name == getattr(model, "mt_task_name", None):
                decoder_out = model_out["mt_decoder_out"]
            else:
                decoder_out = None
            task_loss, task_sample_size, task_logging_output = task_criterion(
-               model.multitask_decoders[task_name], task_sample
+               model.multitask_decoders[task_name], task_sample, net_output=decoder_out
            )

            loss = loss + self.multitask_loss_weight[task_name] * task_loss
@@ -251,6 +255,80 @@ class SpeechToUnitMultitaskTaskCriterion(
        return False


@register_criterion(
    "speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig
)
class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size=0,
        report_accuracy=False,
        rdrop_alpha=0.0,
    ):
        super().__init__(
            task,
            sentence_avg,
            label_smoothing,
            ignore_prefix_size,
            report_accuracy,
            rdrop_alpha,
        )

    def forward(self, model, sample, reduce=True):
        net_input_concat = {
            "src_tokens": sample["net_input"]["src_tokens"],
            "src_lengths": sample["net_input"]["src_lengths"],
            "prev_output_tokens": sample["net_input"]["prev_output_tokens"],
            "prev_output_tokens_mt": sample["multitask"][model.mt_task_name][
                "net_input"
            ]["prev_output_tokens"],
            "tgt_speaker": sample["net_input"].get("tgt_speaker", None),
            "return_all_hiddens": True,
        }
        if getattr(model, "asr_task_name", None) is not None:
            net_input_concat["prev_output_tokens_asr"] = sample["multitask"][
                model.asr_task_name
            ]["net_input"]["prev_output_tokens"]

        if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0:
            net_input_concat = duplicate_input(net_input_concat)

        net_output, extra = model(**net_input_concat)
        loss, nll_loss, rdrop_kl_loss = self.compute_loss(
            model, [net_output], sample, reduce=reduce
        )

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, [net_output], sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        if self.rdrop_alpha > 0:
            logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data)

        if len(self.multitask_criterion) == 0:
            return loss, sample_size, logging_output

        # multitask
        multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
        loss += multitask_loss
        logging_output["multitask"] = multitask_log

        return loss, sample_size, logging_output


@register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
    def __init__(
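When `rdrop_alpha` is active, the criterion above duplicates the whole batch so that each sample gets two dropout passes. A hedged toy version of what fairseq's `duplicate_input` helper does (this simplified variant only handles flat tensor entries):

# Toy sketch of R-Drop input duplication (the real helper is fairseq's
# duplicate_input; this version only concatenates top-level tensors).
import torch

def toy_duplicate_input(net_input):
    return {
        k: torch.cat([v, v], dim=0) if torch.is_tensor(v) else v
        for k, v in net_input.items()
    }

batch = {"src_tokens": torch.tensor([[1, 2, 3]]), "return_all_hiddens": True}
doubled = toy_duplicate_input(batch)
assert doubled["src_tokens"].shape[0] == 2  # two dropout passes per sample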
@@ -351,3 +429,88 @@ class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion):
            return

        MultitaskCriterion.reduce_metrics(logging_outputs)


@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig)
class SpeechToSpectrogram2passMultitaskTaskCriterion(
    SpeechToSpectrogramMultitaskTaskCriterion
):
    def __init__(
        self,
        task,
        sentence_avg,
        use_guided_attention_loss,
        guided_attention_loss_sigma,
        bce_pos_weight,
        ctc_weight,
    ):
        super().__init__(
            task,
            sentence_avg,
            use_guided_attention_loss,
            guided_attention_loss_sigma,
            bce_pos_weight,
            ctc_weight,
        )

    def forward(self, model, sample, reduction="mean"):
        bsz, max_len, _ = sample["target"].size()
        feat_tgt = sample["target"]
        feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len)
        eos_tgt = torch.arange(max_len).to(sample["target"].device)
        eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1)
        eos_tgt = (eos_tgt == (feat_len - 1)).float()

        feat_out, eos_out, extra = model(
            src_tokens=sample["net_input"]["src_tokens"],
            src_lengths=sample["net_input"]["src_lengths"],
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][
                "prev_output_tokens"
            ],
            tgt_speaker=sample["net_input"]["tgt_speaker"],
            target_lengths=sample["target_lengths"],
            return_all_hiddens=True,
        )

        l1_loss, mse_loss, eos_loss = self.compute_loss(
            extra["feature_out"],
            feat_out,
            eos_out,
            feat_tgt,
            eos_tgt,
            sample["target_lengths"],
            reduction,
        )
        attn_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.guided_attn is not None:
            attn_loss = self.guided_attn(
                extra["attn"],
                sample["net_input"]["src_lengths"],
                sample["target_lengths"],
                reduction,
            )
        loss = (
            l1_loss + mse_loss + eos_loss + attn_loss
        )  # do not include ctc loss as there's no text target

        sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "mse_loss": utils.item(mse_loss.data),
            "eos_loss": utils.item(eos_loss.data),
            "attn_loss": utils.item(attn_loss.data),
        }

        if len(self.multitask_criterion) == 0:
            return loss, sample_size, logging_output

        # multitask
        multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra)
        loss += multitask_loss
        logging_output["multitask"] = multitask_log
        return loss, sample_size, logging_output
fairseq/data/audio/data_cfg.py
@@ -257,6 +257,24 @@ class MultitaskConfig(object):
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple 'is_first_pass_decoder: True' entries in the config file,
        the last such task is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file,
        the last task whose task_name includes 'target' and whose decoder_type is not CTC is used.
        """
        idx = -1
        for i, (k, v) in enumerate(self.config.items()):
            if v.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (k, v) in enumerate(self.config.items()):
                if k.startswith("target") and v.decoder_type == "transformer":
                    idx = i
        return idx


class SingleTaskConfig(object):
    def __init__(self, name, config):
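The fallback scan in `first_pass_decoder_task_index` can be exercised in isolation; a toy reproduction with hypothetical task entries (real entries come from the multitask config YAML):

# Toy reproduction of the selection logic above, assuming minimal task
# objects (hypothetical names; real configs are parsed from YAML).
from types import SimpleNamespace

config = {
    "source_unigram": SimpleNamespace(is_first_pass_decoder=False, decoder_type="ctc"),
    "target_letter": SimpleNamespace(is_first_pass_decoder=False, decoder_type="transformer"),
}

idx = -1
for i, (k, v) in enumerate(config.items()):
    if v.is_first_pass_decoder:
        idx = i
if idx < 0:  # fallback: last "target*" task with a transformer decoder
    for i, (k, v) in enumerate(config.items()):
        if k.startswith("target") and v.decoder_type == "transformer":
            idx = i
assert idx == 1  # "target_letter" is picked as the first-pass decoder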
@@ -336,6 +354,34 @@ class SingleTaskConfig(object):
        )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
-       return self.config.get("rdrop_alpha", None)
+       return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                raise Warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    @property
    def get_lang_tag_mapping(self):
        return self.config.get("lang_tag_mapping", {})
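For reference, the per-task options read by the properties above appear as plain keys in a task's block of the multitask config YAML; an illustrative, made-up entry is shown below as the parsed Python dict:

# Illustrative shape of a first-pass decoder entry after YAML parsing
# (keys follow the properties above; the values are invented for the example).
task_cfg = {
    "decoder_type": "transformer",
    "is_first_pass_decoder": True,
    "prepend_bos_and_append_tgt_lang_tag": True,
    "eos_token": "</s>",
    "rdrop_alpha": 0.0,
    "lang_tag_mapping": {"<lang:fr>": "[fr_XX]"},
}
assert task_cfg.get("rdrop_alpha", 0.0) == 0.0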
fairseq/data/audio/speech_to_speech_dataset.py
@@ -247,8 +247,9 @@ class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):

        multitask_target = {}
        sample_id = self.ids[index]
        tgt_lang = self.tgt_langs[index]
        for task_name, task_dataset in self.multitask_data.items():
-           multitask_target[task_name] = task_dataset.get(sample_id)
+           multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)

        return s2s_data, multitask_target

@@ -318,7 +319,7 @@ class SpeechToSpeechDatasetCreator(object):
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

-       has_multitask = len(multitask) > 0
+       has_multitask = multitask is not None and len(multitask.keys()) > 0
        dataset_cls = (
            SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
        )
fairseq/data/audio/speech_to_text_dataset.py
@@ -391,6 +391,7 @@ class SpeechToTextDataset(FairseqDataset):
class TextTargetMultitaskData(object):
    # mandatory columns
    KEY_ID, KEY_TEXT = "id", "tgt_text"
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(self, args, split, tgt_dict):
        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)

@@ -399,6 +400,16 @@ class TextTargetMultitaskData(object):
        self.append_eos = args.decoder_type != "ctc"
        self.pre_tokenizer = self.build_tokenizer(args)
        self.bpe_tokenizer = self.build_bpe(args)
        self.prepend_bos_and_append_tgt_lang_tag = (
            args.prepend_bos_and_append_tgt_lang_tag
        )
        self.eos_token = args.eos_token
        self.lang_tag_mapping = args.get_lang_tag_mapping

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    @classmethod
    def tokenize(cls, tokenizer, text: str):

@@ -409,6 +420,13 @@ class TextTargetMultitaskData(object):
        text = self.tokenize(self.bpe_tokenizer, text)
        return text

    def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
        lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
        lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
        lang_tag_idx = dictionary.index(lang_tag)
        assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
        return lang_tag_idx

    def build_tokenizer(self, args):
        pre_tokenizer = args.config.get("pre_tokenizer")
        if pre_tokenizer is not None:

@@ -425,14 +443,21 @@ class TextTargetMultitaskData(object):
        else:
            return None

-   def get(self, sample_id):
+   def get(self, sample_id, tgt_lang=None):
        if sample_id in self.data:
            tokenized = self.get_tokenized_tgt_text(sample_id)
-           return self.dict.encode_line(
+           target = self.dict.encode_line(
                tokenized,
                add_if_not_exist=False,
                append_eos=self.append_eos,
            )
            if self.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
                assert lang_tag_idx != self.dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)
            return target
        else:
            logger.warning(f"no target for {sample_id}")
            return torch.IntTensor([])
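The new BOS and language-tag wrapping in `get()` is simple tensor surgery; a standalone illustration with toy indices (in fairseq the indices come from the target `Dictionary`):

# Standalone illustration of the BOS + target-lang-tag wrapping in get()
# (toy indices; bos/eos/lang-tag values are invented for the example).
import torch

bos_idx, lang_tag_idx = 0, 7
target = torch.LongTensor([11, 12, 13, 2])  # tokens + eos
wrapped = torch.cat(
    (torch.LongTensor([bos_idx]), target, torch.LongTensor([lang_tag_idx])), 0
)
assert wrapped.tolist() == [0, 11, 12, 13, 2, 7]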
@@ -441,7 +466,7 @@ class TextTargetMultitaskData(object):
        out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
-           eos_idx=self.dict.eos(),
+           eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=False,
        ).long()

@@ -449,7 +474,7 @@ class TextTargetMultitaskData(object):
        prev_out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
-           eos_idx=self.dict.eos(),
+           eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=True,
        ).long()

@@ -551,7 +576,7 @@ class SpeechToTextDatasetCreator(object):
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

-       has_multitask = len(multitask) > 0
+       has_multitask = multitask is not None and len(multitask.keys()) > 0
        dataset_cls = (
            SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
        )
fairseq/dataclass/configs.py
@@ -811,6 +811,10 @@ class GenerationConfig(FairseqDataclass):
        default=5,
        metadata={"help": "beam size"},
    )
    beam_mt: int = field(
        default=0,
        metadata={"help": "beam size for the first-pass decoder"},
    )
    nbest: int = field(
        default=1,
        metadata={"help": "number of hypotheses to output"},

@@ -827,6 +831,18 @@ class GenerationConfig(FairseqDataclass):
            "help": "generate sequences of maximum length ax + b, where x is the source length"
        },
    )
    max_len_a_mt: float = field(
        default=0,
        metadata={
            "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder"
        },
    )
    max_len_b_mt: int = field(
        default=200,
        metadata={
            "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder"
        },
    )
    min_len: int = field(
        default=1,
        metadata={"help": "minimum generation length"},

@@ -853,6 +869,12 @@ class GenerationConfig(FairseqDataclass):
            "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
        },
    )
    lenpen_mt: float = field(
        default=1,
        metadata={
            "help": "length penalty for the first-pass decoder: <1.0 favors shorter, >1.0 favors longer sentences"
        },
    )
    unkpen: float = field(
        default=0,
        metadata={
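The new `*_mt` generation fields mirror the existing second-pass knobs for the first-pass text decoder; in both cases the maximum target length is still computed as a*x + b over the source length x. A small check of that arithmetic (values illustrative):

# Max generation length is a*x + b for either pass; only the (a, b)
# pair differs (max_len_a/b vs. max_len_a_mt/b_mt).
def max_gen_len(src_len, max_len_a=0.0, max_len_b=200):
    return int(max_len_a * src_len + max_len_b)

assert max_gen_len(100) == 200  # defaults, second pass
assert max_gen_len(100, max_len_a=0.5, max_len_b=10) == 60  # e.g. first pass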
fairseq/models/speech_to_speech/__init__.py
@@ -4,4 +4,6 @@
# LICENSE file in the root directory of this source tree.

from .s2s_conformer import *  # noqa
from .s2s_conformer_translatotron2 import *  # noqa
from .s2s_conformer_unity import *  # noqa
from .s2s_transformer import *  # noqa
108 fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py Normal file
@@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional

from torch import Tensor

from fairseq.models.transformer import Linear
from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder


class AugTransformerUnitDecoder(AugTransformerDecoder):
    """Based on Transformer decoder, with support to decoding stacked units"""

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
    ):
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn, output_projection
        )
        self.n_frames_per_step = args.n_frames_per_step

        self.out_proj_n_frames = (
            Linear(
                self.output_embed_dim,
                self.output_embed_dim * self.n_frames_per_step,
                bias=False,
            )
            if self.n_frames_per_step > 1
            else None
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """

        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            encoder_out_aug=encoder_out_aug,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )

        if not features_only:
            bsz, seq_len, d = x.size()
            if self.out_proj_n_frames:
                x = self.out_proj_n_frames(x)
            x = self.output_layer(x.view(bsz, seq_len, self.n_frames_per_step, d))
            x = x.view(bsz, seq_len * self.n_frames_per_step, -1)
            if (
                incremental_state is None and self.n_frames_per_step > 1
            ):  # teacher-forcing mode in training
                x = x[
                    :, : -(self.n_frames_per_step - 1), :
                ]  # remove extra frames after <eos>

        return x, extra

    def upgrade_state_dict_named(self, state_dict, name):
        if self.n_frames_per_step > 1:
            move_keys = [
                (
                    f"{name}.project_in_dim.weight",
                    f"{name}.embed_tokens.project_in_dim.weight",
                )
            ]
            for from_k, to_k in move_keys:
                if from_k in state_dict and to_k not in state_dict:
                    state_dict[to_k] = state_dict[from_k]
                    del state_dict[from_k]
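A shape walk-through of the stacked-unit projection in `forward` above, with `n_frames_per_step = k`: each decoder position is expanded into k unit frames and then flattened back onto the time axis. The sketch below uses toy sizes and skips the vocabulary output layer that the real code applies between the two views:

# Shape sketch of the stacked-unit projection (toy sizes; the real code
# applies output_layer between the two view() calls).
import torch

bsz, seq_len, d, k = 2, 5, 8, 3
x = torch.randn(bsz, seq_len, d)
proj = torch.nn.Linear(d, d * k, bias=False)  # plays the role of out_proj_n_frames
x = proj(x).view(bsz, seq_len, k, d)          # split out the k frames per step
x = x.view(bsz, seq_len * k, d)               # flatten frames back onto time
assert x.shape == (2, 15, 8)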
85 fairseq/models/speech_to_speech/modules/transformer_encoder.py Normal file
@@ -0,0 +1,85 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.models import FairseqEncoder
from fairseq.modules import LayerNorm, TransformerEncoderLayer


class TransformerEncoderNoEmb(FairseqEncoder):
    """Transformer encoder without token embeddings."""

    def __init__(self, args):
        super().__init__(None)

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)]
        )
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None

    def forward(self, x, encoder_padding_mask, return_all_hiddens=False):

        encoder_states = []

        for layer in self.layers:
            x = layer(x, encoder_padding_mask)
            if return_all_hiddens:
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask]
            if encoder_padding_mask is not None and encoder_padding_mask.any()
            else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        new_encoder_out = (
            []
            if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
        )

        new_encoder_padding_mask = (
            []
            if len(encoder_out["encoder_padding_mask"]) == 0
            else [
                x.index_select(0, new_order)
                for x in encoder_out["encoder_padding_mask"]
            ]
        )

        new_encoder_embedding = (
            []
            if len(encoder_out["encoder_embedding"]) == 0
            else [
                x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
            ]
        )

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],  # B x T
            "src_lengths": [],  # B x 1
        }
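A usage-convention sketch for the encoder output dictionary above: `encoder_out` entries are T x B x C, so beam reordering gathers along dim 1, while `encoder_padding_mask` entries are B x T and gather along dim 0 (toy sizes below):

# Beam-tiling sketch matching reorder_encoder_out above (toy sizes).
import torch

T, B, C, beam = 4, 2, 3, 2
enc = torch.randn(T, B, C)
new_order = torch.arange(B).repeat_interleave(beam)  # tile each sentence
tiled = enc.index_select(1, new_order)               # gather along dim 1
assert tiled.shape == (T, B * beam, C)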
@@ -11,7 +11,9 @@ import torch
from fairseq import checkpoint_utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.speech_to_speech.s2s_transformer import (
    S2SpecTTransformerModel,
    S2UTTransformerModel,
    s2spect_architecture_base,
    s2ut_architecture_base,
)
from fairseq.models.speech_to_text import S2TConformerEncoder
@@ -97,6 +99,34 @@ class S2UTConformerModel(S2UTTransformerModel):
        return build_s2s_conformer_encoder(args)


@register_model("s2spect_conformer")
class S2SpecTConformerModel(S2SpecTTransformerModel):
    """
    Direct speech-to-speech translation model with Conformer encoder + TTS Transformer decoder
    """

    @staticmethod
    def add_args(parser):
        S2SpecTTransformerModel.add_args(parser)
        parser.add_argument("--depthwise-conv-kernel-size", type=int, default=31)
        parser.add_argument(
            "--attn-type",
            type=str,
            default=None,
            help="If not specified uses fairseq MHA. Other valid option is espnet for using conformer",
        )
        parser.add_argument(
            "--pos-enc-type",
            type=str,
            default="abs",
            help="Must be specified in addition to attn-type=espnet for rel_pos and rope",
        )

    @classmethod
    def build_encoder(cls, args):
        return build_s2s_conformer_encoder(args)


@register_model_architecture("s2ut_conformer", "s2ut_conformer")
def s2ut_conformer_architecture_base(args):
    args.attn_type = getattr(args, "attn_type", None)
@@ -111,3 +141,32 @@ def s2ut_conformer_architecture_base(args):
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31)
    s2ut_architecture_base(args)


@register_model_architecture("s2spect_conformer", "s2spect_conformer")
def s2spect_conformer_architecture_base(args):
    args.attn_type = getattr(args, "attn_type", None)
    args.pos_enc_type = getattr(args, "pos_enc_type", "abs")
    args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
    args.input_channels = getattr(args, "input_channels", 1)
    args.max_source_positions = getattr(args, "max_source_positions", 6000)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31)
    s2spect_architecture_base(args)


@register_model_architecture("s2spect_conformer", "s2spect_conformer_fisher")
def s2spect_architecture_fisher(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)

    # decoder
    args.prenet_dim = getattr(args, "prenet_dim", 32)

    s2spect_conformer_architecture_base(args)
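The architecture functions above rely on fairseq's convention of mutating the argument namespace in place, setting a hyperparameter only when the user has not already supplied it. A small self-contained illustration of the pattern (fill_defaults and its field values are hypothetical stand-ins, not part of this commit):

import argparse

def fill_defaults(args):
    # Same pattern as s2spect_conformer_architecture_base: only assign a
    # field when it is absent from the namespace.
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)

args = argparse.Namespace(encoder_layers=12)  # user override from the CLI
fill_defaults(args)
assert args.encoder_layers == 12      # override preserved
assert args.encoder_embed_dim == 256  # default applied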
fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py (new file, 262 lines)
@@ -0,0 +1,262 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging

from fairseq.models import (
    FairseqEncoderModel,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder
from fairseq.models.speech_to_speech.modules.transformer_encoder import (
    TransformerEncoderNoEmb,
)
from fairseq.models.speech_to_speech.s2s_conformer import S2SpecTConformerModel
from fairseq.models.speech_to_speech.s2s_conformer_unity import (
    multitask_text_transformer_decoder_arch,
)
from fairseq.models.speech_to_speech.s2s_transformer import (
    base_multitask_text_transformer_decoder_arch,
    s2spect_architecture_base,
)
from fairseq.models.text_to_speech import TTSTransformerDecoder
from fairseq.models.transformer import TransformerDecoder, TransformerModelBase

logger = logging.getLogger(__name__)


@register_model("s2spect2_conformer")
class S2SpecT2ConformerModel(S2SpecTConformerModel):
    """
    Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + TTS Transformer decoder
    """

    @staticmethod
    def add_args(parser):
        S2SpecTConformerModel.add_args(parser)
        parser.add_argument(
            "--translation-decoder-layers",
            type=int,
            default=4,
            metavar="N",
            help="num decoder layers in the first-pass translation module",
        )
        parser.add_argument(
            "--synthesizer",
            default="transformer",
            choices=["transformer"],
            help="",
        )
        parser.add_argument(
            "--synthesizer-encoder-layers",
            type=int,
            default=0,
            metavar="N",
            help="num encoder layers in the second-pass synthesizer module",
        )

    @classmethod
    def build_multitask_decoder(
        cls,
        args,
        tgt_dict,
        in_dim,
        is_mt_decoder,
        decoder_layers,
        decoder_embed_dim,
        decoder_attention_heads,
    ):
        decoder_args = args.decoder_args
        decoder_args.encoder_embed_dim = in_dim
        if args.decoder_type == "transformer":
            if is_mt_decoder:
                multitask_text_transformer_decoder_arch(
                    decoder_args,
                    decoder_layers,
                    decoder_embed_dim,
                    decoder_attention_heads,
                )  # 4L
            else:
                base_multitask_text_transformer_decoder_arch(decoder_args)  # 2L
            task_decoder = TransformerDecoder(
                decoder_args,
                tgt_dict,
                embed_tokens=TransformerModelBase.build_embedding(
                    decoder_args,
                    tgt_dict,
                    decoder_args.decoder_embed_dim,
                ),
            )
        elif args.decoder_type == "ctc":
            task_decoder = CTCDecoder(
                dictionary=tgt_dict,
                in_dim=in_dim,
            )
        else:
            raise NotImplementedError(
                "currently only support multitask decoder_type 'transformer', 'ctc'"
            )

        return task_decoder

    @classmethod
    def build_decoder(cls, args):
        _args = copy.deepcopy(args)
        _args.encoder_embed_dim = args.decoder_embed_dim

        if args.synthesizer == "transformer":
            return TTSTransformerDecoder(_args, None, padding_idx=1)
        else:
            raise NotImplementedError(args.synthesizer)

    @classmethod
    def build_model(cls, args, task):
        encoder = cls.build_encoder(args)
        decoder = cls.build_decoder(args)
        base_model = cls(encoder, decoder)

        # set up multitask decoders
        base_model.mt_task_name = None
        base_model.multitask_decoders = {}
        has_first_pass_decoder = False
        for task_name, task_obj in task.multitask_tasks.items():
            if task_obj.is_first_pass_decoder:
                has_first_pass_decoder = True
                base_model.mt_task_name = task_name

            in_dim = (
                args.encoder_embed_dim
                if task_obj.args.input_from == "encoder"
                else args.decoder_embed_dim
            )
            task_decoder = cls.build_multitask_decoder(
                task_obj.args,
                task_obj.target_dictionary,
                in_dim,
                task_obj.is_first_pass_decoder,
                getattr(args, "translation_decoder_layers", 4),
                getattr(args, "decoder_embed_dim", 256),
                getattr(args, "decoder_attention_heads", 4),
            )

            setattr(base_model, f"{task_name}_decoder", task_decoder)
            decoder_model_cls = (
                FairseqEncoderModel
                if task_obj.args.decoder_type == "ctc"
                else FairseqLanguageModel
            )
            base_model.multitask_decoders[task_name] = decoder_model_cls(
                getattr(base_model, f"{task_name}_decoder")
            )

        assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder"

        # set up encoder on top of the auxiliary MT decoder
        if getattr(args, "synthesizer_encoder_layers", 0) > 0:
            base_model.synthesizer_encoder = cls.build_text_encoder(args)
        else:
            base_model.synthesizer_encoder = None

        return base_model

    @classmethod
    def build_text_encoder(cls, args):
        _args = copy.deepcopy(args)
        _args.encoder_layers = args.synthesizer_encoder_layers
        _args.encoder_embed_dim = args.decoder_embed_dim
        _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim
        _args.encoder_attention_heads = args.decoder_attention_heads
        _args.encoder_normalize_before = True
        return TransformerEncoderNoEmb(_args)

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        prev_output_tokens_mt,
        tgt_speaker=None,
        incremental_state=None,
        target_lengths=None,
        speaker=None,
        return_all_hiddens=False,
    ):
        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            tgt_speaker=tgt_speaker,
            return_all_hiddens=return_all_hiddens,
        )

        # 1. MT decoder
        mt_decoder = getattr(self, f"{self.mt_task_name}_decoder")
        mt_decoder_out = mt_decoder(
            prev_output_tokens_mt,
            encoder_out=encoder_out,
        )
        x = mt_decoder_out[1]["inner_states"][-1]
        if mt_decoder.layer_norm is not None:
            x = mt_decoder.layer_norm(x)

        mt_decoder_padding_mask = None
        if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
            mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)

        # 2. TTS encoder
        if self.synthesizer_encoder is not None:
            tts_encoder_out = self.synthesizer_encoder(
                x,
                mt_decoder_padding_mask,
                return_all_hiddens=return_all_hiddens,
            )
        else:
            tts_encoder_out = {
                "encoder_out": [x],  # T x B x C
                "encoder_padding_mask": [mt_decoder_padding_mask],  # B x T
            }

        # 3. TTS decoder
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=tts_encoder_out,
            incremental_state=incremental_state,
            target_lengths=target_lengths,
            speaker=speaker,
        )
        if return_all_hiddens:
            decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"]
            decoder_out[-1]["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ]
        decoder_out[-1]["mt_decoder_out"] = mt_decoder_out
        return decoder_out


@register_model_architecture(
    model_name="s2spect2_conformer", arch_name="s2spect2_conformer"
)
def s2spect2_conformer_architecture_base(args):
    args.conv_version = getattr(args, "conv_version", "convtransformer")
    args.attn_type = getattr(args, "attn_type", None)
    args.pos_enc_type = getattr(args, "pos_enc_type", "abs")
    args.max_source_positions = getattr(args, "max_source_positions", 6000)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31)
    s2spect_architecture_base(args)


# for old naming
@register_model_architecture(
    model_name="s2spect2_conformer", arch_name="s2spect_conformer_translatotron2"
)
def s2spect2_conformer_architecture_base_legacy(args):
    s2spect2_conformer_architecture_base(args)
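For reference, a minimal sketch of the first-pass-to-second-pass hand-off coded in forward() above: the MT decoder's last inner states (T x B x C) are repackaged as an encoder-out dict so the TTS decoder can cross-attend to them when no synthesizer encoder is configured. The shapes here are hypothetical; only torch is required:

import torch

tgt_len, bsz, dim = 7, 2, 256
mt_states = torch.randn(tgt_len, bsz, dim)          # stand-in for inner_states[-1]
pad_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)

# Same fallback dict shape as the else-branch of step 2 above.
tts_encoder_out = {
    "encoder_out": [mt_states],          # T x B x C
    "encoder_padding_mask": [pad_mask],  # B x T
}
assert tts_encoder_out["encoder_out"][0].shape == (tgt_len, bsz, dim)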
fairseq/models/speech_to_speech/s2s_conformer_unity.py (new file, 298 lines)
@@ -0,0 +1,298 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging

from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder
from fairseq.models.speech_to_speech.modules.stacked_embedding import StackedEmbedding
from fairseq.models.speech_to_speech.modules.transformer_decoder_aug import (
    AugTransformerUnitDecoder,
)
from fairseq.models.speech_to_speech.modules.transformer_encoder import (
    TransformerEncoderNoEmb,
)
from fairseq.models.speech_to_speech.s2s_conformer import S2UTConformerModel
from fairseq.models.speech_to_speech.s2s_transformer import (
    TransformerUnitDecoder,
    base_multitask_text_transformer_decoder_arch,
    s2ut_architecture_base,
)
from fairseq.models.transformer import TransformerDecoder, TransformerModelBase

logger = logging.getLogger(__name__)


def multitask_text_transformer_decoder_arch(
    args, decoder_layers, decoder_embed_dim=256, decoder_attention_heads=4
):
    args.decoder_layers = decoder_layers
    args.decoder_embed_dim = decoder_embed_dim
    args.decoder_attention_heads = decoder_attention_heads
    base_multitask_text_transformer_decoder_arch(args)


@register_model("unity_conformer")
class UnityConformerModel(S2UTConformerModel):
    """
    Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + Transformer discrete unit decoder
    """

    @staticmethod
    def add_args(parser):
        S2UTConformerModel.add_args(parser)
        parser.add_argument(
            "--translation-decoder-layers",
            type=int,
            default=4,
            metavar="N",
            help="num decoder layers in the first-pass translation module",
        )
        parser.add_argument(
            "--synthesizer",
            default="transformer",
            choices=["transformer"],
            help="",
        )
        parser.add_argument(
            "--synthesizer-encoder-layers",
            type=int,
            default=0,
            metavar="N",
            help="num encoder layers in the second-pass synthesizer module",
        )
        parser.add_argument(
            "--synthesizer-augmented-cross-attention",
            action="store_true",
            default=False,
            help="augmented cross-attention over speech encoder output",
        )

    @classmethod
    def build_multitask_decoder(
        cls,
        args,
        tgt_dict,
        in_dim,
        is_first_pass_decoder,
        decoder_layers,
        decoder_embed_dim,
        decoder_attention_heads,
    ):
        decoder_args = args.decoder_args
        decoder_args.encoder_embed_dim = in_dim
        if args.decoder_type == "transformer":
            if is_first_pass_decoder:
                multitask_text_transformer_decoder_arch(
                    decoder_args,
                    decoder_layers,
                    decoder_embed_dim,
                    decoder_attention_heads,
                )  # 4L
            else:
                base_multitask_text_transformer_decoder_arch(decoder_args)  # 2L
            task_decoder = TransformerDecoder(
                decoder_args,
                tgt_dict,
                embed_tokens=TransformerModelBase.build_embedding(
                    decoder_args,
                    tgt_dict,
                    decoder_args.decoder_embed_dim,
                ),
            )
        elif args.decoder_type == "ctc":
            task_decoder = CTCDecoder(
                dictionary=tgt_dict,
                in_dim=in_dim,
            )
        else:
            raise NotImplementedError(
                "currently only support multitask decoder_type 'transformer', 'ctc'"
            )

        return task_decoder

    @classmethod
    def build_decoder(cls, args, tgt_dict, aug_attn=False):
        num_embeddings = len(tgt_dict)
        padding_idx = tgt_dict.pad()
        embed_tokens = StackedEmbedding(
            num_embeddings,
            args.decoder_embed_dim,
            padding_idx,
            num_stacked=args.n_frames_per_step,
        )

        _args = copy.deepcopy(args)
        _args.encoder_embed_dim = args.decoder_embed_dim

        decoder_cls = AugTransformerUnitDecoder if aug_attn else TransformerUnitDecoder
        return decoder_cls(
            _args,
            tgt_dict,
            embed_tokens,
        )

    @classmethod
    def build_model(cls, args, task):
        encoder = cls.build_encoder(args)
        decoder = cls.build_decoder(
            args,
            task.target_dictionary,
            aug_attn=getattr(args, "synthesizer_augmented_cross_attention", False),
        )
        base_model = cls(encoder, decoder)

        base_model.t2u_augmented_cross_attn = getattr(
            args, "synthesizer_augmented_cross_attention", False
        )

        # set up multitask decoders
        base_model.mt_task_name = None
        base_model.multitask_decoders = {}
        has_first_pass_decoder = False
        for task_name, task_obj in task.multitask_tasks.items():
            if task_obj.is_first_pass_decoder:
                has_first_pass_decoder = True
                base_model.mt_task_name = task_name

            in_dim = (
                args.encoder_embed_dim
                if task_obj.args.input_from == "encoder"
                else args.decoder_embed_dim
            )
            task_decoder = cls.build_multitask_decoder(
                task_obj.args,
                task_obj.target_dictionary,
                in_dim,
                task_obj.is_first_pass_decoder,
                getattr(args, "translation_decoder_layers", 4),
                getattr(args, "decoder_embed_dim", 256),
                getattr(args, "decoder_attention_heads", 4),
            )

            setattr(base_model, f"{task_name}_decoder", task_decoder)
            decoder_model_cls = (
                FairseqEncoderModel
                if task_obj.args.decoder_type == "ctc"
                else FairseqLanguageModel
            )
            base_model.multitask_decoders[task_name] = decoder_model_cls(
                getattr(base_model, f"{task_name}_decoder")
            )

        assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder"

        # set up encoder on top of the auxiliary MT decoder
        if getattr(args, "synthesizer_encoder_layers", 0) > 0:
            base_model.synthesizer_encoder = cls.build_text_encoder(args)
        else:
            base_model.synthesizer_encoder = None

        return base_model

    @classmethod
    def build_text_encoder(cls, args):
        _args = copy.deepcopy(args)
        _args.encoder_layers = args.synthesizer_encoder_layers
        _args.encoder_embed_dim = args.decoder_embed_dim
        _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim
        _args.encoder_attention_heads = args.decoder_attention_heads
        _args.encoder_normalize_before = True
        return TransformerEncoderNoEmb(_args)

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        prev_output_tokens_mt,
        tgt_speaker=None,
        return_all_hiddens=False,
    ):
        mt_decoder = getattr(self, f"{self.mt_task_name}_decoder")

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            tgt_speaker=tgt_speaker,
            return_all_hiddens=return_all_hiddens,
        )

        # 1. MT decoder
        mt_decoder_out = mt_decoder(
            prev_output_tokens_mt,
            encoder_out=encoder_out,
        )
        x = mt_decoder_out[1]["inner_states"][-1]
        if mt_decoder.layer_norm is not None:
            x = mt_decoder.layer_norm(x)

        mt_decoder_padding_mask = None
        if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
            mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)

        # 2. T2U encoder
        if self.synthesizer_encoder is not None:
            t2u_encoder_out = self.synthesizer_encoder(
                x,
                mt_decoder_padding_mask,
                return_all_hiddens=return_all_hiddens,
            )
        else:
            t2u_encoder_out = {
                "encoder_out": [x],  # T x B x C
                "encoder_padding_mask": [mt_decoder_padding_mask],  # B x T
            }

        # 3. T2U decoder
        if self.t2u_augmented_cross_attn:
            decoder_out = self.decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                encoder_out_aug=t2u_encoder_out,
            )
        else:
            decoder_out = self.decoder(
                prev_output_tokens,
                encoder_out=t2u_encoder_out,
            )
        if return_all_hiddens:
            decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"]
            decoder_out[-1]["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ]
        decoder_out[-1]["mt_decoder_out"] = mt_decoder_out
        return decoder_out


@register_model_architecture(model_name="unity_conformer", arch_name="unity_conformer")
def unity_conformer_architecture_base(args):
    args.conv_version = getattr(args, "conv_version", "convtransformer")
    args.attn_type = getattr(args, "attn_type", None)
    args.pos_enc_type = getattr(args, "pos_enc_type", "abs")
    args.max_source_positions = getattr(args, "max_source_positions", 6000)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)
    args.encoder_layers = getattr(args, "encoder_layers", 16)
    args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31)
    s2ut_architecture_base(args)


# for old naming
@register_model_architecture(
    model_name="unity_conformer", arch_name="s2ut_conformer_translatotron2"
)
def unity_conformer_architecture_base_legacy(args):
    unity_conformer_architecture_base(args)
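Both two-pass forward() implementations above only materialize the first-pass padding mask when at least one pad token is present, leaving it None on the no-padding fast path. A self-contained illustration of that idiom (padding_idx and the token values are hypothetical):

import torch

padding_idx = 1
prev_output_tokens_mt = torch.tensor([[5, 6, 7], [5, 1, 1]])  # second row padded
mt_decoder_padding_mask = None
if prev_output_tokens_mt.eq(padding_idx).any():
    mt_decoder_padding_mask = prev_output_tokens_mt.eq(padding_idx)
assert mt_decoder_padding_mask.tolist() == [
    [False, False, False],
    [False, True, True],
]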
@@ -10,3 +10,4 @@ from .s2t_conformer import *  # noqa
from .s2t_transformer import *  # noqa
from .s2t_wav_transformer import *  # noqa
from .xm_transformer import *  # noqa
from .xm_transformer_unity import *  # noqa
@@ -623,6 +623,7 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
        _args.dropout = args.decoder_dropout
        _args.attention_dropout = args.decoder_attention_dropout
        _args.activation_dropout = args.decoder_activation_dropout
        _args.layerdrop = _args.decoder_layerdrop

        decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
        decoder = cls.maybe_load_pretrained(
fairseq/models/speech_to_text/xm_transformer_unity.py (new file, 351 lines)
@@ -0,0 +1,351 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging

from fairseq.models import (
    FairseqEncoderModel,
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder
from fairseq.models.speech_to_speech.modules.transformer_encoder import (
    TransformerEncoderNoEmb,
)
from fairseq.models.speech_to_text.xm_transformer import XMTransformerModel
from fairseq.models.speech_to_text.xm_transformer import (
    base_architecture as xm_t_base_architecture,
)
from fairseq.models.speech_to_text.xm_transformer import (
    build_embedding,
    need_finetuning,
    set_default_adaptor_args,
    set_default_general_args,
    set_default_transformer_decoder_args,
    set_default_w2v_encoder_args,
)
from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase
from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder

logger = logging.getLogger(__name__)


def unit_transformer_decoder_arch_base(
    args, decoder_layers=6, decoder_embed_dim=768, decoder_attention_heads=12
):
    args.encoder_layers = decoder_layers
    args.decoder_layers = decoder_layers
    args.decoder_embed_dim = decoder_embed_dim
    args.decoder_ffn_embed_dim = decoder_embed_dim * 4
    args.decoder_attention_heads = decoder_attention_heads
    args.encoder_embed_dim = args.decoder_embed_dim
    args.decoder_output_dim = decoder_embed_dim
    args.decoder_input_dim = decoder_embed_dim


def unit_transformer_decoder_arch_large(
    args, decoder_layers=12, decoder_embed_dim=1024, decoder_attention_heads=16
):
    args.encoder_layers = decoder_layers
    args.decoder_layers = decoder_layers
    args.decoder_embed_dim = decoder_embed_dim
    args.decoder_ffn_embed_dim = decoder_embed_dim * 4
    args.decoder_attention_heads = decoder_attention_heads
    args.encoder_embed_dim = args.decoder_embed_dim
    args.decoder_output_dim = decoder_embed_dim
    args.decoder_input_dim = decoder_embed_dim


@register_model("unity_xm_transformer")
class XMTransformerModelUnitY(XMTransformerModel):
    @classmethod
    def hub_models(cls):
        base_url = "http://dl.fbaipublicfiles.com/fairseq/s2t"
        model_ids = []
        return {i: f"{base_url}/{i}.tar.gz" for i in model_ids}

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        XMTransformerModel.add_args(parser)
        parser.add_argument(
            "--translation-decoder-layers",
            type=int,
            default=4,
            metavar="N",
            help="num decoder layers in the first-pass translation module",
        )
        parser.add_argument(
            "--synthesizer-encoder-layers",
            type=int,
            default=0,
            metavar="N",
            help="num encoder layers in the second-pass synthesizer module",
        )
        parser.add_argument(
            "--synthesizer-augmented-cross-attention",
            action="store_true",
            default=False,
            help="augmented cross-attention over speech encoder output",
        )
        parser.add_argument(
            "--load-pretrained-aux-decoder-from",
            type=str,
            metavar="STR",
            help="model to take decoder weights from (for initialization)",
        )

    @classmethod
    def build_text_decoder(cls, args, tgt_dict):
        _args = copy.deepcopy(args)

        if args.adaptor_proj or args.encoder_proj:  # not V0 arch
            _args.encoder_embed_dim = _args.decoder_embed_dim
        _args.dropout = args.decoder_dropout
        _args.attention_dropout = args.decoder_attention_dropout
        _args.activation_dropout = args.decoder_activation_dropout
        _args.layerdrop = _args.decoder_layerdrop
        _args.decoder_layers = _args.translation_decoder_layers

        embed_tokens = build_embedding(tgt_dict, _args.decoder_embed_dim)
        decoder = TransformerDecoder(_args, tgt_dict, embed_tokens)

        if getattr(args, "load_pretrained_aux_decoder_from", None) is not None:
            decoder = cls.maybe_load_pretrained(
                decoder, getattr(args, "load_pretrained_aux_decoder_from", None)
            )

            for k, p in decoder.named_parameters():
                p.requires_grad = need_finetuning(args.finetune_decoder_params, k)
        return decoder

    @classmethod
    def build_decoder(cls, args, task, aug_attn=False):
        _args = copy.deepcopy(args)
        _args.layerdrop = 0.0  # turn off layerdrop for shallow layers

        _args.encoder_embed_dim = args.decoder_embed_dim

        proj = None
        if args.decoder_embed_dim != _args.decoder_embed_dim:
            proj = Linear(args.decoder_embed_dim, _args.decoder_embed_dim)

        embed_tokens = build_embedding(task.target_dictionary, _args.decoder_embed_dim)
        decoder_cls = AugTransformerDecoder if aug_attn else TransformerDecoder
        decoder = decoder_cls(_args, task.target_dictionary, embed_tokens)

        if getattr(args, "load_pretrained_decoder_from", None) is not None:
            # load all layers first and then discard the bottom layers
            embed_tokens = build_embedding(
                task.target_dictionary, _args.decoder_embed_dim
            )
            decoder_tmp = decoder_cls(_args, task.target_dictionary, embed_tokens)
            decoder_tmp = cls.maybe_load_pretrained(
                decoder_tmp, getattr(_args, "load_pretrained_decoder_from", None)
            )
            state_dict = decoder_tmp.state_dict()
            for k, p in decoder.named_parameters():
                p.data = state_dict[k].data
                p.requires_grad = need_finetuning(_args.finetune_decoder_params, k)
            decoder.layers = decoder.layers[-_args.decoder_layers :]

        return decoder, proj, _args

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        xm_t_base_architecture(args)

        encoder = cls.build_encoder(args)
        decoder, proj, unit_args = cls.build_decoder(
            args,
            task,
            aug_attn=getattr(args, "synthesizer_augmented_cross_attention", False),
        )
        base_model = cls(encoder, decoder)
        setattr(base_model, "proj", proj)

        base_model.t2u_augmented_cross_attn = getattr(
            args, "synthesizer_augmented_cross_attention", False
        )

        # set up multitask decoders
        base_model.mt_task_name = None
        base_model.multitask_decoders = {}
        has_first_pass_decoder = False
        for task_name, task_obj in task.multitask_tasks.items():
            if task_obj.is_first_pass_decoder:
                has_first_pass_decoder = True
                base_model.mt_task_name = task_name

            task_decoder = cls.build_multitask_decoder(
                args,
                task_obj.args,
                task_obj.target_dictionary,
                args.decoder_embed_dim,
                task_obj.is_first_pass_decoder,
            )

            setattr(base_model, f"{task_name}_decoder", task_decoder)
            decoder_model_cls = (
                FairseqEncoderModel
                if task_obj.args.decoder_type == "ctc"
                else FairseqLanguageModel
            )
            base_model.multitask_decoders[task_name] = decoder_model_cls(
                getattr(base_model, f"{task_name}_decoder")
            )

        assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder"

        # set up encoder on top of the auxiliary MT decoder
        if getattr(args, "synthesizer_encoder_layers", 0) > 0:
            base_model.synthesizer_encoder = cls.build_t2u_encoder(unit_args)
        else:
            base_model.synthesizer_encoder = None

        return base_model

    @classmethod
    def build_multitask_decoder(
        cls, args, mtl_args, tgt_dict, in_dim, is_first_pass_decoder
    ):
        decoder_args = mtl_args.decoder_args
        decoder_args.encoder_embed_dim = in_dim
        if mtl_args.decoder_type == "transformer":
            if is_first_pass_decoder:
                task_decoder = cls.build_text_decoder(args, tgt_dict)
            else:
                from fairseq.models.speech_to_speech import (
                    base_multitask_text_transformer_decoder_arch,
                )

                base_multitask_text_transformer_decoder_arch(decoder_args)  # 2L
                task_decoder = TransformerDecoder(
                    decoder_args,
                    tgt_dict,
                    embed_tokens=TransformerModelBase.build_embedding(
                        decoder_args,
                        tgt_dict,
                        decoder_args.decoder_embed_dim,
                    ),
                )
        elif args.decoder_type == "ctc":
            task_decoder = CTCDecoder(
                dictionary=tgt_dict,
                in_dim=in_dim,
            )
        else:
            raise NotImplementedError(
                "currently only support multitask decoder_type 'transformer', 'ctc'"
            )

        return task_decoder

    @classmethod
    def build_t2u_encoder(cls, args):
        _args = copy.deepcopy(args)
        _args.encoder_layers = _args.synthesizer_encoder_layers
        _args.encoder_embed_dim = args.decoder_embed_dim
        _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim
        _args.encoder_attention_heads = args.decoder_attention_heads
        _args.encoder_normalize_before = True
        return TransformerEncoderNoEmb(_args)

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        prev_output_tokens_mt,
        return_all_hiddens=False,
        tgt_speaker=None,
        **kwargs,
    ):
        """
        The forward method inherited from the base class has a **kwargs
        argument in its input, which is not supported in torchscript. This
        method overwrites the forward method definition without **kwargs.
        """
        encoder_out = self.encoder(
            src_tokens=src_tokens, src_lengths=src_lengths, **kwargs
        )

        # 1. MT decoder
        mt_decoder = getattr(self, f"{self.mt_task_name}_decoder")
        mt_decoder_out = mt_decoder(
            prev_output_tokens_mt,
            encoder_out=encoder_out,
        )
        x = mt_decoder_out[1]["inner_states"][-1]
        if mt_decoder.layer_norm is not None:
            x = mt_decoder.layer_norm(x)
        if self.proj is not None:
            x = self.proj(x)

        mt_decoder_padding_mask = None
        if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
            mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)

        # 2. T2U encoder
        if self.synthesizer_encoder is not None:
            t2u_encoder_out = self.synthesizer_encoder(
                x,
                mt_decoder_padding_mask,
            )
        else:
            t2u_encoder_out = {
                "encoder_out": [x],  # T x B x C
                "encoder_padding_mask": [mt_decoder_padding_mask],  # B x T
            }

        # 3. T2U decoder
        if self.t2u_augmented_cross_attn:
            decoder_out = self.decoder(
                prev_output_tokens,
                encoder_out=encoder_out,
                encoder_out_aug=t2u_encoder_out,
            )
        else:
            decoder_out = self.decoder(
                prev_output_tokens,
                encoder_out=t2u_encoder_out,
            )
        if return_all_hiddens:
            decoder_out[-1]["encoder_states"] = encoder_out["encoder_out"]
            # NOTE: from the top layer
            decoder_out[-1]["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ]
        decoder_out[-1]["mt_decoder_out"] = mt_decoder_out
        return decoder_out


@register_model_architecture(
    model_name="unity_xm_transformer", arch_name="unity_xm_transformer"
)
def base_architecture_unity(args):
    set_default_general_args(args)
    set_default_w2v_encoder_args(args)
    set_default_adaptor_args(args)
    set_default_transformer_decoder_args(args)

    args.layernorm_embedding = False
    args.decoder_learned_pos = False


# for old models
@register_model_architecture(
    model_name="unity_xm_transformer", arch_name="xm_transformer_t2"
)
def base_architecture_unity_legacy(args):
    base_architecture_unity(args)
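A self-contained illustration of the "load all layers, then keep only the top ones" initialization used in build_decoder above; plain nn.Linear modules stand in for real decoder layers, and the depths are hypothetical:

import torch.nn as nn

pretrained_depth, kept_depth = 6, 2
decoder_layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(pretrained_depth)])
# Slicing an nn.ModuleList returns a ModuleList, so the same
# `decoder.layers = decoder.layers[-_args.decoder_layers:]` trick keeps
# only the top layers after the pretrained weights have been copied in.
decoder_layers = decoder_layers[-kept_depth:]
assert len(decoder_layers) == kept_depth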
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.distributed import fsdp_wrap
@@ -25,7 +26,6 @@ from fairseq.modules import (
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor


# rewrite name for backward compatibility in `make_generation_fast_`
@@ -42,7 +42,7 @@ class TransformerDecoderBase(FairseqIncrementalDecoder):
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        cfg (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
fairseq/models/transformer/transformer_decoder_aug.py (new file, 392 lines)
@@ -0,0 +1,392 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models.transformer import TransformerConfig
from fairseq.models.transformer.transformer_decoder import TransformerDecoderBase
from fairseq.modules import (
    LayerDropModuleList,
    SinusoidalPositionalEmbedding,
    transformer_layer_aug,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper


class AugTransformerDecoderBase(TransformerDecoderBase):
    """
    Transformer decoder augmented with an additional cross-attention. Each layer
    is a :class:`AugTransformerDecoderLayerBase`.

    Args:
        cfg (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        encoder_attn_merge_type (str, optional): the way to combine outputs from
            two cross-attention modules. If "sequential" is set, two cross-attention
            modules are stacked sequentially. If "parallel" is set, they are processed
            in parallel and combined before feeding it to FFN (default: sequential).
        dropnet_ratio (float, optional): a probability to drop each cross-attention
            module during training (default: 0.0).
    """

    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        output_projection=None,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0.0,
    ):
        super().__init__(
            cfg,
            dictionary,
            embed_tokens,
            no_encoder_attn=False,
            output_projection=output_projection,
        )
        # assert cfg.cross_self_attention
        self.cross_self_attention = cfg.cross_self_attention

        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(cfg, encoder_attn_merge_type, dropnet_ratio)
                for _ in range(cfg.decoder.layers)
            ]
        )

    def build_decoder_layer(
        self,
        cfg,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0,
    ):
        layer = transformer_layer_aug.AugTransformerDecoderLayerBase(
            cfg,
            encoder_attn_merge_type=encoder_attn_merge_type,
            dropnet_ratio=dropnet_ratio,
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """

        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            encoder_out_aug=encoder_out_aug,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )

        if not features_only:
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        encoder_out_aug: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        return self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            encoder_out_aug,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    """
    A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy of
    this function is made to be used in the subclass instead.
    """

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        encoder_out_aug: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        bs, slen = prev_output_tokens.size()
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]

        enc_aug: Optional[Tensor] = None
        padding_mask_aug: Optional[Tensor] = None
        if encoder_out_aug is not None and len(encoder_out_aug["encoder_out"]) > 0:
            enc_aug = encoder_out_aug["encoder_out"][0]
        if (
            encoder_out_aug is not None
            and len(encoder_out_aug["encoder_padding_mask"]) > 0
        ):
            padding_mask_aug = encoder_out_aug["encoder_padding_mask"][0]

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # Prevent torchscript exporting issue for dynamic quant embedding
        prev_output_tokens = prev_output_tokens.contiguous()
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        attn_aug: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, layer_attn_aug, _ = layer(
                x,
                enc,
                padding_mask,
                enc_aug,
                padding_mask_aug,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)
            if layer_attn_aug is not None and idx == alignment_layer:
                attn_aug = layer_attn_aug.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if attn_aug is not None:
            if alignment_heads is not None:
                attn_aug = attn_aug[:alignment_heads]

            # average probabilities over heads
            attn_aug = attn_aug.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "attn_aug": [attn_aug], "inner_states": inner_states}

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)

        if f"{name}.output_projection.weight" not in state_dict:
            if self.share_input_output_embed:
                embed_out_key = f"{name}.embed_tokens.weight"
            else:
                embed_out_key = f"{name}.embed_out"
            if embed_out_key in state_dict:
                state_dict[f"{name}.output_projection.weight"] = state_dict[
                    embed_out_key
                ]
                if not self.share_input_output_embed:
                    del state_dict[embed_out_key]

        for i in range(self.num_layers):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "encoder_attn_layer_norm2",
                "3": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict


class AugTransformerDecoder(AugTransformerDecoderBase):
    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        output_projection=None,
    ):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            output_projection=output_projection,
            encoder_attn_merge_type=getattr(
                args, "synthesizer_augmented_cross_attention_merge_type", "sequential"
            ),
            dropnet_ratio=getattr(args, "dropnet_ratio", 0),
        )

    def build_output_projection(self, args, dictionary, embed_tokens):
        super().build_output_projection(
            TransformerConfig.from_namespace(args), dictionary, embed_tokens
        )

    def build_decoder_layer(
        self,
        args,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0,
    ):
        return super().build_decoder_layer(
            TransformerConfig.from_namespace(args),
            encoder_attn_merge_type=encoder_attn_merge_type,
            dropnet_ratio=dropnet_ratio,
        )
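The layer-norm key migration in upgrade_state_dict_named above maps old indexed layer_norms entries onto named modules, including the new encoder_attn_layer_norm2 for the second cross-attention. A self-contained rerun of that renaming on a one-key toy state dict (the key names follow the code above; the "decoder" prefix and the placeholder weight value are hypothetical):

layer_norm_map = {
    "0": "self_attn_layer_norm",
    "1": "encoder_attn_layer_norm",
    "2": "encoder_attn_layer_norm2",
    "3": "final_layer_norm",
}
state_dict = {"decoder.layers.0.layer_norms.2.weight": "w"}
name, i = "decoder", 0
for old, new in layer_norm_map.items():
    for m in ("weight", "bias"):
        k = f"{name}.layers.{i}.layer_norms.{old}.{m}"
        if k in state_dict:
            # pop-and-reinsert is equivalent to the assign-then-del above
            state_dict[f"{name}.layers.{i}.{new}.{m}"] = state_dict.pop(k)
assert "decoder.layers.0.encoder_attn_layer_norm2.weight" in state_dict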
@@ -28,7 +28,7 @@ class TransformerEncoderLayerBase(nn.Module):
    *cfg.encoder.normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        cfg (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, cfg, return_fc=False):
fairseq/modules/transformer_layer_aug.py (new file, 315 lines)
@@ -0,0 +1,315 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional

import torch
from numpy.random import uniform
from torch import Tensor

from fairseq.modules import LayerNorm
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase


class AugTransformerDecoderLayerBase(TransformerDecoderLayerBase):
    """Decoder layer block augmented with an additional cross-attention.

    This decoder block is processed with the sequence of the following sub-modules.
        self-attention -> cross-attention (first) -> cross-attention (second) -> FFN

    Args:
        cfg (argparse.Namespace): parsed command-line arguments
        encoder_attn_merge_type (str, optional): the way to combine outputs from
            two cross-attention modules. If "sequential" is set, two cross-attention
            modules are stacked sequentially. If "parallel" is set, they are processed
            in parallel and combined before feeding it to FFN (default: sequential).
        dropnet_ratio (float, optional): a probability to drop each cross-attention
            module during training (default: 0.0).
    """

    def __init__(
        self,
        cfg,
        add_bias_kv=False,
        add_zero_attn=False,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0.0,
    ):
        super().__init__(
            cfg,
            no_encoder_attn=False,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
        )
        self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
        self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
        self.encoder_attn2 = self.build_encoder_attention(self.embed_dim, cfg)
        if encoder_attn_merge_type == "sequential":
            self.encoder_attn_layer_norm2 = LayerNorm(self.embed_dim, export=cfg.export)
        else:
            self.encoder_attn_layer_norm2 = None

        self.encoder_attn_merge_type = encoder_attn_merge_type
        self.dropnet_ratio = dropnet_ratio

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        encoder_out_aug: Optional[torch.Tensor] = None,
        encoder_padding_mask2: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        if self.c_attn is not None:
            tgt_len, bsz = x.size(0), x.size(1)
            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
            x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
            x = x.reshape(tgt_len, bsz, self.embed_dim)
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        assert encoder_out is not None
        assert encoder_out_aug is not None

        if self.encoder_attn_merge_type == "sequential":
            ratios = self.get_dropnet_ratio()

            # first encoder attention
            if ratios[0] > 0:
                residual = x
                if self.normalize_before:
                    x = self.encoder_attn_layer_norm(x)
                if prev_attn_state is not None:
                    prev_key, prev_value = prev_attn_state[:2]
                    saved_state: Dict[str, Optional[Tensor]] = {
                        "prev_key": prev_key,
                        "prev_value": prev_value,
                    }
                    if len(prev_attn_state) >= 3:
                        saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                    assert incremental_state is not None
                    self.encoder_attn._set_input_buffer(incremental_state, saved_state)

                x, attn = self.encoder_attn(
                    query=x,
                    key=encoder_out,
                    value=encoder_out,
                    key_padding_mask=encoder_padding_mask,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=need_attn or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )
                x = self.dropout_module(x)
                x = self.residual_connection(x, residual)
                if not self.normalize_before:
                    x = self.encoder_attn_layer_norm(x)
                x = ratios[0] * x

            # second encoder attention
            if ratios[1] > 0:
                residual = x
                if self.normalize_before:
                    x = self.encoder_attn_layer_norm2(x)
                if prev_attn_state is not None:
                    prev_key, prev_value = prev_attn_state[:2]
                    saved_state: Dict[str, Optional[Tensor]] = {
                        "prev_key": prev_key,
                        "prev_value": prev_value,
                    }
                    if len(prev_attn_state) >= 3:
                        saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                    assert incremental_state is not None
                    self.encoder_attn2._set_input_buffer(incremental_state, saved_state)

                x, attn2 = self.encoder_attn2(
                    query=x,
                    key=encoder_out_aug,
                    value=encoder_out_aug,
                    key_padding_mask=encoder_padding_mask2,
                    incremental_state=incremental_state,
                    static_kv=True,
                    need_weights=need_attn or (not self.training and self.need_attn),
                    need_head_weights=need_head_weights,
                )
                x = self.dropout_module(x)
                x = self.residual_connection(x, residual)
                if not self.normalize_before:
                    x = self.encoder_attn_layer_norm2(x)
                x = ratios[1] * x

        elif self.encoder_attn_merge_type == "parallel":
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x1, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x2, attn2 = self.encoder_attn2(
                query=x,
                key=encoder_out_aug,
                value=encoder_out_aug,
                key_padding_mask=encoder_padding_mask2,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
|
||||
need_head_weights=need_head_weights,
|
||||
)
|
||||
x1 = self.dropout_module(x1)
|
||||
x2 = self.dropout_module(x2)
|
||||
ratios = self.get_dropnet_ratio()
|
||||
x = ratios[0] * x1 + ratios[1] * x2
|
||||
x = self.residual_connection(x, residual)
|
||||
if not self.normalize_before:
|
||||
x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(self.encoder_attn_merge_type)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.activation_dropout_module(x)
|
||||
if self.ffn_layernorm is not None:
|
||||
x = self.ffn_layernorm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout_module(x)
|
||||
if self.w_resid is not None:
|
||||
residual = torch.mul(self.w_resid, residual)
|
||||
x = self.residual_connection(x, residual)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
if self.onnx_trace and incremental_state is not None:
|
||||
saved_state = self.self_attn._get_input_buffer(incremental_state)
|
||||
assert saved_state is not None
|
||||
if self_attn_padding_mask is not None:
|
||||
self_attn_state = [
|
||||
saved_state["prev_key"],
|
||||
saved_state["prev_value"],
|
||||
saved_state["prev_key_padding_mask"],
|
||||
]
|
||||
else:
|
||||
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
|
||||
return x, attn, attn2, self_attn_state
|
||||
return x, attn, attn2, None
|
||||
|
||||
def get_dropnet_ratio(self):
|
||||
if self.encoder_attn_merge_type == "sequential":
|
||||
if self.dropnet_ratio > 0:
|
||||
frand = float(uniform(0, 1))
|
||||
if frand < self.dropnet_ratio and self.training:
|
||||
return [2, 0]
|
||||
elif frand > 1 - self.dropnet_ratio and self.training:
|
||||
return [0, 2]
|
||||
else:
|
||||
return [1, 1]
|
||||
else:
|
||||
return [1, 1]
|
||||
|
||||
elif self.encoder_attn_merge_type == "parallel":
|
||||
if self.dropnet_ratio > 0:
|
||||
frand = float(uniform(0, 1))
|
||||
if frand < self.dropnet_ratio and self.training:
|
||||
return [1, 0]
|
||||
elif frand > 1 - self.dropnet_ratio and self.training:
|
||||
return [0, 1]
|
||||
else:
|
||||
return [0.5, 0.5]
|
||||
else:
|
||||
return [0.5, 0.5]
|
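A minimal standalone sketch of the "parallel" branch selection above (illustrative only; `dropnet_ratio` and `training` stand in for the module attributes used in `get_dropnet_ratio`):

from random import uniform

def dropnet_ratio_parallel(dropnet_ratio: float, training: bool):
    # During training, keep only one cross-attention branch with
    # probability dropnet_ratio each; otherwise average the two.
    if dropnet_ratio > 0 and training:
        frand = float(uniform(0, 1))
        if frand < dropnet_ratio:
            return [1, 0]  # primary encoder output only
        if frand > 1 - dropnet_ratio:
            return [0, 1]  # auxiliary encoder output only
    return [0.5, 0.5]  # equal mix; also the inference-time behavior

The "sequential" variant returns [2, 0] / [0, 2] / [1, 1] instead: the surviving branch is doubled, so the expected contribution of each attention stays at 1.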
@@ -76,7 +76,203 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
                incremental_state=incremental_state,
                target_lengths=cur_out_lens,
                speaker=sample["speaker"],
                **kwargs
                **kwargs,
            )
            cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
            feat.append(cur_extra["feature_out"])
            attn.append(cur_extra["attn"])
            eos_prob.append(cur_eos_prob)

            cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold
            out_lens.masked_fill_((~finished) & cur_finished, step + 1)
            finished = finished | cur_finished
            if finished.sum().item() == bsz:
                break
            prev_feat_out = cur_extra["feature_out"]

        feat = torch.cat(feat, dim=1)
        feat = model.decoder.postnet(feat) + feat
        eos_prob = torch.cat(eos_prob, dim=1)
        attn = torch.cat(attn, dim=2)
        alignment = attn.max(dim=1)[1]

        feat = feat.reshape(bsz, -1, raw_dim)
        feat = self.gcmvn_denormalize(feat)

        eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1)
        attn = attn.repeat_interleave(n_frames_per_step, dim=2)
        alignment = alignment.repeat_interleave(n_frames_per_step, dim=1)
        out_lens = out_lens * n_frames_per_step

        finalized = [
            {
                "feature": feat[b, :out_len],
                "eos_prob": eos_prob[b, :out_len],
                "attn": attn[b, :, :out_len],
                "alignment": alignment[b, :out_len],
                "waveform": self.get_waveform(feat[b, :out_len]),
            }
            for b, out_len in zip(range(bsz), out_lens)
        ]

        if has_targ:
            assert sample["target"].size(-1) == out_dim
            tgt_feats = sample["target"].view(bsz, -1, raw_dim)
            tgt_feats = self.gcmvn_denormalize(tgt_feats)
            tgt_lens = sample["target_lengths"] * n_frames_per_step
            for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)):
                finalized[b]["targ_feature"] = f[:l]
                finalized[b]["targ_waveform"] = self.get_waveform(f[:l])
        return finalized


class MultiDecoderSpeechGenerator(SpeechGenerator):
    def __init__(
        self,
        models,
        args,
        vocoder,
        data_cfg,
        tgt_dict_mt,
        max_iter: int = 6000,
        eos_prob_threshold: float = 0.5,
        eos_mt=None,
        symbols_to_strip_from_output=None,
    ):
        super().__init__(models[0], vocoder, data_cfg)
        self.max_iter = max_iter
        self.eos_prob_threshold = eos_prob_threshold

        self.tgt_dict_mt = tgt_dict_mt
        self.eos_mt = eos_mt

        from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator
        from fairseq import search

        self.text_generator = SequenceGenerator(
            models,
            tgt_dict_mt,
            beam_size=max(1, getattr(args, "beam", 5)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search.BeamSearch(tgt_dict_mt),
            eos=eos_mt,
            symbols_to_strip_from_output=symbols_to_strip_from_output,
        )

    @torch.no_grad()
    def generate(self, model, sample, has_targ=False, **kwargs):
        model.eval()

        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()[:2]
        n_frames_per_step = model.decoder.n_frames_per_step
        out_dim = model.decoder.out_dim
        raw_dim = out_dim // n_frames_per_step

        # initialize
        encoder_out = model.forward_encoder(
            src_tokens, src_lengths, speaker=sample["speaker"]
        )

        prefix_tokens = None
        constraints = None
        bos_token = None

        mt_decoder = getattr(model, f"{model.mt_task_name}_decoder")

        # 1. MT decoder
        finalized_mt = self.text_generator.generate_decoder(
            [encoder_out],
            src_tokens,
            src_lengths,
            sample,
            prefix_tokens,
            constraints,
            bos_token,
            aux_task_name=model.mt_task_name,
        )

        # extract decoder output corresponding to the best hypothesis
        max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt])
        prev_output_tokens_mt = (
            src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len)
            .fill_(mt_decoder.padding_idx)
            .int()
        )  # B x T
        for i, hypo in enumerate(finalized_mt):
            i_beam = 0
            tmp = hypo[i_beam]["tokens"].int()  # hyp + eos
            prev_output_tokens_mt[i, 0] = self.text_generator.eos
            if tmp[-1] == self.text_generator.eos:
                tmp = tmp[:-1]
            prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp

            text = "".join([self.tgt_dict_mt[c] for c in tmp])
            text = text.replace("_", " ")
            text = text.replace("▁", " ")
            text = text.replace("<unk>", " ")
            text = text.replace("<s>", "")
            text = text.replace("</s>", "")
            if len(text) > 0 and text[0] == " ":
                text = text[1:]
            sample_id = sample["id"].tolist()[i]
            print("{} (None-{})".format(text, sample_id))

        mt_decoder_out = mt_decoder(
            prev_output_tokens_mt,
            encoder_out=encoder_out,
            features_only=True,
        )
        x = mt_decoder_out[0].transpose(0, 1)

        mt_decoder_padding_mask = None
        if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any():
            mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx)

        # 2. TTS encoder
        if getattr(model, "synthesizer_encoder", None) is not None:
            synthesizer_encoder_out = model.synthesizer_encoder(
                x,
                mt_decoder_padding_mask,
            )
        else:
            synthesizer_encoder_out = {
                "encoder_out": [x],  # T x B x C
                "encoder_padding_mask": [mt_decoder_padding_mask]
                if mt_decoder_padding_mask is not None
                else [],  # B x T
                "encoder_embedding": [],
                "encoder_states": [],
                "src_tokens": [],
                "src_lengths": [],
            }

        # 3. TTS decoder
        incremental_state = {}
        feat, attn, eos_prob = [], [], []
        finished = src_tokens.new_zeros((bsz,)).bool()
        out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter)

        prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim)
        for step in range(self.max_iter):
            cur_out_lens = out_lens.clone()
            cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1)
            _, cur_eos_out, cur_extra = model.forward_decoder(
                prev_feat_out,
                encoder_out=synthesizer_encoder_out,
                incremental_state=incremental_state,
                target_lengths=cur_out_lens,
                speaker=sample["speaker"],
                **kwargs,
            )
            cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
            feat.append(cur_extra["feature_out"])
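One detail worth noting in the first-pass handoff above: the best MT hypothesis is fed back through the MT decoder with teacher forcing, so it must be right-shifted, with EOS doubling as BOS. A small self-contained sketch of that shift (hypothetical helper; `eos` and `pad` stand in for `self.text_generator.eos` and `mt_decoder.padding_idx`):

import torch

def shift_right(tokens: torch.Tensor, eos: int, pad: int, max_len: int) -> torch.Tensor:
    # tokens: one finalized hypothesis including its trailing EOS.
    out = torch.full((max_len,), pad, dtype=torch.int)
    out[0] = eos  # fairseq uses EOS as the begin-of-sentence symbol
    if tokens[-1] == eos:
        tokens = tokens[:-1]  # drop the trailing EOS before shifting
    out[1 : len(tokens) + 1] = tokens
    return out  # e.g. [a, b, c, eos] -> [eos, a, b, c, pad, ...]

# shift_right(torch.tensor([5, 6, 7, 2], dtype=torch.int), eos=2, pad=1, max_len=6)
# -> tensor([2, 5, 6, 7, 1, 1], dtype=torch.int32)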
@@ -8,6 +8,7 @@ import logging
import math
from argparse import Namespace
from pathlib import Path
from typing import List

import torch
import torch.nn as nn
@@ -16,7 +17,10 @@ from fairseq import utils
from fairseq.data import Dictionary
from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig
from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator
from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.tasks.speech_to_text import DummyMultiTask
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion
@@ -209,15 +213,35 @@ class SpeechToSpeechTask(LegacyFairseqTask):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml)

        self.multitask_tasks = {}
        self.tgt_dict_mt = None
        self.eos_token_mt = None
        if getattr(args, "multitask_config_yaml", None) is not None:
            multitask_cfg = MultitaskConfig(
                Path(args.data) / args.multitask_config_yaml
            )
            for task_name, task_config in multitask_cfg.get_all_tasks().items():
                self.multitask_tasks[task_name] = DummyMultiTask(
                    task_config, task_config.tgt_dict
            first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
            for i, (task_name, task_config) in enumerate(
                multitask_cfg.get_all_tasks().items()
            ):
                task_obj = DummyMultiTask(
                    task_config,
                    task_config.tgt_dict,
                    first_pass=i == first_pass_task_idx,
                )
                self.multitask_tasks[task_name] = task_obj
                if task_obj.is_first_pass_decoder:
                    self.tgt_dict_mt = task_obj.target_dictionary
                    if task_config.prepend_bos_and_append_tgt_lang_tag:
                        self.eos_token_mt = task_config.eos_token
                        assert not isinstance(self.eos_token_mt, List)

                    if not self.eos_token_mt:
                        raise Warning(
                            "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
                        )

        self._infer_tgt_lang_id = infer_tgt_lang_id

    @classmethod
@@ -267,11 +291,13 @@ class SpeechToSpeechTask(LegacyFairseqTask):
        from fairseq import criterions

        if len(self.multitask_tasks) > 0:
            if self.args.target_is_code and args._name != "speech_to_unit":
            if self.args.target_is_code and not args._name.startswith("speech_to_unit"):
                raise ValueError(
                    "set --criterion speech_to_unit for speech-to-unit loss with multitask"
                )
            elif not self.args.target_is_code and args._name != "speech_to_spectrogram":
            elif not self.args.target_is_code and not args._name.startswith(
                "speech_to_spectrogram"
            ):
                raise ValueError(
                    "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask"
                )
@@ -296,6 +322,10 @@ class SpeechToSpeechTask(LegacyFairseqTask):
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def target_dictionary_mt(self):
        return self.tgt_dict_mt

    @property
    def source_dictionary(self):
        return None
@@ -326,6 +356,36 @@ class SpeechToSpeechTask(LegacyFairseqTask):

        return model

    def build_generator_dual_decoder(
        self,
        models,
        args,
        extra_gen_cls_kwargs=None,
    ):
        from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
            MultiDecoderSequenceGenerator,
        )

        return MultiDecoderSequenceGenerator(
            models,
            self.target_dictionary,
            self.target_dictionary_mt,
            beam_size=max(1, getattr(args, "beam", 1)),
            beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            max_len_a_mt=getattr(args, "max_len_a_mt", 0),
            max_len_b_mt=getattr(args, "max_len_b_mt", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            **extra_gen_cls_kwargs,
        )

    def build_generator(
        self,
        models,
@@ -344,14 +404,23 @@ class SpeechToSpeechTask(LegacyFairseqTask):
            else self.vocoder.cpu()
        )

        has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None

        if self.args.target_is_code:
            if self.args.n_frames_per_step == 1:
                seq_generator = super().build_generator(
                    models,
                    args,
                    seq_gen_cls=None,
                    extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                )
                if has_dual_decoder:
                    seq_generator = self.build_generator_dual_decoder(
                        models,
                        args,
                        extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                    )
                else:
                    seq_generator = super().build_generator(
                        models,
                        args,
                        seq_gen_cls=None,
                        extra_gen_cls_kwargs=extra_gen_cls_kwargs,
                    )
        else:
            assert (
                getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1
@@ -361,24 +430,64 @@ class SpeechToSpeechTask(LegacyFairseqTask):
                    self.args.target_code_size,
                )
        else:
            if getattr(args, "teacher_forcing", False):
                from fairseq.speech_generator import (
                    TeacherForcingAutoRegressiveSpeechGenerator,
            if has_dual_decoder:
                if getattr(args, "teacher_forcing", False):
                    raise NotImplementedError
                else:
                    from fairseq.speech_generator import MultiDecoderSpeechGenerator

                    generator = MultiDecoderSpeechGenerator

                lang_token_ids_aux = {
                    i
                    for s, i in self.tgt_dict_mt.indices.items()
                    if TextTargetMultitaskData.is_lang_tag(s)
                }

                if extra_gen_cls_kwargs is None:
                    extra_gen_cls_kwargs = {}
                extra_gen_cls_kwargs[
                    "symbols_to_strip_from_output"
                ] = lang_token_ids_aux

                eos_id_mt = (
                    self.tgt_dict_mt.index(self.eos_token_mt)
                    if self.eos_token_mt
                    else None
                )
                assert eos_id_mt != self.tgt_dict_mt.unk()
                extra_gen_cls_kwargs["eos_mt"] = eos_id_mt

                generator = TeacherForcingAutoRegressiveSpeechGenerator
                logger.info("Teacher forcing mode for generation")
                seq_generator = generator(
                    models,
                    args,
                    self.vocoder,
                    self.data_cfg,
                    self.target_dictionary_mt,
                    max_iter=self.args.max_target_positions,
                    eos_prob_threshold=self.args.eos_prob_threshold,
                    **extra_gen_cls_kwargs,
                )
            else:
                from fairseq.speech_generator import AutoRegressiveSpeechGenerator
                if getattr(args, "teacher_forcing", False):
                    from fairseq.speech_generator import (
                        TeacherForcingAutoRegressiveSpeechGenerator,
                    )

            generator = AutoRegressiveSpeechGenerator
            seq_generator = generator(
                models[0],
                self.vocoder,
                self.data_cfg,
                max_iter=self.args.max_target_positions,
                eos_prob_threshold=self.args.eos_prob_threshold,
            )
                    generator = TeacherForcingAutoRegressiveSpeechGenerator
                    logger.info("Teacher forcing mode for generation")
                else:
                    from fairseq.speech_generator import AutoRegressiveSpeechGenerator

                    generator = AutoRegressiveSpeechGenerator

                seq_generator = generator(
                    models[0],
                    self.vocoder,
                    self.data_cfg,
                    max_iter=self.args.max_target_positions,
                    eos_prob_threshold=self.args.eos_prob_threshold,
                )

        return seq_generator

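Both tasks decide between the single- and dual-decoder generators with the same probe on the model. A tiny self-contained sketch of that check (the class names here are hypothetical stand-ins, not part of the commit):

import torch.nn as nn

class UnitYStyleModel(nn.Module):
    # A UnitY-style two-pass model advertises its first-pass task name.
    mt_task_name = "target_text"

class SinglePassModel(nn.Module):
    pass

def has_dual_decoder(model: nn.Module) -> bool:
    # Same probe as build_generator above.
    return getattr(model, "mt_task_name", None) is not None

assert has_dual_decoder(UnitYStyleModel())
assert not has_dual_decoder(SinglePassModel())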
@@ -6,6 +6,7 @@
import logging
from argparse import Namespace
from pathlib import Path
from typing import List

from fairseq.data import Dictionary, encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
@@ -14,6 +15,7 @@ from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    TextTargetMultitaskData,
)
from fairseq.tasks import LegacyFairseqTask, register_task

@@ -66,13 +68,32 @@ class SpeechToTextTask(LegacyFairseqTask):
            )

        self.multitask_tasks = {}
        self.tgt_dict_mt = None
        self.eos_token_mt = None
        if getattr(args, "multitask_config_yaml", None) is not None:
            multitask_cfg = MultitaskConfig(
                Path(args.data) / args.multitask_config_yaml
            )
            for task_name, task_config in multitask_cfg.get_all_tasks().items():
                task_obj = DummyMultiTask(task_config, task_config.tgt_dict)
            first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index
            for i, (task_name, task_config) in enumerate(
                multitask_cfg.get_all_tasks().items()
            ):
                task_obj = DummyMultiTask(
                    task_config,
                    task_config.tgt_dict,
                    first_pass=i == first_pass_task_idx,
                )
                self.multitask_tasks[task_name] = task_obj
                if task_obj.is_first_pass_decoder:
                    self.tgt_dict_mt = task_obj.target_dictionary
                    if task_config.prepend_bos_and_append_tgt_lang_tag:
                        self.eos_token_mt = task_config.eos_token
                        assert not isinstance(self.eos_token_mt, List)

                    if not self.eos_token_mt:
                        raise Warning(
                            "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator"
                        )

    def _get_speaker_to_id(self):
        speaker_to_id = None
@@ -124,12 +145,17 @@ class SpeechToTextTask(LegacyFairseqTask):
            epoch=epoch,
            seed=self.args.seed,
            speaker_to_id=self.speaker_to_id,
            multitask=self.multitask_tasks,
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def target_dictionary_mt(self):
        return self.tgt_dict_mt

    @property
    def source_dictionary(self):
        return None
@@ -143,6 +169,51 @@ class SpeechToTextTask(LegacyFairseqTask):
        args.speaker_to_id = self.speaker_to_id
        return super(SpeechToTextTask, self).build_model(args, from_checkpoint)

    def build_generator_dual_decoder(
        self,
        models,
        args,
        extra_gen_cls_kwargs,
    ):
        from examples.speech_to_speech.unity.sequence_generator_multi_decoder import (
            MultiDecoderSequenceGenerator,
        )

        lang_token_ids_aux = {
            i
            for s, i in self.tgt_dict_mt.indices.items()
            if TextTargetMultitaskData.is_lang_tag(s)
        }

        extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux)

        eos_id_mt = (
            self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None
        )
        assert eos_id_mt != self.tgt_dict_mt.unk()
        extra_gen_cls_kwargs["eos_mt"] = eos_id_mt

        return MultiDecoderSequenceGenerator(
            models,
            self.target_dictionary,
            self.target_dictionary_mt,
            beam_size=max(1, getattr(args, "beam", 1)),
            beam_size_mt=max(1, getattr(args, "beam_mt", 1)),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            max_len_a_mt=getattr(args, "max_len_a_mt", 0),
            max_len_b_mt=getattr(args, "max_len_b_mt", 0),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            len_penalty_mt=getattr(args, "lenpen_mt", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            **extra_gen_cls_kwargs,
        )

    def build_generator(
        self,
        models,
@@ -179,9 +250,21 @@ class SpeechToTextTask(LegacyFairseqTask):
        eos_id = self.tgt_dict.index(eos_token) if eos_token else None
        extra_gen_cls_kwargs["eos"] = eos_id

        return super().build_generator(
            models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs
        )
        has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None

        if has_dual_decoder:
            return self.build_generator_dual_decoder(
                models,
                args,
                extra_gen_cls_kwargs=extra_gen_cls_kwargs,
            )
        else:
            return super().build_generator(
                models,
                args,
                seq_gen_cls=None,
                extra_gen_cls_kwargs=extra_gen_cls_kwargs,
            )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
@@ -225,14 +308,19 @@ class SpeechToTextTask(LegacyFairseqTask):


class DummyMultiTask(LegacyFairseqTask):
    def __init__(self, args, tgt_dict):
    def __init__(self, args, tgt_dict, first_pass=False):
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.first_pass = first_pass

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def is_first_pass_decoder(self):
        return self.first_pass

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
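The `first_pass` flag threaded through `DummyMultiTask` is what ultimately selects which multitask head supplies `tgt_dict_mt`. A compact sketch of the selection logic shared by both tasks (the config class below is a hypothetical stand-in for `MultitaskConfig` entries):

class FakeTaskConfig:
    def __init__(self, tgt_dict):
        self.tgt_dict = tgt_dict

def pick_first_pass_dict(task_configs, first_pass_task_idx):
    # Mirrors the enumerate() loops above: only the task whose index
    # equals first_pass_decoder_task_index supplies the MT dictionary.
    for i, cfg in enumerate(task_configs):
        if i == first_pass_task_idx:
            return cfg.tgt_dict
    return None

configs = [FakeTaskConfig("unit_dict"), FakeTaskConfig("mt_dict")]
assert pick_first_pass_dict(configs, 1) == "mt_dict"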
@@ -10,6 +10,7 @@ import os
import re

import torch

from fairseq.file_io import PathManager


@@ -113,6 +114,9 @@ def main():
    num_group.add_argument('--num-update-checkpoints', type=int,
                           help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by'
                                ' input, and average last this many of them.')
    num_group.add_argument('--num-best-checkpoints', type=int, default=0,
                           help='if set, will try to find checkpoints with names checkpoint_best_ee_xx.pt in the path specified by'
                                ' input, and average last this many of them.')
    parser.add_argument('--checkpoint-upper-bound', type=int,
                        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, '
                             'when using --num-update-checkpoints, this will set an upper bound on which update to use'
@@ -150,6 +154,18 @@ def main():
        )
        print("averaging checkpoints: ", args.inputs)

    if args.num_best_checkpoints > 0:
        args.inputs = list(
            sorted(
                args.inputs,
                key=lambda x: float(
                    os.path.basename(x).split("_")[-1].replace(".pt", "")
                ),
            )
        )
        args.inputs = args.inputs[: args.num_best_checkpoints]
        for path in args.inputs:
            print(os.path.basename(path))
    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, "wb") as f:
        torch.save(new_state, f)
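`--num-best-checkpoints` assumes the metric value is the last underscore-separated field of each file name. A quick illustration of the sort key with hypothetical file names; since the sort is ascending, taking the first N keeps the lowest values, which suits metrics where lower is better:

import os

inputs = [
    "checkpoint.best_loss_3.21.pt",
    "checkpoint.best_loss_2.95.pt",
    "checkpoint.best_loss_3.05.pt",
]
key = lambda x: float(os.path.basename(x).split("_")[-1].replace(".pt", ""))
print(sorted(inputs, key=key)[:2])
# -> ['checkpoint.best_loss_2.95.pt', 'checkpoint.best_loss_3.05.pt']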