Add linting with black (#2678)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2678

Reviewed By: Mortimerp9

Differential Revision: D32653381

Pulled By: dianaml0

fbshipit-source-id: 2810d14867cd7d64f4d340740e2b590b82de47fe
Authored by dianaml0 on 2021-11-29 12:30:10 -08:00, committed by Facebook GitHub Bot
parent 3dc1691df1
commit 0dfd6b6240
137 changed files with 2139 additions and 1353 deletions

View File

@ -53,3 +53,8 @@ jobs:
- name: Run tests
run: |
python setup.py test
- name: Lint with black
run: |
pip install black
black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
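
The hunks that follow are the mechanical reformatting that this check enforces across the repository. As an illustration only (the function name and dictionary keys below are invented, not taken from fairseq), this is the kind of before/after change black produces: single quotes become double quotes, slice colons with expression bounds gain symmetric spacing, wrapped lines that fit within black's default 88-character limit are joined, and call sites that do not fit are exploded one argument per line with a trailing comma. The CI check can be reproduced locally by running the same two commands shown above from the repository root.

# Illustrative sketch, not part of this commit: hypothetical code before running black
def summarize(sample, offset, length):
    ids = sample['ids'][offset:offset + length]
    tag = '{}-{}'.format(sample['split'],
                         sample['epoch'])
    stats = dict(count=len(ids), first=min(ids, default=None),
                 last=max(ids, default=None), label='summary of the selected id range')
    return tag, stats

# The same code after black, which `black --check` verifies without rewriting files:
# quotes normalized, the slice colon spaced, the short call joined onto one line,
# and the over-long dict() call split one argument per line with a trailing comma.
def summarize(sample, offset, length):
    ids = sample["ids"][offset : offset + length]
    tag = "{}-{}".format(sample["split"], sample["epoch"])
    stats = dict(
        count=len(ids),
        first=min(ids, default=None),
        last=max(ids, default=None),
        label="summary of the selected id range",
    )
    return tag, stats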

View File

@ -27,6 +27,7 @@ sys.modules["fairseq.progress_bar"] = progress_bar
# initialize hydra
from fairseq.dataclass.initialize import hydra_init
hydra_init()
import fairseq.criterions # noqa

View File

@ -7,10 +7,10 @@ import logging
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__)
@ -36,7 +36,7 @@ class DummyMTTask(LegacyFairseqTask):
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task. """
"""Setup the task."""
dictionary = Dictionary()
for i in range(args.dict_size):
dictionary.add_symbol("word{}".format(i))

View File

@ -96,10 +96,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
checkpoint_conds[
"checkpoint.best_{}_{:.3f}{}{}.pt".format(
cfg.best_checkpoint_metric,
val_loss,
rand_sfx,
suffix
cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
)
] = worst_best is None or is_better(val_loss, worst_best)
checkpoint_conds[
@ -468,9 +465,7 @@ def load_model_ensemble_and_task(
and len(state["optimizer_history"]) > 0
and "num_updates" in state["optimizer_history"][-1]
):
model.set_num_updates(
state["optimizer_history"][-1]["num_updates"]
)
model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
model.load_state_dict(
state["model"], strict=strict, model_cfg=cfg.model
)
@ -588,9 +583,8 @@ def _upgrade_state_dict(state):
# backward compatibility, cfg updates
if "args" in state and state["args"] is not None:
# old model checkpoints may not have separate source/target positions
if (
hasattr(state["args"], "max_positions")
and not hasattr(state["args"], "max_source_positions")
if hasattr(state["args"], "max_positions") and not hasattr(
state["args"], "max_source_positions"
):
state["args"].max_source_positions = state["args"].max_positions
state["args"].max_target_positions = state["args"].max_positions
@ -615,13 +609,10 @@ def _upgrade_state_dict(state):
state["args"].stop_min_lr = state["args"].min_lr
del state["args"].min_lr
# binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
if (
hasattr(state["args"], "criterion")
and state["args"].criterion in [
"binary_cross_entropy",
"kd_binary_cross_entropy",
]
):
if hasattr(state["args"], "criterion") and state["args"].criterion in [
"binary_cross_entropy",
"kd_binary_cross_entropy",
]:
state["args"].criterion = "wav2vec"
# remove log_keys if it's None (criteria will supply a default value of [])
if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
@ -659,7 +650,9 @@ def _upgrade_state_dict(state):
):
cfg.task.eval_wer_config.print_alignment = "hard"
if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
cfg.generation.print_alignment = "hard" if cfg.generation.print_alignment else None
cfg.generation.print_alignment = (
"hard" if cfg.generation.print_alignment else None
)
if (
"model" in cfg
and "w2v_args" in cfg.model
@ -833,16 +826,16 @@ def load_ema_from_checkpoint(fpath):
params_dict = collections.OrderedDict()
new_state = None
with PathManager.open(fpath, 'rb') as f:
with PathManager.open(fpath, "rb") as f:
new_state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
lambda s, _: torch.serialization.default_restore_location(s, "cpu")
),
)
# EMA model is stored in a separate "extra state"
model_params = new_state['extra_state']['ema']
model_params = new_state["extra_state"]["ema"]
for key in list(model_params.keys()):
p = model_params[key]
@ -860,5 +853,5 @@ def load_ema_from_checkpoint(fpath):
"ema model weights, is this model trained with EMA?"
)
new_state['model'] = params_dict
new_state["model"] = params_dict
return new_state

View File

@ -20,9 +20,7 @@ from fairseq.models.fairseq_model import FairseqEncoderModel
@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
ctc_weight: float = field(
default=0.0, metadata={"help": "weight for CTC loss"}
)
ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
@ -44,7 +42,7 @@ class FastSpeech2Loss(FairseqCriterion):
speaker=sample["speaker"],
durations=sample["durations"],
pitches=sample["pitches"],
energies=sample["energies"]
energies=sample["energies"],
)
src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
@ -57,8 +55,7 @@ class FastSpeech2Loss(FairseqCriterion):
feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
if _feat_out_post is not None:
l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat,
reduction=reduction)
l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)
pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
@ -69,16 +66,23 @@ class FastSpeech2Loss(FairseqCriterion):
log_dur = torch.log(dur + 1)[src_mask]
dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)
ctc_loss = torch.tensor(0.).type_as(l1_loss)
if self.ctc_weight > 0.:
ctc_loss = torch.tensor(0.0).type_as(l1_loss)
if self.ctc_weight > 0.0:
lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = F.ctc_loss(
lprobs, src_tokens_flat, tgt_lens, src_lens,
reduction=reduction, zero_infinity=True
) * self.ctc_weight
ctc_loss = (
F.ctc_loss(
lprobs,
src_tokens_flat,
tgt_lens,
src_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss
@ -102,8 +106,12 @@ class FastSpeech2Loss(FairseqCriterion):
ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns]
for key in [
"loss", "l1_loss", "dur_loss", "pitch_loss", "energy_loss",
"ctc_loss"
"loss",
"l1_loss",
"dur_loss",
"pitch_loss",
"energy_loss",
"ctc_loss",
]:
vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws))
@ -115,10 +123,10 @@ class FastSpeech2Loss(FairseqCriterion):
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)

View File

@ -37,7 +37,14 @@ class HubertCriterionConfig(FairseqDataclass):
@register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion):
def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
def __init__(
self,
task,
pred_masked_weight,
pred_nomask_weight,
loss_weights=None,
log_keys=None,
):
super().__init__(task)
self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight
@ -52,7 +59,7 @@ class HubertCriterion(FairseqCriterion):
3) logging outputs to display while training
"""
net_output = model(target_list=sample["target_list"], **sample["net_input"])
loss = 0.
loss = 0.0
sample_size = 0
logging_output = {}
reduction = "sum" if reduce else "none"
@ -89,7 +96,9 @@ class HubertCriterion(FairseqCriterion):
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
@ -140,12 +149,20 @@ class HubertCriterion(FairseqCriterion):
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens:
metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else:
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
counts = {}
for lk in logging_outputs[0].keys():

View File

@ -9,19 +9,20 @@ from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig
LabelSmoothedCrossEntropyCriterionConfig,
)
try:
from simuleval.metrics.latency import (
AverageLagging,
AverageProportion,
DifferentiableAverageLagging
DifferentiableAverageLagging,
)
LATENCY_METRICS = {
"average_lagging": AverageLagging,
"average_proportion": AverageProportion,
"differentiable_average_lagging": DifferentiableAverageLagging,
"differentiable_average_lagging": DifferentiableAverageLagging,
}
except ImportError:
LATENCY_METRICS = None
@ -56,9 +57,10 @@ class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
metadata={"help": "Add latency loss after certain steps"},
)
@register_criterion(
"latency_augmented_label_smoothed_cross_entropy",
dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig
dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
)
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
LabelSmoothedCrossEntropyCriterion
@ -101,9 +103,9 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
if self.latency_update_after > 0:
num_updates = getattr(model.decoder, "num_updates", None)
assert num_updates is not None, (
"model.decoder doesn't have attribute 'num_updates'"
)
assert (
num_updates is not None
), "model.decoder doesn't have attribute 'num_updates'"
if num_updates <= self.latency_update_after:
latency_loss = 0
@ -134,9 +136,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
assert (
net_output[-1].encoder_padding_mask is None
or not net_output[-1].encoder_padding_mask[:, 0].any()
), (
"Only right padding on source is supported."
)
), "Only right padding on source is supported."
# 1. Obtain the expected alignment
alpha_list = [item["alpha"] for item in net_output[1].attn_list]
num_layers = len(alpha_list)
@ -174,8 +174,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
.view(-1)
)
expected_latency = LATENCY_METRICS[self.latency_avg_type](
expected_delays, src_lengths, None,
target_padding_mask=target_padding_mask
expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
)
# 2.1 average expected latency of heads
@ -210,24 +209,12 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
@classmethod
def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs)
latency = sum(
log.get("latency", 0) for log in logging_outputs
)
delays_var = sum(
log.get("delays_var", 0) for log in logging_outputs
)
latency_loss = sum(
log.get("latency_loss", 0) for log in logging_outputs
)
latency = sum(log.get("latency", 0) for log in logging_outputs)
delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
metrics.log_scalar(
"latency", latency.float() / nsentences, nsentences, round=3
)
metrics.log_scalar(
"delays_var", delays_var / nsentences,
nsentences, round=3
)
metrics.log_scalar(
"latency_loss", latency_loss / nsentences,
nsentences, round=3
"latency_loss", latency_loss / nsentences, nsentences, round=3
)

View File

@ -41,9 +41,7 @@ class Tacotron2CriterionConfig(FairseqDataclass):
default=0.4,
metadata={"help": "weight of positive examples for BCE loss"},
)
ctc_weight: float = field(
default=0.0, metadata={"help": "weight for CTC loss"}
)
ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
sentence_avg: bool = II("optimization.sentence_avg")
@ -70,8 +68,7 @@ class GuidedAttentionLoss(torch.nn.Module):
bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
weights = torch.zeros((bsz, max_t_len, max_s_len))
for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len,
self.sigma)
weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
return weights
@staticmethod
@ -90,9 +87,16 @@ class GuidedAttentionLoss(torch.nn.Module):
@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
class Tacotron2Criterion(FairseqCriterion):
def __init__(self, task, sentence_avg, n_frames_per_step,
use_guided_attention_loss, guided_attention_loss_sigma,
bce_pos_weight, ctc_weight):
def __init__(
self,
task,
sentence_avg,
n_frames_per_step,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(task)
self.sentence_avg = sentence_avg
self.n_frames_per_step = n_frames_per_step
@ -120,31 +124,42 @@ class Tacotron2Criterion(FairseqCriterion):
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"]
speaker=sample["speaker"],
)
l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"], feat_out, eos_out, feat_tgt, eos_tgt,
tgt_lens, reduction,
extra["feature_out"],
feat_out,
eos_out,
feat_tgt,
eos_tgt,
tgt_lens,
reduction,
)
attn_loss = torch.tensor(0.).type_as(l1_loss)
attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None:
attn_loss = self.guided_attn(extra['attn'], src_lens, tgt_lens, reduction)
ctc_loss = torch.tensor(0.).type_as(l1_loss)
if self.ctc_weight > 0.:
attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
ctc_loss = torch.tensor(0.0).type_as(l1_loss)
if self.ctc_weight > 0.0:
net_output = (feat_out, eos_out, extra)
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = F.ctc_loss(
lprobs, src_tokens_flat, tgt_lens, src_lens,
reduction=reduction, zero_infinity=True
) * self.ctc_weight
ctc_loss = (
F.ctc_loss(
lprobs,
src_tokens_flat,
tgt_lens,
src_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss
sample_size = sample["nsentences"] if self.sentence_avg \
else sample["ntokens"]
sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
logging_output = {
"loss": utils.item(loss.data),
"ntokens": sample["ntokens"],
@ -158,8 +173,16 @@ class Tacotron2Criterion(FairseqCriterion):
}
return loss, sample_size, logging_output
def compute_loss(self, feat_out, feat_out_post, eos_out, feat_tgt,
eos_tgt, tgt_lens, reduction="mean"):
def compute_loss(
self,
feat_out,
feat_out_post,
eos_out,
feat_tgt,
eos_tgt,
tgt_lens,
reduction="mean",
):
mask = lengths_to_mask(tgt_lens)
_eos_out = eos_out[mask].squeeze()
_eos_tgt = eos_tgt[mask]
@ -167,17 +190,17 @@ class Tacotron2Criterion(FairseqCriterion):
_feat_out = feat_out[mask]
_feat_out_post = feat_out_post[mask]
l1_loss = (
F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) +
F.l1_loss(_feat_out_post, _feat_tgt, reduction=reduction)
l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
_feat_out_post, _feat_tgt, reduction=reduction
)
mse_loss = (
F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) +
F.mse_loss(_feat_out_post, _feat_tgt, reduction=reduction)
mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
_feat_out_post, _feat_tgt, reduction=reduction
)
eos_loss = F.binary_cross_entropy_with_logits(
_eos_out, _eos_tgt, pos_weight=torch.tensor(self.bce_pos_weight),
reduction=reduction
_eos_out,
_eos_tgt,
pos_weight=torch.tensor(self.bce_pos_weight),
reduction=reduction,
)
return l1_loss, mse_loss, eos_loss
@ -197,10 +220,10 @@ class Tacotron2Criterion(FairseqCriterion):
return
n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"),
("nins", "ins_rate"),
("ndel", "del_rate"),
]:
val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3)

View File

@ -33,6 +33,7 @@ class Wav2VecCriterionConfig(FairseqDataclass):
metadata={"help": "output keys to log"},
)
@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
class Wav2vecCriterion(FairseqCriterion):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
@ -76,16 +77,16 @@ class Wav2vecCriterion(FairseqCriterion):
# we don't shrink tensors using mask_indices.
# Instead, we use mask indices to adjust loss.
mi = (
sample['net_input']['mask_indices']
sample["net_input"]["mask_indices"]
.transpose(0, 1) # logits are transposed in `model.get_logits`
.reshape(logits.size(0))
)
loss = (loss * mi).sum() if reduce else (loss * mi)
if 'sample_size' in sample:
sample_size = sample['sample_size']
elif 'mask_indices' in sample['net_input']:
sample_size = sample['net_input']['mask_indices'].sum()
if "sample_size" in sample:
sample_size = sample["sample_size"]
elif "mask_indices" in sample["net_input"]:
sample_size = sample["net_input"]["mask_indices"].sum()
else:
sample_size = target.numel() if self.infonce else target.long().sum().item()
losses.append(loss.detach().clone())
@ -216,8 +217,8 @@ class Wav2vecCriterion(FairseqCriterion):
metrics.log_scalar(k, val / len(logging_outputs), round=3)
# FIXME: revert when gather based xla reduction is implemented
#@staticmethod
#def logging_outputs_can_be_summed() -> bool:
# @staticmethod
# def logging_outputs_can_be_summed() -> bool:
def logging_outputs_can_be_summed(self) -> bool:
"""
Whether the logging outputs returned by `forward` can be summed

View File

@ -20,7 +20,7 @@ class AddTargetDataset(BaseWrapperDataset):
process_label=None,
label_len_fn=None,
add_to_input=False,
text_compression_level=TextCompressionLevel.none
text_compression_level=TextCompressionLevel.none,
):
super().__init__(dataset)
self.labels = labels

View File

@ -18,26 +18,28 @@ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
def convert_waveform(
waveform: Union[np.ndarray, torch.Tensor], sample_rate: int,
normalize_volume: bool = False, to_mono: bool = False,
to_sample_rate: Optional[int] = None
waveform: Union[np.ndarray, torch.Tensor],
sample_rate: int,
normalize_volume: bool = False,
to_mono: bool = False,
to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
"""convert a waveform:
- to a target sample rate
- from multi-channel to mono channel
- volume normalization
- to a target sample rate
- from multi-channel to mono channel
- volume normalization
Args:
waveform (numpy.ndarray or torch.Tensor): 2D original waveform
(channels x length)
sample_rate (int): original sample rate
normalize_volume (bool): perform volume normalization
to_mono (bool): convert to mono channel if having multiple channels
to_sample_rate (Optional[int]): target sample rate
Returns:
waveform (numpy.ndarray): converted 2D waveform (channels x length)
sample_rate (float): target sample rate
"""
Args:
waveform (numpy.ndarray or torch.Tensor): 2D original waveform
(channels x length)
sample_rate (int): original sample rate
normalize_volume (bool): perform volume normalization
to_mono (bool): convert to mono channel if having multiple channels
to_sample_rate (Optional[int]): target sample rate
Returns:
waveform (numpy.ndarray): converted 2D waveform (channels x length)
sample_rate (float): target sample rate
"""
try:
import torchaudio.sox_effects as ta_sox
except ImportError:
@ -63,10 +65,14 @@ def convert_waveform(
def get_waveform(
path_or_fp: Union[str, BinaryIO], normalization: bool = True,
mono: bool = True, frames: int = -1, start: int = 0,
always_2d: bool = True, output_sample_rate: Optional[int] = None,
normalize_volume: bool = False
path_or_fp: Union[str, BinaryIO],
normalization: bool = True,
mono: bool = True,
frames: int = -1,
start: int = 0,
always_2d: bool = True,
output_sample_rate: Optional[int] = None,
normalize_volume: bool = False,
) -> Tuple[np.ndarray, int]:
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
@ -98,8 +104,11 @@ def get_waveform(
)
waveform = waveform.T # T x C -> C x T
waveform, sample_rate = convert_waveform(
waveform, sample_rate, normalize_volume=normalize_volume, to_mono=mono,
to_sample_rate=output_sample_rate
waveform,
sample_rate,
normalize_volume=normalize_volume,
to_mono=mono,
to_sample_rate=output_sample_rate,
)
if not normalization:
@ -182,7 +191,7 @@ def is_sf_audio_data(data: bytes) -> bool:
def mmap_read(path: str, offset: int, length: int) -> bytes:
with open(path, "rb") as f:
with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
data = mmap_o[offset: offset + length]
data = mmap_o[offset : offset + length]
return data
@ -215,9 +224,7 @@ def parse_path(path: str) -> Tuple[str, List[int]]:
return _path, slice_ptr
def get_window(
window_fn: callable, n_fft: int, win_length: int
) -> torch.Tensor:
def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
padding = n_fft - win_length
assert padding >= 0
return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
@ -226,13 +233,13 @@ def get_window(
def get_fourier_basis(n_fft: int) -> torch.Tensor:
basis = np.fft.fft(np.eye(n_fft))
basis = np.vstack(
[np.real(basis[:n_fft // 2 + 1, :]), np.imag(basis[:n_fft // 2 + 1, :])]
[np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
)
return torch.from_numpy(basis).float()
def get_mel_filters(
sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
) -> torch.Tensor:
try:
import librosa
@ -244,8 +251,12 @@ def get_mel_filters(
class TTSSpectrogram(torch.nn.Module):
def __init__(
self, n_fft: int, win_length: int, hop_length: int,
window_fn: callable = torch.hann_window, return_phase: bool = False
self,
n_fft: int,
win_length: int,
hop_length: int,
window_fn: callable = torch.hann_window,
return_phase: bool = False,
) -> None:
super(TTSSpectrogram, self).__init__()
self.n_fft = n_fft
@ -254,16 +265,16 @@ class TTSSpectrogram(torch.nn.Module):
basis = get_fourier_basis(n_fft).unsqueeze(1)
basis *= get_window(window_fn, n_fft, win_length)
self.register_buffer('basis', basis)
self.register_buffer("basis", basis)
def forward(
self, waveform: torch.Tensor
self, waveform: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
padding = (self.n_fft // 2, self.n_fft // 2)
x = F.pad(waveform.unsqueeze(1), padding, mode='reflect')
x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
x = F.conv1d(x, self.basis, stride=self.hop_length)
real_part = x[:, :self.n_fft // 2 + 1, :]
imag_part = x[:, self.n_fft // 2 + 1:, :]
real_part = x[:, : self.n_fft // 2 + 1, :]
imag_part = x[:, self.n_fft // 2 + 1 :, :]
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
if self.return_phase:
phase = torch.atan2(imag_part, real_part)
@ -273,13 +284,11 @@ class TTSSpectrogram(torch.nn.Module):
class TTSMelScale(torch.nn.Module):
def __init__(
self, n_mels: int, sample_rate: int, f_min: float, f_max: float,
n_stft: int
self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
) -> None:
super(TTSMelScale, self).__init__()
basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min,
f_max)
self.register_buffer('basis', basis)
basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
self.register_buffer("basis", basis)
def forward(self, specgram: torch.Tensor) -> torch.Tensor:
return torch.matmul(self.basis, specgram)

View File

@ -13,11 +13,10 @@ from typing import List, Optional
import numpy as np
import torch
from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import (
S2TDataConfig
)
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.data.audio.text_to_speech_dataset import (
TextToSpeechDataset, TextToSpeechDatasetCreator
TextToSpeechDataset,
TextToSpeechDatasetCreator,
)
logger = logging.getLogger(__name__)
@ -48,7 +47,7 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
chunk_incr=5,
add_eos=True,
dedup=True,
ref_fpu=-1
ref_fpu=-1,
):
# It assumes texts are encoded at a fixed frame-rate
super().__init__(
@ -67,7 +66,7 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id
speaker_to_id=speaker_to_id,
)
self.do_chunk = do_chunk
@ -92,24 +91,23 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
fpu = source.size(0) / target.size(0) # frame-per-unit
fps = self.n_frames_per_step
assert (
self.ref_fpu == -1 or
abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
), f"{fpu*fps} != {self.ref_fpu}"
# only chunk training split
if self.is_train_split and self.do_chunk and self.chunk_size > 0:
lang = target[:int(self.data_cfg.prepend_tgt_lang_tag)]
text = target[int(self.data_cfg.prepend_tgt_lang_tag):]
lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
size = len(text)
chunk_size = min(self.chunk_size, size)
chunk_start = np.random.randint(size - chunk_size + 1)
text = text[chunk_start:chunk_start+chunk_size]
text = text[chunk_start : chunk_start + chunk_size]
target = torch.cat((lang, text), 0)
f_size = int(np.floor(chunk_size * fpu))
f_start = int(np.floor(chunk_start * fpu))
assert(f_size > 0)
source = source[f_start:f_start+f_size, :]
assert f_size > 0
source = source[f_start : f_start + f_size, :]
if self.dedup:
target = torch.unique_consecutive(target)
@ -126,10 +124,12 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
self.chunk_size = self.chunk_init + epoch * self.chunk_incr
if self.chunk_bound > 0:
self.chunk_size = min(self.chunk_size, self.chunk_bound)
logger.info((
f"{self.split}: setting chunk size "
f"from {old} to {self.chunk_size}"
))
logger.info(
(
f"{self.split}: setting chunk size "
f"from {old} to {self.chunk_size}"
)
)
class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
@ -152,7 +152,7 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
chunk_incr: int = 5,
add_eos: bool = True,
dedup: bool = True,
ref_fpu: float = -1
ref_fpu: float = -1,
) -> FrmTextToSpeechDataset:
tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path):
@ -170,9 +170,7 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
assert len(s) > 0
ids = [ss[cls.KEY_ID] for ss in s]
audio_paths = [
op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s
]
audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
@ -203,5 +201,5 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
chunk_incr=chunk_incr,
add_eos=add_eos,
dedup=dedup,
ref_fpu=ref_fpu
ref_fpu=ref_fpu,
)

View File

@ -152,10 +152,7 @@ class HubertDataset(FairseqDataset):
self.label_offsets_list = [
load_label_offset(p, inds, tot) for p in label_paths
]
assert (
label_processors is None
or len(label_processors) == self.num_labels
)
assert label_processors is None or len(label_processors) == self.num_labels
for label_path, label_rate in zip(label_paths, self.label_rates):
verify_label_lengths(
self.sizes, sample_rate, label_path, label_rate, inds, tot
@ -234,8 +231,7 @@ class HubertDataset(FairseqDataset):
)
targets_by_label = [
[s["label_list"][i] for s in samples]
for i in range(self.num_labels)
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
]
targets_list, lengths_list, ntokens_list = self.collater_label(
targets_by_label, audio_size, audio_starts
@ -270,9 +266,7 @@ class HubertDataset(FairseqDataset):
collated_audios[i] = audio
elif diff < 0:
assert self.pad_audio
collated_audios[i] = torch.cat(
[audio, audio.new_full((-diff,), 0.0)]
)
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
padding_mask[i, diff:] = True
else:
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
@ -280,9 +274,7 @@ class HubertDataset(FairseqDataset):
)
return collated_audios, padding_mask, audio_starts
def collater_frm_label(
self, targets, audio_size, audio_starts, label_rate, pad
):
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
assert label_rate > 0
s2f = label_rate / self.sample_rate
frm_starts = [int(round(s * s2f)) for s in audio_starts]
@ -290,24 +282,20 @@ class HubertDataset(FairseqDataset):
if not self.pad_audio:
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size)
targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
logger.debug(f"audio_starts={audio_starts}")
logger.debug(f"frame_starts={frm_starts}")
logger.debug(f"frame_size={frm_size}")
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(
targets, pad_idx=pad, left_pad=False
)
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_seq_label(self, targets, pad):
lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item()
targets = data_utils.collate_tokens(
targets, pad_idx=pad, left_pad=False
)
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
return targets, lengths, ntokens
def collater_label(self, targets_by_label, audio_size, audio_starts):
@ -315,9 +303,7 @@ class HubertDataset(FairseqDataset):
itr = zip(targets_by_label, self.label_rates, self.pad_list)
for targets, label_rate, pad in itr:
if label_rate == -1:
targets, lengths, ntokens = self.collater_seq_label(
targets, pad
)
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
else:
targets, lengths, ntokens = self.collater_frm_label(
targets, audio_size, audio_starts, label_rate, pad

View File

@ -29,6 +29,7 @@ class ModalityDatasetItem(NamedTuple):
max_tokens: Optional[int] = None
max_sentences: Optional[int] = None
# MultiModalityDataset: it concate multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given the ratios for different datasets
# 2) it adds mode to indicate what type of the data samples come from.

View File

@ -308,6 +308,7 @@ class FileAudioDataset(RawAudioDataset):
def __getitem__(self, index):
import soundfile as sf
fn = self.fnames[index]
fn = fn if isinstance(self.fnames, list) else fn.as_py()
fn = self.text_compressor.decompress(fn)

View File

@ -45,7 +45,11 @@ def get_features_from_npy_or_audio(path):
def get_features_or_waveform_from_stored_zip(
path, byte_offset, byte_size, need_waveform=False, use_sample_rate=None,
path,
byte_offset,
byte_size,
need_waveform=False,
use_sample_rate=None,
):
assert path.endswith(".zip")
data = read_from_stored_zip(path, byte_offset, byte_size)
@ -53,18 +57,17 @@ def get_features_or_waveform_from_stored_zip(
if is_npy_data(data):
features_or_waveform = np.load(f)
elif is_sf_audio_data(data):
features_or_waveform = \
get_waveform(
f, always_2d=False, output_sample_rate=use_sample_rate
)[0] if need_waveform else get_fbank(f)
features_or_waveform = (
get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0]
if need_waveform
else get_fbank(f)
)
else:
raise ValueError(f'Unknown file format for "{path}"')
return features_or_waveform
def get_features_or_waveform(
path: str, need_waveform=False, use_sample_rate=None
):
def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=None):
"""Get speech features from .npy file or waveform from .wav/.flac file.
The file may be inside an uncompressed ZIP file and is accessed via byte
offset and length.
@ -87,8 +90,11 @@ def get_features_or_waveform(
return get_features_from_npy_or_audio(_path)
elif len(slice_ptr) == 2:
features_or_waveform = get_features_or_waveform_from_stored_zip(
_path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform,
use_sample_rate=use_sample_rate
_path,
slice_ptr[0],
slice_ptr[1],
need_waveform=need_waveform,
use_sample_rate=use_sample_rate,
)
else:
raise ValueError(f"Invalid path: {path}")
@ -145,7 +151,7 @@ class SpeechToTextDataset(FairseqDataset):
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None
speaker_to_id=None,
):
self.split, self.is_train_split = split, is_train_split
self.cfg = cfg
@ -235,7 +241,7 @@ class SpeechToTextDataset(FairseqDataset):
if self.n_frames_per_step == 1:
return feature
n_packed_frames = feature.shape[0] // self.n_frames_per_step
feature = feature[:self.n_frames_per_step * n_packed_frames]
feature = feature[: self.n_frames_per_step * n_packed_frames]
return feature.reshape(n_packed_frames, -1)
@classmethod
@ -318,9 +324,11 @@ class SpeechToTextDataset(FairseqDataset):
speaker = None
if self.speaker_to_id is not None:
speaker = torch.tensor(
[s.speaker_id for s in samples], dtype=torch.long
).index_select(0, order).view(-1, 1)
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
net_input = {
"src_tokens": frames,
@ -388,7 +396,7 @@ class SpeechToTextDatasetCreator(object):
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id
speaker_to_id,
) -> SpeechToTextDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
@ -415,7 +423,7 @@ class SpeechToTextDatasetCreator(object):
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id
speaker_to_id=speaker_to_id,
)
@classmethod
@ -481,12 +489,19 @@ class SpeechToTextDatasetCreator(object):
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id
speaker_to_id,
) -> SpeechToTextDataset:
samples = cls._load_samples_from_tsv(root, split)
return cls._from_list(
split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer,
bpe_tokenizer, n_frames_per_step, speaker_to_id
split,
is_train_split,
samples,
cfg,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
)
@classmethod
@ -502,12 +517,19 @@ class SpeechToTextDatasetCreator(object):
epoch: int,
seed: int,
n_frames_per_step: int = 1,
speaker_to_id=None
speaker_to_id=None,
) -> SpeechToTextDataset:
datasets = [
cls._from_tsv(
root, cfg, split, tgt_dict, is_train_split, pre_tokenizer,
bpe_tokenizer, n_frames_per_step, speaker_to_id
root,
cfg,
split,
tgt_dict,
is_train_split,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
)
for split in splits.split(",")
]

View File

@ -13,8 +13,11 @@ import numpy as np
import torch
from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig,
_collate_frames, get_features_or_waveform
SpeechToTextDataset,
SpeechToTextDatasetCreator,
S2TDataConfig,
_collate_frames,
get_features_or_waveform,
)
from fairseq.data import Dictionary, data_utils as fairseq_data_utils
@ -32,34 +35,44 @@ class TextToSpeechDatasetItem(object):
class TextToSpeechDataset(SpeechToTextDataset):
def __init__(
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
durations: Optional[List[List[int]]] = None,
pitches: Optional[List[str]] = None,
energies: Optional[List[str]] = None
self,
split: str,
is_train_split: bool,
cfg: S2TDataConfig,
audio_paths: List[str],
n_frames: List[int],
src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None,
bpe_tokenizer=None,
n_frames_per_step=1,
speaker_to_id=None,
durations: Optional[List[List[int]]] = None,
pitches: Optional[List[str]] = None,
energies: Optional[List[str]] = None,
):
super(TextToSpeechDataset, self).__init__(
split, is_train_split, cfg, audio_paths, n_frames,
src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers,
src_langs=src_langs, tgt_langs=tgt_langs, ids=ids,
tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id
split,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
)
self.durations = durations
self.pitches = pitches
@ -84,9 +97,13 @@ class TextToSpeechDataset(SpeechToTextDataset):
np.concatenate((energy, [0])) # pad 0 for EOS
).float()
return TextToSpeechDatasetItem(
index=index, source=s2t_item.source, target=s2t_item.target,
speaker_id=s2t_item.speaker_id, duration=duration, pitch=pitch,
energy=energy
index=index,
source=s2t_item.source,
target=s2t_item.target,
speaker_id=s2t_item.speaker_id,
duration=duration,
pitch=pitch,
energy=energy,
)
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
@ -96,8 +113,9 @@ class TextToSpeechDataset(SpeechToTextDataset):
src_lengths, order = torch.tensor(
[s.target.shape[0] for s in samples], dtype=torch.long
).sort(descending=True)
id_ = torch.tensor([s.index for s in samples],
dtype=torch.long).index_select(0, order)
id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
0, order
)
feat = _collate_frames(
[s.source for s in samples], self.cfg.use_audio_input
).index_select(0, order)
@ -115,9 +133,11 @@ class TextToSpeechDataset(SpeechToTextDataset):
speaker = None
if self.speaker_to_id is not None:
speaker = torch.tensor(
[s.speaker_id for s in samples], dtype=torch.long
).index_select(0, order).view(-1, 1)
speaker = (
torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
.index_select(0, order)
.view(-1, 1)
)
bsz, _, d = feat.size()
prev_output_tokens = torch.cat(
@ -175,7 +195,7 @@ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id
speaker_to_id,
) -> TextToSpeechDataset:
audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples]
@ -189,27 +209,40 @@ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
durations = [s.get(cls.KEY_DURATION, None) for s in samples]
durations = [
None if dd is None else [int(d) for d in dd.split(" ")]
for dd in durations
None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
]
durations = None if any(dd is None for dd in durations) else durations
pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
pitches = [
None if pp is None else (audio_root / pp).as_posix()
for pp in pitches
None if pp is None else (audio_root / pp).as_posix() for pp in pitches
]
pitches = None if any(pp is None for pp in pitches) else pitches
energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
energies = [
None if ee is None else (audio_root / ee).as_posix()
for ee in energies]
None if ee is None else (audio_root / ee).as_posix() for ee in energies
]
energies = None if any(ee is None for ee in energies) else energies
return TextToSpeechDataset(
split_name, is_train_split, cfg, audio_paths, n_frames,
src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict,
pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id,
durations, pitches, energies
split_name,
is_train_split,
cfg,
audio_paths,
n_frames,
src_texts,
tgt_texts,
speakers,
src_langs,
tgt_langs,
ids,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
durations,
pitches,
energies,
)

View File

@ -9,7 +9,7 @@ from . import BaseWrapperDataset
class ColorizeDataset(BaseWrapperDataset):
""" Adds 'colors' property to net input that is obtained from the provided color getter for use by models """
"""Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""
def __init__(self, dataset, color_getter):
super().__init__(dataset)

View File

@ -69,6 +69,7 @@ def collate_tokens(
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res
def load_indexed_dataset(
path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
@ -324,9 +325,7 @@ def batch_by_size(
)
# added int() to avoid TypeError: an integer is required
max_tokens = (
int(max_tokens) if max_tokens is not None else -1
)
max_tokens = int(max_tokens) if max_tokens is not None else -1
max_sentences = max_sentences if max_sentences is not None else -1
bsz_mult = required_batch_size_multiple
@ -375,8 +374,9 @@ def post_process(sentence: str, symbol: str):
sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "silence":
import re
sentence = sentence.replace("<SIL>", "")
sentence = re.sub(' +', ' ', sentence).strip()
sentence = re.sub(" +", " ", sentence).strip()
elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
elif symbol in {"subword_nmt", "@@ ", "@@"}:
@ -547,7 +547,7 @@ def get_buckets(sizes, num_buckets):
np.percentile(
sizes,
np.linspace(0, 100, num_buckets + 1),
interpolation='lower',
interpolation="lower",
)[1:]
)
return buckets
@ -564,7 +564,6 @@ def get_bucketed_sizes(orig_sizes, buckets):
return sizes
def _find_extra_valid_paths(dataset_path: str) -> set:
paths = utils.split_paths(dataset_path)
all_valid_paths = set()

View File

@ -21,8 +21,10 @@ class SentencepieceConfig(FairseqDataclass):
)
sentencepiece_alpha: Optional[float] = field(
default=None,
metadata={"help": "soothing parameter for unigram sampling, "
"and merge probability for BPE-dropout"}
metadata={
"help": "soothing parameter for unigram sampling, "
"and merge probability for BPE-dropout"
},
)
@ -45,8 +47,7 @@ class SentencepieceBPE(object):
def encode(self, x: str) -> str:
return " ".join(
self.sp.Encode(
x, out_type=str, enable_sampling=self.enable_sampling,
alpha=self.alpha
x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha
)
)

View File

@ -138,7 +138,7 @@ class FairseqDataset(torch.utils.data.Dataset, EpochListening):
)
try:
num_tokens_vec = self.num_tokens_vec(indices).astype('int64')
num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
except NotImplementedError:
num_tokens_vec = None

View File

@ -140,7 +140,9 @@ class HuffmanNode:
def is_leaf(self) -> bool:
return self.left is None and self.right is None
def code_table(self, prefix: tp.Optional[bitarray] = None) -> tp.Dict[str, "HuffmanNode"]:
def code_table(
self, prefix: tp.Optional[bitarray] = None
) -> tp.Dict[str, "HuffmanNode"]:
defaulted_prefix = prefix if prefix is not None else bitarray()
if self.is_leaf():
self.code = (

View File

@ -67,7 +67,9 @@ def make_builder(out_file, impl, vocab_size=None):
elif impl == "fasta":
raise NotImplementedError
elif impl == "huffman":
raise ValueError("Use HuffmanCodeBuilder directly as it has a different interface.")
raise ValueError(
"Use HuffmanCodeBuilder directly as it has a different interface."
)
else:
return IndexedDatasetBuilder(out_file)

View File

@ -380,7 +380,9 @@ class EpochBatchIterator(EpochBatchIterating):
# reset _frozen_batches to refresh the next epoch
self._frozen_batches = None
self._cur_epoch_itr = self._get_iterator_for_epoch(
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
self.epoch,
shuffle,
fix_batches_to_gpus=fix_batches_to_gpus,
)
self.shuffle = shuffle
return self._cur_epoch_itr
@ -421,7 +423,9 @@ class EpochBatchIterator(EpochBatchIterating):
if itr_pos > 0:
# fast-forward epoch iterator
self._next_epoch_itr = self._get_iterator_for_epoch(
self.epoch, shuffle=state_dict.get("shuffle", True), offset=itr_pos,
self.epoch,
shuffle=state_dict.get("shuffle", True),
offset=itr_pos,
)
if self._next_epoch_itr is None:
if version == 1:

View File

@ -114,7 +114,10 @@ def collate(
"id": id,
"nsentences": len(samples),
"ntokens": ntokens,
"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths,},
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
},
"target": target,
}
if prev_output_tokens is not None:
@ -467,5 +470,8 @@ class LanguagePairDataset(FairseqDataset):
list: list of removed indices
"""
return data_utils.filter_paired_dataset_indices_by_size(
self.src_sizes, self.tgt_sizes, indices, max_sizes,
self.src_sizes,
self.tgt_sizes,
indices,
max_sizes,
)

View File

@ -80,7 +80,9 @@ class MultiCorpusDataset(FairseqDataset):
def ordered_indices(self):
start = time.time()
with data_utils.numpy_seed(self.seed, self.epoch):
logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}")
logger.info(
f"sampling new dataset with seed {self.seed} epoch {self.epoch}"
)
sampled_indices = []
num_selected_instances = 0

View File

@ -40,8 +40,8 @@ from fairseq.utils import FileContentsAction, csv_str_list, eval_str_dict
logger = logging.getLogger(__name__)
SRC_DICT_NAME = 'src'
TGT_DICT_NAME = 'tgt'
SRC_DICT_NAME = "src"
TGT_DICT_NAME = "tgt"
def _lang_id(dic: Dictionary, lang: str):
@ -64,14 +64,16 @@ class MultilingualDatasetManager(object):
self.seed = args.seed
self.lang_pairs = lang_pairs
self.extra_lang_pairs = (
list(
{p for _, v in args.extra_lang_pairs.items() for p in v.split(",")}
)
if args.extra_lang_pairs
else []
)
self.src_langs = {p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs}
self.tgt_langs = {p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs}
list({p for _, v in args.extra_lang_pairs.items() for p in v.split(",")})
if args.extra_lang_pairs
else []
)
self.src_langs = {
p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs
}
self.tgt_langs = {
p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs
}
self.langs = langs
self.dicts = dicts
self.lang_dict = self.create_lang_dictionary(self.langs)
@ -111,10 +113,18 @@ class MultilingualDatasetManager(object):
"note that the ordering determines language token IDs; "
"--langs and --lang-dict are two exclusive options",
)
parser.add_argument('--source-dict', default=None, type=str,
help='path to source dictionary; if specified it will override per language dictionary loading')
parser.add_argument('--target-dict', default=None, type=str,
help='path to target dictionary; if specified it will override per language dictionary loading')
parser.add_argument(
"--source-dict",
default=None,
type=str,
help="path to source dictionary; if specified it will override per language dictionary loading",
)
parser.add_argument(
"--target-dict",
default=None,
type=str,
help="path to target dictionary; if specified it will override per language dictionary loading",
)
parser.add_argument(
"--lang-tok-style",
default=LangTokStyle.multilingual.value,
@ -378,7 +388,9 @@ class MultilingualDatasetManager(object):
)
return d
dicts = cls.load_all_dictionaries(args, language_list, load_dictionary_and_postproc, training)
dicts = cls.load_all_dictionaries(
args, language_list, load_dictionary_and_postproc, training
)
return language_list, dicts, training
@classmethod
@ -424,7 +436,10 @@ class MultilingualDatasetManager(object):
if args.fixed_dictionary is not None:
fixed_dict = load_dictionary(args.fixed_dictionary)
dicts = {lang: fixed_dict for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts}
dicts = {
lang: fixed_dict
for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts
}
else:
if args.source_dict is None:
load_dicts(src_langs_to_load_dicts)
@ -477,7 +492,10 @@ class MultilingualDatasetManager(object):
lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec
)
return self.get_langtok_index(
langtok, self.get_source_dictionary(src_lang) if src_lang else self.get_target_dictionary(tgt_lang)
langtok,
self.get_source_dictionary(src_lang)
if src_lang
else self.get_target_dictionary(tgt_lang),
)
def get_decoder_langtok(self, tgt_lang, spec=None):
@ -819,7 +837,9 @@ class MultilingualDatasetManager(object):
if self.args.lang_tok_replacing_bos_eos:
ds = self.alter_dataset_langtok(
langpair_ds,
src_eos=self.get_source_dictionary(src).eos() if src else self.get_target_dictionary(tgt).eos(),
src_eos=self.get_source_dictionary(src).eos()
if src
else self.get_target_dictionary(tgt).eos(),
src_lang=src,
tgt_eos=self.get_target_dictionary(tgt).eos(),
tgt_lang=tgt,

View File

@ -298,7 +298,6 @@ class NoisingDataset(torch.utils.data.Dataset):
)
self.sizes = src_dataset.sizes
def __getitem__(self, index):
"""
Returns a single noisy sample. Multiple samples are fed to the collater

View File

@ -14,8 +14,7 @@ class TextCompressionLevel(Enum):
class TextCompressor(object):
def __init__(
self, level: TextCompressionLevel,
max_input_byte_length: int = 2 ** 16
self, level: TextCompressionLevel, max_input_byte_length: int = 2 ** 16
):
self.level = level
self.max_input_length = max_input_byte_length
@ -23,11 +22,13 @@ class TextCompressor(object):
def compress(self, text: str) -> bytes:
if self.level == TextCompressionLevel.low:
import zlib
# zlib: built-in, fast
return zlib.compress(text.encode(), level=0)
elif self.level == TextCompressionLevel.high:
try:
import unishox2
# unishox2: optimized for short text but slower
except ImportError:
raise ImportError(
@ -42,6 +43,7 @@ class TextCompressor(object):
def decompress(self, compressed: bytes) -> str:
if self.level == TextCompressionLevel.low:
import zlib
return zlib.decompress(compressed).decode()
elif self.level == TextCompressionLevel.high:
try:

View File

@ -69,7 +69,10 @@ class TokenBlockDataset(FairseqDataset):
_sizes, split_path, (plasma_id, 1), plasma_path=plasma_path
)
self._block_to_dataset_index = plasma_utils.PlasmaView(
block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path,
block_to_dataset_index,
split_path,
(plasma_id, 2),
plasma_path=plasma_path,
)
else:
self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
@ -127,7 +130,8 @@ class TokenBlockDataset(FairseqDataset):
)
else:
block_to_dataset_index = _get_block_to_dataset_index_fast(
sizes, slice_indices,
sizes,
slice_indices,
)
size_dtype = np.uint16 if block_size < 65535 else np.uint32
num_tokens = slice_indices[-1].max()

View File

@ -52,7 +52,7 @@ class TransformEosLangPairDataset(FairseqDataset):
if len(samples) == 0:
return samples
if 'net_input' not in samples:
if "net_input" not in samples:
return samples
if self.new_src_eos is not None:

View File

@ -126,7 +126,8 @@ class CommonConfig(FairseqDataclass):
metadata={"help": "Weights and Biases project name to use for logging"},
)
azureml_logging: Optional[bool] = field(
default=False, metadata={"help": "Log scalars to AzureML context"},
default=False,
metadata={"help": "Log scalars to AzureML context"},
)
seed: int = field(
default=1, metadata={"help": "pseudo random number generator seed"}
@ -428,19 +429,23 @@ class DistributedTrainingConfig(FairseqDataclass):
tpu: bool = II("common.tpu")
# configuration for --ddp-backend=fully_sharded
no_reshard_after_forward: bool = field(
default=False, metadata={"help": "don't reshard parameters after forward pass"},
default=False,
metadata={"help": "don't reshard parameters after forward pass"},
)
fp32_reduce_scatter: bool = field(
default=False, metadata={"help": "reduce-scatter grads in FP32"},
default=False,
metadata={"help": "reduce-scatter grads in FP32"},
)
cpu_offload: bool = field(
default=False, metadata={"help": "offload FP32 params to CPU"}
)
use_sharded_state: bool = field(
default=False, metadata={"help": "use sharded checkpoint files"},
default=False,
metadata={"help": "use sharded checkpoint files"},
)
not_fsdp_flatten_parameters: bool = field(
default=False, metadata={"help": "not flatten parameter param for fsdp"},
default=False,
metadata={"help": "not flatten parameter param for fsdp"},
)
@ -786,10 +791,12 @@ class FairseqBMUFConfig(FairseqDataclass):
@dataclass
class GenerationConfig(FairseqDataclass):
beam: int = field(
default=5, metadata={"help": "beam size"},
default=5,
metadata={"help": "beam size"},
)
nbest: int = field(
default=1, metadata={"help": "number of hypotheses to output"},
default=1,
metadata={"help": "number of hypotheses to output"},
)
max_len_a: float = field(
default=0,
@ -804,19 +811,24 @@ class GenerationConfig(FairseqDataclass):
},
)
min_len: int = field(
default=1, metadata={"help": "minimum generation length"},
default=1,
metadata={"help": "minimum generation length"},
)
match_source_len: bool = field(
default=False, metadata={"help": "generations should match the source length"},
default=False,
metadata={"help": "generations should match the source length"},
)
unnormalized: bool = field(
default=False, metadata={"help": "compare unnormalized hypothesis scores"},
default=False,
metadata={"help": "compare unnormalized hypothesis scores"},
)
no_early_stop: bool = field(
default=False, metadata={"help": "deprecated"},
default=False,
metadata={"help": "deprecated"},
)
no_beamable_mm: bool = field(
default=False, metadata={"help": "don't use BeamableMM in attention layers"},
default=False,
metadata={"help": "don't use BeamableMM in attention layers"},
)
lenpen: float = field(
default=1,
@ -838,10 +850,12 @@ class GenerationConfig(FairseqDataclass):
},
)
sacrebleu: bool = field(
default=False, metadata={"help": "score with sacrebleu"},
default=False,
metadata={"help": "score with sacrebleu"},
)
score_reference: bool = field(
default=False, metadata={"help": "just score the reference translation"},
default=False,
metadata={"help": "just score the reference translation"},
)
prefix_size: int = field(
default=0,
@ -875,10 +889,12 @@ class GenerationConfig(FairseqDataclass):
},
)
temperature: float = field(
default=1.0, metadata={"help": "temperature for generation"},
default=1.0,
metadata={"help": "temperature for generation"},
)
diverse_beam_groups: int = field(
default=-1, metadata={"help": "number of groups for Diverse Beam Search"},
default=-1,
metadata={"help": "number of groups for Diverse Beam Search"},
)
diverse_beam_strength: float = field(
default=0.5,
@ -897,13 +913,16 @@ class GenerationConfig(FairseqDataclass):
},
)
print_step: bool = field(
default=False, metadata={"help": "print steps"},
default=False,
metadata={"help": "print steps"},
)
lm_path: Optional[str] = field(
default=None, metadata={"help": "path to lm checkpoint for lm fusion"},
default=None,
metadata={"help": "path to lm checkpoint for lm fusion"},
)
lm_weight: float = field(
default=0.0, metadata={"help": "weight for lm probs for lm fusion"},
default=0.0,
metadata={"help": "weight for lm probs for lm fusion"},
)
# arguments for iterative refinement generator
@ -912,7 +931,8 @@ class GenerationConfig(FairseqDataclass):
metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
)
iter_decode_max_iter: int = field(
default=10, metadata={"help": "maximum iterations for iterative refinement."},
default=10,
metadata={"help": "maximum iterations for iterative refinement."},
)
iter_decode_force_max_iter: bool = field(
default=False,
@ -939,7 +959,8 @@ class GenerationConfig(FairseqDataclass):
},
)
retain_dropout: bool = field(
default=False, metadata={"help": "Use dropout at inference time"},
default=False,
metadata={"help": "Use dropout at inference time"},
)
# temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed
# retain_dropout_modules: Optional[List[str]] = field(
@ -964,7 +985,8 @@ class GenerationConfig(FairseqDataclass):
@dataclass
class CommonEvalConfig(FairseqDataclass):
path: Optional[str] = field(
default=None, metadata={"help": "path(s) to model file(s), colon separated"},
default=None,
metadata={"help": "path(s) to model file(s), colon separated"},
)
post_process: Optional[str] = field(
default=None,
@ -1026,7 +1048,8 @@ class InteractiveConfig(FairseqDataclass):
},
)
input: str = field(
default="-", metadata={"help": "file to read from; use - for stdin"},
default="-",
metadata={"help": "file to read from; use - for stdin"},
)
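
Nearly every hunk in this config file is the same mechanical change: a field(...) call that already ended in a trailing comma is expanded by black onto one keyword argument per line (the "magic trailing comma"). A minimal sketch of the resulting pattern, borrowing two field names from the hunks above (this is not the actual fairseq config class):

from dataclasses import dataclass, field


@dataclass
class ExampleGenerationConfig:
    # Each option keeps a default plus an argparse help string in metadata;
    # the trailing comma tells black to keep one argument per line.
    beam: int = field(
        default=5,
        metadata={"help": "beam size"},
    )
    min_len: int = field(
        default=1,
        metadata={"help": "minimum generation length"},
    )
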


@ -35,14 +35,16 @@ def ChoiceEnum(choices: List[str]):
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum([
"c10d", # alias for pytorch_ddp
"fully_sharded", # FullyShardedDataParallel from fairscale
"legacy_ddp",
"no_c10d", # alias for legacy_ddp
"pytorch_ddp",
"slowmo",
])
DDP_BACKEND_CHOICES = ChoiceEnum(
[
"c10d", # alias for pytorch_ddp
"fully_sharded", # FullyShardedDataParallel from fairscale
"legacy_ddp",
"no_c10d", # alias for legacy_ddp
"pytorch_ddp",
"slowmo",
]
)
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
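
For context on the DDP_BACKEND_CHOICES rewrap above: ChoiceEnum turns a list of strings into an enum that can be used directly as argparse choices. The stand-in below is an assumption about how such a helper behaves, not a copy of the fairseq implementation:

from enum import Enum
from typing import List


def choice_enum(choices: List[str]):
    # Build an Enum whose member names equal their values.
    return Enum("Choices", {c: c for c in choices})


ddp_backends = choice_enum(["c10d", "fully_sharded", "legacy_ddp", "pytorch_ddp"])
assert ddp_backends["fully_sharded"].value == "fully_sharded"
assert ddp_backends.c10d.name == "c10d"
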


@ -28,7 +28,7 @@ def hydra_init(cfg_name="config") -> None:
def add_defaults(cfg: DictConfig) -> None:
"""This function adds default values that are stored in dataclasses that hydra doesn't know about """
"""This function adds default values that are stored in dataclasses that hydra doesn't know about"""
from fairseq.registry import REGISTRIES
from fairseq.tasks import TASK_DATACLASS_REGISTRY


@ -57,21 +57,21 @@ def gen_parser_from_dataclass(
with_prefix: Optional[str] = None,
) -> None:
"""
convert a dataclass instance to tailing parser arguments.
convert a dataclass instance to tailing parser arguments.
If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
building a flat namespace from a structured dataclass (see transformer_config.py for example).
If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
building a flat namespace from a structured dataclass (see transformer_config.py for example).
"""
def argparse_name(name: str):
if name == "data" and (with_prefix is None or with_prefix == ''):
if name == "data" and (with_prefix is None or with_prefix == ""):
# normally data is positional args, so we don't add the -- nor the prefix
return name
if name == "_name":
# private member, skip
return None
full_name = "--" + name.replace("_", "-")
if with_prefix is not None and with_prefix != '':
if with_prefix is not None and with_prefix != "":
# if a prefix is specified, construct the prefixed arg name
full_name = with_prefix + "-" + full_name[2:] # strip -- when composing
return full_name
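
Restated outside the diff, the argparse_name logic above maps a dataclass field name to a CLI flag and, when with_prefix is set, nests it under that prefix (a flat namespace built from a structured dataclass). This is a sketch: in fairseq the function is a closure over with_prefix, and the "--encoder" prefix in the asserts is only an illustrative value.

from typing import Optional


def argparse_name(name: str, with_prefix: Optional[str] = None) -> Optional[str]:
    if name == "data" and (with_prefix is None or with_prefix == ""):
        return name  # "data" stays positional when there is no prefix
    if name == "_name":
        return None  # private member, skipped
    full_name = "--" + name.replace("_", "-")
    if with_prefix is not None and with_prefix != "":
        full_name = with_prefix + "-" + full_name[2:]  # strip "--" when composing
    return full_name


assert argparse_name("embed_dim") == "--embed-dim"
assert argparse_name("embed_dim", with_prefix="--encoder") == "--encoder-embed-dim"
assert argparse_name("_name") is None
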
@ -143,8 +143,8 @@ def gen_parser_from_dataclass(
kwargs["default"] = field_default
# build the help with the hierarchical prefix
if with_prefix is not None and with_prefix != '' and field_help is not None:
field_help = with_prefix[2:] + ': ' + field_help
if with_prefix is not None and with_prefix != "" and field_help is not None:
field_help = with_prefix[2:] + ": " + field_help
kwargs["help"] = field_help
if field_const is not None:


@ -4,7 +4,11 @@
# LICENSE file in the root directory of this source tree.
from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel
from .fully_sharded_data_parallel import (
fsdp_enable_wrap,
fsdp_wrap,
FullyShardedDataParallel,
)
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel


@ -33,6 +33,7 @@ class DistributedTimeoutWrapper(nn.Module):
(set to a value <= 0 to disable the timeout)
signal (Optional): signal to send once timeout is triggered
"""
def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT):
super().__init__()
self.module = module
@ -86,9 +87,11 @@ class DistributedTimeoutWrapper(nn.Module):
if self._terminated:
break
elif not success:
logger.error((
"Killing job for not making progress in {} seconds. "
"Set --heartbeat-timeout=-1 to disable this timeout."
).format(int(self.timeout)))
logger.error(
(
"Killing job for not making progress in {} seconds. "
"Set --heartbeat-timeout=-1 to disable this timeout."
).format(int(self.timeout))
)
os.kill(parent_pid, self.signal)
return
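
From the constructor shown above, wrapping a module looks roughly like the sketch below; the 60-second timeout and the toy Linear module are made-up values, and per the docstring a timeout <= 0 disables the watchdog entirely.

import signal

import torch.nn as nn

from fairseq.distributed import DistributedTimeoutWrapper

# Send SIGINT to this rank if no heartbeat is recorded for 60 seconds.
model = nn.Linear(8, 8)
wrapped = DistributedTimeoutWrapper(model, timeout=60, signal=signal.SIGINT)
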


@ -137,7 +137,7 @@ class LegacyDistributedDataParallel(nn.Module):
if param.grad is None:
param.grad = torch.zeros_like(param)
if hasattr(param, 'expert'):
if hasattr(param, "expert"):
# Skip gradient sync for unshared parameters
continue


@ -26,8 +26,9 @@ class ModuleProxyWrapper(nn.Module):
def __init__(self, module: nn.Module):
super().__init__()
assert hasattr(module, "module"), \
"ModuleProxyWrapper expects input to wrap another module"
assert hasattr(
module, "module"
), "ModuleProxyWrapper expects input to wrap another module"
self.module = module
def __getattr__(self, name):


@ -10,7 +10,6 @@ from fairseq.distributed import utils
class TPUDistributedDataParallel(nn.Module):
def __init__(self, module, process_group):
super().__init__()
self.module = module
@ -35,9 +34,10 @@ class TPUDistributedDataParallel(nn.Module):
gradients.append(p.grad)
import torch_xla.core.xla_model as xm
xm.all_reduce(
'sum',
"sum",
gradients,
scale=1. / self.world_size,
scale=1.0 / self.world_size,
groups=self.process_group[1],
)


@ -201,9 +201,7 @@ def _pipeline_parallel_post_init(
# distributed_world_size to be based on the total number of GPUs, so
# we need to correct them to be based on the number of pipelines.
assert cfg.distributed_world_size % num_pipeline_devices == 0
cfg.distributed_world_size = (
cfg.distributed_world_size // num_pipeline_devices
)
cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices
# In the case of 4-way MP on nodes with 8 GPUs, we want
# distributed_rank to be the starting GPU index for each pipeline
# i.e., 0, 2, ...
@ -306,8 +304,10 @@ def distributed_init(cfg: FairseqConfig):
model_part_number = get_model_parallel_rank()
cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}"
if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
cfg.checkpoint.checkpoint_suffix = (
f"-rank-{cfg.distributed_training.distributed_rank}"
)
return cfg.distributed_training.distributed_rank
@ -696,7 +696,7 @@ def broadcast_tensors(
dist_device = torch.device("cpu")
# share metadata first to simplify transfer
is_src_rank = (get_rank(group) == src_rank)
is_src_rank = get_rank(group) == src_rank
if is_src_rank:
metadata = [
{"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
@ -747,7 +747,10 @@ def broadcast_object(
def _broadcast_object_slow(
obj: Any, src_rank: int, group: object, dist_device: torch.device,
obj: Any,
src_rank: int,
group: object,
dist_device: torch.device,
) -> Any:
if get_rank(group) == src_rank:
# Emit data


@ -152,6 +152,7 @@ class PathManager:
"""
ioPath async PathManager methods:
"""
@staticmethod
def opena(
path: str,
@ -169,6 +170,7 @@ class PathManager:
logging.info("ioPath is initializing PathManager.")
try:
from iopath.common.file_io import PathManager
IOPathManager = PathManager()
except Exception:
logging.exception("Failed to initialize ioPath PathManager object.")
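
Stripped of the fairseq specifics, the hunk above is the usual optional-dependency pattern: try to build an iopath PathManager, otherwise fall back to plain local IO. A rough sketch:

import logging

try:
    from iopath.common.file_io import PathManager

    IOPathManager = PathManager()
except Exception:
    # iopath is optional; fall back to local file IO when it is unavailable.
    logging.getLogger(__name__).info("iopath not available, using local IO only")
    IOPathManager = None
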


@ -146,6 +146,7 @@ def cached_path_from_pm(url_or_filename):
"""
try:
from fairseq.file_io import PathManager
local_path = PathManager.get_local_path(url_or_filename)
return local_path
except Exception:


@ -130,6 +130,7 @@ def log_scalar(
agg.add_meter(key, AverageMeter(round=round), priority)
agg[key].update(value, weight)
def log_scalar_sum(
key: str,
value: float,
@ -309,6 +310,7 @@ def load_state_dict(state_dict):
def xla_metrics_report():
try:
import torch_xla.debug.metrics as met
print(met.metrics_report())
except ImportError:
return


@ -52,8 +52,7 @@ class MegatronTrainer(Trainer):
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
extra_state['rng_tracker_states'] \
= get_cuda_rng_tracker().get_states()
extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
super().save_checkpoint(filename, extra_state)
def load_checkpoint(
@ -64,8 +63,13 @@ class MegatronTrainer(Trainer):
optimizer_overrides=None,
reset_meters=False,
):
extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters)
if extra_state is not None and 'rng_tracker_states' in extra_state:
get_cuda_rng_tracker().set_states(
extra_state['rng_tracker_states'])
extra_state = super().load_checkpoint(
filename,
reset_optimizer=reset_optimizer,
reset_lr_scheduler=reset_lr_scheduler,
optimizer_overrides=optimizer_overrides,
reset_meters=reset_meters,
)
if extra_state is not None and "rng_tracker_states" in extra_state:
get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
return extra_state


@ -9,6 +9,7 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.modules import (
AdaptiveSoftmax,
@ -17,7 +18,6 @@ from fairseq.modules import (
PositionalEmbedding,
)
EncoderOut = namedtuple(
"TransformerEncoderOut",
[
@ -30,7 +30,7 @@ EncoderOut = namedtuple(
class TransformerEncoderEmbedding(nn.Module):
""" Encoder Embedding + Positional Embedding """
"""Encoder Embedding + Positional Embedding"""
def __init__(self, args, embed_tokens):
super().__init__()
@ -109,7 +109,7 @@ class TransformerEncoderLayerNorm(nn.Module):
class TransformerDecoderEmbedding(nn.Module):
""" Decoder Embedding + Positional Embedding """
"""Decoder Embedding + Positional Embedding"""
def __init__(self, args, embed_tokens):
super().__init__()


@ -42,16 +42,20 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024
TORCH_PIPE = False
RPC_INIT = False
def import_pipe():
global TORCH_PIPE
global RPC_INIT
try:
from torch.distributed.pipeline.sync import Pipe # noqa
from torch.distributed.pipeline.sync import Pipe # noqa
global Pipe
from torch.distributed.pipeline.sync.utils import partition_model
global partition_model
from torch.distributed import rpc
import tempfile
TORCH_PIPE = True
# Initialize single process RPC agent since TORCH_PIPE requires
# RRef. RRef depends on RPC being initialized and as a result we initialize
@ -64,14 +68,15 @@ def import_pipe():
world_size=1,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
init_method="file://{}".format(tmpfile.name),
)
),
)
RPC_INIT = True
logger.info('Using torch pipe')
logger.info("Using torch pipe")
except ImportError:
try:
from fairscale.nn import Pipe # noqa
logger.info('Using fairscale pipe')
from fairscale.nn import Pipe # noqa
logger.info("Using fairscale pipe")
except ImportError:
raise ImportError("Please install fairscale with: pip install fairscale")
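
The import_pipe fallback above boils down to: prefer the upstream torch.distributed pipeline, otherwise fall back to fairscale. A skeleton of just the import preference (the RPC initialization is omitted):

try:
    from torch.distributed.pipeline.sync import Pipe  # noqa: F401

    backend = "torch pipe"
except ImportError:
    try:
        from fairscale.nn import Pipe  # noqa: F401

        backend = "fairscale pipe"
    except ImportError:
        raise ImportError("Please install fairscale with: pip install fairscale")

print("Using", backend)
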
@ -153,9 +158,14 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
decoder_module_list.append(module)
module_count += 1
self.model = None
self.encoder = TransformerEncoder(cfg.distributed_training, None, None, encoder_module_list)
self.encoder = TransformerEncoder(
cfg.distributed_training, None, None, encoder_module_list
)
self.decoder = TransformerDecoder(
cfg.distributed_training, None, None, decoder_module_list=decoder_module_list
cfg.distributed_training,
None,
None,
decoder_module_list=decoder_module_list,
)
@staticmethod
@ -471,7 +481,9 @@ class TransformerEncoder(FairseqEncoder):
self.use_pipeline = encoder_module_list is not None
if not self.use_pipeline:
self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
self.encoder_layers = nn.Sequential(*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)])
self.encoder_layers = nn.Sequential(
*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
)
if isinstance(embed_tokens, nn.ModuleList):
emb_dim = sum(e.embedding_dim for e in embed_tokens)
else:
@ -490,7 +502,11 @@ class TransformerEncoder(FairseqEncoder):
)
if TORCH_PIPE:
self.model = Pipe(
module=partition_model(nn.Sequential(*encoder_module_list), encoder_balance, encoder_devices),
module=partition_model(
nn.Sequential(*encoder_module_list),
encoder_balance,
encoder_devices,
),
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)
@ -614,10 +630,12 @@ class TransformerDecoder(FairseqDecoder):
self.use_pipeline = decoder_module_list is not None
if not self.use_pipeline:
self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
self.decoder_layers = nn.Sequential(*[
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
])
self.decoder_layers = nn.Sequential(
*[
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
self.decoder_output_layer = TransformerDecoderOutputLayer(
args, embed_tokens, dictionary
)
@ -634,7 +652,11 @@ class TransformerDecoder(FairseqDecoder):
)
if TORCH_PIPE:
self.model = Pipe(
module=partition_model(nn.Sequential(*decoder_module_list), decoder_balance, decoder_devices),
module=partition_model(
nn.Sequential(*decoder_module_list),
decoder_balance,
decoder_devices,
),
chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint,
)


@ -4,11 +4,11 @@
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer_lm import TransformerLanguageModel
try:
from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding
@ -22,7 +22,6 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
@staticmethod
def add_args(parser):
TransformerLanguageModel.add_args(parser)
@ -72,10 +71,6 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
)
return cls(decoder)
@staticmethod
def add_args(parser):
TransformerLanguageModel.add_args(parser)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
def _vocab_init(tensor, **kwargs):


@ -98,9 +98,7 @@ def build_model(cfg: FairseqDataclass, task):
assert model is not None, (
f"Could not infer model type from {cfg}. "
"Available models: {}".format(
MODEL_DATACLASS_REGISTRY.keys()
)
"Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys())
+ f" Requested model type: {model_type}"
)


@ -100,8 +100,8 @@ class BARTHubInterface(GeneratorHubInterface):
raise NotImplementedError("prefix generation not implemented for BART")
res = []
for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
src_tokens = batch['net_input']['src_tokens']
inference_step_args["prefix_tokens"] =src_tokens.new_full(
src_tokens = batch["net_input"]["src_tokens"]
inference_step_args["prefix_tokens"] = src_tokens.new_full(
(src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos()
).to(device=self.device)
results = super().generate(
@ -111,7 +111,7 @@ class BARTHubInterface(GeneratorHubInterface):
skip_invalid_size_inputs=skip_invalid_size_inputs,
**kwargs
)
for id, hypos in zip(batch['id'].tolist(), results):
for id, hypos in zip(batch["id"].tolist(), results):
res.append((id, hypos))
res = [hypos for _, hypos in sorted(res, key=lambda x: x[0])]
return res
@ -177,32 +177,35 @@ class BARTHubInterface(GeneratorHubInterface):
match_source_len: bool = True,
**generate_kwargs
):
masked_token = '<mask>'
masked_token = "<mask>"
batch_tokens = []
for masked_input in masked_inputs:
assert masked_token in masked_input, \
"please add one {} token for the input".format(masked_token)
assert (
masked_token in masked_input
), "please add one {} token for the input".format(masked_token)
text_spans = masked_input.split(masked_token)
text_spans_bpe = (' {0} '.format(masked_token)).join(
[self.bpe.encode(text_span.rstrip()) for text_span in text_spans]
).strip()
text_spans_bpe = (
(" {0} ".format(masked_token))
.join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
.strip()
)
tokens = self.task.source_dictionary.encode_line(
'<s> ' + text_spans_bpe + ' </s>',
"<s> " + text_spans_bpe + " </s>",
append_eos=False,
add_if_not_exist=False,
).long()
batch_tokens.append(tokens)
# ensure beam size is at least as big as topk
generate_kwargs['beam'] = max(
generate_kwargs["beam"] = max(
topk,
generate_kwargs.get('beam', -1),
generate_kwargs.get("beam", -1),
)
generate_kwargs['match_source_len'] = match_source_len
generate_kwargs["match_source_len"] = match_source_len
batch_hypos = self.generate(batch_tokens, **generate_kwargs)
return [
[(self.decode(hypo['tokens']), hypo['score']) for hypo in hypos[:topk]]
[(self.decode(hypo["tokens"]), hypo["score"]) for hypo in hypos[:topk]]
for hypos in batch_hypos
]
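
For orientation, the fill_mask method above backs the usual BART hub workflow. A hypothetical usage sketch follows; the hub model name and the example sentence are assumptions, but the list-of-(text, score)-pairs structure matches the return statement above.

import torch

bart = torch.hub.load("pytorch/fairseq", "bart.base")
bart.eval()

# One list per input sentence, each holding up to topk (filled_text, score) pairs.
predictions = bart.fill_mask(["The cat <mask> on the mat."], topk=3)
for filled_text, score in predictions[0]:
    print(round(float(score), 3), filled_text)
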


@ -90,7 +90,7 @@ class BARTModel(TransformerModel):
src_tokens,
src_lengths=src_lengths,
token_embeddings=token_embeddings,
return_all_hiddens=return_all_hiddens
return_all_hiddens=return_all_hiddens,
)
x, extra = self.decoder(
prev_output_tokens,
@ -103,9 +103,9 @@ class BARTModel(TransformerModel):
)
eos: int = self.eos
if classification_head_name is not None:
sentence_representation = x[
src_tokens.eq(eos), :
].view(x.size(0), -1, x.size(-1))[:, -1, :]
sentence_representation = x[src_tokens.eq(eos), :].view(
x.size(0), -1, x.size(-1)
)[:, -1, :]
for k, head in self.classification_heads.items():
# for torch script only supports iteration
if k == classification_head_name:


@ -25,7 +25,10 @@ logger = logging.getLogger(__name__)
_SLOWMO_DDP_DISABLED = False
try:
from fairscale.experimental.nn.data_parallel import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel
from fairscale.experimental.nn.data_parallel import (
SlowMoBaseAlgorithm,
SlowMoDistributedDataParallel,
)
except ImportError:
_SLOWMO_DDP_DISABLED = True


@ -22,6 +22,7 @@ import copy
import logging
import torch
from fairseq import checkpoint_utils
@ -78,7 +79,9 @@ class EMA(object):
self.fp32_params = {}
if self.config.ema_seed_model is not None:
state = checkpoint_utils.load_ema_from_checkpoint(self.config.ema_seed_model)
state = checkpoint_utils.load_ema_from_checkpoint(
self.config.ema_seed_model
)
self.model.load_state_dict(state["model"], strict=True)
if device is not None:
@ -119,7 +122,7 @@ class EMA(object):
self.fp32_params[param_key] = _to_float(state_dict[param_key])
def restore(self, state_dict, build_fp32_params=False):
""" Load data from a model spec into EMA model """
"""Load data from a model spec into EMA model"""
self.model.load_state_dict(state_dict, strict=False)
if build_fp32_params:
self.build_fp32_params(state_dict)
@ -131,16 +134,20 @@ class EMA(object):
return self.decay
def _step_internal(self, new_model, updates=None):
""" One update of the EMA model based on new model weights """
"""One update of the EMA model based on new model weights"""
decay = self.decay
ema_state_dict = {}
ema_params = self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
ema_params = (
self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
)
for key, param in new_model.state_dict().items():
try:
ema_param = ema_params[key]
except KeyError:
ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
ema_param = (
param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
)
if param.shape != ema_param.shape:
raise ValueError(
@ -151,7 +158,7 @@ class EMA(object):
# Do not decay a model.version pytorch param
continue
ema_param.mul_(decay)
ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1-decay)
ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
ema_state_dict[key] = ema_param
self.restore(ema_state_dict, build_fp32_params=False)
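
The per-parameter update above is a plain exponential moving average, ema <- decay * ema + (1 - decay) * param. A tiny worked example with made-up numbers:

import torch

decay = 0.999
param = torch.tensor([1.0, 2.0])      # current model weights
ema_param = torch.tensor([0.0, 0.0])  # running EMA copy

ema_param.mul_(decay)
ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
# ema_param is now tensor([0.0010, 0.0020])
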
@ -168,8 +175,7 @@ class EMA(object):
"""
self._set_decay(
0
if updates is not None
and updates < self.config.ema_start_update
if updates is not None and updates < self.config.ema_start_update
else self.config.ema_decay
)
if updates is not None and self.config.ema_update_freq > 1:


@ -19,7 +19,6 @@ class FairseqDecoder(nn.Module):
self.onnx_trace = False
self.adaptive_softmax = None
def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
"""
Args:


@ -29,8 +29,9 @@ logger = logging.getLogger(__name__)
def check_type(module, expected_type):
if hasattr(module, "unwrapped_module"):
assert isinstance(module.unwrapped_module, expected_type), \
f"{type(module.unwrapped_module)} != {expected_type}"
assert isinstance(
module.unwrapped_module, expected_type
), f"{type(module.unwrapped_module)} != {expected_type}"
else:
assert isinstance(module, expected_type), f"{type(module)} != {expected_type}"
@ -114,7 +115,9 @@ class BaseFairseqModel(nn.Module):
"""
if model_cfg is None and args is not None:
logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
logger.warn(
"using 'args' is deprecated, please update your code to use dataclass config"
)
model_cfg = convert_namespace_to_omegaconf(args).model
self.upgrade_state_dict(state_dict)
@ -454,7 +457,9 @@ class FairseqMultiModel(BaseFairseqModel):
"""
if model_cfg is None and args is not None:
logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
logger.warn(
"using 'args' is deprecated, please update your code to use dataclass config"
)
model_cfg = convert_namespace_to_omegaconf(args).model
self.upgrade_state_dict(state_dict)


@ -30,9 +30,7 @@ from omegaconf import II
logger = logging.getLogger(__name__)
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(
["static", "uniform", "normal", "poisson"]
)
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
@dataclass
@ -86,9 +84,7 @@ class HubertConfig(FairseqDataclass):
)
dropout_features: float = field(
default=0.0,
metadata={
"help": "dropout to apply to the features (after feat extr)"
},
metadata={"help": "dropout to apply to the features (after feat extr)"},
)
final_dim: int = field(
@ -150,9 +146,7 @@ class HubertConfig(FairseqDataclass):
)
mask_min_space: int = field(
default=1,
metadata={
"help": "min space between spans (if no overlap is enabled)"
},
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
# channel masking
@ -182,23 +176,17 @@ class HubertConfig(FairseqDataclass):
)
mask_channel_min_space: int = field(
default=1,
metadata={
"help": "min space between spans (if no overlap is enabled)"
},
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
# positional embeddings
conv_pos: int = field(
default=128,
metadata={
"help": "number of filters for convolutional positional embeddings"
},
metadata={"help": "number of filters for convolutional positional embeddings"},
)
conv_pos_groups: int = field(
default=16,
metadata={
"help": "number of groups for convolutional positional embedding"
},
metadata={"help": "number of groups for convolutional positional embedding"},
)
latent_temp: Tuple[float, float, float] = field(
@ -238,9 +226,7 @@ class HubertModel(BaseFairseqModel):
conv_bias=cfg.conv_bias,
)
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = (
cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
)
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim)
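
The feat2tar_ratio line above is unit bookkeeping between audio samples, extracted feature frames, and label frames. With assumed (not configured) values of a 16 kHz waveform, conv strides multiplying to 320, and 50 Hz labels, the ratio works out to one target per feature frame:

sample_rate = 16_000   # audio samples per second (assumed)
feature_ds_rate = 320  # product of the conv extractor strides (assumed)
label_rate = 50        # label frames per second (assumed)

feat2tar_ratio = label_rate * feature_ds_rate / sample_rate
assert feat2tar_ratio == 1.0  # one label per extracted feature frame
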
@ -270,9 +256,7 @@ class HubertModel(BaseFairseqModel):
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
final_dim = (
cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
)
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
@ -297,9 +281,7 @@ class HubertModel(BaseFairseqModel):
# modules below are not needed during fine-tuning
if any([d is None for d in dictionaries]):
logger.info(
"cannot find dictionary. assume will be used for fine-tuning"
)
logger.info("cannot find dictionary. assume will be used for fine-tuning")
else:
self.num_classes = [len(d) for d in dictionaries]
self.label_embs_concat = nn.Parameter(
@ -365,9 +347,7 @@ class HubertModel(BaseFairseqModel):
pos = pos.unsqueeze(0)
targets = torch.cat([pos, negs], dim=0)
logits = torch.cosine_similarity(
x.float(), targets.float(), dim=-1
).type_as(x)
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
logits /= self.logit_temp
if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf")
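
Shape-wise, the contrastive scoring above stacks the positive on top of the sampled negatives and scores all of them against the projected features by cosine similarity, then divides by a temperature. A toy re-run with made-up shapes and an assumed logit_temp of 0.1:

import torch

n_masked, dim, n_negatives = 4, 256, 10
x = torch.randn(n_masked, dim)                  # projected masked features
pos = torch.randn(n_masked, dim)                # positive targets
negs = torch.randn(n_negatives, n_masked, dim)  # sampled negatives

targets = torch.cat([pos.unsqueeze(0), negs], dim=0)  # (1 + negatives, masked, dim)
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
logits /= 0.1  # logit_temp (assumed value)
print(logits.shape)  # torch.Size([11, 4])
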
@ -385,7 +365,9 @@ class HubertModel(BaseFairseqModel):
return features
def forward_targets(
self, features: torch.Tensor, target_list: List[torch.Tensor],
self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2)
@ -398,14 +380,14 @@ class HubertModel(BaseFairseqModel):
return features, target_list
def forward_padding_mask(
self, features: torch.Tensor, padding_mask: torch.Tensor,
self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view(
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask = padding_mask.all(-1)
return padding_mask
@ -439,9 +421,7 @@ class HubertModel(BaseFairseqModel):
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(
features, padding_mask, target_list
)
x, mask_indices = self.apply_mask(features, padding_mask, target_list)
else:
x = features
mask_indices = None
@ -454,7 +434,7 @@ class HubertModel(BaseFairseqModel):
x, _ = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1
layer=None if output_layer is None else output_layer - 1,
)
if features_only:
@ -483,9 +463,7 @@ class HubertModel(BaseFairseqModel):
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
logit_m_list = [
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
for i, (proj_x_m, t) in enumerate(
zip(proj_x_m_list, target_list)
)
for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
]
else:
logit_m_list = [None for _ in target_list]
@ -500,9 +478,7 @@ class HubertModel(BaseFairseqModel):
logit_u_list = [
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
for i, (proj_x_u, t) in enumerate(
zip(proj_x_u_list, target_list)
)
for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
]
else:
logit_u_list = [None for _ in target_list]
@ -543,9 +519,7 @@ class HubertModel(BaseFairseqModel):
def get_targets(self, net_output, is_masked=True):
logits_list = self.get_logits(net_output, is_masked)
targets_list = [
x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list
]
targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
return targets_list
def get_extra_losses(self, net_output):


@ -21,9 +21,7 @@ from omegaconf import II, MISSING
@dataclass
class HubertAsrConfig(FairseqDataclass):
w2v_path: str = field(
default=MISSING, metadata={"help": "path to hubert model"}
)
w2v_path: str = field(default=MISSING, metadata={"help": "path to hubert model"})
no_pretrained_weights: bool = field(
default=False,
metadata={"help": "if true, does not load pretrained weights"},
@ -34,9 +32,7 @@ class HubertAsrConfig(FairseqDataclass):
)
final_dropout: float = field(
default=0.0,
metadata={
"help": "dropout after transformer and before final projection"
},
metadata={"help": "dropout after transformer and before final projection"},
)
dropout: float = field(
default=0.0,
@ -45,15 +41,13 @@ class HubertAsrConfig(FairseqDataclass):
attention_dropout: float = field(
default=0.0,
metadata={
"help": "dropout probability for attention weights "
"inside hubert model"
"help": "dropout probability for attention weights " "inside hubert model"
},
)
activation_dropout: float = field(
default=0.0,
metadata={
"help": "dropout probability after activation in FFN "
"inside hubert model"
"help": "dropout probability after activation in FFN " "inside hubert model"
},
)
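
A side note on the rewrapped help strings above: the two quoted pieces on one line are not a black artifact; adjacent string literals are concatenated by the Python parser, so the help text still reads as one sentence.

msg = "dropout probability for attention weights " "inside hubert model"
assert msg == "dropout probability for attention weights inside hubert model"
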
@ -184,9 +178,7 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
decoder_ffn_embed_dim: int = field(
default=3072, metadata={"help": "decoder embedding dimension for FFN"}
)
decoder_layers: int = field(
default=6, metadata={"help": "num of decoder layers"}
)
decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
decoder_layerdrop: float = field(
default=0.0, metadata={"help": "decoder layerdrop chance"}
)
@ -204,8 +196,7 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
no_token_positional_embeddings: bool = field(
default=False,
metadata={
"help": "if set, disables positional embeddings "
"(outside self attention)"
"help": "if set, disables positional embeddings " "(outside self attention)"
},
)
decoder_dropout: float = field(
@ -214,15 +205,13 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
decoder_attention_dropout: float = field(
default=0.0,
metadata={
"help": "dropout probability for attention weights "
"inside the decoder"
"help": "dropout probability for attention weights " "inside the decoder"
},
)
decoder_activation_dropout: float = field(
default=0.0,
metadata={
"help": "dropout probability after activation in FFN "
"inside the decoder"
"help": "dropout probability after activation in FFN " "inside the decoder"
},
)
max_target_positions: int = field(
@ -258,9 +247,7 @@ class HubertEncoder(FairseqEncoder):
}
if cfg.w2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu(
cfg.w2v_path, arg_overrides
)
state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
w2v_args = state.get("cfg", None)
if w2v_args is None:
w2v_args = convert_namespace_to_omegaconf(state["args"])
@ -269,9 +256,7 @@ class HubertEncoder(FairseqEncoder):
state = None
w2v_args = cfg.w2v_args
if isinstance(w2v_args, Namespace):
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
w2v_args
)
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
assert cfg.normalize == w2v_args.task.normalize, (
"Fine-tuning works best when data normalization is the same. "
@ -344,9 +329,9 @@ class HubertEncoder(FairseqEncoder):
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None:
encoder_out["encoder_out"] = encoder_out[
"encoder_out"
].index_select(1, new_order)
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
1, new_order
)
if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask"


@ -225,10 +225,10 @@ class LSTMEncoder(FairseqEncoder):
super().__init__(dictionary)
self.num_layers = num_layers
self.dropout_in_module = FairseqDropout(
dropout_in*1.0, module_name=self.__class__.__name__
dropout_in * 1.0, module_name=self.__class__.__name__
)
self.dropout_out_module = FairseqDropout(
dropout_out*1.0, module_name=self.__class__.__name__
dropout_out * 1.0, module_name=self.__class__.__name__
)
self.bidirectional = bidirectional
self.hidden_size = hidden_size
@ -329,7 +329,9 @@ class LSTMEncoder(FairseqEncoder):
out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
return out.view(self.num_layers, bsz, -1)
def reorder_encoder_out(self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order):
def reorder_encoder_out(
self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order
):
return tuple(
(
encoder_out[0].index_select(1, new_order),
@ -402,10 +404,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
):
super().__init__(dictionary)
self.dropout_in_module = FairseqDropout(
dropout_in*1.0, module_name=self.__class__.__name__
dropout_in * 1.0, module_name=self.__class__.__name__
)
self.dropout_out_module = FairseqDropout(
dropout_out*1.0, module_name=self.__class__.__name__
dropout_out * 1.0, module_name=self.__class__.__name__
)
self.hidden_size = hidden_size
self.share_input_output_embed = share_input_output_embed


@ -18,7 +18,10 @@ def ensemble_encoder(func):
def wrapper(self, *args, **kwargs):
if self.ensemble_models is None or len(self.ensemble_models) == 1:
return func(self, *args, **kwargs)
encoder_outs = [func(model, *args, **kwargs, return_all_hiddens=True) for model in self.ensemble_models]
encoder_outs = [
func(model, *args, **kwargs, return_all_hiddens=True)
for model in self.ensemble_models
]
_encoder_out = encoder_outs[0].copy()
def stack(key):
@ -56,8 +59,7 @@ def ensemble_decoder(func):
model,
normalize=normalize,
encoder_out=_replace(
encoder_out,
encoder_out["encoder_out"][0][:, :, :, i]
encoder_out, encoder_out["encoder_out"][0][:, :, :, i]
),
*args,
**kwargs


@ -85,7 +85,8 @@ class EnsembleLevT(BasicEnsembleModel):
else:
if not encoder_outs[0]["encoder_padding_mask"]:
src_lens = (
encoder_outs[0]["encoder_out"][0].new(bsz)
encoder_outs[0]["encoder_out"][0]
.new(bsz)
.fill_(encoder_outs[0]["encoder_out"][0].size(1))
)
else:


@ -183,7 +183,7 @@ class RobertaModel(FairseqEncoderModel):
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
),
)
@classmethod
@ -542,7 +542,9 @@ def base_architecture(args):
args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", True)
args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", True)
args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
args.encoder_normalize_before = safe_getattr(args, "encoder_normalize_before", False)
args.encoder_normalize_before = safe_getattr(
args, "encoder_normalize_before", False
)
args.pooler_activation_fn = safe_getattr(args, "pooler_activation_fn", "tanh")
args.untie_weights_roberta = safe_getattr(args, "untie_weights_roberta", False)


@ -12,26 +12,26 @@ from .hub_interface import RobertaHubInterface
from .model import RobertaModel
@register_model('gottbert')
@register_model("gottbert")
class GottbertModel(RobertaModel):
@classmethod
def hub_models(cls):
return {
'gottbert-base': 'https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz',
"gottbert-base": "https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz",
}
@classmethod
def from_pretrained(cls,
model_name_or_path,
checkpoint_file='model.pt',
data_name_or_path='.',
bpe='hf_byte_bpe',
bpe_vocab='vocab.json',
bpe_merges='merges.txt',
bpe_add_prefix_space=False,
**kwargs
):
def from_pretrained(
cls,
model_name_or_path,
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="hf_byte_bpe",
bpe_vocab="vocab.json",
bpe_merges="merges.txt",
bpe_add_prefix_space=False,
**kwargs
):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
@ -46,4 +46,4 @@ class GottbertModel(RobertaModel):
bpe_add_prefix_space=bpe_add_prefix_space,
**kwargs,
)
return RobertaHubInterface(x['args'], x['task'], x['models'][0])
return RobertaHubInterface(x["args"], x["task"], x["models"][0])
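
A hypothetical usage sketch for the reformatted from_pretrained above; the "gottbert-base" key comes from hub_models(), while the import path, the encode/extract_features calls, and the German example sentence are assumptions:

from fairseq.models.roberta import GottbertModel

gottbert = GottbertModel.from_pretrained("gottbert-base")
gottbert.eval()

tokens = gottbert.encode("Ein Beispielsatz .")
features = gottbert.extract_features(tokens)
print(features.shape)
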


@ -202,10 +202,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="model to take encoder weights from (for initialization)",
)
parser.add_argument(
'--encoder-freezing-updates',
"--encoder-freezing-updates",
type=int,
metavar='N',
help='freeze encoder for first N updates'
metavar="N",
help="freeze encoder for first N updates",
)
@classmethod
@ -329,7 +329,9 @@ class S2TTransformerEncoder(FairseqEncoder):
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T
"encoder_padding_mask": [encoder_padding_mask]
if encoder_padding_mask.any()
else [], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
@ -339,27 +341,37 @@ class S2TTransformerEncoder(FairseqEncoder):
def forward(self, src_tokens, src_lengths, return_all_hiddens=False):
if self.num_updates < self.encoder_freezing_updates:
with torch.no_grad():
x = self._forward(src_tokens, src_lengths,
return_all_hiddens=return_all_hiddens)
x = self._forward(
src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
)
else:
x = self._forward(src_tokens, src_lengths,
return_all_hiddens=return_all_hiddens)
x = self._forward(
src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
)
return x
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0
[]
if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
new_encoder_padding_mask = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
[]
if len(encoder_out["encoder_padding_mask"]) == 0
else [
x.index_select(0, new_order)
for x in encoder_out["encoder_padding_mask"]
]
)
new_encoder_embedding = (
[] if len(encoder_out["encoder_embedding"]) == 0
else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]]
[]
if len(encoder_out["encoder_embedding"]) == 0
else [
x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
]
)
encoder_states = encoder_out["encoder_states"]
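
The index_select calls above are how beam search reorders cached encoder state along the batch dimension; on a toy (T, B, C) tensor the effect is just a permutation of dim 1:

import torch

T, B, C = 5, 3, 4
x = torch.arange(T * B * C, dtype=torch.float).view(T, B, C)  # fake encoder_out
new_order = torch.tensor([2, 0, 1])

reordered = x.index_select(1, new_order)  # batch is dim 1 for T x B x C tensors
assert reordered.shape == x.shape
assert torch.equal(reordered[:, 0], x[:, 2])
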


@ -9,8 +9,12 @@ import copy
from typing import Dict, List, Optional, Tuple
from fairseq import utils, checkpoint_utils
from fairseq.models import (FairseqEncoderDecoderModel, FairseqEncoder,
register_model, register_model_architecture)
from fairseq.models import (
FairseqEncoderDecoderModel,
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.wav2vec import Wav2VecEncoder
from fairseq.modules.layer_norm import LayerNorm
@ -24,18 +28,23 @@ logger = logging.getLogger(__name__)
class Conv1dAdaptor(nn.Module):
def __init__(self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2,
add_layernorm=False):
def __init__(
self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2, add_layernorm=False
):
super().__init__()
self.layers = nn.ModuleList(
nn.Conv1d(in_dim if i == 0 else out_dim, out_dim * 2, kernel_size,
stride=stride, padding=kernel_size // 2)
nn.Conv1d(
in_dim if i == 0 else out_dim,
out_dim * 2,
kernel_size,
stride=stride,
padding=kernel_size // 2,
)
for i in range(n_layers)
)
self.layernorms = None
if add_layernorm:
self.layernorms = nn.ModuleList(LayerNorm(out_dim)
for _ in range(n_layers))
self.layernorms = nn.ModuleList(LayerNorm(out_dim) for _ in range(n_layers))
self.stride = stride
@classmethod
@ -43,7 +52,7 @@ class Conv1dAdaptor(nn.Module):
parser.add_argument("--adaptor-n-layers", type=int)
parser.add_argument("--adaptor-kernel-size", type=int)
parser.add_argument("--adaptor-stride", type=int)
parser.add_argument("--adaptor-layernorm", action='store_true')
parser.add_argument("--adaptor-layernorm", action="store_true")
def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
out = in_seq_lens_tensor.clone()
@ -197,15 +206,18 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
encoder_out_dim = self.w2v_encoder.w2v_model.encoder.embedding_dim
# Projection + 8x shrinking
self.adaptor = Conv1dAdaptor(
encoder_out_dim, args.decoder_embed_dim,
encoder_out_dim,
args.decoder_embed_dim,
n_layers=args.adaptor_n_layers,
kernel_size=args.adaptor_kernel_size, stride=args.adaptor_stride,
add_layernorm=args.adaptor_layernorm
kernel_size=args.adaptor_kernel_size,
stride=args.adaptor_stride,
add_layernorm=args.adaptor_layernorm,
)
for k, p in self.w2v_encoder.w2v_model.named_parameters():
# Freeze pretrained models by default
if safe_hasattr(args, 'finetune_w2v_params') and XMTransformerModel.finetune_params(
args.finetune_w2v_params, k):
if safe_hasattr(
args, "finetune_w2v_params"
) and XMTransformerModel.finetune_params(args.finetune_w2v_params, k):
p.requires_grad = True
else:
p.requires_grad = False
@ -214,11 +226,16 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
def add_args(cls, parser):
add_wav2vec_asr_args(parser)
parser.add_argument(
"--normalize", action="store_true",
"--normalize",
action="store_true",
help="if set, normalizes input to have 0 mean and unit variance",
)
parser.add_argument("--finetune-w2v-params", type=str, metavar="STR",
help="comma-separated param strings to finetune.")
parser.add_argument(
"--finetune-w2v-params",
type=str,
metavar="STR",
help="comma-separated param strings to finetune.",
)
Conv1dAdaptor.add_args(parser)
def forward(self, src_tokens, src_lengths=None, **kwargs):
@ -227,13 +244,17 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
x = out["encoder_out"]
enc_padding_mask = None
if out["encoder_padding_mask"] is not None:
enc_padding_mask = out["encoder_padding_mask"].transpose(0, 1) # T X B --> B X T
enc_padding_mask = out["encoder_padding_mask"].transpose(
0, 1
) # T X B --> B X T
x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [enc_padding_mask] if enc_padding_mask.any() else [], # B x T
"encoder_padding_mask": [enc_padding_mask]
if enc_padding_mask.any()
else [], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
@ -242,20 +263,26 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0
[]
if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
)
new_encoder_padding_mask = (
[] if len(encoder_out["encoder_padding_mask"]) == 0
else [x.index_select(0, new_order) for x in
encoder_out["encoder_padding_mask"]]
[]
if len(encoder_out["encoder_padding_mask"]) == 0
else [
x.index_select(0, new_order)
for x in encoder_out["encoder_padding_mask"]
]
)
new_encoder_embedding = (
[] if len(encoder_out["encoder_embedding"]) == 0
else [x.index_select(0, new_order) for x in
encoder_out["encoder_embedding"]]
[]
if len(encoder_out["encoder_embedding"]) == 0
else [
x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
]
)
encoder_states = encoder_out["encoder_states"]
@ -274,38 +301,71 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
def add_decoder_args(parser):
parser.add_argument("--activation-fn", type=str, default='relu',
choices=utils.get_available_activation_fns(),
help="activation function to use")
parser.add_argument("--decoder-dropout", type=float, metavar="D",
help="dropout probability")
parser.add_argument("--decoder-attention-dropout", type=float,
metavar="D",
help="dropout probability for attention weights")
parser.add_argument("--decoder-activation-dropout", type=float,
metavar="D",
help="dropout probability after activation in FFN.")
parser.add_argument("--decoder-embed-dim", type=int, metavar="N",
help="decoder embedding dimension")
parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N",
help="decoder embedding dimension for FFN")
parser.add_argument("--decoder-layers", type=int, metavar="N",
help="num decoder layers")
parser.add_argument("--decoder-attention-heads", type=int, metavar="N",
help="num decoder attention heads")
parser.add_argument("--decoder-normalize-before", action="store_true",
help="apply layernorm before each decoder block")
parser.add_argument("--layernorm-embedding", action="store_true",
help="add layernorm to embedding")
parser.add_argument("--no-scale-embedding", action="store_true",
help="if True, dont scale embeddings")
parser.add_argument(
"--load-pretrained-decoder-from", type=str, metavar="STR",
help="model to take decoder weights from (for initialization)"
"--activation-fn",
type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--decoder-dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--decoder-attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--decoder-activation-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--decoder-embed-dim", type=int, metavar="N", help="decoder embedding dimension"
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--layernorm-embedding", action="store_true", help="add layernorm to embedding"
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--finetune-decoder-params",
type=str,
metavar="STR",
help="comma-separated param strings to finetune.",
)
parser.add_argument("--finetune-decoder-params", type=str,
metavar="STR",
help="comma-separated param strings to finetune.")
parser.add_argument("--checkpoint-activations", action="store_true")
@ -342,16 +402,16 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
_args.activation_dropout = args.decoder_activation_dropout
_args.max_target_positions = 1024
decoder = TransformerDecoder(_args, task.target_dictionary,
embed_tokens)
decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None):
decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from
)
for k, p in decoder.named_parameters():
# Freeze pretrained models by default
if safe_hasattr(args, 'finetune_decoder_params') and XMTransformerModel.finetune_params(
args.finetune_decoder_params, k):
if safe_hasattr(
args, "finetune_decoder_params"
) and XMTransformerModel.finetune_params(args.finetune_decoder_params, k):
p.requires_grad = True
else:
p.requires_grad = False
@ -369,8 +429,9 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
decoder_embed_tokens = build_embedding(task.target_dictionary,
args.decoder_embed_dim)
decoder_embed_tokens = build_embedding(
task.target_dictionary, args.decoder_embed_dim
)
encoder = cls.build_encoder(args)
decoder = cls.build_decoder(args, task, decoder_embed_tokens)
return cls(encoder, decoder)
@ -382,8 +443,7 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
sample: Optional[Dict[str, Tensor]] = None,
):
# net_output['encoder_out'] is a (B, T, D) tensor
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs,
sample)
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
lprobs.batch_first = True
return lprobs
@ -393,17 +453,19 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
argument in its input, which is not supported in torchscript. This
method overrites the forward method definition without **kwargs.
"""
encoder_out = self.encoder(src_tokens=src_tokens,
src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(prev_output_tokens=prev_output_tokens,
encoder_out=encoder_out)
encoder_out = self.encoder(
src_tokens=src_tokens, src_lengths=src_lengths, **kwargs
)
decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
return decoder_out
def upgrade_state_dict(self, state_dict):
for k, _ in state_dict.items():
if 'adaptor.layers' in state_dict:
if "adaptor.layers" in state_dict:
print(k)
new = k.replace('adaptor.layers', 'adaptor_layers')
new = k.replace("adaptor.layers", "adaptor_layers")
state_dict[new] = state_dict[k]
del state_dict[k]
@ -435,11 +497,9 @@ def set_default_w2v_encoder_args(args):
args.mask_channel_length = getattr(args, "mask_channel_length", 10)
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
args.mask_channel_before = getattr(args, "mask_channel_before", False)
args.mask_channel_selection = getattr(args, "mask_channel_selection",
"static")
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
args.mask_channel_other = getattr(args, "mask_channel_other", 0)
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap",
False)
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0)
args.feature_grad_mult = 0.1
@ -456,49 +516,43 @@ def set_default_adaptor_args(args):
def set_default_mbart_decoder_args(args):
args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim',
4 * 1024)
args.decoder_layers = getattr(args, 'decoder_layers', 12)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before',
True)
args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True)
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024)
args.decoder_layers = getattr(args, "decoder_layers", 12)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_attention_dropout = getattr(args, 'decoder_attention_dropout',
0.)
args.decoder_activation_dropout = getattr(args,
'decoder_activation_dropout', 0.)
args.decoder_dropout = getattr(args, 'decoder_dropout', 0.1)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff',
None)
args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0.0)
args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0.0)
args.decoder_dropout = getattr(args, "decoder_dropout", 0.1)
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.share_decoder_input_output_embed = getattr(
args, 'share_decoder_input_output_embed', True
args, "share_decoder_input_output_embed", True
)
args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False
)
args.decoder_output_dim = getattr(args, 'decoder_output_dim',
args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim',
args.decoder_embed_dim)
args.decoder_output_dim = getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.layernorm_embedding = getattr(args, 'layernorm_embedding', True)
args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
@register_model_architecture(model_name="xm_transformer",
arch_name="xm_transformer")
@register_model_architecture(model_name="xm_transformer", arch_name="xm_transformer")
def base_architecture(args):
set_default_w2v_encoder_args(args)
set_default_adaptor_args(args)
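
The getattr calls that black rewrapped above all follow the same architecture-defaults idiom: values already set on the parsed namespace win, everything else falls back to the architecture default. A minimal sketch:

from argparse import Namespace

args = Namespace(decoder_layers=6)  # set explicitly (e.g. on the command line)

args.decoder_layers = getattr(args, "decoder_layers", 12)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)

assert args.decoder_layers == 6        # explicit value kept
assert args.decoder_embed_dim == 1024  # architecture default filled in
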


@ -8,10 +8,17 @@ import logging
import torch
from torch import nn
from fairseq.models import (FairseqEncoder, FairseqEncoderModel, register_model,
register_model_architecture)
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import (
LayerNorm, PositionalEmbedding, FairseqDropout, MultiheadAttention
LayerNorm,
PositionalEmbedding,
FairseqDropout,
MultiheadAttention,
)
from fairseq import utils
from fairseq.data.data_utils import lengths_to_padding_mask
@ -36,11 +43,19 @@ class PositionwiseFeedForward(nn.Module):
def __init__(self, in_dim, hidden_dim, kernel_size, dropout):
super().__init__()
self.ffn = nn.Sequential(
nn.Conv1d(in_dim, hidden_dim, kernel_size=kernel_size,
padding=(kernel_size - 1) // 2),
nn.Conv1d(
in_dim,
hidden_dim,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
),
nn.ReLU(),
nn.Conv1d(hidden_dim, in_dim, kernel_size=kernel_size,
padding=(kernel_size - 1) // 2)
nn.Conv1d(
hidden_dim,
in_dim,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
),
)
self.layer_norm = LayerNorm(in_dim)
self.dropout = self.dropout_module = FairseqDropout(
@ -57,8 +72,7 @@ class PositionwiseFeedForward(nn.Module):
class FFTLayer(torch.nn.Module):
def __init__(
self, embed_dim, n_heads, hidden_dim, kernel_size, dropout,
attention_dropout
self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, attention_dropout
):
super().__init__()
self.self_attn = MultiheadAttention(
@ -74,8 +88,7 @@ class FFTLayer(torch.nn.Module):
residual = x
x = x.transpose(0, 1)
x, _ = self.self_attn(
query=x, key=x, value=x, key_padding_mask=padding_mask,
need_weights=False
query=x, key=x, value=x, key_padding_mask=padding_mask, need_weights=False
)
x = x.transpose(0, 1)
x = self.layer_norm(x + residual)
@ -106,11 +119,12 @@ class VariancePredictor(nn.Module):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv1d(
args.encoder_embed_dim, args.var_pred_hidden_dim,
args.encoder_embed_dim,
args.var_pred_hidden_dim,
kernel_size=args.var_pred_kernel_size,
padding=(args.var_pred_kernel_size - 1) // 2
padding=(args.var_pred_kernel_size - 1) // 2,
),
nn.ReLU()
nn.ReLU(),
)
self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim)
self.dropout_module = FairseqDropout(
@ -118,10 +132,12 @@ class VariancePredictor(nn.Module):
)
self.conv2 = nn.Sequential(
nn.Conv1d(
args.var_pred_hidden_dim, args.var_pred_hidden_dim,
kernel_size=args.var_pred_kernel_size, padding=1
args.var_pred_hidden_dim,
args.var_pred_hidden_dim,
kernel_size=args.var_pred_kernel_size,
padding=1,
),
nn.ReLU()
nn.ReLU(),
)
self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim)
self.proj = nn.Linear(args.var_pred_hidden_dim, 1)
@ -171,8 +187,15 @@ class VarianceAdaptor(nn.Module):
return out, emb
def forward(
self, x, padding_mask, durations=None, pitches=None, energies=None,
d_factor=1.0, p_factor=1.0, e_factor=1.0
self,
x,
padding_mask,
durations=None,
pitches=None,
energies=None,
d_factor=1.0,
p_factor=1.0,
e_factor=1.0,
):
# x: B x T x C
log_dur_out = self.duration_predictor(x)
@ -205,8 +228,7 @@ class FastSpeech2Encoder(FairseqEncoder):
self.spk_emb_proj = None
if embed_speaker is not None:
self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim,
args.encoder_embed_dim
args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
)
self.dropout_module = FairseqDropout(
@ -224,9 +246,12 @@ class FastSpeech2Encoder(FairseqEncoder):
self.encoder_fft_layers = nn.ModuleList(
FFTLayer(
args.encoder_embed_dim, args.encoder_attention_heads,
args.fft_hidden_dim, args.fft_kernel_size,
dropout=args.dropout, attention_dropout=args.attention_dropout
args.encoder_embed_dim,
args.encoder_attention_heads,
args.fft_hidden_dim,
args.fft_kernel_size,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
)
for _ in range(args.encoder_layers)
)
@ -235,9 +260,12 @@ class FastSpeech2Encoder(FairseqEncoder):
self.decoder_fft_layers = nn.ModuleList(
FFTLayer(
args.decoder_embed_dim, args.decoder_attention_heads,
args.fft_hidden_dim, args.fft_kernel_size,
dropout=args.dropout, attention_dropout=args.attention_dropout
args.decoder_embed_dim,
args.decoder_attention_heads,
args.fft_hidden_dim,
args.fft_kernel_size,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
)
for _ in range(args.decoder_layers)
)
@ -247,15 +275,25 @@ class FastSpeech2Encoder(FairseqEncoder):
self.postnet = None
if args.add_postnet:
self.postnet = Postnet(
self.out_dim, args.postnet_conv_dim,
self.out_dim,
args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers, args.postnet_dropout
args.postnet_layers,
args.postnet_dropout,
)
self.apply(model_init)
def forward(self, src_tokens, src_lengths=None, speaker=None,
durations=None, pitches=None, energies=None, **kwargs):
def forward(
self,
src_tokens,
src_lengths=None,
speaker=None,
durations=None,
pitches=None,
energies=None,
**kwargs
):
x = self.embed_tokens(src_tokens)
enc_padding_mask = src_tokens.eq(self.padding_idx)
@ -270,8 +308,9 @@ class FastSpeech2Encoder(FairseqEncoder):
emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
x = self.spk_emb_proj(torch.cat([x, emb], dim=2))
x, out_lens, log_dur_out, pitch_out, energy_out = \
self.var_adaptor(x, enc_padding_mask, durations, pitches, energies)
x, out_lens, log_dur_out, pitch_out, energy_out = self.var_adaptor(
x, enc_padding_mask, durations, pitches, energies
)
dec_padding_mask = lengths_to_padding_mask(out_lens)
x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask)
@ -326,7 +365,7 @@ class FastSpeech2Model(FairseqEncoderModel):
out_dim = args.output_frame_dim * args.n_frames_per_step
self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.:
if getattr(args, "ctc_weight", 0.0) > 0.0:
self.ctc_proj = nn.Linear(out_dim, len(src_dict))
@classmethod


@ -119,7 +119,7 @@ class Generator(torch.nn.Module):
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(
zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"])
zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"])
):
self.ups.append(
weight_norm(
@ -137,7 +137,7 @@ class Generator(torch.nn.Module):
for i in range(len(self.ups)):
ch = cfg["upsample_initial_channel"] // (2 ** (i + 1))
for k, d in zip(
cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"]
cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"]
):
self.resblocks.append(ResBlock(ch, k, d))


@ -9,9 +9,13 @@ import torch
from torch import nn
from torch.nn import functional as F
from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
FairseqIncrementalDecoder, register_model,
register_model_architecture)
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import LSTMCellWithZoneOut, LocationAttention
@ -31,29 +35,36 @@ class Tacotron2Encoder(FairseqEncoder):
self.spk_emb_proj = None
if embed_speaker is not None:
self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim,
args.encoder_embed_dim
args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
)
self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim,
padding_idx=self.padding_idx)
self.embed_tokens = nn.Embedding(
len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
)
assert(args.encoder_conv_kernel_size % 2 == 1)
assert args.encoder_conv_kernel_size % 2 == 1
self.convolutions = nn.ModuleList(
nn.Sequential(
nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim,
kernel_size=args.encoder_conv_kernel_size,
padding=((args.encoder_conv_kernel_size - 1) // 2)),
nn.Conv1d(
args.encoder_embed_dim,
args.encoder_embed_dim,
kernel_size=args.encoder_conv_kernel_size,
padding=((args.encoder_conv_kernel_size - 1) // 2),
),
nn.BatchNorm1d(args.encoder_embed_dim),
nn.ReLU(),
nn.Dropout(args.encoder_dropout)
nn.Dropout(args.encoder_dropout),
)
for _ in range(args.encoder_conv_layers)
)
self.lstm = nn.LSTM(args.encoder_embed_dim, args.encoder_embed_dim // 2,
num_layers=args.encoder_lstm_layers,
batch_first=True, bidirectional=True)
self.lstm = nn.LSTM(
args.encoder_embed_dim,
args.encoder_embed_dim // 2,
num_layers=args.encoder_lstm_layers,
batch_first=True,
bidirectional=True,
)
self.apply(encoder_init)
@ -78,7 +89,7 @@ class Tacotron2Encoder(FairseqEncoder):
return {
"encoder_out": [x], # B x T x C
"encoder_padding_mask": encoder_padding_mask, # B x T
"encoder_padding_mask": encoder_padding_mask, # B x T
}
@ -86,8 +97,7 @@ class Prenet(nn.Module):
def __init__(self, in_dim, n_layers, n_units, dropout):
super().__init__()
self.layers = nn.ModuleList(
nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units),
nn.ReLU())
nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units), nn.ReLU())
for i in range(n_layers)
)
self.dropout = dropout
@ -102,20 +112,24 @@ class Postnet(nn.Module):
def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout):
super(Postnet, self).__init__()
self.convolutions = nn.ModuleList()
assert(kernel_size % 2 == 1)
assert kernel_size % 2 == 1
for i in range(n_layers):
cur_layers = [
nn.Conv1d(in_dim if i == 0 else n_channels,
n_channels if i < n_layers - 1 else in_dim,
kernel_size=kernel_size,
padding=((kernel_size - 1) // 2)),
nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim)
] + ([nn.Tanh()] if i < n_layers - 1 else []) + [nn.Dropout(dropout)]
cur_layers = (
[
nn.Conv1d(
in_dim if i == 0 else n_channels,
n_channels if i < n_layers - 1 else in_dim,
kernel_size=kernel_size,
padding=((kernel_size - 1) // 2),
),
nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim),
]
+ ([nn.Tanh()] if i < n_layers - 1 else [])
+ [nn.Dropout(dropout)]
)
nn.init.xavier_uniform_(
cur_layers[0].weight,
torch.nn.init.calculate_gain(
"tanh" if i < n_layers - 1 else "linear"
)
torch.nn.init.calculate_gain("tanh" if i < n_layers - 1 else "linear"),
)
self.convolutions.append(nn.Sequential(*cur_layers))
@ -138,21 +152,25 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
self.n_frames_per_step = args.n_frames_per_step
self.out_dim = args.output_frame_dim * args.n_frames_per_step
self.prenet = Prenet(self.out_dim, args.prenet_layers, args.prenet_dim,
args.prenet_dropout)
self.prenet = Prenet(
self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
)
# take prev_context, prev_frame, (speaker embedding) as input
self.attention_lstm = LSTMCellWithZoneOut(
args.zoneout,
args.prenet_dim + args.encoder_embed_dim,
args.decoder_lstm_dim
args.decoder_lstm_dim,
)
# take attention_lstm output, attention_state, encoder_out as input
self.attention = LocationAttention(
args.attention_dim, args.encoder_embed_dim, args.decoder_lstm_dim,
args.attention_dim,
args.encoder_embed_dim,
args.decoder_lstm_dim,
(1 + int(args.attention_use_cumprob)),
args.attention_conv_dim, args.attention_conv_kernel_size
args.attention_conv_dim,
args.attention_conv_kernel_size,
)
# take attention_lstm output, context, (gated_latent) as input
@ -160,7 +178,7 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
LSTMCellWithZoneOut(
args.zoneout,
args.encoder_embed_dim + args.decoder_lstm_dim,
args.decoder_lstm_dim
args.decoder_lstm_dim,
)
for i in range(args.decoder_lstm_layers)
)
@ -169,12 +187,16 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
self.feat_proj = nn.Linear(proj_in_dim, self.out_dim)
self.eos_proj = nn.Linear(proj_in_dim, 1)
self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers, args.postnet_dropout)
self.postnet = Postnet(
self.out_dim,
args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers,
args.postnet_dropout,
)
self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.:
if getattr(args, "ctc_weight", 0.0) > 0.0:
self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))
self.apply(decoder_init)
@ -190,12 +212,16 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
lstm_h = self.get_incremental_state(incremental_state, "lstm_h")
if lstm_h is None:
lstm_h = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
for _ in range(self.args.decoder_lstm_layers)]
lstm_h = [
enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
for _ in range(self.args.decoder_lstm_layers)
]
lstm_c = self.get_incremental_state(incremental_state, "lstm_c")
if lstm_c is None:
lstm_c = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
for _ in range(self.args.decoder_lstm_layers)]
lstm_c = [
enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
for _ in range(self.args.decoder_lstm_layers)
]
attn_w = self.get_incremental_state(incremental_state, "attn_w")
if attn_w is None:
@ -216,8 +242,14 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
else:
raise ValueError(f"{self.args.init_attn_c} not supported")
def forward(self, prev_output_tokens, encoder_out=None,
incremental_state=None, target_lengths=None, **kwargs):
def forward(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
target_lengths=None,
**kwargs,
):
enc_mask = encoder_out["encoder_padding_mask"]
enc_out = encoder_out["encoder_out"][0]
in_len = enc_out.size(1)
@ -227,8 +259,9 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
bsz, out_len, _ = prev_output_tokens.size()
prenet_out = self.prenet(prev_output_tokens)
(alstm_h, alstm_c, lstm_h, lstm_c,
attn_w, attn_w_cum) = self._get_states(incremental_state, enc_out)
(alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum) = self._get_states(
incremental_state, enc_out
)
attn_ctx = self._get_init_attn_c(enc_out, enc_mask)
attn_out = enc_out.new_zeros(bsz, in_len, out_len)
@ -241,9 +274,7 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
attn_state = attn_w.unsqueeze(1)
if self.args.attention_use_cumprob:
attn_state = torch.stack((attn_w, attn_w_cum), dim=1)
attn_ctx, attn_w = self.attention(
enc_out, enc_mask, alstm_h, attn_state
)
attn_ctx, attn_w = self.attention(enc_out, enc_mask, alstm_h, attn_state)
attn_w_cum = attn_w_cum + attn_w
attn_out[:, :, t] = attn_w
@ -297,7 +328,7 @@ class Tacotron2Model(FairseqEncoderDecoderModel):
parser.add_argument("--postnet-conv-dim", type=int)
parser.add_argument("--postnet-conv-kernel-size", type=int)
parser.add_argument("--init-attn-c", type=str)
parser.add_argument("--attention-use-cumprob", action='store_true')
parser.add_argument("--attention-use-cumprob", action="store_true")
parser.add_argument("--zoneout", type=float)
parser.add_argument("--decoder-lstm-layers", type=int)
parser.add_argument("--decoder-lstm-dim", type=int)
@ -333,8 +364,7 @@ def base_architecture(args):
# decoder
args.attention_dim = getattr(args, "attention_dim", 128)
args.attention_conv_dim = getattr(args, "attention_conv_dim", 32)
args.attention_conv_kernel_size = getattr(args,
"attention_conv_kernel_size", 15)
args.attention_conv_kernel_size = getattr(args, "attention_conv_kernel_size", 15)
args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
args.prenet_layers = getattr(args, "prenet_layers", 2)
args.prenet_dim = getattr(args, "prenet_dim", 256)


@ -9,12 +9,14 @@ from typing import List, Optional
import torch
from torch import nn
from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
FairseqIncrementalDecoder, register_model,
register_model_architecture)
from fairseq.modules import (
TransformerEncoderLayer, TransformerDecoderLayer
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from fairseq.models.text_to_speech.tacotron2 import Prenet, Postnet
from fairseq.modules import LayerNorm, PositionalEmbedding, FairseqDropout
from fairseq.data.data_utils import lengths_to_padding_mask
@ -42,30 +44,31 @@ class TTSTransformerEncoder(FairseqEncoder):
self.spk_emb_proj = None
if embed_speaker is not None:
self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim,
args.encoder_embed_dim
args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
)
self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__
)
self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim,
padding_idx=self.padding_idx)
assert(args.encoder_conv_kernel_size % 2 == 1)
self.embed_tokens = nn.Embedding(
len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
)
assert args.encoder_conv_kernel_size % 2 == 1
self.prenet = nn.ModuleList(
nn.Sequential(
nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim,
kernel_size=args.encoder_conv_kernel_size,
padding=((args.encoder_conv_kernel_size - 1) // 2)),
nn.Conv1d(
args.encoder_embed_dim,
args.encoder_embed_dim,
kernel_size=args.encoder_conv_kernel_size,
padding=((args.encoder_conv_kernel_size - 1) // 2),
),
nn.BatchNorm1d(args.encoder_embed_dim),
nn.ReLU(),
nn.Dropout(args.encoder_dropout),
)
for _ in range(args.encoder_conv_layers)
)
self.prenet_proj = nn.Linear(
args.encoder_embed_dim, args.encoder_embed_dim
)
self.prenet_proj = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx
)
@ -112,7 +115,9 @@ class TTSTransformerEncoder(FairseqEncoder):
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [padding_mask] if padding_mask.any() else [], # B x T
"encoder_padding_mask": [padding_mask]
if padding_mask.any()
else [], # B x T
"encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C]
"src_tokens": [],
@ -143,15 +148,15 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
)
self.pos_emb_alpha = nn.Parameter(torch.ones(1))
self.prenet = nn.Sequential(
Prenet(self.out_dim, args.prenet_layers, args.prenet_dim,
args.prenet_dropout),
Prenet(
self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
),
nn.Linear(args.prenet_dim, args.decoder_embed_dim),
)
self.n_transformer_layers = args.decoder_transformer_layers
self.transformer_layers = nn.ModuleList(
TransformerDecoderLayer(args)
for _ in range(self.n_transformer_layers)
TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers)
)
if args.decoder_normalize_before:
self.layer_norm = LayerNorm(args.decoder_embed_dim)
@ -161,19 +166,28 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)
self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers, args.postnet_dropout)
self.postnet = Postnet(
self.out_dim,
args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers,
args.postnet_dropout,
)
self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.:
if getattr(args, "ctc_weight", 0.0) > 0.0:
self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))
self.apply(decoder_init)
def extract_features(
self, prev_outputs, encoder_out=None, incremental_state=None,
target_lengths=None, speaker=None, **kwargs
self,
prev_outputs,
encoder_out=None,
incremental_state=None,
target_lengths=None,
speaker=None,
**kwargs
):
alignment_layer = self.n_transformer_layers - 1
self_attn_padding_mask = lengths_to_padding_mask(target_lengths)
@ -212,8 +226,8 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
else None,
encoder_out["encoder_padding_mask"][0]
if (
encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0
encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0
)
else None,
incremental_state,
@ -239,13 +253,22 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
return x, {"attn": attn, "inner_states": inner_states}
def forward(self, prev_output_tokens, encoder_out=None,
incremental_state=None, target_lengths=None, speaker=None,
**kwargs):
def forward(
self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
target_lengths=None,
speaker=None,
**kwargs
):
x, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out,
incremental_state=incremental_state, target_lengths=target_lengths,
speaker=speaker, **kwargs
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
target_lengths=target_lengths,
speaker=speaker,
**kwargs
)
attn = extra["attn"]
feat_out = self.feat_proj(x)
@ -328,8 +351,9 @@ class TTSTransformerModel(FairseqEncoderDecoderModel):
return cls(encoder, decoder)
def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs):
return self.encoder(src_tokens, src_lengths=src_lengths,
speaker=speaker, **kwargs)
return self.encoder(
src_tokens, src_lengths=src_lengths, speaker=speaker, **kwargs
)
def set_num_updates(self, num_updates):
super().set_num_updates(num_updates)
@ -348,7 +372,9 @@ def base_architecture(args):
# encoder transformer layers
args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim)
args.encoder_ffn_embed_dim = getattr(
args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
@ -366,6 +392,8 @@ def base_architecture(args):
# decoder transformer layers
args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim
)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)


@ -13,7 +13,10 @@ from torch import nn
import torch.nn.functional as F
from fairseq.data.audio.audio_utils import (
get_window, get_fourier_basis, get_mel_filters, TTSSpectrogram
get_window,
get_fourier_basis,
get_mel_filters,
TTSSpectrogram,
)
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel
@ -25,11 +28,9 @@ class PseudoInverseMelScale(torch.nn.Module):
def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None:
super(PseudoInverseMelScale, self).__init__()
self.n_mels = n_mels
basis = get_mel_filters(
sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max
)
basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
basis = torch.pinverse(basis) # F x F_mel
self.register_buffer('basis', basis)
self.register_buffer("basis", basis)
def forward(self, melspec: torch.Tensor) -> torch.Tensor:
# pack batch
@ -48,8 +49,12 @@ class PseudoInverseMelScale(torch.nn.Module):
class GriffinLim(torch.nn.Module):
def __init__(
self, n_fft: int, win_length: int, hop_length: int, n_iter: int,
window_fn=torch.hann_window
self,
n_fft: int,
win_length: int,
hop_length: int,
n_iter: int,
window_fn=torch.hann_window,
):
super(GriffinLim, self).__init__()
self.transform = TTSSpectrogram(
@ -59,7 +64,7 @@ class GriffinLim(torch.nn.Module):
basis = get_fourier_basis(n_fft)
basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :]
basis *= get_window(window_fn, n_fft, win_length)
self.register_buffer('basis', basis)
self.register_buffer("basis", basis)
self.n_fft = n_fft
self.win_length = win_length
@ -70,33 +75,33 @@ class GriffinLim(torch.nn.Module):
@classmethod
def get_window_sum_square(
cls, n_frames, hop_length, win_length, n_fft,
window_fn=torch.hann_window
cls, n_frames, hop_length, win_length, n_fft, window_fn=torch.hann_window
) -> torch.Tensor:
w_sq = get_window(window_fn, n_fft, win_length) ** 2
n = n_fft + hop_length * (n_frames - 1)
x = torch.zeros(n, dtype=torch.float32)
for i in range(n_frames):
ofst = i * hop_length
x[ofst: min(n, ofst + n_fft)] += w_sq[:max(0, min(n_fft, n - ofst))]
x[ofst : min(n, ofst + n_fft)] += w_sq[: max(0, min(n_fft, n - ofst))]
return x
def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor:
x = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)],
dim=1
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
)
x = F.conv_transpose1d(x, self.basis, stride=self.hop_length)
win_sum_sq = self.get_window_sum_square(
magnitude.shape[-1], hop_length=self.hop_length,
win_length=self.win_length, n_fft=self.n_fft
magnitude.shape[-1],
hop_length=self.hop_length,
win_length=self.win_length,
n_fft=self.n_fft,
).to(magnitude.device)
# remove modulation effects
approx_nonzero_indices = win_sum_sq > self.tiny
x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices]
x *= self.n_fft / self.hop_length
x = x[:, :, self.n_fft // 2:]
x = x[:, :, :-self.n_fft // 2:]
x = x[:, :, self.n_fft // 2 :]
x = x[:, :, : -self.n_fft // 2 :]
return x
def forward(self, specgram: torch.Tensor) -> torch.Tensor:
@ -111,18 +116,33 @@ class GriffinLim(torch.nn.Module):
class GriffinLimVocoder(nn.Module):
def __init__(self, sample_rate, win_size, hop_size, n_fft,
n_mels, f_min, f_max, window_fn,
spec_bwd_max_iter=32,
fp16=False):
def __init__(
self,
sample_rate,
win_size,
hop_size,
n_fft,
n_mels,
f_min,
f_max,
window_fn,
spec_bwd_max_iter=32,
fp16=False,
):
super().__init__()
self.inv_mel_transform = PseudoInverseMelScale(
n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate,
f_min=f_min, f_max=f_max
n_stft=n_fft // 2 + 1,
n_mels=n_mels,
sample_rate=sample_rate,
f_min=f_min,
f_max=f_max,
)
self.gl_transform = GriffinLim(
n_fft=n_fft, win_length=win_size, hop_length=hop_size,
window_fn=window_fn, n_iter=spec_bwd_max_iter
n_fft=n_fft,
win_length=win_size,
hop_length=hop_size,
window_fn=window_fn,
n_iter=spec_bwd_max_iter,
)
if fp16:
self.half()
@ -151,17 +171,19 @@ class GriffinLimVocoder(nn.Module):
sample_rate=feat_cfg["sample_rate"],
win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]),
hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]),
n_fft=feat_cfg["n_fft"], n_mels=feat_cfg["n_mels"],
f_min=feat_cfg["f_min"], f_max=feat_cfg["f_max"],
window_fn=window_fn, spec_bwd_max_iter=args.spec_bwd_max_iter,
fp16=args.fp16
n_fft=feat_cfg["n_fft"],
n_mels=feat_cfg["n_mels"],
f_min=feat_cfg["f_min"],
f_max=feat_cfg["f_max"],
window_fn=window_fn,
spec_bwd_max_iter=args.spec_bwd_max_iter,
fp16=args.fp16,
)
class HiFiGANVocoder(nn.Module):
def __init__(
self, checkpoint_path: str, model_cfg: Dict[str, str],
fp16: bool = False
self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False
) -> None:
super().__init__()
self.model = HiFiGANModel(model_cfg)


@ -29,8 +29,8 @@ from torch import Tensor
# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
if module_name == 'TransformerDecoderBase':
return 'TransformerDecoder'
if module_name == "TransformerDecoderBase":
return "TransformerDecoder"
else:
return module_name


@ -29,8 +29,8 @@ from fairseq.models.transformer import (
# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
if module_name == 'TransformerEncoderBase':
return 'TransformerEncoder'
if module_name == "TransformerEncoderBase":
return "TransformerEncoder"
else:
return module_name
@ -232,7 +232,12 @@ class TransformerEncoderBase(FairseqEncoder):
# `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
src_lengths = src_tokens.ne(self.padding_idx).sum(dim=1, dtype=torch.int32).reshape(-1, 1).contiguous()
src_lengths = (
src_tokens.ne(self.padding_idx)
.sum(dim=1, dtype=torch.int32)
.reshape(-1, 1)
.contiguous()
)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T


@ -15,7 +15,9 @@ from fairseq.models import (
register_model_architecture,
)
from fairseq.models.transformer import (
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder
DEFAULT_MIN_PARAMS_TO_WRAP,
Embedding,
TransformerDecoder,
)
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
from fairseq.utils import safe_getattr, safe_hasattr
@ -179,7 +181,7 @@ class TransformerLanguageModelConfig(FairseqDataclass):
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
)
}
},
)
# config for "BASE Layers: Simplifying Training of Large, Sparse Models"
base_layers: Optional[int] = field(
@ -189,13 +191,25 @@ class TransformerLanguageModelConfig(FairseqDataclass):
default=1, metadata={"help": "number of sublayers in each BASE layer"}
)
base_shuffle: Optional[int] = field(
default=1, metadata={"help": "shuffle tokens between workers before computing assignment"}
default=1,
metadata={"help": "shuffle tokens between workers before computing assignment"},
)
# NormFormer
scale_fc: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm between fully connected layers'})
scale_attn: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm after attention'})
scale_heads: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each attention head'})
scale_resids: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each residual connection'})
scale_fc: Optional[bool] = field(
default=False,
metadata={"help": "Insert LayerNorm between fully connected layers"},
)
scale_attn: Optional[bool] = field(
default=False, metadata={"help": "Insert LayerNorm after attention"}
)
scale_heads: Optional[bool] = field(
default=False,
metadata={"help": "Learn a scale coefficient for each attention head"},
)
scale_resids: Optional[bool] = field(
default=False,
metadata={"help": "Learn a scale coefficient for each residual connection"},
)
# options from other parts of the config
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")
@ -345,7 +359,9 @@ def base_lm_architecture(args):
args.decoder_output_dim = safe_getattr(
args, "decoder_output_dim", args.decoder_embed_dim
)
args.decoder_input_dim = safe_getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.decoder_input_dim = safe_getattr(
args, "decoder_input_dim", args.decoder_embed_dim
)
# Model training is not stable without this
args.decoder_normalize_before = True
@ -362,10 +378,10 @@ def base_lm_architecture(args):
args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False)
args.offload_activations = safe_getattr(args, "offload_activations", False)
args.scale_fc = safe_getattr(args, 'scale_fc', False)
args.scale_attn = safe_getattr(args, 'scale_attn', False)
args.scale_heads = safe_getattr(args, 'scale_heads', False)
args.scale_resids = safe_getattr(args, 'scale_resids', False)
args.scale_fc = safe_getattr(args, "scale_fc", False)
args.scale_attn = safe_getattr(args, "scale_attn", False)
args.scale_heads = safe_getattr(args, "scale_heads", False)
args.scale_resids = safe_getattr(args, "scale_resids", False)
if args.offload_activations:
args.checkpoint_activations = True
@ -387,7 +403,9 @@ def transformer_lm_baevski_wiki103(args):
args.dropout = safe_getattr(args, "dropout", 0.3)
args.adaptive_input = safe_getattr(args, "adaptive_input", True)
args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", True)
args.adaptive_input_cutoff = safe_getattr(args, "adaptive_input_cutoff", "20000,60000")
args.adaptive_input_cutoff = safe_getattr(
args, "adaptive_input_cutoff", "20000,60000"
)
args.adaptive_softmax_cutoff = safe_getattr(
args, "adaptive_softmax_cutoff", "20000,60000"
)
@ -472,7 +490,9 @@ def transformer_lm_gpt2_big(args):
def base_gpt3_architecture(args):
args.decoder_input_dim = args.decoder_embed_dim
args.decoder_output_dim = args.decoder_embed_dim
args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4)
args.decoder_ffn_embed_dim = safe_getattr(
args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4
)
# GPT-3 used learned positional embeddings, rather than sinusoidal
args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", True)
args.dropout = safe_getattr(args, "dropout", 0.0)


@ -232,9 +232,11 @@ class Wav2Vec2Config(FairseqDataclass):
)
checkpoint_activations: bool = field(
default=False, metadata={"help": "recompute activations and save memory for extra compute"}
default=False,
metadata={"help": "recompute activations and save memory for extra compute"},
)
@register_model("wav2vec2", dataclass=Wav2Vec2Config)
class Wav2Vec2Model(BaseFairseqModel):
def __init__(self, cfg: Wav2Vec2Config):
@ -844,14 +846,14 @@ class TransformerEncoder(nn.Module):
layers = []
for _ in range(args.encoder_layers):
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
)
if args.checkpoint_activations:
layer = fsdp_wrap(layer)


@ -152,10 +152,12 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
w2v_args: Any = None
checkpoint_activations: bool = field(
default=False, metadata={"help": "recompute activations and save memory for extra compute"}
default=False,
metadata={"help": "recompute activations and save memory for extra compute"},
)
ddp_backend: str = II("distributed_training.ddp_backend")
@dataclass
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
blank_weight: float = 0
@ -268,6 +270,7 @@ class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig):
)
autoregressive: bool = II("task.autoregressive")
@register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig)
class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
@ -394,12 +397,17 @@ class Wav2VecEncoder(FairseqEncoder):
def load_model_weights(self, state, model, cfg):
if cfg.ddp_backend == "fully_sharded":
from fairseq.distributed import FullyShardedDataParallel
for name, module in model.named_modules():
if "encoder.layers" in name and len(name.split(".")) == 3:
# Only for layers, we do a special handling and load the weights one by one
# We dont load all weights together as that wont be memory efficient and may
# cause oom
new_dict = {k.replace(name+".", "") : v for (k, v) in state["model"].items() if name+"." in k}
new_dict = {
k.replace(name + ".", ""): v
for (k, v) in state["model"].items()
if name + "." in k
}
assert isinstance(module, FullyShardedDataParallel)
with module.summon_full_params():
module.load_state_dict(new_dict, strict=True)
@ -409,7 +417,9 @@ class Wav2VecEncoder(FairseqEncoder):
r = re.compile("encoder.layers.\d.")
filtered_list = list(filter(r.match, state["model"].keys()))
new_big_dict = {k: v for (k, v) in state["model"].items() if k not in filtered_list}
new_big_dict = {
k: v for (k, v) in state["model"].items() if k not in filtered_list
}
model.load_state_dict(new_big_dict, strict=False)
else:
@ -462,9 +472,9 @@ class Wav2VecEncoder(FairseqEncoder):
1, new_order
)
if encoder_out["padding_mask"] is not None:
encoder_out["padding_mask"] = encoder_out[
"padding_mask"
].index_select(0, new_order)
encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
0, new_order
)
return encoder_out
def max_positions(self):
@ -640,7 +650,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self_attn_mask=self.buffered_future_mask(x)
if incremental_state is None
else None,
self_attn_padding_mask=self_attn_padding_mask
self_attn_padding_mask=self_attn_padding_mask,
)
inner_states.append(x)


@ -12,14 +12,17 @@ from fairseq.modules.layer_norm import LayerNorm
class BaseLayer(nn.Module):
def __init__(self, args):
super().__init__()
self.num_workers = distributed_utils.get_data_parallel_world_size()
expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim)
torch.nn.init.orthogonal_(expert_centroids, gain=0.1)
self.register_parameter("expert_centroids", torch.nn.Parameter(expert_centroids))
self.expert_network = nn.Sequential(*([BaseSublayer(args) for _ in range(args.base_sublayers)]))
self.register_parameter(
"expert_centroids", torch.nn.Parameter(expert_centroids)
)
self.expert_network = nn.Sequential(
*([BaseSublayer(args) for _ in range(args.base_sublayers)])
)
self.expert_id = distributed_utils.get_data_parallel_rank()
self.shuffle = args.base_shuffle
self.cpp = self.load_assignment()
@ -39,20 +42,34 @@ class BaseLayer(nn.Module):
with torch.no_grad():
# Compute similarity of each token to each expert, for routing
token_expert_affinities = features.matmul(self.expert_centroids.transpose(0, 1))
token_expert_affinities = features.matmul(
self.expert_centroids.transpose(0, 1)
)
# Compute which token goes to which expert
sort_by_expert, input_splits, output_splits = self.balanced_assignment(token_expert_affinities) \
if is_training else self.greedy_assignment(token_expert_affinities)
sort_by_expert, input_splits, output_splits = (
self.balanced_assignment(token_expert_affinities)
if is_training
else self.greedy_assignment(token_expert_affinities)
)
# Swap these tokens for the right ones for our expert
routed_features = All2All.apply(features[sort_by_expert], output_splits, input_splits)
routed_features = All2All.apply(
features[sort_by_expert], output_splits, input_splits
)
if routed_features.size(0) > 0:
# Mix in the expert network based on how appropriate it is for these tokens
alpha = torch.sigmoid(routed_features.mv(self.expert_centroids[self.expert_id])).unsqueeze(1)
routed_features = alpha * self.expert_network(routed_features) + (1 - alpha) * routed_features
alpha = torch.sigmoid(
routed_features.mv(self.expert_centroids[self.expert_id])
).unsqueeze(1)
routed_features = (
alpha * self.expert_network(routed_features)
+ (1 - alpha) * routed_features
)
# Return to original worker and ordering
result = All2All.apply(routed_features, input_splits, output_splits)[self.inverse_sort(sort_by_expert)]
result = All2All.apply(routed_features, input_splits, output_splits)[
self.inverse_sort(sort_by_expert)
]
if self.shuffle and is_training:
# Undo shuffling
@ -63,7 +80,9 @@ class BaseLayer(nn.Module):
def inverse_sort(self, order):
# Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)]
return torch.empty_like(order).scatter_(0, order, torch.arange(0, order.size(0), device=order.device))
return torch.empty_like(order).scatter_(
0, order, torch.arange(0, order.size(0), device=order.device)
)
def balanced_assignment(self, scores):
ok = scores.isfinite()
@ -79,7 +98,9 @@ class BaseLayer(nn.Module):
worker2token = sort_ordering // k
# Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers)
output_splits = torch.zeros((self.num_workers,), dtype=torch.long, device=scores.device)
output_splits = torch.zeros(
(self.num_workers,), dtype=torch.long, device=scores.device
)
workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True)
output_splits[workers] = counts
# Tell other workers how many tokens to expect from us
@ -103,7 +124,7 @@ class BaseSublayer(nn.Module):
def __init__(self, args):
super().__init__()
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, 'activation_fn', 'relu') or "relu"
activation=getattr(args, "activation_fn", "relu") or "relu"
)
self.norm = LayerNorm(args.decoder_embed_dim, export=False)
self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim)
@ -121,15 +142,29 @@ class All2All(torch.autograd.Function):
ctx.input_splits = input_splits
ctx.output_splits = output_splits
ys = torch.empty_like(xs) if output_splits is None else \
xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:]))
torch.distributed.all_to_all_single(ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits)
ys = (
torch.empty_like(xs)
if output_splits is None
else xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:]))
)
torch.distributed.all_to_all_single(
ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits
)
return ys
@staticmethod
def backward(ctx, grad_output):
result = torch.empty_like(grad_output) if ctx.input_splits is None else \
grad_output.new_empty(size=[sum(ctx.input_splits)] + list(grad_output.size()[1:]))
torch.distributed.all_to_all_single(result, grad_output,
output_split_sizes=ctx.input_splits, input_split_sizes=ctx.output_splits)
result = (
torch.empty_like(grad_output)
if ctx.input_splits is None
else grad_output.new_empty(
size=[sum(ctx.input_splits)] + list(grad_output.size()[1:])
)
)
torch.distributed.all_to_all_single(
result,
grad_output,
output_split_sizes=ctx.input_splits,
input_split_sizes=ctx.output_splits,
)
return result, None, None


@ -166,7 +166,9 @@ class CheckpointFunction(torch.autograd.Function):
if parent_ctx_dict["offload"]:
ctx.fwd_device = tuple(x.device for x in tensor_inputs)
ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
tensor_inputs = tuple(x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs)
tensor_inputs = tuple(
x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs
)
else:
ctx.fwd_device, ctx.grad_requirements = None, None
@ -199,7 +201,8 @@ class CheckpointFunction(torch.autograd.Function):
tensor_inputs = checkpoint.detach_variable(tensor_inputs)
if ctx.fwd_device is not None:
tensor_inputs = [
t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs)
t.to(ctx.fwd_device[i], non_blocking=True)
for i, t in enumerate(tensor_inputs)
]
for i, need_grad in enumerate(ctx.grad_requirements):
tensor_inputs[i].requires_grad = need_grad


@ -75,6 +75,7 @@ class GumbelVectorQuantizer(nn.Module):
if isinstance(temp, str):
import ast
temp = ast.literal_eval(temp)
assert len(temp) == 3, f"{temp}, {len(temp)}"


@ -47,11 +47,12 @@ def cache_fn(f):
return cache
cache = f(*args, **kwargs)
return cache
return cached_fn
def to(t):
return {'device': t.device, 'dtype': t.dtype}
return {"device": t.device, "dtype": t.dtype}
def find_modules(nn_module, type):
@ -102,7 +103,7 @@ def reshape_dim(t, dim, split_dims):
shape = list(t.shape)
num_dims = len(shape)
dim = (dim + num_dims) % num_dims
shape[dim:dim+1] = split_dims
shape[dim : dim + 1] = split_dims
return t.reshape(shape)
@ -118,6 +119,7 @@ def ema_inplace(moving_avg, new, decay):
return
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
# helper classes
@ -173,6 +175,7 @@ class ScaleNorm(nn.Module):
def norm(t):
n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps)
return t / n * self.g
return map_first_tuple_or_el(x, norm)
@ -202,51 +205,62 @@ class MatrixMultiply(nn.Module):
tensor = tensor.t()
return x @ tensor
# positional embeddings
class DepthWiseConv1d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False):
super().__init__()
self.padding = ((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2)
self.padding = (
((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2)
)
self.net = nn.Sequential(
nn.Conv1d(dim_in, dim_in, kernel_size=kernel_size, groups=dim_in, stride=stride, bias=bias),
nn.Conv1d(dim_in, dim_out, 1, bias=bias)
nn.Conv1d(
dim_in,
dim_in,
kernel_size=kernel_size,
groups=dim_in,
stride=stride,
bias=bias,
),
nn.Conv1d(dim_in, dim_out, 1, bias=bias),
)
def forward(self, x):
x = F.pad(x, self.padding, value=0.)
x = F.pad(x, self.padding, value=0.0)
return self.net(x)
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
position = torch.arange(0, max_seq_len, dtype=torch.float)
sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
self.register_buffer('emb', emb)
self.register_buffer("emb", emb)
def forward(self, x):
return self.emb[None, :x.shape[1], :].to(x)
return self.emb[None, : x.shape[1], :].to(x)
def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j=2)
x = rearrange(x, "... (d j) -> ... d j", j=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
return rearrange(x, "... d j -> ... (d j)")
def apply_rotary_pos_emb(q, k, sinu_pos):
sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j=2)
sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2)
sin, cos = sinu_pos.unbind(dim=-2)
sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2), (sin, cos))
sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos))
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
return q, k
# kmeans related function and class
@ -261,7 +275,7 @@ def update_kmeans_on_backwards(module):
def similarity(x, means):
return torch.einsum('bhld,hcd->bhlc', x, means)
return torch.einsum("bhld,hcd->bhlc", x, means)
def dists_and_buckets(x, means):
@ -303,13 +317,15 @@ def distribution(dists, window_size):
class Kmeans(nn.Module):
def __init__(self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4):
def __init__(
self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4
):
super().__init__()
self.commitment = commitment
self.ema_decay = ema_decay
self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim))
self.register_buffer('initted', torch.tensor(False))
self.register_buffer("means", torch.randn(num_heads, num_clusters, head_dim))
self.register_buffer("initted", torch.tensor(False))
self.num_new_means = 0
self.new_means = None
@ -341,7 +357,7 @@ class Kmeans(nn.Module):
@torch.no_grad()
def update(self, new_means=None):
new_means = default(new_means, self.new_means)
assert exists(new_means), 'new kmeans has not been supplied'
assert exists(new_means), "new kmeans has not been supplied"
ema_inplace(self.means, new_means, self.ema_decay)
del self.new_means
@ -364,16 +380,33 @@ class Kmeans(nn.Module):
if update_means:
with torch.no_grad():
means = kmeans_iter(x, means, buckets)
self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1))
self.new_means = ema(
self.new_means, means, self.num_new_means / (self.num_new_means + 1)
)
self.num_new_means += 1
return dists, loss
# kmeans attention class
class KmeansAttention(nn.Module):
def __init__(self, num_clusters, window_size, num_heads, head_dim, causal=False, dropout=0., ema_decay=0.999, commitment=1e-4, context_window_size=None, receives_context=False, num_mem_kv=0, shared_qk=False):
def __init__(
self,
num_clusters,
window_size,
num_heads,
head_dim,
causal=False,
dropout=0.0,
ema_decay=0.999,
commitment=1e-4,
context_window_size=None,
receives_context=False,
num_mem_kv=0,
shared_qk=False,
):
super().__init__()
self.num_heads = num_heads
self.num_clusters = num_clusters
@ -389,18 +422,32 @@ class KmeansAttention(nn.Module):
self.dropout = nn.Dropout(dropout)
self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0)
self.mem_key = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))
self.mem_value = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))
self.mem_key = nn.Parameter(
torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
)
self.mem_value = nn.Parameter(
torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
)
def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs):
b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype
is_reverse = kwargs.pop('_reverse', False)
b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = (
*q.shape,
k.shape[2],
self.window_size,
self.context_window_size,
self.num_clusters,
q.device,
q.dtype,
)
is_reverse = kwargs.pop("_reverse", False)
out = torch.zeros_like(q, dtype=dtype)
update_kmeans = self.training and not is_reverse
key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask
key_mask = (
default(key_mask, query_mask) if not self.receives_context else key_mask
)
kv_wsz = wsz if not self.receives_context else c_wsz
wsz = min(wsz, t)
@ -424,16 +471,22 @@ class KmeansAttention(nn.Module):
reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
q, k, v = map(reshape_with_window, (q, k, v))
m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value))
m_k, m_v = map(
lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)
)
k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))
dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5)
dots = torch.einsum("bhnid,bhnjd->bhnij", q, k) * (d ** -0.5)
mask_value = max_neg_value(dots)
if exists(query_mask) or exists(key_mask):
query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool())
key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool())
query_mask = default(
query_mask, lambda: torch.ones((b, t), device=device).bool()
)
key_mask = default(
key_mask, lambda: torch.ones((b, kv_t), device=device).bool()
)
q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)
@ -444,14 +497,18 @@ class KmeansAttention(nn.Module):
del mask
if self.causal:
q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
q_mask, kv_mask = map(
lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
)
mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
dots.masked_fill_(~mask, mask_value)
del mask
if self.shared_qk:
q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
q_mask, kv_mask = map(
lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
)
mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=0)
dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
@ -460,24 +517,32 @@ class KmeansAttention(nn.Module):
dots = dots.softmax(dim=-1)
dots = self.dropout(dots)
bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v)
bo = torch.einsum("bhcij,bhcjd->bhcid", dots, v)
so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
return out, aux_loss
# feedforward
class GELU_(nn.Module):
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
return (
0.5
* x
* (
1
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))
)
)
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_
GELU = nn.GELU if hasattr(nn, "GELU") else GELU_
class FeedForward(nn.Module):
def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False):
def __init__(self, dim, mult=4, dropout=0.0, activation=None, glu=False):
super().__init__()
activation = default(activation, GELU)
@ -499,17 +564,49 @@ class FeedForward(nn.Module):
x = self.w2(x)
return x
# self attention
class SelfAttention(nn.Module):
def __init__(self, dim, max_seq_len, heads, local_attn_heads, window_size, dim_head=None, local_attn_window_size=None, local_attn_radius_blocks=1, causal=False, attn_dropout=0., dropout=0., kmeans_ema_decay=0.999, commitment_factor=1e-4, receives_context=False, context_window_size=None, rel_pos_emb=True, num_mem_kv=0, shared_qk=False, conv_query_kernel=9):
def __init__(
self,
dim,
max_seq_len,
heads,
local_attn_heads,
window_size,
dim_head=None,
local_attn_window_size=None,
local_attn_radius_blocks=1,
causal=False,
attn_dropout=0.0,
dropout=0.0,
kmeans_ema_decay=0.999,
commitment_factor=1e-4,
receives_context=False,
context_window_size=None,
rel_pos_emb=True,
num_mem_kv=0,
shared_qk=False,
conv_query_kernel=9,
):
super().__init__()
assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size'
assert local_attn_heads <= heads, 'number of local attention heads must be less than total heads'
assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context'
assert not (receives_context and causal), 'contextual attention layer cannot be causal'
assert (
dim_head or (dim % heads) == 0
), "hidden dimension must be divisible by number of heads"
assert (
max_seq_len % window_size
) == 0, "maximum sequence length must be divisible by the target window size"
assert (
local_attn_heads <= heads
), "number of local attention heads must be less than total heads"
assert not (
receives_context and local_attn_heads > 0
), "local attention cannot be used for self attention with context"
assert not (
receives_context and causal
), "contextual attention layer cannot be causal"
local_attn_window_size = default(local_attn_window_size, window_size)
context_window_size = default(context_window_size, window_size)
@ -535,7 +632,15 @@ class SelfAttention(nn.Module):
if self.local_attn_heads > 0:
rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
self.local_attn = LocalAttention(dim=dim_head, window_size=local_attn_window_size, causal=causal, dropout=attn_dropout, rel_pos_emb_config=rel_pos_emb_config, look_backward=local_attn_radius_blocks, look_forward=0 if causal else local_attn_radius_blocks)
self.local_attn = LocalAttention(
dim=dim_head,
window_size=local_attn_window_size,
causal=causal,
dropout=attn_dropout,
rel_pos_emb_config=rel_pos_emb_config,
look_backward=local_attn_radius_blocks,
look_forward=0 if causal else local_attn_radius_blocks,
)
self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)
# global
@ -543,12 +648,24 @@ class SelfAttention(nn.Module):
global_dim_heads = dim_head * self.global_attn_heads
if self.global_attn_heads > 0:
self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal=causal, dropout=attn_dropout, ema_decay=kmeans_ema_decay, commitment=commitment_factor, receives_context=receives_context, num_mem_kv=num_mem_kv, shared_qk=shared_qk)
self.global_attn = KmeansAttention(
num_clusters,
window_size,
self.global_attn_heads,
dim_head,
causal=causal,
dropout=attn_dropout,
ema_decay=kmeans_ema_decay,
commitment=commitment_factor,
receives_context=receives_context,
num_mem_kv=num_mem_kv,
shared_qk=shared_qk,
)
self.to_q = nn.Sequential(
Rearrange('b n c -> b c n'),
Rearrange("b n c -> b c n"),
DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal),
Rearrange('b c n -> b n c')
Rearrange("b c n -> b n c"),
)
self.to_v = nn.Linear(dim, global_dim_heads, bias=False)
@ -561,14 +678,30 @@ class SelfAttention(nn.Module):
self.to_out = nn.Linear(dim_heads, dim, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, context=None, key_padding_mask=None, context_mask=None, pos_emb=None, **kwargs):
assert not (self.receives_context and not exists(context)), 'context must be passed if self attention is set to receive context'
def forward(
self,
query,
key,
value,
context=None,
key_padding_mask=None,
context_mask=None,
pos_emb=None,
**kwargs
):
assert not (
self.receives_context and not exists(context)
), "context must be passed if self attention is set to receive context"
input_mask = key_padding_mask
x = query.transpose(0, 1)
b, t, _, h, dh = *x.shape, self.heads, self.dim_head
has_local, has_global = map(lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads))
has_local, has_global = map(
lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads)
)
split_heads = lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()
split_heads = (
lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()
)
if has_local:
local_qkv = self.local_to_qkv(x).chunk(3, dim=-1)
@ -587,7 +720,7 @@ class SelfAttention(nn.Module):
q, k, v = map(split_heads, (q, k, v))
out = []
total_loss = torch.tensor(0., requires_grad=True, **to(x))
total_loss = torch.tensor(0.0, requires_grad=True, **to(x))
if has_local:
local_out = self.local_attn(lq, lk, lv, input_mask=input_mask)
@ -597,7 +730,9 @@ class SelfAttention(nn.Module):
if not self.receives_context and exists(pos_emb):
q, k = apply_rotary_pos_emb(q, k, pos_emb)
global_out, loss = self.global_attn(q, k, v, query_mask=input_mask, key_mask=context_mask)
global_out, loss = self.global_attn(
q, k, v, query_mask=input_mask, key_mask=context_mask
)
total_loss = total_loss + loss
out.append(global_out)


@ -13,6 +13,7 @@ from .conv_tbc import ConvTBC
from typing import Dict, Optional
from torch import Tensor
@with_incremental_state
class LinearizedConvolution(ConvTBC):
"""An optimized version of nn.Conv1d.
@ -41,7 +42,11 @@ class LinearizedConvolution(ConvTBC):
del state_dict[prefix + "_linearized_weight"]
@torch.jit.export
def forward(self, input, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None):
def forward(
self,
input,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
"""
Args:
incremental_state: Used to buffer signal; if not None, then input is
@ -80,18 +85,28 @@ class LinearizedConvolution(ConvTBC):
return output.view(bsz, 1, -1)
@torch.jit.unused
def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order):
def reorder_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
new_order,
):
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
input_buffer = input_buffer.index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)
@torch.jit.unused
def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
return utils.get_incremental_state(self, incremental_state, "input_buffer")
@torch.jit.unused
def _set_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_buffer):
def _set_input_buffer(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
new_buffer,
):
return utils.set_incremental_state(
self, incremental_state, "input_buffer", new_buffer
)


@ -20,9 +20,16 @@ class LocationAttention(nn.Module):
:param int conv_kernel_size: filter size of attention convolution
"""
def __init__(self, attn_dim, encoder_dim, decoder_dim,
attn_state_kernel_size, conv_dim, conv_kernel_size,
scaling=2.0):
def __init__(
self,
attn_dim,
encoder_dim,
decoder_dim,
attn_state_kernel_size,
conv_dim,
conv_kernel_size,
scaling=2.0,
):
super(LocationAttention, self).__init__()
self.attn_dim = attn_dim
self.decoder_dim = decoder_dim
@ -30,9 +37,13 @@ class LocationAttention(nn.Module):
self.proj_enc = nn.Linear(encoder_dim, attn_dim)
self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
self.conv = nn.Conv1d(attn_state_kernel_size, conv_dim,
2 * conv_kernel_size + 1,
padding=conv_kernel_size, bias=False)
self.conv = nn.Conv1d(
attn_state_kernel_size,
conv_dim,
2 * conv_kernel_size + 1,
padding=conv_kernel_size,
bias=False,
)
self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))
self.proj_enc_out = None # cache
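
The attention-state convolution above uses a kernel of 2 * conv_kernel_size + 1 with padding conv_kernel_size, so the time axis of the previous attention weights is preserved. A quick check with assumed sizes:

import torch
import torch.nn as nn

attn_state_kernel_size, conv_dim, k = 2, 10, 5           # assumed hyper-parameters
conv = nn.Conv1d(attn_state_kernel_size, conv_dim, 2 * k + 1, padding=k, bias=False)
prev_attn = torch.randn(8, attn_state_kernel_size, 37)   # (batch, attention states, encoder length)
assert conv(prev_attn).shape == (8, conv_dim, 37)        # encoder length unchanged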

View File

@ -12,20 +12,20 @@ class LSTMCellWithZoneOut(nn.Module):
https://arxiv.org/abs/1606.01305
"""
def __init__(self, prob: float, input_size: int, hidden_size: int,
bias: bool = True):
def __init__(
self, prob: float, input_size: int, hidden_size: int, bias: bool = True
):
super(LSTMCellWithZoneOut, self).__init__()
self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
self.prob = prob
if prob > 1.0 or prob < 0.0:
raise ValueError("zoneout probability must be in the range from "
"0.0 to 1.0.")
raise ValueError(
"zoneout probability must be in the range from " "0.0 to 1.0."
)
def zoneout(self, h, next_h, prob):
if isinstance(h, tuple):
return tuple(
[self.zoneout(h[i], next_h[i], prob) for i in range(len(h))]
)
return tuple([self.zoneout(h[i], next_h[i], prob) for i in range(len(h))])
if self.training:
mask = h.new_zeros(*h.size()).bernoulli_(prob)
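
The zoneout step above keeps each hidden unit at its previous value with probability prob during training. A minimal sketch of the whole update; the inference branch is an assumption, since the hunk only shows the training-time mask:

import torch

def zoneout(h, next_h, prob: float, training: bool):
    if training:
        mask = h.new_zeros(*h.size()).bernoulli_(prob)   # 1 = keep the previous value
        return mask * h + (1 - mask) * next_h
    return prob * h + (1 - prob) * next_h                # assumed eval-time expectation

zoneout(torch.zeros(2, 8), torch.randn(2, 8), prob=0.1, training=True)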

View File

@ -60,7 +60,9 @@ def quantize_model_(
to layers_to_quantize[step]
"""
quantized_layers = get_layers(model, layers_to_quantize[step], remove_weights=remove_weights)
quantized_layers = get_layers(
model, layers_to_quantize[step], remove_weights=remove_weights
)
for layer in quantized_layers:
@ -108,8 +110,8 @@ def quantize_model_(
centroids = torch.rand(centroids.size())
centroids.cuda()
# Get counts and assignment keys from layer in loaded checkpoint.
counts_key = layer+"."+"counts"
assignment_key = layer+"."+"assignments"
counts_key = layer + "." + "counts"
assignment_key = layer + "." + "assignments"
# Get number of different bins to include.
counts = list(state_dict[counts_key].shape)[0]
print(layer)
@ -122,7 +124,7 @@ def quantize_model_(
print(num_assignments)
print(num_extra)
assignments_bins = torch.arange(counts)
assignments_rand = torch.randint(0, counts-1, (num_extra, ))
assignments_rand = torch.randint(0, counts - 1, (num_extra,))
assignments = torch.cat((assignments_bins, assignments_rand), 0)
# assignments = assignments.type(torch.IntTensor)
assignments.cuda()

View File

@ -16,7 +16,9 @@ from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear
MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d}
def quantize_model_(model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False):
def quantize_model_(
model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False
):
"""
Replaces all modules with their scalar quantized counterpart and
registers hooks to quantize the post-ativations of those modules.
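
quantize_model_ swaps nn.Linear/nn.Embedding/nn.Conv2d modules for their quantized counterparts via the MAPPING table above. The sketch below shows only that replacement pattern; FakeIntLinear and the factory call signature are made up, since the real IntLinear takes explicit sizes rather than a module:

import torch.nn as nn

def swap_modules(model: nn.Module, mapping) -> None:
    # Recursively replace children whose type has an entry in `mapping`.
    for name, child in model.named_children():
        if type(child) in mapping:
            setattr(model, name, mapping[type(child)](child))
        else:
            swap_modules(child, mapping)

class FakeIntLinear(nn.Module):                          # hypothetical stand-in for IntLinear
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear
    def forward(self, x):
        return self.inner(x)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
swap_modules(model, {nn.Linear: FakeIntLinear})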

View File

@ -132,8 +132,7 @@ class TransformerEncoderLayerBase(nn.Module):
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool),
-1e8 if x.dtype == torch.float32 else -1e4
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
)
residual = x
@ -213,11 +212,19 @@ class TransformerDecoderLayerBase(nn.Module):
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
self.attn_ln = LayerNorm(self.embed_dim) if utils.safe_getattr(cfg, 'scale_attn', False) else None
self.attn_ln = (
LayerNorm(self.embed_dim)
if utils.safe_getattr(cfg, "scale_attn", False)
else None
)
self.nh = self.self_attn.num_heads
self.head_dim = self.self_attn.head_dim
scale_heads = utils.safe_getattr(cfg, 'scale_heads', False)
self.c_attn = nn.Parameter(torch.ones((self.nh,)), requires_grad=True) if scale_heads else None
scale_heads = utils.safe_getattr(cfg, "scale_heads", False)
self.c_attn = (
nn.Parameter(torch.ones((self.nh,)), requires_grad=True)
if scale_heads
else None
)
self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn)
activation_dropout_p = cfg.activation_dropout
@ -238,8 +245,21 @@ class TransformerDecoderLayerBase(nn.Module):
self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.ffn_layernorm = LayerNorm(cfg.decoder.ffn_embed_dim) if utils.safe_getattr(cfg, 'scale_fc', False) else None
self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if utils.safe_getattr(cfg, 'scale_resids', False) else None
self.ffn_layernorm = (
LayerNorm(cfg.decoder.ffn_embed_dim)
if utils.safe_getattr(cfg, "scale_fc", False)
else None
)
self.w_resid = (
nn.Parameter(
torch.ones(
self.embed_dim,
),
requires_grad=True,
)
if utils.safe_getattr(cfg, "scale_resids", False)
else None
)
self.fc1 = self.build_fc1(
self.embed_dim,
@ -297,7 +317,6 @@ class TransformerDecoderLayerBase(nn.Module):
def residual_connection(self, x, residual):
return residual + x
def forward(
self,
x,
@ -377,7 +396,7 @@ class TransformerDecoderLayerBase(nn.Module):
if self.c_attn is not None:
tgt_len, bsz = x.size(0), x.size(1)
x = x.view(tgt_len, bsz, self.nh, self.head_dim)
x = torch.einsum('tbhd,h->tbhd', x, self.c_attn)
x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
x = x.reshape(tgt_len, bsz, self.embed_dim)
if self.attn_ln is not None:
x = self.attn_ln(x)
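
The c_attn einsum above multiplies every attention head by its own learned scalar before the heads are flattened back together. A standalone sketch with assumed sizes:

import torch

tgt_len, bsz, num_heads, head_dim = 5, 2, 4, 16          # assumed sizes
x = torch.randn(tgt_len, bsz, num_heads * head_dim)
c_attn = torch.ones(num_heads)                           # one learned scale per head in the layer

x = x.view(tgt_len, bsz, num_heads, head_dim)
x = torch.einsum("tbhd,h->tbhd", x, c_attn)              # scale each head independently
x = x.reshape(tgt_len, bsz, num_heads * head_dim)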

View File

@ -35,9 +35,7 @@ def init_bert_params(module):
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
)
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
if isinstance(module, nn.Linear):
normal_(module.weight.data)
@ -276,7 +274,9 @@ class TransformerSentenceEncoder(nn.Module):
inner_states.append(x)
for layer in self.layers:
x, _ = layer(x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask)
x, _ = layer(
x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask
)
if not last_state_only:
inner_states.append(x)

View File

@ -2,13 +2,13 @@
# Licensed under the MIT License.
""" Wrapper for ngram_repeat_block cuda extension """
import math
import warnings
from typing import Dict, List, Optional
import torch
from torch import nn
import math
from typing import Dict, List, Optional
import warnings
try:
from fairseq import ngram_repeat_block_cuda
@ -37,7 +37,7 @@ def is_cuda_extension_usable() -> bool:
class NGramRepeatBlock(nn.Module):
""" Wrapper class for calling ngram_repeat_block cuda extension """
"""Wrapper class for calling ngram_repeat_block cuda extension"""
def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
super().__init__()
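
The CUDA extension wrapped above bans beam candidates that would repeat an n-gram already present in the hypothesis. The pure-Python sketch below shows that banning rule only; it is not the extension's API, and the tensor-level batching of the real kernel is omitted:

from typing import List

def banned_next_tokens(tokens: List[int], ngram_size: int) -> List[int]:
    # Tokens that would complete an n-gram already seen in `tokens`.
    if len(tokens) + 1 < ngram_size:
        return []
    prefix = tuple(tokens[-(ngram_size - 1):]) if ngram_size > 1 else tuple()
    banned = []
    for i in range(len(tokens) - ngram_size + 1):
        if tuple(tokens[i:i + ngram_size - 1]) == prefix:
            banned.append(tokens[i + ngram_size - 1])
    return banned

banned_next_tokens([1, 2, 3, 1, 2], ngram_size=3)        # -> [3]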

View File

@ -67,13 +67,13 @@ class FairseqAdam(FairseqOptimizer):
elif use_fused_adam:
logger.info("using FusedAdam")
self._optimizer = fused_adam_cls(
params,
use_fp16_stats=self.cfg.fp16_adam_stats,
**self.optimizer_config
params, use_fp16_stats=self.cfg.fp16_adam_stats, **self.optimizer_config
)
else:
if self.cfg.fp16_adam_stats:
raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1")
raise NotImplementedError(
"--fp16-adam-stats is only supported with FusedAdamV1"
)
self._optimizer = Adam(params, **self.optimizer_config)
@property

View File

@ -63,8 +63,9 @@ class AMPOptimizer(optim.FairseqOptimizer):
).format(self.min_loss_scale, new_loss_scale)
)
else:
logger.info("AMP: overflow detected, setting scale to "
f"to {new_loss_scale}")
logger.info(
"AMP: overflow detected, setting scale to " f"to {new_loss_scale}"
)
return grad_norm
@property

View File

@ -23,7 +23,9 @@ class OptimizerAndSchedulerConfig(FairseqDataclass):
optimizer: Any = None
lr_scheduler: Optional[Any] = None
lr: List = II("optimization.lr")
lr_float: Optional[float] = None # this makes it easier to sweep on learning rate with auto sweepers
lr_float: Optional[
float
] = None # this makes it easier to sweep on learning rate with auto sweepers
@dataclass

View File

@ -16,6 +16,7 @@ from omegaconf import II, DictConfig
try:
import deepspeed
has_deepspeed = True
except ImportError as e:
has_deepspeed = False
@ -24,12 +25,15 @@ except ImportError as e:
def _get_cpu_adam():
try:
from deepspeed.ops.op_builder import CPUAdamBuilder
return CPUAdamBuilder().load()
except ImportError:
# fbcode
from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam
return ds_opt_adam
@dataclass
class FairseqCPUAdamConfig(FairseqDataclass):
adam_betas: str = field(

View File

@ -64,9 +64,9 @@ class _FP16OptimizerMixin(object):
fp32_params = []
for p in params:
p32 = torch.nn.Parameter(p.data.float())
if hasattr(p, 'expert'):
if hasattr(p, "expert"):
p32.expert = True
elif hasattr(p, 'base_expert'):
elif hasattr(p, "base_expert"):
p32.base_expert = True
p32.grad = torch.zeros_like(p32.data)
if hasattr(p, "param_group"):
@ -209,7 +209,9 @@ class _FP16OptimizerMixin(object):
self._sync_fp16_grads_to_fp32()
if getattr(self, "supports_step_with_scale", False):
self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
self.fp32_optimizer.step(
closure, scale=(1.0 / self._multiply_factor), groups=groups
)
else:
self._unscale_grads()
self.fp32_optimizer.step(closure, groups=groups)
@ -434,7 +436,9 @@ class _MemoryEfficientFP16OptimizerMixin(object):
"""Performs a single optimization step."""
if getattr(self, "supports_step_with_scale", False):
# NOTE(msb) optimizer divides by scale factor
self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
self.wrapped_optimizer.step(
closure, scale=(1.0 / self._multiply_factor), groups=groups
)
else:
self._unscale_grads()
self.wrapped_optimizer.step(closure, groups=groups)
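
Both mixins above follow the same master-weights recipe: keep an fp32 copy of every fp16 parameter, accumulate and unscale gradients in fp32, step, then write the result back. A compressed sketch of that recipe (the function names and explicit loss_scale argument are illustrative, not the mixin's API):

import torch

def make_fp32_copies(fp16_params):
    fp32_params = []
    for p in fp16_params:
        p32 = torch.nn.Parameter(p.data.float())          # fp32 master weight
        p32.grad = torch.zeros_like(p32.data)
        fp32_params.append(p32)
    return fp32_params

def unscale_and_step(fp16_params, fp32_params, fp32_optimizer, loss_scale: float):
    for p16, p32 in zip(fp16_params, fp32_params):
        if p16.grad is not None:
            p32.grad.data.copy_(p16.grad.data)            # cast half grads up to fp32
            p32.grad.data.mul_(1.0 / loss_scale)          # undo the loss scaling
    fp32_optimizer.step()
    for p16, p32 in zip(fp16_params, fp32_params):
        p16.data.copy_(p32.data)                          # write updated weights back as fp16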

View File

@ -179,7 +179,7 @@ class FusedAdamV1(torch.optim.Optimizer):
if p.device.type == "cpu":
p_data_fp32 = p.data.cuda(non_blocking=True).float()
out_p = torch.tensor([], dtype = torch.float)
out_p = torch.tensor([], dtype=torch.float)
else:
p_data_fp32 = p.data.float()
out_p = p.data
@ -234,6 +234,7 @@ class FusedAdamV1(torch.optim.Optimizer):
p.data.copy_(p_data_fp32, non_blocking=True)
if self.use_fp16_stats:
def inf_norm(t):
return torch.norm(t, float("inf"))
@ -262,7 +263,9 @@ try:
def __init__(self, *args, use_fp16_stats=False, **kwargs):
if use_fp16_stats:
raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1")
raise NotImplementedError(
"--fp16-adam-stats is only supported with FusedAdamV1"
)
super().__init__(*args, **kwargs)
if not hasattr(self, "multi_tensor_adam"):
raise Exception(

View File

@ -32,7 +32,7 @@ class ManualSchedule(LegacyFairseqLRScheduler):
self.optimizer.set_lr(self.lr) # Set the beginning of the epoch.
def parse_manuallr_args(self, lr_args_str):
lr_dict = ast.literal_eval(lr_args_str.replace(' ', ''))
lr_dict = ast.literal_eval(lr_args_str.replace(" ", ""))
if not isinstance(lr_dict, dict):
raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict")
@ -84,9 +84,14 @@ class ManualSchedule(LegacyFairseqLRScheduler):
if manual_keys:
manual_lr = self.epoch2lr[max(manual_keys)]
else:
logger.warning("@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format(
epoch, list(self.epoch2lr.items())[:min(10, len(self.epoch2lr.keys())-1)]
))
logger.warning(
"@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format(
epoch,
list(self.epoch2lr.items())[
: min(10, len(self.epoch2lr.keys()) - 1)
],
)
)
manual_lr = self.optimizer.get_lr()
return manual_lr
@ -102,8 +107,14 @@ class ManualSchedule(LegacyFairseqLRScheduler):
if manual_keys:
manual_lr = self.update2lr[max(manual_keys)]
else:
logger.warning("epoch={} does not exist in manual lr input update2lr={}...".format(
num_updates, list(self.update2lr.items())[:min(10, len(self.update2lr.keys())-1)]))
logger.warning(
"epoch={} does not exist in manual lr input update2lr={}...".format(
num_updates,
list(self.update2lr.items())[
: min(10, len(self.update2lr.keys()) - 1)
],
)
)
manual_lr = self.optimizer.get_lr()
self.optimizer.set_lr(manual_lr)
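
The manual scheduler above evaluates a literal dict and then picks the entry with the largest epoch/update key that does not exceed the current step. The lookup, isolated (the example schedule string is made up):

import ast

def manual_lr(schedule_str: str, step: int, fallback: float) -> float:
    schedule = ast.literal_eval(schedule_str.replace(" ", ""))
    keys = [k for k in schedule if k <= step]
    return schedule[max(keys)] if keys else fallback      # fall back to the optimizer's current lr

manual_lr("{1: 0.0005, 10: 0.0001}", step=12, fallback=0.001)   # -> 0.0001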

View File

@ -36,8 +36,7 @@ class StepLRScheduleConfig(FairseqDataclass):
@register_lr_scheduler("step", dataclass=StepLRScheduleConfig)
class StepLRSchedule(FairseqLRScheduler):
"""Decay learning rate every k updates by a fixed factor
"""
"""Decay learning rate every k updates by a fixed factor"""
def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer):
super().__init__(cfg, fairseq_optimizer)
@ -50,16 +49,16 @@ class StepLRSchedule(FairseqLRScheduler):
cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr
)
assert(self.lr_deacy_period > 0)
assert(self.lr_decay <= 1)
assert(self.min_lr >= 0)
assert(self.max_lr > self.min_lr)
assert self.lr_deacy_period > 0
assert self.lr_decay <= 1
assert self.min_lr >= 0
assert self.max_lr > self.min_lr
if cfg.warmup_updates > 0:
# linearly warmup for the first cfg.warmup_updates
self.warmup_lr_step = (
(self.max_lr - self.warmup_init_lr) / self.warmup_updates
)
self.max_lr - self.warmup_init_lr
) / self.warmup_updates
else:
self.warmup_lr_step = 1
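
Read together, the asserts and warmup setup above describe a schedule that warms up linearly to max_lr and then decays by lr_decay every lr_deacy_period updates, never dropping below min_lr. A sketch under that reading (the post-warmup branch is an assumption; only the warmup arithmetic appears in the hunk):

def step_lr(num_updates, max_lr, min_lr, lr_decay, decay_period,
            warmup_updates=0, warmup_init_lr=0.0):
    if warmup_updates > 0 and num_updates < warmup_updates:
        warmup_step = (max_lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + num_updates * warmup_step
    n = (num_updates - warmup_updates) // decay_period    # decay periods elapsed so far
    return max(max_lr * (lr_decay ** n), min_lr)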

View File

@ -171,7 +171,9 @@ class SequenceGenerator(nn.Module):
yield id, src, ref, hypos[i]
@torch.no_grad()
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
def generate(
self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
) -> List[List[Dict[str, Tensor]]]:
"""Generate translations. Match the api of other fairseq generators.
Args:
@ -223,7 +225,10 @@ class SequenceGenerator(nn.Module):
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
)
else:
raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
raise Exception(
"expected src_tokens or source in net input. input keys: "
+ str(net_input.keys())
)
# bsz: total number of sentences in beam
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
@ -328,7 +333,9 @@ class SequenceGenerator(nn.Module):
encoder_outs = self.model.reorder_encoder_out(
encoder_outs, reorder_state
)
with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"):
with torch.autograd.profiler.record_function(
"EnsembleModel: forward_decoder"
):
lprobs, avg_attn_scores = self.model.forward_decoder(
tokens[:, : step + 1],
encoder_outs,
@ -751,7 +758,14 @@ class EnsembleModel(nn.Module):
return self.has_incremental
def max_decoder_positions(self):
return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
return min(
[
m.max_decoder_positions()
for m in self.models
if hasattr(m, "max_decoder_positions")
]
+ [sys.maxsize]
)
@torch.jit.export
def forward_encoder(self, net_input: Dict[str, Tensor]):

View File

@ -35,8 +35,12 @@ class SpeechGenerator(object):
class AutoRegressiveSpeechGenerator(SpeechGenerator):
def __init__(
self, model, vocoder, data_cfg, max_iter: int = 6000,
eos_prob_threshold: float = 0.5,
self,
model,
vocoder,
data_cfg,
max_iter: int = 6000,
eos_prob_threshold: float = 0.5,
):
super().__init__(model, vocoder, data_cfg)
self.max_iter = max_iter
@ -54,8 +58,9 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
raw_dim = out_dim // n_frames_per_step
# initialize
encoder_out = model.forward_encoder(src_tokens, src_lengths,
speaker=sample["speaker"])
encoder_out = model.forward_encoder(
src_tokens, src_lengths, speaker=sample["speaker"]
)
incremental_state = {}
feat, attn, eos_prob = [], [], []
finished = src_tokens.new_zeros((bsz,)).bool()
@ -66,21 +71,24 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
cur_out_lens = out_lens.clone()
cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1)
_, cur_eos_out, cur_extra = model.forward_decoder(
prev_feat_out, encoder_out=encoder_out,
prev_feat_out,
encoder_out=encoder_out,
incremental_state=incremental_state,
target_lengths=cur_out_lens, speaker=sample["speaker"], **kwargs
target_lengths=cur_out_lens,
speaker=sample["speaker"],
**kwargs
)
cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
feat.append(cur_extra['feature_out'])
attn.append(cur_extra['attn'])
feat.append(cur_extra["feature_out"])
attn.append(cur_extra["attn"])
eos_prob.append(cur_eos_prob)
cur_finished = (cur_eos_prob.squeeze(1) > self.eos_prob_threshold)
cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold
out_lens.masked_fill_((~finished) & cur_finished, step + 1)
finished = finished | cur_finished
if finished.sum().item() == bsz:
break
prev_feat_out = cur_extra['feature_out']
prev_feat_out = cur_extra["feature_out"]
feat = torch.cat(feat, dim=1)
feat = model.decoder.postnet(feat) + feat
@ -98,11 +106,11 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
finalized = [
{
'feature': feat[b, :out_len],
'eos_prob': eos_prob[b, :out_len],
'attn': attn[b, :, :out_len],
'alignment': alignment[b, :out_len],
'waveform': self.get_waveform(feat[b, :out_len]),
"feature": feat[b, :out_len],
"eos_prob": eos_prob[b, :out_len],
"attn": attn[b, :, :out_len],
"alignment": alignment[b, :out_len],
"waveform": self.get_waveform(feat[b, :out_len]),
}
for b, out_len in zip(range(bsz), out_lens)
]
@ -134,7 +142,7 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None,
target_lengths=sample["target_lengths"],
speaker=sample["speaker"]
speaker=sample["speaker"],
)
if feat_post is not None:
feat = feat_post
@ -142,9 +150,7 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
feat = feat.view(bsz, -1, raw_dim)
feat = self.gcmvn_denormalize(feat)
dur_out = torch.clamp(
torch.round(torch.exp(log_dur_out) - 1).long(), min=0
)
dur_out = torch.clamp(torch.round(torch.exp(log_dur_out) - 1).long(), min=0)
def get_dur_plot_data(d):
r = []
@ -155,11 +161,11 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
out_lens = out_lens * n_frames_per_step
finalized = [
{
'feature': feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]),
'waveform': self.get_waveform(
"feature": feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]),
"waveform": self.get_waveform(
feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim])
),
'attn': feat.new_tensor(get_dur_plot_data(dur_out[b])),
"attn": feat.new_tensor(get_dur_plot_data(dur_out[b])),
}
for b, l in zip(range(bsz), out_lens)
]
@ -188,8 +194,12 @@ class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator)
bsz = src_tokens.shape[0]
feat, eos_prob, extra = model(
src_tokens, src_lens, prev_out_tokens, incremental_state=None,
target_lengths=tgt_lens, speaker=sample["speaker"]
src_tokens,
src_lens,
prev_out_tokens,
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"],
)
attn = extra["attn"] # B x T_s x T_t
@ -203,11 +213,11 @@ class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator)
finalized = [
{
'feature': feat[b, :tgt_len],
'eos_prob': eos_prob[b, :tgt_len],
'attn': attn[b, :, :tgt_len],
'alignment': alignment[b, :tgt_len],
'waveform': self.get_waveform(feat[b, :tgt_len]),
"feature": feat[b, :tgt_len],
"eos_prob": eos_prob[b, :tgt_len],
"attn": attn[b, :, :tgt_len],
"alignment": alignment[b, :tgt_len],
"waveform": self.get_waveform(feat[b, :tgt_len]),
}
for b, tgt_len in zip(range(bsz), tgt_lens)
]
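
The autoregressive decoding loop in AutoRegressiveSpeechGenerator above stops a sentence the first time its eos probability crosses eos_prob_threshold and freezes its output length at that step. The bookkeeping, isolated with made-up sizes and a random stand-in for sigmoid(eos_out):

import torch

bsz, max_iter, threshold = 3, 10, 0.5                    # assumed sizes and threshold
out_lens = torch.full((bsz,), max_iter, dtype=torch.long)
finished = torch.zeros(bsz, dtype=torch.bool)
for step in range(max_iter):
    cur_eos_prob = torch.rand(bsz)                        # stand-in for sigmoid(eos_out).squeeze()
    cur_finished = cur_eos_prob > threshold
    out_lens.masked_fill_((~finished) & cur_finished, step + 1)
    finished = finished | cur_finished
    if finished.all():
        break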

View File

@ -67,31 +67,31 @@ class AudioFinetuningConfig(AudioPretrainingConfig):
default=False, metadata={"help": "evaluation with BLEU scores"}
)
eval_bleu_detok: Optional[str] = field(
default=None, metadata={
default=None,
metadata={
"help": "detokenize before computing BLEU (e.g., 'moses'); "
"required if using --eval-bleu; use 'space' to disable "
"detokenization; see fairseq.data.encoders for other options"
}
"required if using --eval-bleu; use 'space' to disable "
"detokenization; see fairseq.data.encoders for other options"
},
)
eval_bleu_detok_args: str = field(
default="{}",
metadata={"help": "args for building the tokenizer, if needed"}
default="{}", metadata={"help": "args for building the tokenizer, if needed"}
)
eval_tokenized_bleu: bool = field(
default=False,
metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
)
eval_bleu_remove_bpe: Optional[str] = field(
default=None, metadata={"help": "remove BPE before computing BLEU"}
)
eval_bleu_args: str = field(
default="{}",
metadata={"help": "generation args for BLUE scoring, e.g., "
"'{\"beam\": 4, \"lenpen\": 0.6}'"}
metadata={
"help": "generation args for BLUE scoring, e.g., "
'\'{"beam": 4, "lenpen": 0.6}\''
},
)
eval_bleu_print_samples: bool = field(
default=False,
metadata={"help": "print sample generations during validation"}
default=False, metadata={"help": "print sample generations during validation"}
)
autoregressive: bool = field(
default=False,
@ -123,7 +123,9 @@ class AudioFinetuningTask(AudioPretrainingTask):
return Dictionary.load(dict_path)
return None
def load_dataset(self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs):
def load_dataset(
self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs
):
super().load_dataset(split, task_cfg, **kwargs)
task_cfg = task_cfg or self.cfg
@ -138,7 +140,8 @@ class AudioFinetuningTask(AudioPretrainingTask):
with open(label_path, "r") as f:
labels = [
text_compressor.compress(l)
for i, l in enumerate(f) if i not in skipped_indices
for i, l in enumerate(f)
if i not in skipped_indices
]
assert len(labels) == len(self.datasets[split]), (
@ -157,7 +160,7 @@ class AudioFinetuningTask(AudioPretrainingTask):
process_label=process_label,
label_len_fn=label_len_fn,
add_to_input=task_cfg.get("autoregressive", False),
text_compression_level=text_compression_level
text_compression_level=text_compression_level,
)
@property
@ -176,8 +179,8 @@ class AudioFinetuningTask(AudioPretrainingTask):
logging_output["_num_words"] = metrics["num_words"]
if self.cfg.eval_bleu and self.cfg.autoregressive:
metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
logging_output['_bleu_sys_len'] = metrics.sys_len
logging_output['_bleu_ref_len'] = metrics.ref_len
logging_output["_bleu_sys_len"] = metrics.sys_len
logging_output["_bleu_ref_len"] = metrics.ref_len
# we split counts into separate entries so that they can be
# summed efficiently across workers using fast-stat-sync
assert len(metrics.counts) == 4
@ -200,9 +203,9 @@ class AudioFinetuningTask(AudioPretrainingTask):
self.tokenizer = None
if self.cfg.eval_bleu and self.cfg.autoregressive:
assert self.cfg.eval_bleu_detok is not None, (
'--eval-bleu-detok is required if using --eval-bleu; '
'try --eval-bleu-detok=moses (or --eval-bleu-detok=space '
'to disable detokenization, e.g., when using sentencepiece)'
"--eval-bleu-detok is required if using --eval-bleu; "
"try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
"to disable detokenization, e.g., when using sentencepiece)"
)
detok_args = json.loads(self.cfg.eval_bleu_detok_args)
self.tokenizer = encoders.build_tokenizer(
@ -261,9 +264,7 @@ class AudioFinetuningTask(AudioPretrainingTask):
# BLEU scores. Instead, we use a somewhat more verbose
# alternative that is unlikely to appear in the real
# reference, but doesn't get split into multiple tokens.
unk_string=(
"UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"
),
unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
)
if self.tokenizer:
s = self.tokenizer.decode(s)
@ -272,21 +273,18 @@ class AudioFinetuningTask(AudioPretrainingTask):
gen_out = self.inference_step(generator, [model], sample)
hyps, refs = [], []
for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]['tokens'], is_ref=False))
hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
refs.append(
decode(
utils.strip_pad(
sample['target'][i],
self.target_dictionary.pad()
),
utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
is_ref=True, # don't count <unk> as matches to the hypo
)
)
if self.cfg.eval_bleu_print_samples:
logger.info('H-{} {}'.format(sample["id"][0], hyps[0]))
logger.info('T-{} {}'.format(sample["id"][0], refs[0]))
logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
logger.info("T-{} {}".format(sample["id"][0], refs[0]))
eval_tokenization = 'none' if self.cfg.eval_tokenized_bleu else '13a'
eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)
def reduce_metrics(self, logging_outputs, criterion):
@ -329,18 +327,17 @@ class AudioFinetuningTask(AudioPretrainingTask):
count_keys = [f"_bleu_counts_{i}" for i in range(4)]
total_keys = [f"_bleu_totals_{i}" for i in range(4)]
for k in len_keys + count_keys + total_keys:
metrics.log_scalar(
k, sum(log.get(k, 0) for log in logging_outputs)
)
metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))
import sacrebleu
metrics.log_derived(
'bleu',
"bleu",
lambda meters: sacrebleu.compute_bleu(
correct=[meters[k].sum for k in count_keys],
total=[meters[k].sum for k in total_keys],
sys_len=meters['_bleu_sys_len'].sum,
ref_len=meters['_bleu_ref_len'].sum,
smooth_method="exp"
).score
sys_len=meters["_bleu_sys_len"].sum,
ref_len=meters["_bleu_ref_len"].sum,
smooth_method="exp",
).score,
)
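
The eval_bleu_detok_args field is json.loads-ed in the hunk above, and the help text for eval_bleu_args documents the same JSON format for generation arguments; the escaped quoting is easier to read as an actual value:

import json

gen_args = json.loads('{"beam": 4, "lenpen": 0.6}')       # -> {'beam': 4, 'lenpen': 0.6}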

View File

@ -50,8 +50,7 @@ class AudioPretrainingConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
labels: Optional[str] = field(
default=None,
metadata={
"help": "extension of the label file to load, used for fine-tuning"},
metadata={"help": "extension of the label file to load, used for fine-tuning"},
)
binarized_dataset: bool = field(
default=False,
@ -102,8 +101,8 @@ class AudioPretrainingConfig(FairseqDataclass):
default="none",
metadata={
"help": "compression level for texts (e.g. audio filenames, "
"target texts): none/low/high (default: none). "
}
"target texts): none/low/high (default: none). "
},
)

View File

@ -135,7 +135,6 @@ class DenoisingTask(LegacyFairseqTask):
'e.g., "train,valid" (default: all dataset splits)',
)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary

View File

@ -11,20 +11,19 @@ from fairseq.tasks.text_to_speech import TextToSpeechTask
logging.basicConfig(
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
@register_task('frm_text_to_speech')
@register_task("frm_text_to_speech")
class FrmTextToSpeechTask(TextToSpeechTask):
@staticmethod
def add_args(parser):
TextToSpeechTask.add_args(parser)
parser.add_argument(
"--do_chunk", action="store_true", help="train on chunks"
)
parser.add_argument("--do_chunk", action="store_true", help="train on chunks")
parser.add_argument("--chunk_bound", default=-1, type=int)
parser.add_argument("--chunk_init", default=50, type=int)
parser.add_argument("--chunk_incr", default=5, type=int)
@ -52,5 +51,5 @@ class FrmTextToSpeechTask(TextToSpeechTask):
chunk_incr=self.args.chunk_incr,
add_eos=self.args.add_eos,
dedup=self.args.dedup,
ref_fpu=self.args.ref_fpu
ref_fpu=self.args.ref_fpu,
)
