Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-07-14 18:50:22 +03:00)
Add linting with black (#2678)
Summary:

# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Fixes # (issue).

## PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?

Make sure you had fun coding 🙂

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2678

Reviewed By: Mortimerp9

Differential Revision: D32653381

Pulled By: dianaml0

fbshipit-source-id: 2810d14867cd7d64f4d340740e2b590b82de47fe
parent 3dc1691df1, commit 0dfd6b6240
.github/workflows/build.yml (vendored) | 5 +++++
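The only functional change in this commit is the new CI step below; the rest of the diff is mechanical reformatting produced by black. The same check can be reproduced locally before pushing. A minimal sketch, assuming black is installed in the active Python environment (the command mirrors the one added to the workflow):

```sh
# Run the same check as CI: exits non-zero if any tracked file would be
# reformatted; the regex skips examples/ and the forked megatron code.
pip install black
black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
```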
@@ -53,3 +53,8 @@ jobs:
     - name: Run tests
       run: |
         python setup.py test
+
+    - name: Lint with black
+      run: |
+        pip install black
+        black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
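To preview the rewrites black would make without touching any files, black's `--diff` flag can be used in place of `--check`; a sketch under the same assumptions as above:

```sh
# Print the would-be changes to stdout instead of modifying files in place.
black --diff --extend-exclude 'examples|fairseq\/model_parallel\/megatron' .
```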
@@ -27,6 +27,7 @@ sys.modules["fairseq.progress_bar"] = progress_bar

 # initialize hydra
 from fairseq.dataclass.initialize import hydra_init
+
 hydra_init()

 import fairseq.criterions  # noqa
@@ -7,10 +7,10 @@ import logging

 import numpy as np
 import torch

 from fairseq.data import Dictionary, FairseqDataset
 from fairseq.tasks import LegacyFairseqTask, register_task

 logger = logging.getLogger(__name__)

@@ -36,7 +36,7 @@ class DummyMTTask(LegacyFairseqTask):
     @classmethod
     def setup_task(cls, args, **kwargs):
-        """Setup the task. """
+        """Setup the task."""
         dictionary = Dictionary()
         for i in range(args.dict_size):
             dictionary.add_symbol("word{}".format(i))
@@ -96,10 +96,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):

     checkpoint_conds[
         "checkpoint.best_{}_{:.3f}{}{}.pt".format(
-            cfg.best_checkpoint_metric,
-            val_loss,
-            rand_sfx,
-            suffix
+            cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
         )
     ] = worst_best is None or is_better(val_loss, worst_best)
     checkpoint_conds[
@@ -468,9 +465,7 @@ def load_model_ensemble_and_task(
             and len(state["optimizer_history"]) > 0
             and "num_updates" in state["optimizer_history"][-1]
         ):
-            model.set_num_updates(
-                state["optimizer_history"][-1]["num_updates"]
-            )
+            model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
         model.load_state_dict(
             state["model"], strict=strict, model_cfg=cfg.model
         )
@@ -588,9 +583,8 @@ def _upgrade_state_dict(state):
     # backward compatibility, cfg updates
     if "args" in state and state["args"] is not None:
         # old model checkpoints may not have separate source/target positions
-        if (
-            hasattr(state["args"], "max_positions")
-            and not hasattr(state["args"], "max_source_positions")
+        if hasattr(state["args"], "max_positions") and not hasattr(
+            state["args"], "max_source_positions"
         ):
             state["args"].max_source_positions = state["args"].max_positions
             state["args"].max_target_positions = state["args"].max_positions
@@ -615,13 +609,10 @@ def _upgrade_state_dict(state):
             state["args"].stop_min_lr = state["args"].min_lr
             del state["args"].min_lr
         # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
-        if (
-            hasattr(state["args"], "criterion")
-            and state["args"].criterion in [
-                "binary_cross_entropy",
-                "kd_binary_cross_entropy",
-            ]
-        ):
+        if hasattr(state["args"], "criterion") and state["args"].criterion in [
+            "binary_cross_entropy",
+            "kd_binary_cross_entropy",
+        ]:
             state["args"].criterion = "wav2vec"
         # remove log_keys if it's None (criteria will supply a default value of [])
         if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
@@ -659,7 +650,9 @@ def _upgrade_state_dict(state):
         ):
             cfg.task.eval_wer_config.print_alignment = "hard"
         if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
-            cfg.generation.print_alignment = "hard" if cfg.generation.print_alignment else None
+            cfg.generation.print_alignment = (
+                "hard" if cfg.generation.print_alignment else None
+            )
         if (
             "model" in cfg
             and "w2v_args" in cfg.model
@@ -833,16 +826,16 @@ def load_ema_from_checkpoint(fpath):
     params_dict = collections.OrderedDict()
     new_state = None

-    with PathManager.open(fpath, 'rb') as f:
+    with PathManager.open(fpath, "rb") as f:
         new_state = torch.load(
             f,
             map_location=(
-                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
+                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
             ),
         )

         # EMA model is stored in a separate "extra state"
-        model_params = new_state['extra_state']['ema']
+        model_params = new_state["extra_state"]["ema"]

         for key in list(model_params.keys()):
             p = model_params[key]
@@ -860,5 +853,5 @@ def load_ema_from_checkpoint(fpath):
             "ema model weights, is this model trained with EMA?"
         )

-    new_state['model'] = params_dict
+    new_state["model"] = params_dict
     return new_state
@@ -20,9 +20,7 @@ from fairseq.models.fairseq_model import FairseqEncoderModel

 @dataclass
 class FastSpeech2CriterionConfig(FairseqDataclass):
-    ctc_weight: float = field(
-        default=0.0, metadata={"help": "weight for CTC loss"}
-    )
+    ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})


 @register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
@@ -44,7 +42,7 @@ class FastSpeech2Loss(FairseqCriterion):
             speaker=sample["speaker"],
             durations=sample["durations"],
             pitches=sample["pitches"],
-            energies=sample["energies"]
+            energies=sample["energies"],
         )

         src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
@@ -57,8 +55,7 @@ class FastSpeech2Loss(FairseqCriterion):
         feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
         l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
         if _feat_out_post is not None:
-            l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat,
-                                 reduction=reduction)
+            l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)

         pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
         energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
@@ -69,16 +66,23 @@ class FastSpeech2Loss(FairseqCriterion):
         log_dur = torch.log(dur + 1)[src_mask]
         dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)

-        ctc_loss = torch.tensor(0.).type_as(l1_loss)
-        if self.ctc_weight > 0.:
+        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
+        if self.ctc_weight > 0.0:
             lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
             lprobs = lprobs.transpose(0, 1)  # T x B x C
             src_mask = lengths_to_mask(src_lens)
             src_tokens_flat = src_tokens.masked_select(src_mask)
-            ctc_loss = F.ctc_loss(
-                lprobs, src_tokens_flat, tgt_lens, src_lens,
-                reduction=reduction, zero_infinity=True
-            ) * self.ctc_weight
+            ctc_loss = (
+                F.ctc_loss(
+                    lprobs,
+                    src_tokens_flat,
+                    tgt_lens,
+                    src_lens,
+                    reduction=reduction,
+                    zero_infinity=True,
+                )
+                * self.ctc_weight
+            )

         loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

@@ -102,8 +106,12 @@ class FastSpeech2Loss(FairseqCriterion):
         ntot = sum(ns)
         ws = [n / (ntot + 1e-8) for n in ns]
         for key in [
-            "loss", "l1_loss", "dur_loss", "pitch_loss", "energy_loss",
-            "ctc_loss"
+            "loss",
+            "l1_loss",
+            "dur_loss",
+            "pitch_loss",
+            "energy_loss",
+            "ctc_loss",
         ]:
             vals = [log.get(key, 0) for log in logging_outputs]
             val = sum(val * w for val, w in zip(vals, ws))
@@ -115,10 +123,10 @@ class FastSpeech2Loss(FairseqCriterion):
             return
         n = sum(log.get("targ_frames", 0) for log in logging_outputs)
         for key, new_key in [
-                ("mcd_loss", "mcd_loss"),
-                ("pred_frames", "pred_ratio"),
-                ("nins", "ins_rate"),
-                ("ndel", "del_rate"),
+            ("mcd_loss", "mcd_loss"),
+            ("pred_frames", "pred_ratio"),
+            ("nins", "ins_rate"),
+            ("ndel", "del_rate"),
         ]:
             val = sum(log.get(key, 0) for log in logging_outputs)
             metrics.log_scalar(new_key, val / n, n, round=3)
@@ -37,7 +37,14 @@ class HubertCriterionConfig(FairseqDataclass):

 @register_criterion("hubert", dataclass=HubertCriterionConfig)
 class HubertCriterion(FairseqCriterion):
-    def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None):
+    def __init__(
+        self,
+        task,
+        pred_masked_weight,
+        pred_nomask_weight,
+        loss_weights=None,
+        log_keys=None,
+    ):
         super().__init__(task)
         self.pred_masked_weight = pred_masked_weight
         self.pred_nomask_weight = pred_nomask_weight
@@ -52,7 +59,7 @@ class HubertCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(target_list=sample["target_list"], **sample["net_input"])
-        loss = 0.
+        loss = 0.0
         sample_size = 0
         logging_output = {}
         reduction = "sum" if reduce else "none"
@@ -89,7 +96,9 @@ class HubertCriterion(FairseqCriterion):
                 names = [names]
             if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                 self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
-            assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
+            assert len(extra_losses) == len(
+                self.loss_weights
+            ), f"{len(extra_losses)}, {len(self.loss_weights)}"
             for p, n, coef in zip(extra_losses, names, self.loss_weights):
                 if coef != 0 and p is not None:
                     p = coef * p.float() * sample_size
@@ -140,12 +149,20 @@ class HubertCriterion(FairseqCriterion):
         ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
         sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

-        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
         if sample_size != ntokens:
-            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
-            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
+            metrics.log_scalar(
+                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
+            )
+            metrics.log_derived(
+                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
+            )
         else:
-            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
+            metrics.log_derived(
+                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
+            )

         counts = {}
         for lk in logging_outputs[0].keys():
@@ -9,19 +9,20 @@ from fairseq import metrics, utils
 from fairseq.criterions import register_criterion
 from fairseq.criterions.label_smoothed_cross_entropy import (
     LabelSmoothedCrossEntropyCriterion,
-    LabelSmoothedCrossEntropyCriterionConfig
+    LabelSmoothedCrossEntropyCriterionConfig,
 )

 try:
     from simuleval.metrics.latency import (
         AverageLagging,
         AverageProportion,
-        DifferentiableAverageLagging
+        DifferentiableAverageLagging,
     )
+
     LATENCY_METRICS = {
         "average_lagging": AverageLagging,
         "average_proportion": AverageProportion,
-            "differentiable_average_lagging": DifferentiableAverageLagging,
+        "differentiable_average_lagging": DifferentiableAverageLagging,
     }
 except ImportError:
     LATENCY_METRICS = None
@@ -56,9 +57,10 @@ class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
         metadata={"help": "Add latency loss after certain steps"},
     )

+
 @register_criterion(
     "latency_augmented_label_smoothed_cross_entropy",
-    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig
+    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
 )
 class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
     LabelSmoothedCrossEntropyCriterion
@@ -101,9 +103,9 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(

         if self.latency_update_after > 0:
             num_updates = getattr(model.decoder, "num_updates", None)
-            assert num_updates is not None, (
-                "model.decoder doesn't have attribute 'num_updates'"
-            )
+            assert (
+                num_updates is not None
+            ), "model.decoder doesn't have attribute 'num_updates'"
             if num_updates <= self.latency_update_after:
                 latency_loss = 0

@@ -134,9 +136,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
         assert (
             net_output[-1].encoder_padding_mask is None
             or not net_output[-1].encoder_padding_mask[:, 0].any()
-        ), (
-            "Only right padding on source is supported."
-        )
+        ), "Only right padding on source is supported."
         # 1. Obtain the expected alignment
         alpha_list = [item["alpha"] for item in net_output[1].attn_list]
         num_layers = len(alpha_list)
@@ -174,8 +174,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
             .view(-1)
         )
         expected_latency = LATENCY_METRICS[self.latency_avg_type](
-            expected_delays, src_lengths, None,
-            target_padding_mask=target_padding_mask
+            expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
         )

         # 2.1 average expected latency of heads
@@ -210,24 +209,12 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
     @classmethod
     def reduce_metrics(cls, logging_outputs) -> None:
         super().reduce_metrics(logging_outputs)
-        latency = sum(
-            log.get("latency", 0) for log in logging_outputs
-        )
-        delays_var = sum(
-            log.get("delays_var", 0) for log in logging_outputs
-        )
-        latency_loss = sum(
-            log.get("latency_loss", 0) for log in logging_outputs
-        )
+        latency = sum(log.get("latency", 0) for log in logging_outputs)
+        delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
+        latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
         nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
-        metrics.log_scalar(
-            "latency", latency.float() / nsentences, nsentences, round=3
-        )
-        metrics.log_scalar(
-            "delays_var", delays_var / nsentences,
-            nsentences, round=3
-        )
+        metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
+        metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
         metrics.log_scalar(
-            "latency_loss", latency_loss / nsentences,
-            nsentences, round=3
+            "latency_loss", latency_loss / nsentences, nsentences, round=3
         )
@@ -41,9 +41,7 @@ class Tacotron2CriterionConfig(FairseqDataclass):
         default=0.4,
         metadata={"help": "weight of positive examples for BCE loss"},
     )
-    ctc_weight: float = field(
-        default=0.0, metadata={"help": "weight for CTC loss"}
-    )
+    ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
     sentence_avg: bool = II("optimization.sentence_avg")

@@ -70,8 +68,7 @@ class GuidedAttentionLoss(torch.nn.Module):
         bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
         weights = torch.zeros((bsz, max_t_len, max_s_len))
         for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
-            weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len,
-                                                          self.sigma)
+            weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
         return weights

     @staticmethod
@@ -90,9 +87,16 @@ class GuidedAttentionLoss(torch.nn.Module):

 @register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
 class Tacotron2Criterion(FairseqCriterion):
-    def __init__(self, task, sentence_avg, n_frames_per_step,
-                 use_guided_attention_loss, guided_attention_loss_sigma,
-                 bce_pos_weight, ctc_weight):
+    def __init__(
+        self,
+        task,
+        sentence_avg,
+        n_frames_per_step,
+        use_guided_attention_loss,
+        guided_attention_loss_sigma,
+        bce_pos_weight,
+        ctc_weight,
+    ):
         super().__init__(task)
         self.sentence_avg = sentence_avg
         self.n_frames_per_step = n_frames_per_step
@@ -120,31 +124,42 @@ class Tacotron2Criterion(FairseqCriterion):
             prev_output_tokens=sample["net_input"]["prev_output_tokens"],
             incremental_state=None,
             target_lengths=tgt_lens,
-            speaker=sample["speaker"]
+            speaker=sample["speaker"],
         )

         l1_loss, mse_loss, eos_loss = self.compute_loss(
-            extra["feature_out"], feat_out, eos_out, feat_tgt, eos_tgt,
-            tgt_lens, reduction,
+            extra["feature_out"],
+            feat_out,
+            eos_out,
+            feat_tgt,
+            eos_tgt,
+            tgt_lens,
+            reduction,
         )
-        attn_loss = torch.tensor(0.).type_as(l1_loss)
+        attn_loss = torch.tensor(0.0).type_as(l1_loss)
         if self.guided_attn is not None:
-            attn_loss = self.guided_attn(extra['attn'], src_lens, tgt_lens, reduction)
-        ctc_loss = torch.tensor(0.).type_as(l1_loss)
-        if self.ctc_weight > 0.:
+            attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
+        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
+        if self.ctc_weight > 0.0:
             net_output = (feat_out, eos_out, extra)
             lprobs = model.get_normalized_probs(net_output, log_probs=True)
             lprobs = lprobs.transpose(0, 1)  # T x B x C
             src_mask = lengths_to_mask(src_lens)
             src_tokens_flat = src_tokens.masked_select(src_mask)
-            ctc_loss = F.ctc_loss(
-                lprobs, src_tokens_flat, tgt_lens, src_lens,
-                reduction=reduction, zero_infinity=True
-            ) * self.ctc_weight
+            ctc_loss = (
+                F.ctc_loss(
+                    lprobs,
+                    src_tokens_flat,
+                    tgt_lens,
+                    src_lens,
+                    reduction=reduction,
+                    zero_infinity=True,
+                )
+                * self.ctc_weight
+            )
         loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss

-        sample_size = sample["nsentences"] if self.sentence_avg \
-            else sample["ntokens"]
+        sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
         logging_output = {
             "loss": utils.item(loss.data),
             "ntokens": sample["ntokens"],
@@ -158,8 +173,16 @@ class Tacotron2Criterion(FairseqCriterion):
         }
         return loss, sample_size, logging_output

-    def compute_loss(self, feat_out, feat_out_post, eos_out, feat_tgt,
-                     eos_tgt, tgt_lens, reduction="mean"):
+    def compute_loss(
+        self,
+        feat_out,
+        feat_out_post,
+        eos_out,
+        feat_tgt,
+        eos_tgt,
+        tgt_lens,
+        reduction="mean",
+    ):
         mask = lengths_to_mask(tgt_lens)
         _eos_out = eos_out[mask].squeeze()
         _eos_tgt = eos_tgt[mask]
@@ -167,17 +190,17 @@ class Tacotron2Criterion(FairseqCriterion):
         _feat_out = feat_out[mask]
         _feat_out_post = feat_out_post[mask]

-        l1_loss = (
-            F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) +
-            F.l1_loss(_feat_out_post, _feat_tgt, reduction=reduction)
+        l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
+            _feat_out_post, _feat_tgt, reduction=reduction
         )
-        mse_loss = (
-            F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) +
-            F.mse_loss(_feat_out_post, _feat_tgt, reduction=reduction)
+        mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
+            _feat_out_post, _feat_tgt, reduction=reduction
         )
         eos_loss = F.binary_cross_entropy_with_logits(
-            _eos_out, _eos_tgt, pos_weight=torch.tensor(self.bce_pos_weight),
-            reduction=reduction
+            _eos_out,
+            _eos_tgt,
+            pos_weight=torch.tensor(self.bce_pos_weight),
+            reduction=reduction,
         )
         return l1_loss, mse_loss, eos_loss

@@ -197,10 +220,10 @@ class Tacotron2Criterion(FairseqCriterion):
             return
         n = sum(log.get("targ_frames", 0) for log in logging_outputs)
         for key, new_key in [
-                ("mcd_loss", "mcd_loss"),
-                ("pred_frames", "pred_ratio"),
-                ("nins", "ins_rate"),
-                ("ndel", "del_rate"),
+            ("mcd_loss", "mcd_loss"),
+            ("pred_frames", "pred_ratio"),
+            ("nins", "ins_rate"),
+            ("ndel", "del_rate"),
         ]:
             val = sum(log.get(key, 0) for log in logging_outputs)
             metrics.log_scalar(new_key, val / n, n, round=3)
@@ -33,6 +33,7 @@ class Wav2VecCriterionConfig(FairseqDataclass):
         metadata={"help": "output keys to log"},
     )

+
 @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
 class Wav2vecCriterion(FairseqCriterion):
     def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
@@ -76,16 +77,16 @@ class Wav2vecCriterion(FairseqCriterion):
                 # we don't shrink tensors using mask_indices.
                 # Instead, we use mask indices to adjust loss.
                 mi = (
-                    sample['net_input']['mask_indices']
+                    sample["net_input"]["mask_indices"]
                     .transpose(0, 1)  # logits are transposed in `model.get_logits`
                     .reshape(logits.size(0))
                 )
                 loss = (loss * mi).sum() if reduce else (loss * mi)

-            if 'sample_size' in sample:
-                sample_size = sample['sample_size']
-            elif 'mask_indices' in sample['net_input']:
-                sample_size = sample['net_input']['mask_indices'].sum()
+            if "sample_size" in sample:
+                sample_size = sample["sample_size"]
+            elif "mask_indices" in sample["net_input"]:
+                sample_size = sample["net_input"]["mask_indices"].sum()
             else:
                 sample_size = target.numel() if self.infonce else target.long().sum().item()
             losses.append(loss.detach().clone())
@@ -216,8 +217,8 @@ class Wav2vecCriterion(FairseqCriterion):
             metrics.log_scalar(k, val / len(logging_outputs), round=3)

     # FIXME: revert when gather based xla reduction is implemented
-    #@staticmethod
-    #def logging_outputs_can_be_summed() -> bool:
+    # @staticmethod
+    # def logging_outputs_can_be_summed() -> bool:
     def logging_outputs_can_be_summed(self) -> bool:
         """
         Whether the logging outputs returned by `forward` can be summed
@@ -20,7 +20,7 @@ class AddTargetDataset(BaseWrapperDataset):
         process_label=None,
         label_len_fn=None,
         add_to_input=False,
-        text_compression_level=TextCompressionLevel.none
+        text_compression_level=TextCompressionLevel.none,
     ):
         super().__init__(dataset)
         self.labels = labels
@@ -18,26 +18,28 @@ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}


 def convert_waveform(
-    waveform: Union[np.ndarray, torch.Tensor], sample_rate: int,
-    normalize_volume: bool = False, to_mono: bool = False,
-    to_sample_rate: Optional[int] = None
+    waveform: Union[np.ndarray, torch.Tensor],
+    sample_rate: int,
+    normalize_volume: bool = False,
+    to_mono: bool = False,
+    to_sample_rate: Optional[int] = None,
 ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
     """convert a waveform:
-    - to a target sample rate
-    - from multi-channel to mono channel
-    - volume normalization
+    - to a target sample rate
+    - from multi-channel to mono channel
+    - volume normalization

-    Args:
-        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
-            (channels x length)
-        sample_rate (int): original sample rate
-        normalize_volume (bool): perform volume normalization
-        to_mono (bool): convert to mono channel if having multiple channels
-        to_sample_rate (Optional[int]): target sample rate
-    Returns:
-        waveform (numpy.ndarray): converted 2D waveform (channels x length)
-        sample_rate (float): target sample rate
-    """
+    Args:
+        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
+            (channels x length)
+        sample_rate (int): original sample rate
+        normalize_volume (bool): perform volume normalization
+        to_mono (bool): convert to mono channel if having multiple channels
+        to_sample_rate (Optional[int]): target sample rate
+    Returns:
+        waveform (numpy.ndarray): converted 2D waveform (channels x length)
+        sample_rate (float): target sample rate
+    """
     try:
         import torchaudio.sox_effects as ta_sox
     except ImportError:
@@ -63,10 +65,14 @@ def convert_waveform(


 def get_waveform(
-    path_or_fp: Union[str, BinaryIO], normalization: bool = True,
-    mono: bool = True, frames: int = -1, start: int = 0,
-    always_2d: bool = True, output_sample_rate: Optional[int] = None,
-    normalize_volume: bool = False
+    path_or_fp: Union[str, BinaryIO],
+    normalization: bool = True,
+    mono: bool = True,
+    frames: int = -1,
+    start: int = 0,
+    always_2d: bool = True,
+    output_sample_rate: Optional[int] = None,
+    normalize_volume: bool = False,
 ) -> Tuple[np.ndarray, int]:
     """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.

@@ -98,8 +104,11 @@ def get_waveform(
         )
     waveform = waveform.T  # T x C -> C x T
     waveform, sample_rate = convert_waveform(
-        waveform, sample_rate, normalize_volume=normalize_volume, to_mono=mono,
-        to_sample_rate=output_sample_rate
+        waveform,
+        sample_rate,
+        normalize_volume=normalize_volume,
+        to_mono=mono,
+        to_sample_rate=output_sample_rate,
     )

     if not normalization:
@@ -182,7 +191,7 @@ def is_sf_audio_data(data: bytes) -> bool:
 def mmap_read(path: str, offset: int, length: int) -> bytes:
     with open(path, "rb") as f:
         with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
-            data = mmap_o[offset: offset + length]
+            data = mmap_o[offset : offset + length]
     return data

@@ -215,9 +224,7 @@ def parse_path(path: str) -> Tuple[str, List[int]]:
     return _path, slice_ptr


-def get_window(
-    window_fn: callable, n_fft: int, win_length: int
-) -> torch.Tensor:
+def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
     padding = n_fft - win_length
     assert padding >= 0
     return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
@@ -226,13 +233,13 @@ def get_window(
 def get_fourier_basis(n_fft: int) -> torch.Tensor:
     basis = np.fft.fft(np.eye(n_fft))
     basis = np.vstack(
-        [np.real(basis[:n_fft // 2 + 1, :]), np.imag(basis[:n_fft // 2 + 1, :])]
+        [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
     )
     return torch.from_numpy(basis).float()


 def get_mel_filters(
-        sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
+    sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
 ) -> torch.Tensor:
     try:
         import librosa
@@ -244,8 +251,12 @@ def get_mel_filters(

 class TTSSpectrogram(torch.nn.Module):
     def __init__(
-        self, n_fft: int, win_length: int, hop_length: int,
-        window_fn: callable = torch.hann_window, return_phase: bool = False
+        self,
+        n_fft: int,
+        win_length: int,
+        hop_length: int,
+        window_fn: callable = torch.hann_window,
+        return_phase: bool = False,
     ) -> None:
         super(TTSSpectrogram, self).__init__()
         self.n_fft = n_fft
@@ -254,16 +265,16 @@ class TTSSpectrogram(torch.nn.Module):

         basis = get_fourier_basis(n_fft).unsqueeze(1)
         basis *= get_window(window_fn, n_fft, win_length)
-        self.register_buffer('basis', basis)
+        self.register_buffer("basis", basis)

     def forward(
-            self, waveform: torch.Tensor
+        self, waveform: torch.Tensor
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         padding = (self.n_fft // 2, self.n_fft // 2)
-        x = F.pad(waveform.unsqueeze(1), padding, mode='reflect')
+        x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
         x = F.conv1d(x, self.basis, stride=self.hop_length)
-        real_part = x[:, :self.n_fft // 2 + 1, :]
-        imag_part = x[:, self.n_fft // 2 + 1:, :]
+        real_part = x[:, : self.n_fft // 2 + 1, :]
+        imag_part = x[:, self.n_fft // 2 + 1 :, :]
         magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
         if self.return_phase:
             phase = torch.atan2(imag_part, real_part)
@@ -273,13 +284,11 @@ class TTSSpectrogram(torch.nn.Module):

 class TTSMelScale(torch.nn.Module):
     def __init__(
-        self, n_mels: int, sample_rate: int, f_min: float, f_max: float,
-        n_stft: int
+        self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
     ) -> None:
         super(TTSMelScale, self).__init__()
-        basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min,
-                                f_max)
-        self.register_buffer('basis', basis)
+        basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
+        self.register_buffer("basis", basis)

     def forward(self, specgram: torch.Tensor) -> torch.Tensor:
         return torch.matmul(self.basis, specgram)
@@ -13,11 +13,10 @@ from typing import List, Optional
 import numpy as np
 import torch
 from fairseq.data import Dictionary
-from fairseq.data.audio.speech_to_text_dataset import (
-    S2TDataConfig
-)
+from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
 from fairseq.data.audio.text_to_speech_dataset import (
-    TextToSpeechDataset, TextToSpeechDatasetCreator
+    TextToSpeechDataset,
+    TextToSpeechDatasetCreator,
 )

 logger = logging.getLogger(__name__)
@@ -48,7 +47,7 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
         chunk_incr=5,
         add_eos=True,
         dedup=True,
-        ref_fpu=-1
+        ref_fpu=-1,
     ):
         # It assumes texts are encoded at a fixed frame-rate
         super().__init__(
@@ -67,7 +66,7 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
             pre_tokenizer=pre_tokenizer,
             bpe_tokenizer=bpe_tokenizer,
             n_frames_per_step=n_frames_per_step,
-            speaker_to_id=speaker_to_id
+            speaker_to_id=speaker_to_id,
         )

         self.do_chunk = do_chunk
@@ -92,24 +91,23 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
         fpu = source.size(0) / target.size(0)  # frame-per-unit
         fps = self.n_frames_per_step
         assert (
-            self.ref_fpu == -1 or
-            abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
+            self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
         ), f"{fpu*fps} != {self.ref_fpu}"

         # only chunk training split
         if self.is_train_split and self.do_chunk and self.chunk_size > 0:
-            lang = target[:int(self.data_cfg.prepend_tgt_lang_tag)]
-            text = target[int(self.data_cfg.prepend_tgt_lang_tag):]
+            lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
+            text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
             size = len(text)
             chunk_size = min(self.chunk_size, size)
             chunk_start = np.random.randint(size - chunk_size + 1)
-            text = text[chunk_start:chunk_start+chunk_size]
+            text = text[chunk_start : chunk_start + chunk_size]
             target = torch.cat((lang, text), 0)

             f_size = int(np.floor(chunk_size * fpu))
             f_start = int(np.floor(chunk_start * fpu))
-            assert(f_size > 0)
-            source = source[f_start:f_start+f_size, :]
+            assert f_size > 0
+            source = source[f_start : f_start + f_size, :]

         if self.dedup:
             target = torch.unique_consecutive(target)
@@ -126,10 +124,12 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
             self.chunk_size = self.chunk_init + epoch * self.chunk_incr
             if self.chunk_bound > 0:
                 self.chunk_size = min(self.chunk_size, self.chunk_bound)
-            logger.info((
-                f"{self.split}: setting chunk size "
-                f"from {old} to {self.chunk_size}"
-            ))
+            logger.info(
+                (
+                    f"{self.split}: setting chunk size "
+                    f"from {old} to {self.chunk_size}"
+                )
+            )


 class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
@@ -152,7 +152,7 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
         chunk_incr: int = 5,
         add_eos: bool = True,
         dedup: bool = True,
-        ref_fpu: float = -1
+        ref_fpu: float = -1,
     ) -> FrmTextToSpeechDataset:
         tsv_path = op.join(root, f"{split}.tsv")
         if not op.isfile(tsv_path):
@@ -170,9 +170,7 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
         assert len(s) > 0

         ids = [ss[cls.KEY_ID] for ss in s]
-        audio_paths = [
-            op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s
-        ]
+        audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
         n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
         tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
         src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
@@ -203,5 +201,5 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
             chunk_incr=chunk_incr,
             add_eos=add_eos,
             dedup=dedup,
-            ref_fpu=ref_fpu
+            ref_fpu=ref_fpu,
         )
@@ -152,10 +152,7 @@ class HubertDataset(FairseqDataset):
         self.label_offsets_list = [
             load_label_offset(p, inds, tot) for p in label_paths
         ]
-        assert (
-            label_processors is None
-            or len(label_processors) == self.num_labels
-        )
+        assert label_processors is None or len(label_processors) == self.num_labels
         for label_path, label_rate in zip(label_paths, self.label_rates):
             verify_label_lengths(
                 self.sizes, sample_rate, label_path, label_rate, inds, tot
@@ -234,8 +231,7 @@ class HubertDataset(FairseqDataset):
         )

         targets_by_label = [
-            [s["label_list"][i] for s in samples]
-            for i in range(self.num_labels)
+            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
         ]
         targets_list, lengths_list, ntokens_list = self.collater_label(
             targets_by_label, audio_size, audio_starts
@@ -270,9 +266,7 @@ class HubertDataset(FairseqDataset):
                 collated_audios[i] = audio
             elif diff < 0:
                 assert self.pad_audio
-                collated_audios[i] = torch.cat(
-                    [audio, audio.new_full((-diff,), 0.0)]
-                )
+                collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
                 padding_mask[i, diff:] = True
             else:
                 collated_audios[i], audio_starts[i] = self.crop_to_max_size(
@@ -280,9 +274,7 @@ class HubertDataset(FairseqDataset):
             )
         return collated_audios, padding_mask, audio_starts

-    def collater_frm_label(
-        self, targets, audio_size, audio_starts, label_rate, pad
-    ):
+    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
         assert label_rate > 0
         s2f = label_rate / self.sample_rate
         frm_starts = [int(round(s * s2f)) for s in audio_starts]
@@ -290,24 +282,20 @@ class HubertDataset(FairseqDataset):
         if not self.pad_audio:
             rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
             frm_size = min(frm_size, *rem_size)
-        targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)]
+        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
         logger.debug(f"audio_starts={audio_starts}")
         logger.debug(f"frame_starts={frm_starts}")
         logger.debug(f"frame_size={frm_size}")

         lengths = torch.LongTensor([len(t) for t in targets])
         ntokens = lengths.sum().item()
-        targets = data_utils.collate_tokens(
-            targets, pad_idx=pad, left_pad=False
-        )
+        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
         return targets, lengths, ntokens

     def collater_seq_label(self, targets, pad):
         lengths = torch.LongTensor([len(t) for t in targets])
         ntokens = lengths.sum().item()
-        targets = data_utils.collate_tokens(
-            targets, pad_idx=pad, left_pad=False
-        )
+        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
         return targets, lengths, ntokens

     def collater_label(self, targets_by_label, audio_size, audio_starts):
@@ -315,9 +303,7 @@ class HubertDataset(FairseqDataset):
         itr = zip(targets_by_label, self.label_rates, self.pad_list)
         for targets, label_rate, pad in itr:
             if label_rate == -1:
-                targets, lengths, ntokens = self.collater_seq_label(
-                    targets, pad
-                )
+                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
             else:
                 targets, lengths, ntokens = self.collater_frm_label(
                     targets, audio_size, audio_starts, label_rate, pad
@@ -29,6 +29,7 @@ class ModalityDatasetItem(NamedTuple):
     max_tokens: Optional[int] = None
     max_sentences: Optional[int] = None

+
 # MultiModalityDataset: it concate multiple datasets with different modalities.
 # Compared with ConcatDataset it can 1) sample data given the ratios for different datasets
 # 2) it adds mode to indicate what type of the data samples come from.
@@ -308,6 +308,7 @@ class FileAudioDataset(RawAudioDataset):

     def __getitem__(self, index):
         import soundfile as sf
+
         fn = self.fnames[index]
         fn = fn if isinstance(self.fnames, list) else fn.as_py()
         fn = self.text_compressor.decompress(fn)
@@ -45,7 +45,11 @@ def get_features_from_npy_or_audio(path):


 def get_features_or_waveform_from_stored_zip(
-    path, byte_offset, byte_size, need_waveform=False, use_sample_rate=None,
+    path,
+    byte_offset,
+    byte_size,
+    need_waveform=False,
+    use_sample_rate=None,
 ):
     assert path.endswith(".zip")
     data = read_from_stored_zip(path, byte_offset, byte_size)
@@ -53,18 +57,17 @@ def get_features_or_waveform_from_stored_zip(
     if is_npy_data(data):
         features_or_waveform = np.load(f)
     elif is_sf_audio_data(data):
-        features_or_waveform = \
-            get_waveform(
-                f, always_2d=False, output_sample_rate=use_sample_rate
-            )[0] if need_waveform else get_fbank(f)
+        features_or_waveform = (
+            get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0]
+            if need_waveform
+            else get_fbank(f)
+        )
     else:
         raise ValueError(f'Unknown file format for "{path}"')
     return features_or_waveform


-def get_features_or_waveform(
-    path: str, need_waveform=False, use_sample_rate=None
-):
+def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=None):
     """Get speech features from .npy file or waveform from .wav/.flac file.
     The file may be inside an uncompressed ZIP file and is accessed via byte
     offset and length.
@@ -87,8 +90,11 @@ def get_features_or_waveform(
         return get_features_from_npy_or_audio(_path)
     elif len(slice_ptr) == 2:
         features_or_waveform = get_features_or_waveform_from_stored_zip(
-            _path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform,
-            use_sample_rate=use_sample_rate
+            _path,
+            slice_ptr[0],
+            slice_ptr[1],
+            need_waveform=need_waveform,
+            use_sample_rate=use_sample_rate,
         )
     else:
         raise ValueError(f"Invalid path: {path}")
@@ -145,7 +151,7 @@ class SpeechToTextDataset(FairseqDataset):
         pre_tokenizer=None,
         bpe_tokenizer=None,
         n_frames_per_step=1,
-        speaker_to_id=None
+        speaker_to_id=None,
     ):
         self.split, self.is_train_split = split, is_train_split
         self.cfg = cfg
@@ -235,7 +241,7 @@ class SpeechToTextDataset(FairseqDataset):
         if self.n_frames_per_step == 1:
             return feature
         n_packed_frames = feature.shape[0] // self.n_frames_per_step
-        feature = feature[:self.n_frames_per_step * n_packed_frames]
+        feature = feature[: self.n_frames_per_step * n_packed_frames]
         return feature.reshape(n_packed_frames, -1)

     @classmethod
@@ -318,9 +324,11 @@ class SpeechToTextDataset(FairseqDataset):

         speaker = None
         if self.speaker_to_id is not None:
-            speaker = torch.tensor(
-                [s.speaker_id for s in samples], dtype=torch.long
-            ).index_select(0, order).view(-1, 1)
+            speaker = (
+                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
+                .index_select(0, order)
+                .view(-1, 1)
+            )

         net_input = {
             "src_tokens": frames,
@@ -388,7 +396,7 @@ class SpeechToTextDatasetCreator(object):
         pre_tokenizer,
         bpe_tokenizer,
         n_frames_per_step,
-        speaker_to_id
+        speaker_to_id,
     ) -> SpeechToTextDataset:
         audio_root = Path(cfg.audio_root)
         ids = [s[cls.KEY_ID] for s in samples]
@@ -415,7 +423,7 @@ class SpeechToTextDatasetCreator(object):
             pre_tokenizer=pre_tokenizer,
             bpe_tokenizer=bpe_tokenizer,
             n_frames_per_step=n_frames_per_step,
-            speaker_to_id=speaker_to_id
+            speaker_to_id=speaker_to_id,
         )

     @classmethod
@@ -481,12 +489,19 @@ class SpeechToTextDatasetCreator(object):
         pre_tokenizer,
         bpe_tokenizer,
         n_frames_per_step,
-        speaker_to_id
+        speaker_to_id,
     ) -> SpeechToTextDataset:
         samples = cls._load_samples_from_tsv(root, split)
         return cls._from_list(
-            split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer,
-            bpe_tokenizer, n_frames_per_step, speaker_to_id
+            split,
+            is_train_split,
+            samples,
+            cfg,
+            tgt_dict,
+            pre_tokenizer,
+            bpe_tokenizer,
+            n_frames_per_step,
+            speaker_to_id,
         )

     @classmethod
@@ -502,12 +517,19 @@ class SpeechToTextDatasetCreator(object):
         epoch: int,
         seed: int,
         n_frames_per_step: int = 1,
-        speaker_to_id=None
+        speaker_to_id=None,
     ) -> SpeechToTextDataset:
         datasets = [
             cls._from_tsv(
-                root, cfg, split, tgt_dict, is_train_split, pre_tokenizer,
-                bpe_tokenizer, n_frames_per_step, speaker_to_id
+                root,
+                cfg,
+                split,
+                tgt_dict,
+                is_train_split,
+                pre_tokenizer,
+                bpe_tokenizer,
+                n_frames_per_step,
+                speaker_to_id,
             )
             for split in splits.split(",")
         ]
@@ -13,8 +13,11 @@ import numpy as np
 import torch

 from fairseq.data.audio.speech_to_text_dataset import (
-    SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig,
-    _collate_frames, get_features_or_waveform
+    SpeechToTextDataset,
+    SpeechToTextDatasetCreator,
+    S2TDataConfig,
+    _collate_frames,
+    get_features_or_waveform,
 )
 from fairseq.data import Dictionary, data_utils as fairseq_data_utils

@@ -32,34 +35,44 @@ class TextToSpeechDatasetItem(object):

 class TextToSpeechDataset(SpeechToTextDataset):
     def __init__(
-            self,
-            split: str,
-            is_train_split: bool,
-            cfg: S2TDataConfig,
-            audio_paths: List[str],
-            n_frames: List[int],
-            src_texts: Optional[List[str]] = None,
-            tgt_texts: Optional[List[str]] = None,
-            speakers: Optional[List[str]] = None,
-            src_langs: Optional[List[str]] = None,
-            tgt_langs: Optional[List[str]] = None,
-            ids: Optional[List[str]] = None,
-            tgt_dict: Optional[Dictionary] = None,
-            pre_tokenizer=None,
-            bpe_tokenizer=None,
-            n_frames_per_step=1,
-            speaker_to_id=None,
-            durations: Optional[List[List[int]]] = None,
-            pitches: Optional[List[str]] = None,
-            energies: Optional[List[str]] = None
+        self,
+        split: str,
+        is_train_split: bool,
+        cfg: S2TDataConfig,
+        audio_paths: List[str],
+        n_frames: List[int],
+        src_texts: Optional[List[str]] = None,
+        tgt_texts: Optional[List[str]] = None,
+        speakers: Optional[List[str]] = None,
+        src_langs: Optional[List[str]] = None,
+        tgt_langs: Optional[List[str]] = None,
+        ids: Optional[List[str]] = None,
+        tgt_dict: Optional[Dictionary] = None,
+        pre_tokenizer=None,
+        bpe_tokenizer=None,
+        n_frames_per_step=1,
+        speaker_to_id=None,
+        durations: Optional[List[List[int]]] = None,
+        pitches: Optional[List[str]] = None,
+        energies: Optional[List[str]] = None,
     ):
         super(TextToSpeechDataset, self).__init__(
-            split, is_train_split, cfg, audio_paths, n_frames,
-            src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers,
-            src_langs=src_langs, tgt_langs=tgt_langs, ids=ids,
-            tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer,
-            bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step,
-            speaker_to_id=speaker_to_id
+            split,
+            is_train_split,
+            cfg,
+            audio_paths,
+            n_frames,
+            src_texts=src_texts,
+            tgt_texts=tgt_texts,
+            speakers=speakers,
+            src_langs=src_langs,
+            tgt_langs=tgt_langs,
+            ids=ids,
+            tgt_dict=tgt_dict,
+            pre_tokenizer=pre_tokenizer,
+            bpe_tokenizer=bpe_tokenizer,
+            n_frames_per_step=n_frames_per_step,
+            speaker_to_id=speaker_to_id,
         )
         self.durations = durations
         self.pitches = pitches
@@ -84,9 +97,13 @@ class TextToSpeechDataset(SpeechToTextDataset):
             np.concatenate((energy, [0]))  # pad 0 for EOS
         ).float()
         return TextToSpeechDatasetItem(
-            index=index, source=s2t_item.source, target=s2t_item.target,
-            speaker_id=s2t_item.speaker_id, duration=duration, pitch=pitch,
-            energy=energy
+            index=index,
+            source=s2t_item.source,
+            target=s2t_item.target,
+            speaker_id=s2t_item.speaker_id,
+            duration=duration,
+            pitch=pitch,
+            energy=energy,
         )

     def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
@@ -96,8 +113,9 @@ class TextToSpeechDataset(SpeechToTextDataset):
         src_lengths, order = torch.tensor(
             [s.target.shape[0] for s in samples], dtype=torch.long
         ).sort(descending=True)
-        id_ = torch.tensor([s.index for s in samples],
-                           dtype=torch.long).index_select(0, order)
+        id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
+            0, order
+        )
         feat = _collate_frames(
             [s.source for s in samples], self.cfg.use_audio_input
         ).index_select(0, order)
@@ -115,9 +133,11 @@ class TextToSpeechDataset(SpeechToTextDataset):

         speaker = None
         if self.speaker_to_id is not None:
-            speaker = torch.tensor(
-                [s.speaker_id for s in samples], dtype=torch.long
-            ).index_select(0, order).view(-1, 1)
+            speaker = (
+                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
+                .index_select(0, order)
+                .view(-1, 1)
+            )

         bsz, _, d = feat.size()
         prev_output_tokens = torch.cat(
@@ -175,7 +195,7 @@ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
         pre_tokenizer,
         bpe_tokenizer,
         n_frames_per_step,
-        speaker_to_id
+        speaker_to_id,
     ) -> TextToSpeechDataset:
         audio_root = Path(cfg.audio_root)
         ids = [s[cls.KEY_ID] for s in samples]
@@ -189,27 +209,40 @@ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):

         durations = [s.get(cls.KEY_DURATION, None) for s in samples]
         durations = [
-            None if dd is None else [int(d) for d in dd.split(" ")]
-            for dd in durations
+            None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
         ]
         durations = None if any(dd is None for dd in durations) else durations

         pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
         pitches = [
-            None if pp is None else (audio_root / pp).as_posix()
-            for pp in pitches
+            None if pp is None else (audio_root / pp).as_posix() for pp in pitches
         ]
         pitches = None if any(pp is None for pp in pitches) else pitches

         energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
         energies = [
-            None if ee is None else (audio_root / ee).as_posix()
-            for ee in energies]
+            None if ee is None else (audio_root / ee).as_posix() for ee in energies
+        ]
         energies = None if any(ee is None for ee in energies) else energies

         return TextToSpeechDataset(
-            split_name, is_train_split, cfg, audio_paths, n_frames,
-            src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict,
-            pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id,
-            durations, pitches, energies
+            split_name,
+            is_train_split,
+            cfg,
+            audio_paths,
+            n_frames,
+            src_texts,
+            tgt_texts,
+            speakers,
+            src_langs,
+            tgt_langs,
+            ids,
+            tgt_dict,
+            pre_tokenizer,
+            bpe_tokenizer,
+            n_frames_per_step,
+            speaker_to_id,
+            durations,
+            pitches,
+            energies,
         )
@@ -9,7 +9,7 @@ from . import BaseWrapperDataset


 class ColorizeDataset(BaseWrapperDataset):
-    """ Adds 'colors' property to net input that is obtained from the provided color getter for use by models """
+    """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""

     def __init__(self, dataset, color_getter):
         super().__init__(dataset)
@@ -69,6 +69,7 @@ def collate_tokens(
         copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
     return res

+
 def load_indexed_dataset(
     path, dictionary=None, dataset_impl=None, combine=False, default="cached"
 ):
@@ -324,9 +325,7 @@ def batch_by_size(
         )

     # added int() to avoid TypeError: an integer is required
-    max_tokens = (
-        int(max_tokens) if max_tokens is not None else -1
-    )
+    max_tokens = int(max_tokens) if max_tokens is not None else -1
     max_sentences = max_sentences if max_sentences is not None else -1
     bsz_mult = required_batch_size_multiple

@@ -375,8 +374,9 @@ def post_process(sentence: str, symbol: str):
         sentence = sentence.replace(" ", "").replace("|", " ").strip()
     elif symbol == "silence":
         import re
+
         sentence = sentence.replace("<SIL>", "")
-        sentence = re.sub(' +', ' ', sentence).strip()
+        sentence = re.sub(" +", " ", sentence).strip()
     elif symbol == "_EOW":
         sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
     elif symbol in {"subword_nmt", "@@ ", "@@"}:
@@ -547,7 +547,7 @@ def get_buckets(sizes, num_buckets):
         np.percentile(
             sizes,
             np.linspace(0, 100, num_buckets + 1),
-            interpolation='lower',
+            interpolation="lower",
         )[1:]
     )
     return buckets
@@ -564,7 +564,6 @@ def get_bucketed_sizes(orig_sizes, buckets):
     return sizes


-
 def _find_extra_valid_paths(dataset_path: str) -> set:
     paths = utils.split_paths(dataset_path)
     all_valid_paths = set()
@@ -21,8 +21,10 @@ class SentencepieceConfig(FairseqDataclass):
     )
     sentencepiece_alpha: Optional[float] = field(
         default=None,
-        metadata={"help": "soothing parameter for unigram sampling, "
-                          "and merge probability for BPE-dropout"}
+        metadata={
+            "help": "soothing parameter for unigram sampling, "
+            "and merge probability for BPE-dropout"
+        },
     )

@@ -45,8 +47,7 @@ class SentencepieceBPE(object):
     def encode(self, x: str) -> str:
         return " ".join(
             self.sp.Encode(
-                x, out_type=str, enable_sampling=self.enable_sampling,
-                alpha=self.alpha
+                x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha
             )
         )
@@ -138,7 +138,7 @@ class FairseqDataset(torch.utils.data.Dataset, EpochListening):
         )

         try:
-            num_tokens_vec = self.num_tokens_vec(indices).astype('int64')
+            num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
         except NotImplementedError:
             num_tokens_vec = None

|
@ -140,7 +140,9 @@ class HuffmanNode:
|
||||
def is_leaf(self) -> bool:
|
||||
return self.left is None and self.right is None
|
||||
|
||||
def code_table(self, prefix: tp.Optional[bitarray] = None) -> tp.Dict[str, "HuffmanNode"]:
|
||||
def code_table(
|
||||
self, prefix: tp.Optional[bitarray] = None
|
||||
) -> tp.Dict[str, "HuffmanNode"]:
|
||||
defaulted_prefix = prefix if prefix is not None else bitarray()
|
||||
if self.is_leaf():
|
||||
self.code = (
|
||||
|
@@ -67,7 +67,9 @@ def make_builder(out_file, impl, vocab_size=None):
     elif impl == "fasta":
         raise NotImplementedError
     elif impl == "huffman":
-        raise ValueError("Use HuffmanCodeBuilder directly as it has a different interface.")
+        raise ValueError(
+            "Use HuffmanCodeBuilder directly as it has a different interface."
+        )
     else:
         return IndexedDatasetBuilder(out_file)
@@ -380,7 +380,9 @@ class EpochBatchIterator(EpochBatchIterating):
             # reset _frozen_batches to refresh the next epoch
             self._frozen_batches = None
         self._cur_epoch_itr = self._get_iterator_for_epoch(
-            self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
+            self.epoch,
+            shuffle,
+            fix_batches_to_gpus=fix_batches_to_gpus,
         )
         self.shuffle = shuffle
         return self._cur_epoch_itr
@@ -421,7 +423,9 @@ class EpochBatchIterator(EpochBatchIterating):
         if itr_pos > 0:
             # fast-forward epoch iterator
             self._next_epoch_itr = self._get_iterator_for_epoch(
-                self.epoch, shuffle=state_dict.get("shuffle", True), offset=itr_pos,
+                self.epoch,
+                shuffle=state_dict.get("shuffle", True),
+                offset=itr_pos,
             )
             if self._next_epoch_itr is None:
                 if version == 1:
@@ -114,7 +114,10 @@ def collate(
         "id": id,
         "nsentences": len(samples),
         "ntokens": ntokens,
-        "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths,},
+        "net_input": {
+            "src_tokens": src_tokens,
+            "src_lengths": src_lengths,
+        },
         "target": target,
     }
     if prev_output_tokens is not None:
@@ -467,5 +470,8 @@ class LanguagePairDataset(FairseqDataset):
             list: list of removed indices
         """
         return data_utils.filter_paired_dataset_indices_by_size(
-            self.src_sizes, self.tgt_sizes, indices, max_sizes,
+            self.src_sizes,
+            self.tgt_sizes,
+            indices,
+            max_sizes,
         )
@ -80,7 +80,9 @@ class MultiCorpusDataset(FairseqDataset):
|
||||
def ordered_indices(self):
|
||||
start = time.time()
|
||||
with data_utils.numpy_seed(self.seed, self.epoch):
|
||||
logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}")
|
||||
logger.info(
|
||||
f"sampling new dataset with seed {self.seed} epoch {self.epoch}"
|
||||
)
|
||||
sampled_indices = []
|
||||
num_selected_instances = 0
|
||||
|
||||
|
@ -40,8 +40,8 @@ from fairseq.utils import FileContentsAction, csv_str_list, eval_str_dict

logger = logging.getLogger(__name__)

SRC_DICT_NAME = 'src'
TGT_DICT_NAME = 'tgt'
SRC_DICT_NAME = "src"
TGT_DICT_NAME = "tgt"


def _lang_id(dic: Dictionary, lang: str):
@ -64,14 +64,16 @@ class MultilingualDatasetManager(object):
        self.seed = args.seed
        self.lang_pairs = lang_pairs
        self.extra_lang_pairs = (
            list(
                {p for _, v in args.extra_lang_pairs.items() for p in v.split(",")}
            )
            if args.extra_lang_pairs
            else []
        )
        self.src_langs = {p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs}
        self.tgt_langs = {p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs}
            list({p for _, v in args.extra_lang_pairs.items() for p in v.split(",")})
            if args.extra_lang_pairs
            else []
        )
        self.src_langs = {
            p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs
        }
        self.tgt_langs = {
            p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs
        }
        self.langs = langs
        self.dicts = dicts
        self.lang_dict = self.create_lang_dictionary(self.langs)
@ -111,10 +113,18 @@ class MultilingualDatasetManager(object):
            "note that the ordering determines language token IDs; "
            "--langs and --lang-dict are two exclusive options",
        )
        parser.add_argument('--source-dict', default=None, type=str,
                            help='path to source dictionary; if specified it will override per language dictionary loading')
        parser.add_argument('--target-dict', default=None, type=str,
                            help='path to target dictionary; if specified it will override per language dictionary loading')
        parser.add_argument(
            "--source-dict",
            default=None,
            type=str,
            help="path to source dictionary; if specified it will override per language dictionary loading",
        )
        parser.add_argument(
            "--target-dict",
            default=None,
            type=str,
            help="path to target dictionary; if specified it will override per language dictionary loading",
        )
        parser.add_argument(
            "--lang-tok-style",
            default=LangTokStyle.multilingual.value,
@ -378,7 +388,9 @@ class MultilingualDatasetManager(object):
            )
            return d

        dicts = cls.load_all_dictionaries(args, language_list, load_dictionary_and_postproc, training)
        dicts = cls.load_all_dictionaries(
            args, language_list, load_dictionary_and_postproc, training
        )
        return language_list, dicts, training

    @classmethod
@ -424,7 +436,10 @@ class MultilingualDatasetManager(object):

        if args.fixed_dictionary is not None:
            fixed_dict = load_dictionary(args.fixed_dictionary)
            dicts = {lang: fixed_dict for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts}
            dicts = {
                lang: fixed_dict
                for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts
            }
        else:
            if args.source_dict is None:
                load_dicts(src_langs_to_load_dicts)
@ -477,7 +492,10 @@ class MultilingualDatasetManager(object):
            lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec
        )
        return self.get_langtok_index(
            langtok, self.get_source_dictionary(src_lang) if src_lang else self.get_target_dictionary(tgt_lang)
            langtok,
            self.get_source_dictionary(src_lang)
            if src_lang
            else self.get_target_dictionary(tgt_lang),
        )

    def get_decoder_langtok(self, tgt_lang, spec=None):
@ -819,7 +837,9 @@ class MultilingualDatasetManager(object):
        if self.args.lang_tok_replacing_bos_eos:
            ds = self.alter_dataset_langtok(
                langpair_ds,
                src_eos=self.get_source_dictionary(src).eos() if src else self.get_target_dictionary(tgt).eos(),
                src_eos=self.get_source_dictionary(src).eos()
                if src
                else self.get_target_dictionary(tgt).eos(),
                src_lang=src,
                tgt_eos=self.get_target_dictionary(tgt).eos(),
                tgt_lang=tgt,

@ -298,7 +298,6 @@ class NoisingDataset(torch.utils.data.Dataset):
        )
        self.sizes = src_dataset.sizes


    def __getitem__(self, index):
        """
        Returns a single noisy sample. Multiple samples are fed to the collater
@ -14,8 +14,7 @@ class TextCompressionLevel(Enum):

class TextCompressor(object):
    def __init__(
        self, level: TextCompressionLevel,
        max_input_byte_length: int = 2 ** 16
        self, level: TextCompressionLevel, max_input_byte_length: int = 2 ** 16
    ):
        self.level = level
        self.max_input_length = max_input_byte_length
@ -23,11 +22,13 @@ class TextCompressor(object):
    def compress(self, text: str) -> bytes:
        if self.level == TextCompressionLevel.low:
            import zlib

            # zlib: built-in, fast
            return zlib.compress(text.encode(), level=0)
        elif self.level == TextCompressionLevel.high:
            try:
                import unishox2

                # unishox2: optimized for short text but slower
            except ImportError:
                raise ImportError(
@ -42,6 +43,7 @@ class TextCompressor(object):
    def decompress(self, compressed: bytes) -> str:
        if self.level == TextCompressionLevel.low:
            import zlib

            return zlib.decompress(compressed).decode()
        elif self.level == TextCompressionLevel.high:
            try:
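A short round-trip sketch of the class being reformatted above (assuming it is importable as fairseq.data.text_compressor; the sample string is arbitrary):

    from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor

    tc = TextCompressor(level=TextCompressionLevel.low)  # "low" maps to zlib with level=0
    blob = tc.compress("hello world")
    assert tc.decompress(blob) == "hello world"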
@ -69,7 +69,10 @@ class TokenBlockDataset(FairseqDataset):
                _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path
            )
            self._block_to_dataset_index = plasma_utils.PlasmaView(
                block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path,
                block_to_dataset_index,
                split_path,
                (plasma_id, 2),
                plasma_path=plasma_path,
            )
        else:
            self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
@ -127,7 +130,8 @@ class TokenBlockDataset(FairseqDataset):
            )
        else:
            block_to_dataset_index = _get_block_to_dataset_index_fast(
                sizes, slice_indices,
                sizes,
                slice_indices,
            )
        size_dtype = np.uint16 if block_size < 65535 else np.uint32
        num_tokens = slice_indices[-1].max()

@ -52,7 +52,7 @@ class TransformEosLangPairDataset(FairseqDataset):
        if len(samples) == 0:
            return samples

        if 'net_input' not in samples:
        if "net_input" not in samples:
            return samples

        if self.new_src_eos is not None:
@ -126,7 +126,8 @@ class CommonConfig(FairseqDataclass):
        metadata={"help": "Weights and Biases project name to use for logging"},
    )
    azureml_logging: Optional[bool] = field(
        default=False, metadata={"help": "Log scalars to AzureML context"},
        default=False,
        metadata={"help": "Log scalars to AzureML context"},
    )
    seed: int = field(
        default=1, metadata={"help": "pseudo random number generator seed"}
@ -428,19 +429,23 @@ class DistributedTrainingConfig(FairseqDataclass):
    tpu: bool = II("common.tpu")
    # configuration for --ddp-backend=fully_sharded
    no_reshard_after_forward: bool = field(
        default=False, metadata={"help": "don't reshard parameters after forward pass"},
        default=False,
        metadata={"help": "don't reshard parameters after forward pass"},
    )
    fp32_reduce_scatter: bool = field(
        default=False, metadata={"help": "reduce-scatter grads in FP32"},
        default=False,
        metadata={"help": "reduce-scatter grads in FP32"},
    )
    cpu_offload: bool = field(
        default=False, metadata={"help": "offload FP32 params to CPU"}
    )
    use_sharded_state: bool = field(
        default=False, metadata={"help": "use sharded checkpoint files"},
        default=False,
        metadata={"help": "use sharded checkpoint files"},
    )
    not_fsdp_flatten_parameters: bool = field(
        default=False, metadata={"help": "not flatten parameter param for fsdp"},
        default=False,
        metadata={"help": "not flatten parameter param for fsdp"},
    )

@ -786,10 +791,12 @@ class FairseqBMUFConfig(FairseqDataclass):
@dataclass
class GenerationConfig(FairseqDataclass):
    beam: int = field(
        default=5, metadata={"help": "beam size"},
        default=5,
        metadata={"help": "beam size"},
    )
    nbest: int = field(
        default=1, metadata={"help": "number of hypotheses to output"},
        default=1,
        metadata={"help": "number of hypotheses to output"},
    )
    max_len_a: float = field(
        default=0,
@ -804,19 +811,24 @@ class GenerationConfig(FairseqDataclass):
        },
    )
    min_len: int = field(
        default=1, metadata={"help": "minimum generation length"},
        default=1,
        metadata={"help": "minimum generation length"},
    )
    match_source_len: bool = field(
        default=False, metadata={"help": "generations should match the source length"},
        default=False,
        metadata={"help": "generations should match the source length"},
    )
    unnormalized: bool = field(
        default=False, metadata={"help": "compare unnormalized hypothesis scores"},
        default=False,
        metadata={"help": "compare unnormalized hypothesis scores"},
    )
    no_early_stop: bool = field(
        default=False, metadata={"help": "deprecated"},
        default=False,
        metadata={"help": "deprecated"},
    )
    no_beamable_mm: bool = field(
        default=False, metadata={"help": "don't use BeamableMM in attention layers"},
        default=False,
        metadata={"help": "don't use BeamableMM in attention layers"},
    )
    lenpen: float = field(
        default=1,
@ -838,10 +850,12 @@ class GenerationConfig(FairseqDataclass):
        },
    )
    sacrebleu: bool = field(
        default=False, metadata={"help": "score with sacrebleu"},
        default=False,
        metadata={"help": "score with sacrebleu"},
    )
    score_reference: bool = field(
        default=False, metadata={"help": "just score the reference translation"},
        default=False,
        metadata={"help": "just score the reference translation"},
    )
    prefix_size: int = field(
        default=0,
@ -875,10 +889,12 @@ class GenerationConfig(FairseqDataclass):
        },
    )
    temperature: float = field(
        default=1.0, metadata={"help": "temperature for generation"},
        default=1.0,
        metadata={"help": "temperature for generation"},
    )
    diverse_beam_groups: int = field(
        default=-1, metadata={"help": "number of groups for Diverse Beam Search"},
        default=-1,
        metadata={"help": "number of groups for Diverse Beam Search"},
    )
    diverse_beam_strength: float = field(
        default=0.5,
@ -897,13 +913,16 @@ class GenerationConfig(FairseqDataclass):
        },
    )
    print_step: bool = field(
        default=False, metadata={"help": "print steps"},
        default=False,
        metadata={"help": "print steps"},
    )
    lm_path: Optional[str] = field(
        default=None, metadata={"help": "path to lm checkpoint for lm fusion"},
        default=None,
        metadata={"help": "path to lm checkpoint for lm fusion"},
    )
    lm_weight: float = field(
        default=0.0, metadata={"help": "weight for lm probs for lm fusion"},
        default=0.0,
        metadata={"help": "weight for lm probs for lm fusion"},
    )

    # arguments for iterative refinement generator
@ -912,7 +931,8 @@ class GenerationConfig(FairseqDataclass):
        metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
    )
    iter_decode_max_iter: int = field(
        default=10, metadata={"help": "maximum iterations for iterative refinement."},
        default=10,
        metadata={"help": "maximum iterations for iterative refinement."},
    )
    iter_decode_force_max_iter: bool = field(
        default=False,
@ -939,7 +959,8 @@ class GenerationConfig(FairseqDataclass):
        },
    )
    retain_dropout: bool = field(
        default=False, metadata={"help": "Use dropout at inference time"},
        default=False,
        metadata={"help": "Use dropout at inference time"},
    )
    # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed
    # retain_dropout_modules: Optional[List[str]] = field(
@ -964,7 +985,8 @@ class GenerationConfig(FairseqDataclass):
@dataclass
class CommonEvalConfig(FairseqDataclass):
    path: Optional[str] = field(
        default=None, metadata={"help": "path(s) to model file(s), colon separated"},
        default=None,
        metadata={"help": "path(s) to model file(s), colon separated"},
    )
    post_process: Optional[str] = field(
        default=None,
@ -1026,7 +1048,8 @@ class InteractiveConfig(FairseqDataclass):
        },
    )
    input: str = field(
        default="-", metadata={"help": "file to read from; use - for stdin"},
        default="-",
        metadata={"help": "file to read from; use - for stdin"},
    )

@ -35,14 +35,16 @@ def ChoiceEnum(choices: List[str]):


LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum([
    "c10d",  # alias for pytorch_ddp
    "fully_sharded",  # FullyShardedDataParallel from fairscale
    "legacy_ddp",
    "no_c10d",  # alias for legacy_ddp
    "pytorch_ddp",
    "slowmo",
])
DDP_BACKEND_CHOICES = ChoiceEnum(
    [
        "c10d",  # alias for pytorch_ddp
        "fully_sharded",  # FullyShardedDataParallel from fairscale
        "legacy_ddp",
        "no_c10d",  # alias for legacy_ddp
        "pytorch_ddp",
        "slowmo",
    ]
)
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
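As a hedged sketch of how one of these ChoiceEnum constants is consumed (the config class here is hypothetical, but the field pattern mirrors the dataclasses elsewhere in this change):

    from dataclasses import dataclass, field

    from fairseq.dataclass import FairseqDataclass
    from fairseq.dataclass.constants import DDP_BACKEND_CHOICES

    @dataclass
    class ExampleDistributedConfig(FairseqDataclass):  # hypothetical
        ddp_backend: DDP_BACKEND_CHOICES = field(
            default="pytorch_ddp",
            metadata={"help": "DistributedDataParallel backend"},
        )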
@ -28,7 +28,7 @@ def hydra_init(cfg_name="config") -> None:


def add_defaults(cfg: DictConfig) -> None:
    """This function adds default values that are stored in dataclasses that hydra doesn't know about """
    """This function adds default values that are stored in dataclasses that hydra doesn't know about"""

    from fairseq.registry import REGISTRIES
    from fairseq.tasks import TASK_DATACLASS_REGISTRY

@ -57,21 +57,21 @@ def gen_parser_from_dataclass(
    with_prefix: Optional[str] = None,
) -> None:
    """
    convert a dataclass instance to tailing parser arguments.
    convert a dataclass instance to tailing parser arguments.

    If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
    building a flat namespace from a structured dataclass (see transformer_config.py for example).
    If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
    building a flat namespace from a structured dataclass (see transformer_config.py for example).
    """

    def argparse_name(name: str):
        if name == "data" and (with_prefix is None or with_prefix == ''):
        if name == "data" and (with_prefix is None or with_prefix == ""):
            # normally data is positional args, so we don't add the -- nor the prefix
            return name
        if name == "_name":
            # private member, skip
            return None
        full_name = "--" + name.replace("_", "-")
        if with_prefix is not None and with_prefix != '':
        if with_prefix is not None and with_prefix != "":
            # if a prefix is specified, construct the prefixed arg name
            full_name = with_prefix + "-" + full_name[2:]  # strip -- when composing
        return full_name
@ -143,8 +143,8 @@ def gen_parser_from_dataclass(
            kwargs["default"] = field_default

        # build the help with the hierarchical prefix
        if with_prefix is not None and with_prefix != '' and field_help is not None:
            field_help = with_prefix[2:] + ': ' + field_help
        if with_prefix is not None and with_prefix != "" and field_help is not None:
            field_help = with_prefix[2:] + ": " + field_help

        kwargs["help"] = field_help
        if field_const is not None:
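A tiny self-checking sketch of the prefixing rule implemented by argparse_name above (the helper name is made up; the behavior follows the code as shown):

    def sketch_argparse_name(name, with_prefix=None):
        full_name = "--" + name.replace("_", "-")
        if with_prefix:
            full_name = with_prefix + "-" + full_name[2:]  # strip "--" when composing
        return full_name

    # a dataclass field "embed_dim" under prefix "--decoder" becomes a flat option
    assert sketch_argparse_name("embed_dim", "--decoder") == "--decoder-embed-dim"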
@ -4,7 +4,11 @@
# LICENSE file in the root directory of this source tree.

from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel
from .fully_sharded_data_parallel import (
    fsdp_enable_wrap,
    fsdp_wrap,
    FullyShardedDataParallel,
)
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel

@ -33,6 +33,7 @@ class DistributedTimeoutWrapper(nn.Module):
            (set to a value <= 0 to disable the timeout)
        signal (Optional): signal to send once timeout is triggered
    """

    def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT):
        super().__init__()
        self.module = module
@ -86,9 +87,11 @@ class DistributedTimeoutWrapper(nn.Module):
            if self._terminated:
                break
            elif not success:
                logger.error((
                    "Killing job for not making progress in {} seconds. "
                    "Set --heartbeat-timeout=-1 to disable this timeout."
                ).format(int(self.timeout)))
                logger.error(
                    (
                        "Killing job for not making progress in {} seconds. "
                        "Set --heartbeat-timeout=-1 to disable this timeout."
                    ).format(int(self.timeout))
                )
                os.kill(parent_pid, self.signal)
                return
@ -137,7 +137,7 @@ class LegacyDistributedDataParallel(nn.Module):
                if param.grad is None:
                    param.grad = torch.zeros_like(param)

                if hasattr(param, 'expert'):
                if hasattr(param, "expert"):
                    # Skip gradient sync for unshared parameters
                    continue

@ -26,8 +26,9 @@ class ModuleProxyWrapper(nn.Module):

    def __init__(self, module: nn.Module):
        super().__init__()
        assert hasattr(module, "module"), \
            "ModuleProxyWrapper expects input to wrap another module"
        assert hasattr(
            module, "module"
        ), "ModuleProxyWrapper expects input to wrap another module"
        self.module = module

    def __getattr__(self, name):
@ -10,7 +10,6 @@ from fairseq.distributed import utils


class TPUDistributedDataParallel(nn.Module):

    def __init__(self, module, process_group):
        super().__init__()
        self.module = module
@ -35,9 +34,10 @@ class TPUDistributedDataParallel(nn.Module):
            gradients.append(p.grad)

        import torch_xla.core.xla_model as xm

        xm.all_reduce(
            'sum',
            "sum",
            gradients,
            scale=1. / self.world_size,
            scale=1.0 / self.world_size,
            groups=self.process_group[1],
        )
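For readers unfamiliar with the call above: an all-reduce "sum" followed by a 1/world_size scale averages per-replica gradients. A plain-torch sketch with toy values (no XLA required):

    import torch

    world_size = 4  # hypothetical number of replicas
    per_replica_grads = [torch.full((2,), float(i)) for i in range(world_size)]
    averaged = sum(per_replica_grads) / world_size  # what the scaled all-reduce computes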
@ -201,9 +201,7 @@ def _pipeline_parallel_post_init(
        # distributed_world_size to be based on the total number of GPUs, so
        # we need to correct them to be based on the number of pipelines.
        assert cfg.distributed_world_size % num_pipeline_devices == 0
        cfg.distributed_world_size = (
            cfg.distributed_world_size // num_pipeline_devices
        )
        cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices
        # In the case of 4-way MP on nodes with 8 GPUs, we want
        # distributed_rank to be the starting GPU index for each pipeline
        # i.e., 0, 2, ...
@ -306,8 +304,10 @@ def distributed_init(cfg: FairseqConfig):
            model_part_number = get_model_parallel_rank()
            cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)

    if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
        cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}"
    if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
        cfg.checkpoint.checkpoint_suffix = (
            f"-rank-{cfg.distributed_training.distributed_rank}"
        )

    return cfg.distributed_training.distributed_rank

@ -696,7 +696,7 @@ def broadcast_tensors(
        dist_device = torch.device("cpu")

    # share metadata first to simplify transfer
    is_src_rank = (get_rank(group) == src_rank)
    is_src_rank = get_rank(group) == src_rank
    if is_src_rank:
        metadata = [
            {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
@ -747,7 +747,10 @@ def broadcast_object(


def _broadcast_object_slow(
    obj: Any, src_rank: int, group: object, dist_device: torch.device,
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: torch.device,
) -> Any:
    if get_rank(group) == src_rank:
        # Emit data

@ -152,6 +152,7 @@ class PathManager:
    """
    ioPath async PathManager methods:
    """

    @staticmethod
    def opena(
        path: str,
@ -169,6 +170,7 @@ class PathManager:
        logging.info("ioPath is initializing PathManager.")
        try:
            from iopath.common.file_io import PathManager

            IOPathManager = PathManager()
        except Exception:
            logging.exception("Failed to initialize ioPath PathManager object.")
@ -146,6 +146,7 @@ def cached_path_from_pm(url_or_filename):
    """
    try:
        from fairseq.file_io import PathManager

        local_path = PathManager.get_local_path(url_or_filename)
        return local_path
    except Exception:

@ -130,6 +130,7 @@ def log_scalar(
        agg.add_meter(key, AverageMeter(round=round), priority)
        agg[key].update(value, weight)


def log_scalar_sum(
    key: str,
    value: float,
@ -309,6 +310,7 @@ def load_state_dict(state_dict):
def xla_metrics_report():
    try:
        import torch_xla.debug.metrics as met

        print(met.metrics_report())
    except ImportError:
        return
@ -52,8 +52,7 @@ class MegatronTrainer(Trainer):

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state['rng_tracker_states'] \
            = get_cuda_rng_tracker().get_states()
        extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(
@ -64,8 +63,13 @@ class MegatronTrainer(Trainer):
        optimizer_overrides=None,
        reset_meters=False,
    ):
        extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters)
        if extra_state is not None and 'rng_tracker_states' in extra_state:
            get_cuda_rng_tracker().set_states(
                extra_state['rng_tracker_states'])
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer=reset_optimizer,
            reset_lr_scheduler=reset_lr_scheduler,
            optimizer_overrides=optimizer_overrides,
            reset_meters=reset_meters,
        )
        if extra_state is not None and "rng_tracker_states" in extra_state:
            get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
        return extra_state
@ -9,6 +9,7 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils
from fairseq.modules import (
    AdaptiveSoftmax,
@ -17,7 +18,6 @@ from fairseq.modules import (
    PositionalEmbedding,
)


EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
@ -30,7 +30,7 @@ EncoderOut = namedtuple(


class TransformerEncoderEmbedding(nn.Module):
    """ Encoder Embedding + Positional Embedding """
    """Encoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
@ -109,7 +109,7 @@ class TransformerEncoderLayerNorm(nn.Module):


class TransformerDecoderEmbedding(nn.Module):
    """ Decoder Embedding + Positional Embedding """
    """Decoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
@ -42,16 +42,20 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024
TORCH_PIPE = False
RPC_INIT = False


def import_pipe():
    global TORCH_PIPE
    global RPC_INIT
    try:
        from torch.distributed.pipeline.sync import Pipe  # noqa
        from torch.distributed.pipeline.sync import Pipe  # noqa

        global Pipe
        from torch.distributed.pipeline.sync.utils import partition_model

        global partition_model
        from torch.distributed import rpc
        import tempfile

        TORCH_PIPE = True
        # Initialize single process RPC agent since TORCH_PIPE requires
        # RRef. RRef depends on RPC being initialized and as a result we initialize
@ -64,14 +68,15 @@ def import_pipe():
            world_size=1,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
                init_method="file://{}".format(tmpfile.name),
            )
            ),
        )
        RPC_INIT = True
        logger.info('Using torch pipe')
        logger.info("Using torch pipe")
    except ImportError:
        try:
            from fairscale.nn import Pipe  # noqa
            logger.info('Using fairscale pipe')
            from fairscale.nn import Pipe  # noqa

            logger.info("Using fairscale pipe")
        except ImportError:
            raise ImportError("Please install fairscale with: pip install fairscale")
@ -153,9 +158,14 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
            decoder_module_list.append(module)
            module_count += 1
        self.model = None
        self.encoder = TransformerEncoder(cfg.distributed_training, None, None, encoder_module_list)
        self.encoder = TransformerEncoder(
            cfg.distributed_training, None, None, encoder_module_list
        )
        self.decoder = TransformerDecoder(
            cfg.distributed_training, None, None, decoder_module_list=decoder_module_list
            cfg.distributed_training,
            None,
            None,
            decoder_module_list=decoder_module_list,
        )

    @staticmethod
@ -471,7 +481,9 @@ class TransformerEncoder(FairseqEncoder):
        self.use_pipeline = encoder_module_list is not None
        if not self.use_pipeline:
            self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
            self.encoder_layers = nn.Sequential(*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)])
            self.encoder_layers = nn.Sequential(
                *[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
            )
            if isinstance(embed_tokens, nn.ModuleList):
                emb_dim = sum(e.embedding_dim for e in embed_tokens)
            else:
@ -490,7 +502,11 @@ class TransformerEncoder(FairseqEncoder):
            )
            if TORCH_PIPE:
                self.model = Pipe(
                    module=partition_model(nn.Sequential(*encoder_module_list), encoder_balance, encoder_devices),
                    module=partition_model(
                        nn.Sequential(*encoder_module_list),
                        encoder_balance,
                        encoder_devices,
                    ),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
@ -614,10 +630,12 @@ class TransformerDecoder(FairseqDecoder):
        self.use_pipeline = decoder_module_list is not None
        if not self.use_pipeline:
            self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
            self.decoder_layers = nn.Sequential(*[
                TransformerDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ])
            self.decoder_layers = nn.Sequential(
                *[
                    TransformerDecoderLayer(args, no_encoder_attn)
                    for _ in range(args.decoder_layers)
                ]
            )
            self.decoder_output_layer = TransformerDecoderOutputLayer(
                args, embed_tokens, dictionary
            )
@ -634,7 +652,11 @@ class TransformerDecoder(FairseqDecoder):
            )
            if TORCH_PIPE:
                self.model = Pipe(
                    module=partition_model(nn.Sequential(*decoder_module_list), decoder_balance, decoder_devices),
                    module=partition_model(
                        nn.Sequential(*decoder_module_list),
                        decoder_balance,
                        decoder_devices,
                    ),
                    chunks=args.pipeline_chunks,
                    checkpoint=args.pipeline_checkpoint,
                )
@ -4,11 +4,11 @@
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer_lm import TransformerLanguageModel


try:
    from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding

@ -22,7 +22,6 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024

@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):

    @staticmethod
    def add_args(parser):
        TransformerLanguageModel.add_args(parser)
@ -72,10 +71,6 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
        )
        return cls(decoder)

    @staticmethod
    def add_args(parser):
        TransformerLanguageModel.add_args(parser)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        def _vocab_init(tensor, **kwargs):

@ -98,9 +98,7 @@ def build_model(cfg: FairseqDataclass, task):

    assert model is not None, (
        f"Could not infer model type from {cfg}. "
        "Available models: {}".format(
            MODEL_DATACLASS_REGISTRY.keys()
        )
        "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys())
        + f" Requested model type: {model_type}"
    )
@ -100,8 +100,8 @@ class BARTHubInterface(GeneratorHubInterface):
            raise NotImplementedError("prefix generation not implemented for BART")
        res = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            src_tokens = batch['net_input']['src_tokens']
            inference_step_args["prefix_tokens"] =src_tokens.new_full(
            src_tokens = batch["net_input"]["src_tokens"]
            inference_step_args["prefix_tokens"] = src_tokens.new_full(
                (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos()
            ).to(device=self.device)
            results = super().generate(
@ -111,7 +111,7 @@ class BARTHubInterface(GeneratorHubInterface):
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                **kwargs
            )
            for id, hypos in zip(batch['id'].tolist(), results):
            for id, hypos in zip(batch["id"].tolist(), results):
                res.append((id, hypos))
        res = [hypos for _, hypos in sorted(res, key=lambda x: x[0])]
        return res
@ -177,32 +177,35 @@ class BARTHubInterface(GeneratorHubInterface):
        match_source_len: bool = True,
        **generate_kwargs
    ):
        masked_token = '<mask>'
        masked_token = "<mask>"
        batch_tokens = []
        for masked_input in masked_inputs:
            assert masked_token in masked_input, \
                "please add one {} token for the input".format(masked_token)
            assert (
                masked_token in masked_input
            ), "please add one {} token for the input".format(masked_token)

            text_spans = masked_input.split(masked_token)
            text_spans_bpe = (' {0} '.format(masked_token)).join(
                [self.bpe.encode(text_span.rstrip()) for text_span in text_spans]
            ).strip()
            text_spans_bpe = (
                (" {0} ".format(masked_token))
                .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
                .strip()
            )
            tokens = self.task.source_dictionary.encode_line(
                '<s> ' + text_spans_bpe + ' </s>',
                "<s> " + text_spans_bpe + " </s>",
                append_eos=False,
                add_if_not_exist=False,
            ).long()
            batch_tokens.append(tokens)

        # ensure beam size is at least as big as topk
        generate_kwargs['beam'] = max(
        generate_kwargs["beam"] = max(
            topk,
            generate_kwargs.get('beam', -1),
            generate_kwargs.get("beam", -1),
        )
        generate_kwargs['match_source_len'] = match_source_len
        generate_kwargs["match_source_len"] = match_source_len
        batch_hypos = self.generate(batch_tokens, **generate_kwargs)

        return [
            [(self.decode(hypo['tokens']), hypo['score']) for hypo in hypos[:topk]]
            [(self.decode(hypo["tokens"]), hypo["score"]) for hypo in hypos[:topk]]
            for hypos in batch_hypos
        ]
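A hedged usage sketch for the fill_mask entry point being reformatted above (the checkpoint path is a placeholder; the call pattern follows the interface as shown):

    from fairseq.models.bart import BARTModel

    bart = BARTModel.from_pretrained("checkpoints/bart.large")  # hypothetical path
    bart.eval()
    # returns, per input, up to topk (decoded hypothesis, score) pairs
    print(bart.fill_mask(["The cat <mask> on the mat."], topk=3))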
@ -90,7 +90,7 @@ class BARTModel(TransformerModel):
            src_tokens,
            src_lengths=src_lengths,
            token_embeddings=token_embeddings,
            return_all_hiddens=return_all_hiddens
            return_all_hiddens=return_all_hiddens,
        )
        x, extra = self.decoder(
            prev_output_tokens,
@ -103,9 +103,9 @@ class BARTModel(TransformerModel):
        )
        eos: int = self.eos
        if classification_head_name is not None:
            sentence_representation = x[
                src_tokens.eq(eos), :
            ].view(x.size(0), -1, x.size(-1))[:, -1, :]
            sentence_representation = x[src_tokens.eq(eos), :].view(
                x.size(0), -1, x.size(-1)
            )[:, -1, :]
            for k, head in self.classification_heads.items():
                # for torch script only supports iteration
                if k == classification_head_name:

@ -25,7 +25,10 @@ logger = logging.getLogger(__name__)

_SLOWMO_DDP_DISABLED = False
try:
    from fairscale.experimental.nn.data_parallel import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel
    from fairscale.experimental.nn.data_parallel import (
        SlowMoBaseAlgorithm,
        SlowMoDistributedDataParallel,
    )
except ImportError:
    _SLOWMO_DDP_DISABLED = True
@ -22,6 +22,7 @@ import copy
import logging

import torch

from fairseq import checkpoint_utils


@ -78,7 +79,9 @@ class EMA(object):
        self.fp32_params = {}

        if self.config.ema_seed_model is not None:
            state = checkpoint_utils.load_ema_from_checkpoint(self.config.ema_seed_model)
            state = checkpoint_utils.load_ema_from_checkpoint(
                self.config.ema_seed_model
            )
            self.model.load_state_dict(state["model"], strict=True)

        if device is not None:
@ -119,7 +122,7 @@ class EMA(object):
            self.fp32_params[param_key] = _to_float(state_dict[param_key])

    def restore(self, state_dict, build_fp32_params=False):
        """ Load data from a model spec into EMA model """
        """Load data from a model spec into EMA model"""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)
@ -131,16 +134,20 @@ class EMA(object):
        return self.decay

    def _step_internal(self, new_model, updates=None):
        """ One update of the EMA model based on new model weights """
        """One update of the EMA model based on new model weights"""
        decay = self.decay

        ema_state_dict = {}
        ema_params = self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        ema_params = (
            self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        )
        for key, param in new_model.state_dict().items():
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
@ -151,7 +158,7 @@ class EMA(object):
                # Do not decay a model.version pytorch param
                continue
            ema_param.mul_(decay)
            ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1-decay)
            ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param
        self.restore(ema_state_dict, build_fp32_params=False)
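The two in-place calls at the end implement the usual exponential moving average, ema = decay * ema + (1 - decay) * param. A toy sketch with concrete tensors:

    import torch

    decay = 0.999
    ema_param = torch.zeros(3)
    param = torch.ones(3)
    ema_param.mul_(decay)
    ema_param.add_(param, alpha=1 - decay)  # ema_param is now 0.001 everywhere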
@ -168,8 +175,7 @@ class EMA(object):
        """
        self._set_decay(
            0
            if updates is not None
            and updates < self.config.ema_start_update
            if updates is not None and updates < self.config.ema_start_update
            else self.config.ema_decay
        )
        if updates is not None and self.config.ema_update_freq > 1:

@ -19,7 +19,6 @@ class FairseqDecoder(nn.Module):
        self.onnx_trace = False
        self.adaptive_softmax = None


    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
        """
        Args:
@ -29,8 +29,9 @@ logger = logging.getLogger(__name__)

def check_type(module, expected_type):
    if hasattr(module, "unwrapped_module"):
        assert isinstance(module.unwrapped_module, expected_type), \
            f"{type(module.unwrapped_module)} != {expected_type}"
        assert isinstance(
            module.unwrapped_module, expected_type
        ), f"{type(module.unwrapped_module)} != {expected_type}"
    else:
        assert isinstance(module, expected_type), f"{type(module)} != {expected_type}"

@ -114,7 +115,9 @@ class BaseFairseqModel(nn.Module):
        """

        if model_cfg is None and args is not None:
            logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
            logger.warn(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)
@ -454,7 +457,9 @@ class FairseqMultiModel(BaseFairseqModel):
        """

        if model_cfg is None and args is not None:
            logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
            logger.warn(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)
@ -30,9 +30,7 @@ from omegaconf import II

logger = logging.getLogger(__name__)

EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(
    ["static", "uniform", "normal", "poisson"]
)
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])


@dataclass
@ -86,9 +84,7 @@ class HubertConfig(FairseqDataclass):
    )
    dropout_features: float = field(
        default=0.0,
        metadata={
            "help": "dropout to apply to the features (after feat extr)"
        },
        metadata={"help": "dropout to apply to the features (after feat extr)"},
    )

    final_dim: int = field(
@ -150,9 +146,7 @@ class HubertConfig(FairseqDataclass):
    )
    mask_min_space: int = field(
        default=1,
        metadata={
            "help": "min space between spans (if no overlap is enabled)"
        },
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # channel masking
@ -182,23 +176,17 @@ class HubertConfig(FairseqDataclass):
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={
            "help": "min space between spans (if no overlap is enabled)"
        },
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={
            "help": "number of filters for convolutional positional embeddings"
        },
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={
            "help": "number of groups for convolutional positional embedding"
        },
        metadata={"help": "number of groups for convolutional positional embedding"},
    )

    latent_temp: Tuple[float, float, float] = field(
@ -238,9 +226,7 @@ class HubertModel(BaseFairseqModel):
            conv_bias=cfg.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        self.feat2tar_ratio = (
            cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
        )
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
@ -270,9 +256,7 @@ class HubertModel(BaseFairseqModel):
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        final_dim = (
            cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
        )
        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
@ -297,9 +281,7 @@ class HubertModel(BaseFairseqModel):

        # modules below are not needed during fine-tuning
        if any([d is None for d in dictionaries]):
            logger.info(
                "cannot find dictionary. assume will be used for fine-tuning"
            )
            logger.info("cannot find dictionary. assume will be used for fine-tuning")
        else:
            self.num_classes = [len(d) for d in dictionaries]
            self.label_embs_concat = nn.Parameter(
@ -365,9 +347,7 @@ class HubertModel(BaseFairseqModel):
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)

        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)
        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")
@ -385,7 +365,9 @@ class HubertModel(BaseFairseqModel):
        return features

    def forward_targets(
        self, features: torch.Tensor, target_list: List[torch.Tensor],
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
@ -398,14 +380,14 @@ class HubertModel(BaseFairseqModel):
        return features, target_list

    def forward_padding_mask(
        self, features: torch.Tensor, padding_mask: torch.Tensor,
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask
@ -439,9 +421,7 @@ class HubertModel(BaseFairseqModel):
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(
                features, padding_mask, target_list
            )
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None
@ -454,7 +434,7 @@ class HubertModel(BaseFairseqModel):
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
@ -483,9 +463,7 @@ class HubertModel(BaseFairseqModel):
            proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
            logit_m_list = [
                compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
                for i, (proj_x_m, t) in enumerate(
                    zip(proj_x_m_list, target_list)
                )
                for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
            ]
        else:
            logit_m_list = [None for _ in target_list]
@ -500,9 +478,7 @@ class HubertModel(BaseFairseqModel):

            logit_u_list = [
                compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
                for i, (proj_x_u, t) in enumerate(
                    zip(proj_x_u_list, target_list)
                )
                for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
            ]
        else:
            logit_u_list = [None for _ in target_list]
@ -543,9 +519,7 @@ class HubertModel(BaseFairseqModel):

    def get_targets(self, net_output, is_masked=True):
        logits_list = self.get_logits(net_output, is_masked)
        targets_list = [
            x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list
        ]
        targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
        return targets_list

    def get_extra_losses(self, net_output):
@ -21,9 +21,7 @@ from omegaconf import II, MISSING

@dataclass
class HubertAsrConfig(FairseqDataclass):
    w2v_path: str = field(
        default=MISSING, metadata={"help": "path to hubert model"}
    )
    w2v_path: str = field(default=MISSING, metadata={"help": "path to hubert model"})
    no_pretrained_weights: bool = field(
        default=False,
        metadata={"help": "if true, does not load pretrained weights"},
@ -34,9 +32,7 @@ class HubertAsrConfig(FairseqDataclass):
    )
    final_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout after transformer and before final projection"
        },
        metadata={"help": "dropout after transformer and before final projection"},
    )
    dropout: float = field(
        default=0.0,
@ -45,15 +41,13 @@ class HubertAsrConfig(FairseqDataclass):
    attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights "
            "inside hubert model"
            "help": "dropout probability for attention weights " "inside hubert model"
        },
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN "
            "inside hubert model"
            "help": "dropout probability after activation in FFN " "inside hubert model"
        },
    )
||||
@ -184,9 +178,7 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
|
||||
decoder_ffn_embed_dim: int = field(
|
||||
default=3072, metadata={"help": "decoder embedding dimension for FFN"}
|
||||
)
|
||||
decoder_layers: int = field(
|
||||
default=6, metadata={"help": "num of decoder layers"}
|
||||
)
|
||||
decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
|
||||
decoder_layerdrop: float = field(
|
||||
default=0.0, metadata={"help": "decoder layerdrop chance"}
|
||||
)
|
||||
@ -204,8 +196,7 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
|
||||
no_token_positional_embeddings: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, disables positional embeddings "
|
||||
"(outside self attention)"
|
||||
"help": "if set, disables positional embeddings " "(outside self attention)"
|
||||
},
|
||||
)
|
||||
decoder_dropout: float = field(
|
||||
@ -214,15 +205,13 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
|
||||
decoder_attention_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "dropout probability for attention weights "
|
||||
"inside the decoder"
|
||||
"help": "dropout probability for attention weights " "inside the decoder"
|
||||
},
|
||||
)
|
||||
decoder_activation_dropout: float = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "dropout probability after activation in FFN "
|
||||
"inside the decoder"
|
||||
"help": "dropout probability after activation in FFN " "inside the decoder"
|
||||
},
|
||||
)
|
||||
max_target_positions: int = field(
|
||||
@ -258,9 +247,7 @@ class HubertEncoder(FairseqEncoder):
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                cfg.w2v_path, arg_overrides
            )
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
@ -269,9 +256,7 @@ class HubertEncoder(FairseqEncoder):
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(
                    w2v_args
                )
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
@ -344,9 +329,9 @@ class HubertEncoder(FairseqEncoder):

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out[
                "encoder_out"
            ].index_select(1, new_order)
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
@ -225,10 +225,10 @@ class LSTMEncoder(FairseqEncoder):
        super().__init__(dictionary)
        self.num_layers = num_layers
        self.dropout_in_module = FairseqDropout(
            dropout_in*1.0, module_name=self.__class__.__name__
            dropout_in * 1.0, module_name=self.__class__.__name__
        )
        self.dropout_out_module = FairseqDropout(
            dropout_out*1.0, module_name=self.__class__.__name__
            dropout_out * 1.0, module_name=self.__class__.__name__
        )
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size
@ -329,7 +329,9 @@ class LSTMEncoder(FairseqEncoder):
        out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
        return out.view(self.num_layers, bsz, -1)

    def reorder_encoder_out(self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order):
    def reorder_encoder_out(
        self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order
    ):
        return tuple(
            (
                encoder_out[0].index_select(1, new_order),
@ -402,10 +404,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
    ):
        super().__init__(dictionary)
        self.dropout_in_module = FairseqDropout(
            dropout_in*1.0, module_name=self.__class__.__name__
            dropout_in * 1.0, module_name=self.__class__.__name__
        )
        self.dropout_out_module = FairseqDropout(
            dropout_out*1.0, module_name=self.__class__.__name__
            dropout_out * 1.0, module_name=self.__class__.__name__
        )
        self.hidden_size = hidden_size
        self.share_input_output_embed = share_input_output_embed
@ -18,7 +18,10 @@ def ensemble_encoder(func):
    def wrapper(self, *args, **kwargs):
        if self.ensemble_models is None or len(self.ensemble_models) == 1:
            return func(self, *args, **kwargs)
        encoder_outs = [func(model, *args, **kwargs, return_all_hiddens=True) for model in self.ensemble_models]
        encoder_outs = [
            func(model, *args, **kwargs, return_all_hiddens=True)
            for model in self.ensemble_models
        ]
        _encoder_out = encoder_outs[0].copy()

        def stack(key):
@ -56,8 +59,7 @@ def ensemble_decoder(func):
            model,
            normalize=normalize,
            encoder_out=_replace(
                encoder_out,
                encoder_out["encoder_out"][0][:, :, :, i]
                encoder_out, encoder_out["encoder_out"][0][:, :, :, i]
            ),
            *args,
            **kwargs

@ -85,7 +85,8 @@ class EnsembleLevT(BasicEnsembleModel):
        else:
            if not encoder_outs[0]["encoder_padding_mask"]:
                src_lens = (
                    encoder_outs[0]["encoder_out"][0].new(bsz)
                    encoder_outs[0]["encoder_out"][0]
                    .new(bsz)
                    .fill_(encoder_outs[0]["encoder_out"][0].size(1))
                )
            else:
@ -183,7 +183,7 @@ class RobertaModel(FairseqEncoderModel):
                "communication less efficient due to smaller input sizes. This option "
                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                "--offload-activations are passed."
            )
            ),
        )

    @classmethod
@ -542,7 +542,9 @@ def base_architecture(args):
    args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", True)
    args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", True)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    args.encoder_normalize_before = safe_getattr(args, "encoder_normalize_before", False)
    args.encoder_normalize_before = safe_getattr(
        args, "encoder_normalize_before", False
    )
    args.pooler_activation_fn = safe_getattr(args, "pooler_activation_fn", "tanh")
    args.untie_weights_roberta = safe_getattr(args, "untie_weights_roberta", False)
@ -12,26 +12,26 @@ from .hub_interface import RobertaHubInterface
from .model import RobertaModel


@register_model('gottbert')
@register_model("gottbert")
class GottbertModel(RobertaModel):

    @classmethod
    def hub_models(cls):
        return {
            'gottbert-base': 'https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz',
            "gottbert-base": "https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz",
        }

    @classmethod
    def from_pretrained(cls,
                        model_name_or_path,
                        checkpoint_file='model.pt',
                        data_name_or_path='.',
                        bpe='hf_byte_bpe',
                        bpe_vocab='vocab.json',
                        bpe_merges='merges.txt',
                        bpe_add_prefix_space=False,
                        **kwargs
                        ):
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="hf_byte_bpe",
        bpe_vocab="vocab.json",
        bpe_merges="merges.txt",
        bpe_add_prefix_space=False,
        **kwargs
    ):
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
@ -46,4 +46,4 @@ class GottbertModel(RobertaModel):
            bpe_add_prefix_space=bpe_add_prefix_space,
            **kwargs,
        )
        return RobertaHubInterface(x['args'], x['task'], x['models'][0])
        return RobertaHubInterface(x["args"], x["task"], x["models"][0])
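A hedged loading sketch for the hub entry point above (the local path is a placeholder; from_pretrained returns a RobertaHubInterface, per the final line of the hunk):

    from fairseq.models.roberta import GottbertModel

    gottbert = GottbertModel.from_pretrained("path/to/gottbert-base")  # hypothetical
    tokens = gottbert.encode("Ein Beispielsatz.")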
@ -202,10 +202,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
            help="model to take encoder weights from (for initialization)",
        )
        parser.add_argument(
            '--encoder-freezing-updates',
            "--encoder-freezing-updates",
            type=int,
            metavar='N',
            help='freeze encoder for first N updates'
            metavar="N",
            help="freeze encoder for first N updates",
        )

    @classmethod
@ -329,7 +329,9 @@ class S2TTransformerEncoder(FairseqEncoder):

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [],  # B x T
            "encoder_padding_mask": [encoder_padding_mask]
            if encoder_padding_mask.any()
            else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
@ -339,27 +341,37 @@ class S2TTransformerEncoder(FairseqEncoder):
|
||||
def forward(self, src_tokens, src_lengths, return_all_hiddens=False):
|
||||
if self.num_updates < self.encoder_freezing_updates:
|
||||
with torch.no_grad():
|
||||
x = self._forward(src_tokens, src_lengths,
|
||||
return_all_hiddens=return_all_hiddens)
|
||||
x = self._forward(
|
||||
src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
|
||||
)
|
||||
else:
|
||||
x = self._forward(src_tokens, src_lengths,
|
||||
return_all_hiddens=return_all_hiddens)
|
||||
x = self._forward(
|
||||
src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
|
||||
)
|
||||
return x
|
||||
|
||||
def reorder_encoder_out(self, encoder_out, new_order):
|
||||
new_encoder_out = (
|
||||
[] if len(encoder_out["encoder_out"]) == 0
|
||||
[]
|
||||
if len(encoder_out["encoder_out"]) == 0
|
||||
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
|
||||
)
|
||||
|
||||
new_encoder_padding_mask = (
|
||||
[] if len(encoder_out["encoder_padding_mask"]) == 0
|
||||
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]]
|
||||
[]
|
||||
if len(encoder_out["encoder_padding_mask"]) == 0
|
||||
else [
|
||||
x.index_select(0, new_order)
|
||||
for x in encoder_out["encoder_padding_mask"]
|
||||
]
|
||||
)
|
||||
|
||||
new_encoder_embedding = (
|
||||
[] if len(encoder_out["encoder_embedding"]) == 0
|
||||
else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]]
|
||||
[]
|
||||
if len(encoder_out["encoder_embedding"]) == 0
|
||||
else [
|
||||
x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
|
||||
]
|
||||
)
|
||||
|
||||
encoder_states = encoder_out["encoder_states"]
|
||||
|
@@ -9,8 +9,12 @@ import copy
from typing import Dict, List, Optional, Tuple

from fairseq import utils, checkpoint_utils
from fairseq.models import (FairseqEncoderDecoderModel, FairseqEncoder,
                            register_model, register_model_architecture)
from fairseq.models import (
    FairseqEncoderDecoderModel,
    FairseqEncoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.wav2vec import Wav2VecEncoder
from fairseq.modules.layer_norm import LayerNorm

@@ -24,18 +28,23 @@ logger = logging.getLogger(__name__)


class Conv1dAdaptor(nn.Module):
    def __init__(self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2,
                 add_layernorm=False):
    def __init__(
        self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2, add_layernorm=False
    ):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Conv1d(in_dim if i == 0 else out_dim, out_dim * 2, kernel_size,
                      stride=stride, padding=kernel_size // 2)
            nn.Conv1d(
                in_dim if i == 0 else out_dim,
                out_dim * 2,
                kernel_size,
                stride=stride,
                padding=kernel_size // 2,
            )
            for i in range(n_layers)
        )
        self.layernorms = None
        if add_layernorm:
            self.layernorms = nn.ModuleList(LayerNorm(out_dim)
                                            for _ in range(n_layers))
            self.layernorms = nn.ModuleList(LayerNorm(out_dim) for _ in range(n_layers))
        self.stride = stride

    @classmethod

@@ -43,7 +52,7 @@ class Conv1dAdaptor(nn.Module):
        parser.add_argument("--adaptor-n-layers", type=int)
        parser.add_argument("--adaptor-kernel-size", type=int)
        parser.add_argument("--adaptor-stride", type=int)
        parser.add_argument("--adaptor-layernorm", action='store_true')
        parser.add_argument("--adaptor-layernorm", action="store_true")

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        out = in_seq_lens_tensor.clone()

@@ -197,15 +206,18 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
        encoder_out_dim = self.w2v_encoder.w2v_model.encoder.embedding_dim
        # Projection + 8x shrinking
        self.adaptor = Conv1dAdaptor(
            encoder_out_dim, args.decoder_embed_dim,
            encoder_out_dim,
            args.decoder_embed_dim,
            n_layers=args.adaptor_n_layers,
            kernel_size=args.adaptor_kernel_size, stride=args.adaptor_stride,
            add_layernorm=args.adaptor_layernorm
            kernel_size=args.adaptor_kernel_size,
            stride=args.adaptor_stride,
            add_layernorm=args.adaptor_layernorm,
        )
        for k, p in self.w2v_encoder.w2v_model.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(args, 'finetune_w2v_params') and XMTransformerModel.finetune_params(
                    args.finetune_w2v_params, k):
            if safe_hasattr(
                args, "finetune_w2v_params"
            ) and XMTransformerModel.finetune_params(args.finetune_w2v_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False

@@ -214,11 +226,16 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
    def add_args(cls, parser):
        add_wav2vec_asr_args(parser)
        parser.add_argument(
            "--normalize", action="store_true",
            "--normalize",
            action="store_true",
            help="if set, normalizes input to have 0 mean and unit variance",
        )
        parser.add_argument("--finetune-w2v-params", type=str, metavar="STR",
                            help="comma-separated param strings to finetune.")
        parser.add_argument(
            "--finetune-w2v-params",
            type=str,
            metavar="STR",
            help="comma-separated param strings to finetune.",
        )
        Conv1dAdaptor.add_args(parser)

    def forward(self, src_tokens, src_lengths=None, **kwargs):

@@ -227,13 +244,17 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
        x = out["encoder_out"]
        enc_padding_mask = None
        if out["encoder_padding_mask"] is not None:
            enc_padding_mask = out["encoder_padding_mask"].transpose(0, 1)  # T X B --> B X T
            enc_padding_mask = out["encoder_padding_mask"].transpose(
                0, 1
            )  # T X B --> B X T

        x, enc_padding_mask = self.adaptor(x, enc_padding_mask)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [enc_padding_mask] if enc_padding_mask.any() else [],  # B x T
            "encoder_padding_mask": [enc_padding_mask]
            if enc_padding_mask.any()
            else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],

@@ -242,20 +263,26 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):

    def reorder_encoder_out(self, encoder_out, new_order):
        new_encoder_out = (
            [] if len(encoder_out["encoder_out"]) == 0
            []
            if len(encoder_out["encoder_out"]) == 0
            else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
        )

        new_encoder_padding_mask = (
            [] if len(encoder_out["encoder_padding_mask"]) == 0
            else [x.index_select(0, new_order) for x in
                  encoder_out["encoder_padding_mask"]]
            []
            if len(encoder_out["encoder_padding_mask"]) == 0
            else [
                x.index_select(0, new_order)
                for x in encoder_out["encoder_padding_mask"]
            ]
        )

        new_encoder_embedding = (
            [] if len(encoder_out["encoder_embedding"]) == 0
            else [x.index_select(0, new_order) for x in
                  encoder_out["encoder_embedding"]]
            []
            if len(encoder_out["encoder_embedding"]) == 0
            else [
                x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
            ]
        )

        encoder_states = encoder_out["encoder_states"]
@@ -274,38 +301,71 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):


def add_decoder_args(parser):
    parser.add_argument("--activation-fn", type=str, default='relu',
                        choices=utils.get_available_activation_fns(),
                        help="activation function to use")
    parser.add_argument("--decoder-dropout", type=float, metavar="D",
                        help="dropout probability")
    parser.add_argument("--decoder-attention-dropout", type=float,
                        metavar="D",
                        help="dropout probability for attention weights")
    parser.add_argument("--decoder-activation-dropout", type=float,
                        metavar="D",
                        help="dropout probability after activation in FFN.")
    parser.add_argument("--decoder-embed-dim", type=int, metavar="N",
                        help="decoder embedding dimension")
    parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N",
                        help="decoder embedding dimension for FFN")
    parser.add_argument("--decoder-layers", type=int, metavar="N",
                        help="num decoder layers")
    parser.add_argument("--decoder-attention-heads", type=int, metavar="N",
                        help="num decoder attention heads")
    parser.add_argument("--decoder-normalize-before", action="store_true",
                        help="apply layernorm before each decoder block")
    parser.add_argument("--layernorm-embedding", action="store_true",
                        help="add layernorm to embedding")
    parser.add_argument("--no-scale-embedding", action="store_true",
                        help="if True, dont scale embeddings")
    parser.add_argument(
        "--load-pretrained-decoder-from", type=str, metavar="STR",
        help="model to take decoder weights from (for initialization)"
        "--activation-fn",
        type=str,
        default="relu",
        choices=utils.get_available_activation_fns(),
        help="activation function to use",
    )
    parser.add_argument(
        "--decoder-dropout", type=float, metavar="D", help="dropout probability"
    )
    parser.add_argument(
        "--decoder-attention-dropout",
        type=float,
        metavar="D",
        help="dropout probability for attention weights",
    )
    parser.add_argument(
        "--decoder-activation-dropout",
        type=float,
        metavar="D",
        help="dropout probability after activation in FFN.",
    )
    parser.add_argument(
        "--decoder-embed-dim", type=int, metavar="N", help="decoder embedding dimension"
    )
    parser.add_argument(
        "--decoder-ffn-embed-dim",
        type=int,
        metavar="N",
        help="decoder embedding dimension for FFN",
    )
    parser.add_argument(
        "--decoder-layers", type=int, metavar="N", help="num decoder layers"
    )
    parser.add_argument(
        "--decoder-attention-heads",
        type=int,
        metavar="N",
        help="num decoder attention heads",
    )
    parser.add_argument(
        "--decoder-normalize-before",
        action="store_true",
        help="apply layernorm before each decoder block",
    )
    parser.add_argument(
        "--layernorm-embedding", action="store_true", help="add layernorm to embedding"
    )
    parser.add_argument(
        "--no-scale-embedding",
        action="store_true",
        help="if True, dont scale embeddings",
    )
    parser.add_argument(
        "--load-pretrained-decoder-from",
        type=str,
        metavar="STR",
        help="model to take decoder weights from (for initialization)",
    )
    parser.add_argument(
        "--finetune-decoder-params",
        type=str,
        metavar="STR",
        help="comma-separated param strings to finetune.",
    )
    parser.add_argument("--finetune-decoder-params", type=str,
                        metavar="STR",
                        help="comma-separated param strings to finetune.")
    parser.add_argument("--checkpoint-activations", action="store_true")
@@ -342,16 +402,16 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
        _args.activation_dropout = args.decoder_activation_dropout
        _args.max_target_positions = 1024

        decoder = TransformerDecoder(_args, task.target_dictionary,
                                     embed_tokens)
        decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
        if getattr(args, "load_pretrained_decoder_from", None):
            decoder = checkpoint_utils.load_pretrained_component_from_model(
                component=decoder, checkpoint=args.load_pretrained_decoder_from
            )
        for k, p in decoder.named_parameters():
            # Freeze pretrained models by default
            if safe_hasattr(args, 'finetune_decoder_params') and XMTransformerModel.finetune_params(
                    args.finetune_decoder_params, k):
            if safe_hasattr(
                args, "finetune_decoder_params"
            ) and XMTransformerModel.finetune_params(args.finetune_decoder_params, k):
                p.requires_grad = True
            else:
                p.requires_grad = False

@@ -369,8 +429,9 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)

        decoder_embed_tokens = build_embedding(task.target_dictionary,
                                               args.decoder_embed_dim)
        decoder_embed_tokens = build_embedding(
            task.target_dictionary, args.decoder_embed_dim
        )
        encoder = cls.build_encoder(args)
        decoder = cls.build_decoder(args, task, decoder_embed_tokens)
        return cls(encoder, decoder)

@@ -382,8 +443,7 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        # net_output['encoder_out'] is a (B, T, D) tensor
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs,
                                                      sample)
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs

@@ -393,17 +453,19 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
        argument in its input, which is not supported in torchscript. This
        method overrites the forward method definition without **kwargs.
        """
        encoder_out = self.encoder(src_tokens=src_tokens,
                                   src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(prev_output_tokens=prev_output_tokens,
                                   encoder_out=encoder_out)
        encoder_out = self.encoder(
            src_tokens=src_tokens, src_lengths=src_lengths, **kwargs
        )
        decoder_out = self.decoder(
            prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
        )
        return decoder_out

    def upgrade_state_dict(self, state_dict):
        for k, _ in state_dict.items():
            if 'adaptor.layers' in state_dict:
            if "adaptor.layers" in state_dict:
                print(k)
                new = k.replace('adaptor.layers', 'adaptor_layers')
                new = k.replace("adaptor.layers", "adaptor_layers")
                state_dict[new] = state_dict[k]
                del state_dict[k]

@@ -435,11 +497,9 @@ def set_default_w2v_encoder_args(args):
    args.mask_channel_length = getattr(args, "mask_channel_length", 10)
    args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
    args.mask_channel_before = getattr(args, "mask_channel_before", False)
    args.mask_channel_selection = getattr(args, "mask_channel_selection",
                                          "static")
    args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
    args.mask_channel_other = getattr(args, "mask_channel_other", 0)
    args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap",
                                           False)
    args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)

    args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0)
    args.feature_grad_mult = 0.1

@@ -456,49 +516,43 @@ def set_default_adaptor_args(args):


def set_default_mbart_decoder_args(args):
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim',
                                         4 * 1024)
    args.decoder_layers = getattr(args, 'decoder_layers', 12)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before',
                                            True)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024)
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.decoder_attention_dropout = getattr(args, 'decoder_attention_dropout',
                                             0.)
    args.decoder_activation_dropout = getattr(args,
                                              'decoder_activation_dropout', 0.)
    args.decoder_dropout = getattr(args, 'decoder_dropout', 0.1)
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff',
                                           None)
    args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
    args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0.0)
    args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0.0)
    args.decoder_dropout = getattr(args, "decoder_dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, 'share_decoder_input_output_embed', True
        args, "share_decoder_input_output_embed", True
    )
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )

    args.decoder_output_dim = getattr(args, 'decoder_output_dim',
                                      args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, 'decoder_input_dim',
                                     args.decoder_embed_dim)
    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.layernorm_embedding = getattr(args, 'layernorm_embedding', True)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)

    args.activation_fn = getattr(args, 'activation_fn', 'gelu')
    args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
    args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)


@register_model_architecture(model_name="xm_transformer",
                             arch_name="xm_transformer")
@register_model_architecture(model_name="xm_transformer", arch_name="xm_transformer")
def base_architecture(args):
    set_default_w2v_encoder_args(args)
    set_default_adaptor_args(args)
@@ -8,10 +8,17 @@ import logging
import torch
from torch import nn

from fairseq.models import (FairseqEncoder, FairseqEncoderModel, register_model,
                            register_model_architecture)
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    LayerNorm, PositionalEmbedding, FairseqDropout, MultiheadAttention
    LayerNorm,
    PositionalEmbedding,
    FairseqDropout,
    MultiheadAttention,
)
from fairseq import utils
from fairseq.data.data_utils import lengths_to_padding_mask

@@ -36,11 +43,19 @@ class PositionwiseFeedForward(nn.Module):
    def __init__(self, in_dim, hidden_dim, kernel_size, dropout):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Conv1d(in_dim, hidden_dim, kernel_size=kernel_size,
                      padding=(kernel_size - 1) // 2),
            nn.Conv1d(
                in_dim,
                hidden_dim,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
            ),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, in_dim, kernel_size=kernel_size,
                      padding=(kernel_size - 1) // 2)
            nn.Conv1d(
                hidden_dim,
                in_dim,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
            ),
        )
        self.layer_norm = LayerNorm(in_dim)
        self.dropout = self.dropout_module = FairseqDropout(

@@ -57,8 +72,7 @@ class PositionwiseFeedForward(nn.Module):

class FFTLayer(torch.nn.Module):
    def __init__(
        self, embed_dim, n_heads, hidden_dim, kernel_size, dropout,
        attention_dropout
        self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, attention_dropout
    ):
        super().__init__()
        self.self_attn = MultiheadAttention(

@@ -74,8 +88,7 @@ class FFTLayer(torch.nn.Module):
        residual = x
        x = x.transpose(0, 1)
        x, _ = self.self_attn(
            query=x, key=x, value=x, key_padding_mask=padding_mask,
            need_weights=False
            query=x, key=x, value=x, key_padding_mask=padding_mask, need_weights=False
        )
        x = x.transpose(0, 1)
        x = self.layer_norm(x + residual)

@@ -106,11 +119,12 @@ class VariancePredictor(nn.Module):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(
                args.encoder_embed_dim, args.var_pred_hidden_dim,
                args.encoder_embed_dim,
                args.var_pred_hidden_dim,
                kernel_size=args.var_pred_kernel_size,
                padding=(args.var_pred_kernel_size - 1) // 2
                padding=(args.var_pred_kernel_size - 1) // 2,
            ),
            nn.ReLU()
            nn.ReLU(),
        )
        self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim)
        self.dropout_module = FairseqDropout(

@@ -118,10 +132,12 @@ class VariancePredictor(nn.Module):
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(
                args.var_pred_hidden_dim, args.var_pred_hidden_dim,
                kernel_size=args.var_pred_kernel_size, padding=1
                args.var_pred_hidden_dim,
                args.var_pred_hidden_dim,
                kernel_size=args.var_pred_kernel_size,
                padding=1,
            ),
            nn.ReLU()
            nn.ReLU(),
        )
        self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim)
        self.proj = nn.Linear(args.var_pred_hidden_dim, 1)

@@ -171,8 +187,15 @@ class VarianceAdaptor(nn.Module):
        return out, emb

    def forward(
        self, x, padding_mask, durations=None, pitches=None, energies=None,
        d_factor=1.0, p_factor=1.0, e_factor=1.0
        self,
        x,
        padding_mask,
        durations=None,
        pitches=None,
        energies=None,
        d_factor=1.0,
        p_factor=1.0,
        e_factor=1.0,
    ):
        # x: B x T x C
        log_dur_out = self.duration_predictor(x)
@@ -205,8 +228,7 @@ class FastSpeech2Encoder(FairseqEncoder):
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim,
                args.encoder_embed_dim
                args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
            )

        self.dropout_module = FairseqDropout(

@@ -224,9 +246,12 @@ class FastSpeech2Encoder(FairseqEncoder):

        self.encoder_fft_layers = nn.ModuleList(
            FFTLayer(
                args.encoder_embed_dim, args.encoder_attention_heads,
                args.fft_hidden_dim, args.fft_kernel_size,
                dropout=args.dropout, attention_dropout=args.attention_dropout
                args.encoder_embed_dim,
                args.encoder_attention_heads,
                args.fft_hidden_dim,
                args.fft_kernel_size,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
            )
            for _ in range(args.encoder_layers)
        )

@@ -235,9 +260,12 @@ class FastSpeech2Encoder(FairseqEncoder):

        self.decoder_fft_layers = nn.ModuleList(
            FFTLayer(
                args.decoder_embed_dim, args.decoder_attention_heads,
                args.fft_hidden_dim, args.fft_kernel_size,
                dropout=args.dropout, attention_dropout=args.attention_dropout
                args.decoder_embed_dim,
                args.decoder_attention_heads,
                args.fft_hidden_dim,
                args.fft_kernel_size,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
            )
            for _ in range(args.decoder_layers)
        )

@@ -247,15 +275,25 @@ class FastSpeech2Encoder(FairseqEncoder):
        self.postnet = None
        if args.add_postnet:
            self.postnet = Postnet(
                self.out_dim, args.postnet_conv_dim,
                self.out_dim,
                args.postnet_conv_dim,
                args.postnet_conv_kernel_size,
                args.postnet_layers, args.postnet_dropout
                args.postnet_layers,
                args.postnet_dropout,
            )

        self.apply(model_init)

    def forward(self, src_tokens, src_lengths=None, speaker=None,
                durations=None, pitches=None, energies=None, **kwargs):
    def forward(
        self,
        src_tokens,
        src_lengths=None,
        speaker=None,
        durations=None,
        pitches=None,
        energies=None,
        **kwargs
    ):
        x = self.embed_tokens(src_tokens)

        enc_padding_mask = src_tokens.eq(self.padding_idx)

@@ -270,8 +308,9 @@ class FastSpeech2Encoder(FairseqEncoder):
            emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
            x = self.spk_emb_proj(torch.cat([x, emb], dim=2))

        x, out_lens, log_dur_out, pitch_out, energy_out = \
            self.var_adaptor(x, enc_padding_mask, durations, pitches, energies)
        x, out_lens, log_dur_out, pitch_out, energy_out = self.var_adaptor(
            x, enc_padding_mask, durations, pitches, energies
        )

        dec_padding_mask = lengths_to_padding_mask(out_lens)
        x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask)

@@ -326,7 +365,7 @@ class FastSpeech2Model(FairseqEncoderModel):

        out_dim = args.output_frame_dim * args.n_frames_per_step
        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.) > 0.:
        if getattr(args, "ctc_weight", 0.0) > 0.0:
            self.ctc_proj = nn.Linear(out_dim, len(src_dict))

    @classmethod
@@ -119,7 +119,7 @@ class Generator(torch.nn.Module):

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(
                zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"])
            zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"])
        ):
            self.ups.append(
                weight_norm(

@@ -137,7 +137,7 @@ class Generator(torch.nn.Module):
        for i in range(len(self.ups)):
            ch = cfg["upsample_initial_channel"] // (2 ** (i + 1))
            for k, d in zip(
                    cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"]
                cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"]
            ):
                self.resblocks.append(ResBlock(ch, k, d))
@@ -9,9 +9,13 @@ import torch
from torch import nn
from torch.nn import functional as F

from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
                            FairseqIncrementalDecoder, register_model,
                            register_model_architecture)
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.modules import LSTMCellWithZoneOut, LocationAttention


@@ -31,29 +35,36 @@ class Tacotron2Encoder(FairseqEncoder):
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim,
                args.encoder_embed_dim
                args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
            )

        self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim,
                                         padding_idx=self.padding_idx)
        self.embed_tokens = nn.Embedding(
            len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
        )

        assert(args.encoder_conv_kernel_size % 2 == 1)
        assert args.encoder_conv_kernel_size % 2 == 1
        self.convolutions = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim,
                          kernel_size=args.encoder_conv_kernel_size,
                          padding=((args.encoder_conv_kernel_size - 1) // 2)),
                nn.Conv1d(
                    args.encoder_embed_dim,
                    args.encoder_embed_dim,
                    kernel_size=args.encoder_conv_kernel_size,
                    padding=((args.encoder_conv_kernel_size - 1) // 2),
                ),
                nn.BatchNorm1d(args.encoder_embed_dim),
                nn.ReLU(),
                nn.Dropout(args.encoder_dropout)
                nn.Dropout(args.encoder_dropout),
            )
            for _ in range(args.encoder_conv_layers)
        )

        self.lstm = nn.LSTM(args.encoder_embed_dim, args.encoder_embed_dim // 2,
                            num_layers=args.encoder_lstm_layers,
                            batch_first=True, bidirectional=True)
        self.lstm = nn.LSTM(
            args.encoder_embed_dim,
            args.encoder_embed_dim // 2,
            num_layers=args.encoder_lstm_layers,
            batch_first=True,
            bidirectional=True,
        )

        self.apply(encoder_init)

@@ -78,7 +89,7 @@ class Tacotron2Encoder(FairseqEncoder):

        return {
            "encoder_out": [x],  # B x T x C
            "encoder_padding_mask": encoder_padding_mask, # B x T
            "encoder_padding_mask": encoder_padding_mask,  # B x T
        }


@@ -86,8 +97,7 @@ class Prenet(nn.Module):
    def __init__(self, in_dim, n_layers, n_units, dropout):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units),
                          nn.ReLU())
            nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units), nn.ReLU())
            for i in range(n_layers)
        )
        self.dropout = dropout

@@ -102,20 +112,24 @@ class Postnet(nn.Module):
    def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()
        assert(kernel_size % 2 == 1)
        assert kernel_size % 2 == 1
        for i in range(n_layers):
            cur_layers = [
                nn.Conv1d(in_dim if i == 0 else n_channels,
                          n_channels if i < n_layers - 1 else in_dim,
                          kernel_size=kernel_size,
                          padding=((kernel_size - 1) // 2)),
                nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim)
            ] + ([nn.Tanh()] if i < n_layers - 1 else []) + [nn.Dropout(dropout)]
            cur_layers = (
                [
                    nn.Conv1d(
                        in_dim if i == 0 else n_channels,
                        n_channels if i < n_layers - 1 else in_dim,
                        kernel_size=kernel_size,
                        padding=((kernel_size - 1) // 2),
                    ),
                    nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim),
                ]
                + ([nn.Tanh()] if i < n_layers - 1 else [])
                + [nn.Dropout(dropout)]
            )
            nn.init.xavier_uniform_(
                cur_layers[0].weight,
                torch.nn.init.calculate_gain(
                    "tanh" if i < n_layers - 1 else "linear"
                )
                torch.nn.init.calculate_gain("tanh" if i < n_layers - 1 else "linear"),
            )
            self.convolutions.append(nn.Sequential(*cur_layers))
@@ -138,21 +152,25 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        self.prenet = Prenet(self.out_dim, args.prenet_layers, args.prenet_dim,
                             args.prenet_dropout)
        self.prenet = Prenet(
            self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
        )

        # take prev_context, prev_frame, (speaker embedding) as input
        self.attention_lstm = LSTMCellWithZoneOut(
            args.zoneout,
            args.prenet_dim + args.encoder_embed_dim,
            args.decoder_lstm_dim
            args.decoder_lstm_dim,
        )

        # take attention_lstm output, attention_state, encoder_out as input
        self.attention = LocationAttention(
            args.attention_dim, args.encoder_embed_dim, args.decoder_lstm_dim,
            args.attention_dim,
            args.encoder_embed_dim,
            args.decoder_lstm_dim,
            (1 + int(args.attention_use_cumprob)),
            args.attention_conv_dim, args.attention_conv_kernel_size
            args.attention_conv_dim,
            args.attention_conv_kernel_size,
        )

        # take attention_lstm output, context, (gated_latent) as input

@@ -160,7 +178,7 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
            LSTMCellWithZoneOut(
                args.zoneout,
                args.encoder_embed_dim + args.decoder_lstm_dim,
                args.decoder_lstm_dim
                args.decoder_lstm_dim,
            )
            for i in range(args.decoder_lstm_layers)
        )

@@ -169,12 +187,16 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
        self.feat_proj = nn.Linear(proj_in_dim, self.out_dim)
        self.eos_proj = nn.Linear(proj_in_dim, 1)

        self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
                               args.postnet_conv_kernel_size,
                               args.postnet_layers, args.postnet_dropout)
        self.postnet = Postnet(
            self.out_dim,
            args.postnet_conv_dim,
            args.postnet_conv_kernel_size,
            args.postnet_layers,
            args.postnet_dropout,
        )

        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.) > 0.:
        if getattr(args, "ctc_weight", 0.0) > 0.0:
            self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))

        self.apply(decoder_init)

@@ -190,12 +212,16 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):

        lstm_h = self.get_incremental_state(incremental_state, "lstm_h")
        if lstm_h is None:
            lstm_h = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
                      for _ in range(self.args.decoder_lstm_layers)]
            lstm_h = [
                enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
                for _ in range(self.args.decoder_lstm_layers)
            ]
        lstm_c = self.get_incremental_state(incremental_state, "lstm_c")
        if lstm_c is None:
            lstm_c = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
                      for _ in range(self.args.decoder_lstm_layers)]
            lstm_c = [
                enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
                for _ in range(self.args.decoder_lstm_layers)
            ]

        attn_w = self.get_incremental_state(incremental_state, "attn_w")
        if attn_w is None:

@@ -216,8 +242,14 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
        else:
            raise ValueError(f"{self.args.init_attn_c} not supported")

    def forward(self, prev_output_tokens, encoder_out=None,
                incremental_state=None, target_lengths=None, **kwargs):
    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
        incremental_state=None,
        target_lengths=None,
        **kwargs,
    ):
        enc_mask = encoder_out["encoder_padding_mask"]
        enc_out = encoder_out["encoder_out"][0]
        in_len = enc_out.size(1)

@@ -227,8 +259,9 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
        bsz, out_len, _ = prev_output_tokens.size()

        prenet_out = self.prenet(prev_output_tokens)
        (alstm_h, alstm_c, lstm_h, lstm_c,
         attn_w, attn_w_cum) = self._get_states(incremental_state, enc_out)
        (alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum) = self._get_states(
            incremental_state, enc_out
        )
        attn_ctx = self._get_init_attn_c(enc_out, enc_mask)

        attn_out = enc_out.new_zeros(bsz, in_len, out_len)

@@ -241,9 +274,7 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
            attn_state = attn_w.unsqueeze(1)
            if self.args.attention_use_cumprob:
                attn_state = torch.stack((attn_w, attn_w_cum), dim=1)
            attn_ctx, attn_w = self.attention(
                enc_out, enc_mask, alstm_h, attn_state
            )
            attn_ctx, attn_w = self.attention(enc_out, enc_mask, alstm_h, attn_state)
            attn_w_cum = attn_w_cum + attn_w
            attn_out[:, :, t] = attn_w

@@ -297,7 +328,7 @@ class Tacotron2Model(FairseqEncoderDecoderModel):
        parser.add_argument("--postnet-conv-dim", type=int)
        parser.add_argument("--postnet-conv-kernel-size", type=int)
        parser.add_argument("--init-attn-c", type=str)
        parser.add_argument("--attention-use-cumprob", action='store_true')
        parser.add_argument("--attention-use-cumprob", action="store_true")
        parser.add_argument("--zoneout", type=float)
        parser.add_argument("--decoder-lstm-layers", type=int)
        parser.add_argument("--decoder-lstm-dim", type=int)

@@ -333,8 +364,7 @@ def base_architecture(args):
    # decoder
    args.attention_dim = getattr(args, "attention_dim", 128)
    args.attention_conv_dim = getattr(args, "attention_conv_dim", 32)
    args.attention_conv_kernel_size = getattr(args,
                                              "attention_conv_kernel_size", 15)
    args.attention_conv_kernel_size = getattr(args, "attention_conv_kernel_size", 15)
    args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
    args.prenet_layers = getattr(args, "prenet_layers", 2)
    args.prenet_dim = getattr(args, "prenet_dim", 256)
@@ -9,12 +9,14 @@ from typing import List, Optional
import torch
from torch import nn

from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel,
                            FairseqIncrementalDecoder, register_model,
                            register_model_architecture)
from fairseq.modules import (
    TransformerEncoderLayer, TransformerDecoderLayer
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from fairseq.models.text_to_speech.tacotron2 import Prenet, Postnet
from fairseq.modules import LayerNorm, PositionalEmbedding, FairseqDropout
from fairseq.data.data_utils import lengths_to_padding_mask

@@ -42,30 +44,31 @@ class TTSTransformerEncoder(FairseqEncoder):
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim,
                args.encoder_embed_dim
                args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
            )

        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim,
                                         padding_idx=self.padding_idx)
        assert(args.encoder_conv_kernel_size % 2 == 1)
        self.embed_tokens = nn.Embedding(
            len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
        )
        assert args.encoder_conv_kernel_size % 2 == 1
        self.prenet = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim,
                          kernel_size=args.encoder_conv_kernel_size,
                          padding=((args.encoder_conv_kernel_size - 1) // 2)),
                nn.Conv1d(
                    args.encoder_embed_dim,
                    args.encoder_embed_dim,
                    kernel_size=args.encoder_conv_kernel_size,
                    padding=((args.encoder_conv_kernel_size - 1) // 2),
                ),
                nn.BatchNorm1d(args.encoder_embed_dim),
                nn.ReLU(),
                nn.Dropout(args.encoder_dropout),
            )
            for _ in range(args.encoder_conv_layers)
        )
        self.prenet_proj = nn.Linear(
            args.encoder_embed_dim, args.encoder_embed_dim
        )
        self.prenet_proj = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )

@@ -112,7 +115,9 @@ class TTSTransformerEncoder(FairseqEncoder):

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask] if padding_mask.any() else [],  # B x T
            "encoder_padding_mask": [padding_mask]
            if padding_mask.any()
            else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],

@@ -143,15 +148,15 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
        )
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))
        self.prenet = nn.Sequential(
            Prenet(self.out_dim, args.prenet_layers, args.prenet_dim,
                   args.prenet_dropout),
            Prenet(
                self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
            ),
            nn.Linear(args.prenet_dim, args.decoder_embed_dim),
        )

        self.n_transformer_layers = args.decoder_transformer_layers
        self.transformer_layers = nn.ModuleList(
            TransformerDecoderLayer(args)
            for _ in range(self.n_transformer_layers)
            TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers)
        )
        if args.decoder_normalize_before:
            self.layer_norm = LayerNorm(args.decoder_embed_dim)
@@ -161,19 +166,28 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
        self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
        self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)

        self.postnet = Postnet(self.out_dim, args.postnet_conv_dim,
                               args.postnet_conv_kernel_size,
                               args.postnet_layers, args.postnet_dropout)
        self.postnet = Postnet(
            self.out_dim,
            args.postnet_conv_dim,
            args.postnet_conv_kernel_size,
            args.postnet_layers,
            args.postnet_dropout,
        )

        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.) > 0.:
        if getattr(args, "ctc_weight", 0.0) > 0.0:
            self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))

        self.apply(decoder_init)

    def extract_features(
        self, prev_outputs, encoder_out=None, incremental_state=None,
        target_lengths=None, speaker=None, **kwargs
        self,
        prev_outputs,
        encoder_out=None,
        incremental_state=None,
        target_lengths=None,
        speaker=None,
        **kwargs
    ):
        alignment_layer = self.n_transformer_layers - 1
        self_attn_padding_mask = lengths_to_padding_mask(target_lengths)

@@ -212,8 +226,8 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
                else None,
                encoder_out["encoder_padding_mask"][0]
                if (
                        encoder_out is not None
                        and len(encoder_out["encoder_padding_mask"]) > 0
                    encoder_out is not None
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
                incremental_state,

@@ -239,13 +253,22 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):

        return x, {"attn": attn, "inner_states": inner_states}

    def forward(self, prev_output_tokens, encoder_out=None,
                incremental_state=None, target_lengths=None, speaker=None,
                **kwargs):
    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
        incremental_state=None,
        target_lengths=None,
        speaker=None,
        **kwargs
    ):
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out=encoder_out,
            incremental_state=incremental_state, target_lengths=target_lengths,
            speaker=speaker, **kwargs
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            target_lengths=target_lengths,
            speaker=speaker,
            **kwargs
        )
        attn = extra["attn"]
        feat_out = self.feat_proj(x)

@@ -328,8 +351,9 @@ class TTSTransformerModel(FairseqEncoderDecoderModel):
        return cls(encoder, decoder)

    def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs):
        return self.encoder(src_tokens, src_lengths=src_lengths,
                            speaker=speaker, **kwargs)
        return self.encoder(
            src_tokens, src_lengths=src_lengths, speaker=speaker, **kwargs
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

@@ -348,7 +372,9 @@ def base_architecture(args):
    # encoder transformer layers
    args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim)
    args.encoder_ffn_embed_dim = getattr(
        args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
    )
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)

@@ -366,6 +392,8 @@ def base_architecture(args):
    # decoder transformer layers
    args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim
    )
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
@@ -13,7 +13,10 @@ from torch import nn
import torch.nn.functional as F

from fairseq.data.audio.audio_utils import (
    get_window, get_fourier_basis, get_mel_filters, TTSSpectrogram
    get_window,
    get_fourier_basis,
    get_mel_filters,
    TTSSpectrogram,
)
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel

@@ -25,11 +28,9 @@ class PseudoInverseMelScale(torch.nn.Module):
    def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None:
        super(PseudoInverseMelScale, self).__init__()
        self.n_mels = n_mels
        basis = get_mel_filters(
            sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max
        )
        basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
        basis = torch.pinverse(basis)  # F x F_mel
        self.register_buffer('basis', basis)
        self.register_buffer("basis", basis)

    def forward(self, melspec: torch.Tensor) -> torch.Tensor:
        # pack batch

@@ -48,8 +49,12 @@ class PseudoInverseMelScale(torch.nn.Module):

class GriffinLim(torch.nn.Module):
    def __init__(
        self, n_fft: int, win_length: int, hop_length: int, n_iter: int,
        window_fn=torch.hann_window
        self,
        n_fft: int,
        win_length: int,
        hop_length: int,
        n_iter: int,
        window_fn=torch.hann_window,
    ):
        super(GriffinLim, self).__init__()
        self.transform = TTSSpectrogram(

@@ -59,7 +64,7 @@ class GriffinLim(torch.nn.Module):
        basis = get_fourier_basis(n_fft)
        basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :]
        basis *= get_window(window_fn, n_fft, win_length)
        self.register_buffer('basis', basis)
        self.register_buffer("basis", basis)

        self.n_fft = n_fft
        self.win_length = win_length
@@ -70,33 +75,33 @@ class GriffinLim(torch.nn.Module):

    @classmethod
    def get_window_sum_square(
        cls, n_frames, hop_length, win_length, n_fft,
        window_fn=torch.hann_window
        cls, n_frames, hop_length, win_length, n_fft, window_fn=torch.hann_window
    ) -> torch.Tensor:
        w_sq = get_window(window_fn, n_fft, win_length) ** 2
        n = n_fft + hop_length * (n_frames - 1)
        x = torch.zeros(n, dtype=torch.float32)
        for i in range(n_frames):
            ofst = i * hop_length
            x[ofst: min(n, ofst + n_fft)] += w_sq[:max(0, min(n_fft, n - ofst))]
            x[ofst : min(n, ofst + n_fft)] += w_sq[: max(0, min(n_fft, n - ofst))]
        return x

    def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor:
        x = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)],
            dim=1
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        x = F.conv_transpose1d(x, self.basis, stride=self.hop_length)
        win_sum_sq = self.get_window_sum_square(
            magnitude.shape[-1], hop_length=self.hop_length,
            win_length=self.win_length, n_fft=self.n_fft
            magnitude.shape[-1],
            hop_length=self.hop_length,
            win_length=self.win_length,
            n_fft=self.n_fft,
        ).to(magnitude.device)
        # remove modulation effects
        approx_nonzero_indices = win_sum_sq > self.tiny
        x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices]
        x *= self.n_fft / self.hop_length
        x = x[:, :, self.n_fft // 2:]
        x = x[:, :, :-self.n_fft // 2:]
        x = x[:, :, self.n_fft // 2 :]
        x = x[:, :, : -self.n_fft // 2 :]
        return x

    def forward(self, specgram: torch.Tensor) -> torch.Tensor:

@@ -111,18 +116,33 @@ class GriffinLim(torch.nn.Module):


class GriffinLimVocoder(nn.Module):
    def __init__(self, sample_rate, win_size, hop_size, n_fft,
                 n_mels, f_min, f_max, window_fn,
                 spec_bwd_max_iter=32,
                 fp16=False):
    def __init__(
        self,
        sample_rate,
        win_size,
        hop_size,
        n_fft,
        n_mels,
        f_min,
        f_max,
        window_fn,
        spec_bwd_max_iter=32,
        fp16=False,
    ):
        super().__init__()
        self.inv_mel_transform = PseudoInverseMelScale(
            n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate,
            f_min=f_min, f_max=f_max
            n_stft=n_fft // 2 + 1,
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
        )
        self.gl_transform = GriffinLim(
            n_fft=n_fft, win_length=win_size, hop_length=hop_size,
            window_fn=window_fn, n_iter=spec_bwd_max_iter
            n_fft=n_fft,
            win_length=win_size,
            hop_length=hop_size,
            window_fn=window_fn,
            n_iter=spec_bwd_max_iter,
        )
        if fp16:
            self.half()

@@ -151,17 +171,19 @@ class GriffinLimVocoder(nn.Module):
            sample_rate=feat_cfg["sample_rate"],
            win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]),
            hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]),
            n_fft=feat_cfg["n_fft"], n_mels=feat_cfg["n_mels"],
            f_min=feat_cfg["f_min"], f_max=feat_cfg["f_max"],
            window_fn=window_fn, spec_bwd_max_iter=args.spec_bwd_max_iter,
            fp16=args.fp16
            n_fft=feat_cfg["n_fft"],
            n_mels=feat_cfg["n_mels"],
            f_min=feat_cfg["f_min"],
            f_max=feat_cfg["f_max"],
            window_fn=window_fn,
            spec_bwd_max_iter=args.spec_bwd_max_iter,
            fp16=args.fp16,
        )


class HiFiGANVocoder(nn.Module):
    def __init__(
        self, checkpoint_path: str, model_cfg: Dict[str, str],
        fp16: bool = False
        self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False
    ) -> None:
        super().__init__()
        self.model = HiFiGANModel(model_cfg)
@@ -29,8 +29,8 @@ from torch import Tensor

# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == 'TransformerDecoderBase':
        return 'TransformerDecoder'
    if module_name == "TransformerDecoderBase":
        return "TransformerDecoder"
    else:
        return module_name
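(A behavior sketch, not part of the diff: the helper above only maps the renamed decoder class back to its legacy name so dropout modules keep stable names in `make_generation_fast_`; any other name passes through unchanged.)

# Sketch of module_name_fordropout (decoder variant) from the hunk above.
assert module_name_fordropout("TransformerDecoderBase") == "TransformerDecoder"
assert module_name_fordropout("AnotherModule") == "AnotherModule"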
@@ -29,8 +29,8 @@ from fairseq.models.transformer import (

# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == 'TransformerEncoderBase':
        return 'TransformerEncoder'
    if module_name == "TransformerEncoderBase":
        return "TransformerEncoder"
    else:
        return module_name

@@ -232,7 +232,12 @@ class TransformerEncoderBase(FairseqEncoder):
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        src_lengths = src_tokens.ne(self.padding_idx).sum(dim=1, dtype=torch.int32).reshape(-1, 1).contiguous()
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
@ -15,7 +15,9 @@ from fairseq.models import (
|
||||
register_model_architecture,
|
||||
)
|
||||
from fairseq.models.transformer import (
|
||||
DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder
|
||||
DEFAULT_MIN_PARAMS_TO_WRAP,
|
||||
Embedding,
|
||||
TransformerDecoder,
|
||||
)
|
||||
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
|
||||
from fairseq.utils import safe_getattr, safe_hasattr
|
||||
@ -179,7 +181,7 @@ class TransformerLanguageModelConfig(FairseqDataclass):
|
||||
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
||||
"--offload-activations are passed."
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
# config for "BASE Layers: Simplifying Training of Large, Sparse Models"
|
||||
base_layers: Optional[int] = field(
|
||||
@ -189,13 +191,25 @@ class TransformerLanguageModelConfig(FairseqDataclass):
|
||||
default=1, metadata={"help": "number of sublayers in each BASE layer"}
|
||||
)
|
||||
base_shuffle: Optional[int] = field(
|
||||
default=1, metadata={"help": "shuffle tokens between workers before computing assignment"}
|
||||
default=1,
|
||||
metadata={"help": "shuffle tokens between workers before computing assignment"},
|
||||
)
|
||||
# NormFormer
|
||||
scale_fc: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm between fully connected layers'})
|
||||
scale_attn: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm after attention'})
|
||||
scale_heads: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each attention head'})
|
||||
scale_resids: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each residual connection'})
|
||||
scale_fc: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Insert LayerNorm between fully connected layers"},
|
||||
)
|
||||
scale_attn: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Insert LayerNorm after attention"}
|
||||
)
|
||||
scale_heads: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Learn a scale coefficient for each attention head"},
|
||||
)
|
||||
scale_resids: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Learn a scale coefficient for each residual connection"},
|
||||
)
|
||||
# options from other parts of the config
|
||||
add_bos_token: bool = II("task.add_bos_token")
|
||||
tokens_per_sample: int = II("task.tokens_per_sample")
|
||||
@ -345,7 +359,9 @@ def base_lm_architecture(args):
|
||||
args.decoder_output_dim = safe_getattr(
|
||||
args, "decoder_output_dim", args.decoder_embed_dim
|
||||
)
|
||||
args.decoder_input_dim = safe_getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
||||
args.decoder_input_dim = safe_getattr(
|
||||
args, "decoder_input_dim", args.decoder_embed_dim
|
||||
)
|
||||
|
||||
# Model training is not stable without this
|
||||
args.decoder_normalize_before = True
|
||||
@ -362,10 +378,10 @@ def base_lm_architecture(args):
|
||||
args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
|
||||
args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False)
|
||||
args.offload_activations = safe_getattr(args, "offload_activations", False)
|
||||
args.scale_fc = safe_getattr(args, 'scale_fc', False)
|
||||
args.scale_attn = safe_getattr(args, 'scale_attn', False)
|
||||
args.scale_heads = safe_getattr(args, 'scale_heads', False)
|
||||
args.scale_resids = safe_getattr(args, 'scale_resids', False)
|
||||
args.scale_fc = safe_getattr(args, "scale_fc", False)
|
||||
args.scale_attn = safe_getattr(args, "scale_attn", False)
|
||||
args.scale_heads = safe_getattr(args, "scale_heads", False)
|
||||
args.scale_resids = safe_getattr(args, "scale_resids", False)
|
||||
if args.offload_activations:
|
||||
args.checkpoint_activations = True
|
||||
|
||||
@ -387,7 +403,9 @@ def transformer_lm_baevski_wiki103(args):
|
||||
args.dropout = safe_getattr(args, "dropout", 0.3)
|
||||
args.adaptive_input = safe_getattr(args, "adaptive_input", True)
|
||||
args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", True)
|
||||
args.adaptive_input_cutoff = safe_getattr(args, "adaptive_input_cutoff", "20000,60000")
|
||||
args.adaptive_input_cutoff = safe_getattr(
|
||||
args, "adaptive_input_cutoff", "20000,60000"
|
||||
)
|
||||
args.adaptive_softmax_cutoff = safe_getattr(
|
||||
args, "adaptive_softmax_cutoff", "20000,60000"
|
||||
)
|
||||
@ -472,7 +490,9 @@ def transformer_lm_gpt2_big(args):
|
||||
def base_gpt3_architecture(args):
|
||||
args.decoder_input_dim = args.decoder_embed_dim
|
||||
args.decoder_output_dim = args.decoder_embed_dim
|
||||
args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4)
|
||||
args.decoder_ffn_embed_dim = safe_getattr(
|
||||
args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4
|
||||
)
|
||||
# GPT-3 used learned positional embeddings, rather than sinusoidal
|
||||
args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", True)
|
||||
args.dropout = safe_getattr(args, "dropout", 0.0)
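The four NormFormer flags above (scale_fc, scale_attn, scale_heads, scale_resids) only toggle extra LayerNorms and learned scale parameters inside each layer. As a minimal standalone sketch (not fairseq code) of the per-head scaling that scale_heads enables, using the same einsum that appears later in this diff and assuming a (time, batch, heads, head_dim) attention output:

import torch

tgt_len, bsz, nh, head_dim = 5, 2, 4, 16
x = torch.randn(tgt_len, bsz, nh, head_dim)  # per-head attention output
c_attn = torch.ones(nh)  # one learned coefficient per head
x = torch.einsum("tbhd,h->tbhd", x, c_attn)  # scale each head independently
x = x.reshape(tgt_len, bsz, nh * head_dim)  # back to the model dimension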
@ -232,9 +232,11 @@ class Wav2Vec2Config(FairseqDataclass):
)

checkpoint_activations: bool = field(
default=False, metadata={"help": "recompute activations and save memory for extra compute"}
default=False,
metadata={"help": "recompute activations and save memory for extra compute"},
)


@register_model("wav2vec2", dataclass=Wav2Vec2Config)
class Wav2Vec2Model(BaseFairseqModel):
def __init__(self, cfg: Wav2Vec2Config):

@ -844,14 +846,14 @@ class TransformerEncoder(nn.Module):
layers = []
for _ in range(args.encoder_layers):
layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first,
)
if args.checkpoint_activations:
layer = fsdp_wrap(layer)

@ -152,10 +152,12 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
w2v_args: Any = None

checkpoint_activations: bool = field(
default=False, metadata={"help": "recompute activations and save memory for extra compute"}
default=False,
metadata={"help": "recompute activations and save memory for extra compute"},
)
ddp_backend: str = II("distributed_training.ddp_backend")


@dataclass
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
blank_weight: float = 0

@ -268,6 +270,7 @@ class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig):
)
autoregressive: bool = II("task.autoregressive")


@register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig)
class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):

@ -394,12 +397,17 @@ class Wav2VecEncoder(FairseqEncoder):
def load_model_weights(self, state, model, cfg):
if cfg.ddp_backend == "fully_sharded":
from fairseq.distributed import FullyShardedDataParallel

for name, module in model.named_modules():
if "encoder.layers" in name and len(name.split(".")) == 3:
# Only for layers, we do a special handling and load the weights one by one
# We dont load all weights together as that wont be memory efficient and may
# cause oom
new_dict = {k.replace(name+".", "") : v for (k, v) in state["model"].items() if name+"." in k}
new_dict = {
k.replace(name + ".", ""): v
for (k, v) in state["model"].items()
if name + "." in k
}
assert isinstance(module, FullyShardedDataParallel)
with module.summon_full_params():
module.load_state_dict(new_dict, strict=True)

@ -409,7 +417,9 @@ class Wav2VecEncoder(FairseqEncoder):
r = re.compile("encoder.layers.\d.")
filtered_list = list(filter(r.match, state["model"].keys()))

new_big_dict = {k: v for (k, v) in state["model"].items() if k not in filtered_list}
new_big_dict = {
k: v for (k, v) in state["model"].items() if k not in filtered_list
}

model.load_state_dict(new_big_dict, strict=False)
else:

@ -462,9 +472,9 @@ class Wav2VecEncoder(FairseqEncoder):
1, new_order
)
if encoder_out["padding_mask"] is not None:
encoder_out["padding_mask"] = encoder_out[
"padding_mask"
].index_select(0, new_order)
encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
0, new_order
)
return encoder_out

def max_positions(self):

@ -640,7 +650,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self_attn_mask=self.buffered_future_mask(x)
if incremental_state is None
else None,
self_attn_padding_mask=self_attn_padding_mask
self_attn_padding_mask=self_attn_padding_mask,
)
inner_states.append(x)
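The per-layer loading in load_model_weights above reduces to filtering a checkpoint state dict by a module-name prefix and stripping that prefix. A self-contained sketch with hypothetical keys:

# hypothetical checkpoint entries, for illustration only
state = {"encoder.layers.0.fc1.weight": 1, "encoder.layers.1.fc1.weight": 2}
name = "encoder.layers.0"
new_dict = {
    k.replace(name + ".", ""): v
    for (k, v) in state.items()
    if name + "." in k
}
assert new_dict == {"fc1.weight": 1}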
@ -12,14 +12,17 @@ from fairseq.modules.layer_norm import LayerNorm


class BaseLayer(nn.Module):

def __init__(self, args):
super().__init__()
self.num_workers = distributed_utils.get_data_parallel_world_size()
expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim)
torch.nn.init.orthogonal_(expert_centroids, gain=0.1)
self.register_parameter("expert_centroids", torch.nn.Parameter(expert_centroids))
self.expert_network = nn.Sequential(*([BaseSublayer(args) for _ in range(args.base_sublayers)]))
self.register_parameter(
"expert_centroids", torch.nn.Parameter(expert_centroids)
)
self.expert_network = nn.Sequential(
*([BaseSublayer(args) for _ in range(args.base_sublayers)])
)
self.expert_id = distributed_utils.get_data_parallel_rank()
self.shuffle = args.base_shuffle
self.cpp = self.load_assignment()

@ -39,20 +42,34 @@ class BaseLayer(nn.Module):

with torch.no_grad():
# Compute similarity of each token to each expert, for routing
token_expert_affinities = features.matmul(self.expert_centroids.transpose(0, 1))
token_expert_affinities = features.matmul(
self.expert_centroids.transpose(0, 1)
)

# Compute which token goes to which expert
sort_by_expert, input_splits, output_splits = self.balanced_assignment(token_expert_affinities) \
if is_training else self.greedy_assignment(token_expert_affinities)
sort_by_expert, input_splits, output_splits = (
self.balanced_assignment(token_expert_affinities)
if is_training
else self.greedy_assignment(token_expert_affinities)
)
# Swap these tokens for the right ones for our expert
routed_features = All2All.apply(features[sort_by_expert], output_splits, input_splits)
routed_features = All2All.apply(
features[sort_by_expert], output_splits, input_splits
)

if routed_features.size(0) > 0:
# Mix in the expert network based on how appropriate it is for these tokens
alpha = torch.sigmoid(routed_features.mv(self.expert_centroids[self.expert_id])).unsqueeze(1)
routed_features = alpha * self.expert_network(routed_features) + (1 - alpha) * routed_features
alpha = torch.sigmoid(
routed_features.mv(self.expert_centroids[self.expert_id])
).unsqueeze(1)
routed_features = (
alpha * self.expert_network(routed_features)
+ (1 - alpha) * routed_features
)
# Return to original worker and ordering
result = All2All.apply(routed_features, input_splits, output_splits)[self.inverse_sort(sort_by_expert)]
result = All2All.apply(routed_features, input_splits, output_splits)[
self.inverse_sort(sort_by_expert)
]

if self.shuffle and is_training:
# Undo shuffling

@ -63,7 +80,9 @@ class BaseLayer(nn.Module):

def inverse_sort(self, order):
# Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)]
return torch.empty_like(order).scatter_(0, order, torch.arange(0, order.size(0), device=order.device))
return torch.empty_like(order).scatter_(
0, order, torch.arange(0, order.size(0), device=order.device)
)

def balanced_assignment(self, scores):
ok = scores.isfinite()

@ -79,7 +98,9 @@ class BaseLayer(nn.Module):
worker2token = sort_ordering // k

# Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers)
output_splits = torch.zeros((self.num_workers,), dtype=torch.long, device=scores.device)
output_splits = torch.zeros(
(self.num_workers,), dtype=torch.long, device=scores.device
)
workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True)
output_splits[workers] = counts
# Tell other workers how many tokens to expect from us

@ -103,7 +124,7 @@ class BaseSublayer(nn.Module):
def __init__(self, args):
super().__init__()
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, 'activation_fn', 'relu') or "relu"
activation=getattr(args, "activation_fn", "relu") or "relu"
)
self.norm = LayerNorm(args.decoder_embed_dim, export=False)
self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim)

@ -121,15 +142,29 @@ class All2All(torch.autograd.Function):
ctx.input_splits = input_splits
ctx.output_splits = output_splits

ys = torch.empty_like(xs) if output_splits is None else \
xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:]))
torch.distributed.all_to_all_single(ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits)
ys = (
torch.empty_like(xs)
if output_splits is None
else xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:]))
)
torch.distributed.all_to_all_single(
ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits
)
return ys

@staticmethod
def backward(ctx, grad_output):
result = torch.empty_like(grad_output) if ctx.input_splits is None else \
grad_output.new_empty(size=[sum(ctx.input_splits)] + list(grad_output.size()[1:]))
torch.distributed.all_to_all_single(result, grad_output,
output_split_sizes=ctx.input_splits, input_split_sizes=ctx.output_splits)
result = (
torch.empty_like(grad_output)
if ctx.input_splits is None
else grad_output.new_empty(
size=[sum(ctx.input_splits)] + list(grad_output.size()[1:])
)
)
torch.distributed.all_to_all_single(
result,
grad_output,
output_split_sizes=ctx.input_splits,
input_split_sizes=ctx.output_splits,
)
return result, None, None
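inverse_sort above builds the permutation that undoes a sort, i.e. xs == xs[order][inverse_sort(order)] per the comment in the diff. A quick standalone check of that identity, assuming 1-D integer input:

import torch

xs = torch.tensor([3, 1, 2])
order = xs.argsort()  # [1, 2, 0]
# scatter writes position i into slot order[i], inverting the permutation
inverse = torch.empty_like(order).scatter_(0, order, torch.arange(order.size(0)))
assert torch.equal(xs, xs[order][inverse])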
@ -166,7 +166,9 @@ class CheckpointFunction(torch.autograd.Function):
if parent_ctx_dict["offload"]:
ctx.fwd_device = tuple(x.device for x in tensor_inputs)
ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
tensor_inputs = tuple(x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs)
tensor_inputs = tuple(
x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs
)

else:
ctx.fwd_device, ctx.grad_requirements = None, None

@ -199,7 +201,8 @@ class CheckpointFunction(torch.autograd.Function):
tensor_inputs = checkpoint.detach_variable(tensor_inputs)
if ctx.fwd_device is not None:
tensor_inputs = [
t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs)
t.to(ctx.fwd_device[i], non_blocking=True)
for i, t in enumerate(tensor_inputs)
]
for i, need_grad in enumerate(ctx.grad_requirements):
tensor_inputs[i].requires_grad = need_grad

@ -75,6 +75,7 @@ class GumbelVectorQuantizer(nn.Module):

if isinstance(temp, str):
import ast

temp = ast.literal_eval(temp)
assert len(temp) == 3, f"{temp}, {len(temp)}"
@ -47,11 +47,12 @@ def cache_fn(f):
return cache
cache = f(*args, **kwargs)
return cache

return cached_fn


def to(t):
return {'device': t.device, 'dtype': t.dtype}
return {"device": t.device, "dtype": t.dtype}


def find_modules(nn_module, type):

@ -102,7 +103,7 @@ def reshape_dim(t, dim, split_dims):
shape = list(t.shape)
num_dims = len(shape)
dim = (dim + num_dims) % num_dims
shape[dim:dim+1] = split_dims
shape[dim : dim + 1] = split_dims
return t.reshape(shape)


@ -118,6 +119,7 @@ def ema_inplace(moving_avg, new, decay):
return
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


# helper classes


@ -173,6 +175,7 @@ class ScaleNorm(nn.Module):
def norm(t):
n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps)
return t / n * self.g

return map_first_tuple_or_el(x, norm)


@ -202,51 +205,62 @@ class MatrixMultiply(nn.Module):
tensor = tensor.t()
return x @ tensor


# positional embeddings


class DepthWiseConv1d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False):
super().__init__()
self.padding = ((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2)
self.padding = (
((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2)
)

self.net = nn.Sequential(
nn.Conv1d(dim_in, dim_in, kernel_size=kernel_size, groups=dim_in, stride=stride, bias=bias),
nn.Conv1d(dim_in, dim_out, 1, bias=bias)
nn.Conv1d(
dim_in,
dim_in,
kernel_size=kernel_size,
groups=dim_in,
stride=stride,
bias=bias,
),
nn.Conv1d(dim_in, dim_out, 1, bias=bias),
)

def forward(self, x):
x = F.pad(x, self.padding, value=0.)
x = F.pad(x, self.padding, value=0.0)
return self.net(x)


class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
position = torch.arange(0, max_seq_len, dtype=torch.float)
sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
self.register_buffer('emb', emb)
self.register_buffer("emb", emb)

def forward(self, x):
return self.emb[None, :x.shape[1], :].to(x)
return self.emb[None, : x.shape[1], :].to(x)


def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j=2)
x = rearrange(x, "... (d j) -> ... d j", j=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
return rearrange(x, "... d j -> ... (d j)")


def apply_rotary_pos_emb(q, k, sinu_pos):
sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j=2)
sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2)
sin, cos = sinu_pos.unbind(dim=-2)
sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2), (sin, cos))
sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos))
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
return q, k


# kmeans related function and class


@ -261,7 +275,7 @@ def update_kmeans_on_backwards(module):


def similarity(x, means):
return torch.einsum('bhld,hcd->bhlc', x, means)
return torch.einsum("bhld,hcd->bhlc", x, means)


def dists_and_buckets(x, means):

@ -303,13 +317,15 @@ def distribution(dists, window_size):


class Kmeans(nn.Module):
def __init__(self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4):
def __init__(
self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4
):
super().__init__()
self.commitment = commitment
self.ema_decay = ema_decay

self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim))
self.register_buffer('initted', torch.tensor(False))
self.register_buffer("means", torch.randn(num_heads, num_clusters, head_dim))
self.register_buffer("initted", torch.tensor(False))
self.num_new_means = 0
self.new_means = None

@ -341,7 +357,7 @@ class Kmeans(nn.Module):
@torch.no_grad()
def update(self, new_means=None):
new_means = default(new_means, self.new_means)
assert exists(new_means), 'new kmeans has not been supplied'
assert exists(new_means), "new kmeans has not been supplied"
ema_inplace(self.means, new_means, self.ema_decay)

del self.new_means

@ -364,16 +380,33 @@ class Kmeans(nn.Module):
if update_means:
with torch.no_grad():
means = kmeans_iter(x, means, buckets)
self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1))
self.new_means = ema(
self.new_means, means, self.num_new_means / (self.num_new_means + 1)
)
self.num_new_means += 1

return dists, loss


# kmeans attention class


class KmeansAttention(nn.Module):
def __init__(self, num_clusters, window_size, num_heads, head_dim, causal=False, dropout=0., ema_decay=0.999, commitment=1e-4, context_window_size=None, receives_context=False, num_mem_kv=0, shared_qk=False):
def __init__(
self,
num_clusters,
window_size,
num_heads,
head_dim,
causal=False,
dropout=0.0,
ema_decay=0.999,
commitment=1e-4,
context_window_size=None,
receives_context=False,
num_mem_kv=0,
shared_qk=False,
):
super().__init__()
self.num_heads = num_heads
self.num_clusters = num_clusters

@ -389,18 +422,32 @@ class KmeansAttention(nn.Module):
self.dropout = nn.Dropout(dropout)

self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0)
self.mem_key = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))
self.mem_value = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim))
self.mem_key = nn.Parameter(
torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
)
self.mem_value = nn.Parameter(
torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
)

def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs):
b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype
is_reverse = kwargs.pop('_reverse', False)
b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = (
*q.shape,
k.shape[2],
self.window_size,
self.context_window_size,
self.num_clusters,
q.device,
q.dtype,
)
is_reverse = kwargs.pop("_reverse", False)

out = torch.zeros_like(q, dtype=dtype)

update_kmeans = self.training and not is_reverse

key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask
key_mask = (
default(key_mask, query_mask) if not self.receives_context else key_mask
)
kv_wsz = wsz if not self.receives_context else c_wsz

wsz = min(wsz, t)

@ -424,16 +471,22 @@ class KmeansAttention(nn.Module):
reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
q, k, v = map(reshape_with_window, (q, k, v))

m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value))
m_k, m_v = map(
lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)
)
k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))

dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5)
dots = torch.einsum("bhnid,bhnjd->bhnij", q, k) * (d ** -0.5)

mask_value = max_neg_value(dots)

if exists(query_mask) or exists(key_mask):
query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool())
key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool())
query_mask = default(
query_mask, lambda: torch.ones((b, t), device=device).bool()
)
key_mask = default(
key_mask, lambda: torch.ones((b, kv_t), device=device).bool()
)

q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)

@ -444,14 +497,18 @@ class KmeansAttention(nn.Module):
del mask

if self.causal:
q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
q_mask, kv_mask = map(
lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
)
mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
dots.masked_fill_(~mask, mask_value)
del mask

if self.shared_qk:
q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices))
q_mask, kv_mask = map(
lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
)
mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=0)
dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)

@ -460,24 +517,32 @@ class KmeansAttention(nn.Module):
dots = dots.softmax(dim=-1)
dots = self.dropout(dots)

bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v)
bo = torch.einsum("bhcij,bhcjd->bhcid", dots, v)
so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
return out, aux_loss


# feedforward


class GELU_(nn.Module):
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
return (
0.5
* x
* (
1
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))
)
)


GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_
GELU = nn.GELU if hasattr(nn, "GELU") else GELU_


class FeedForward(nn.Module):
def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False):
def __init__(self, dim, mult=4, dropout=0.0, activation=None, glu=False):
super().__init__()
activation = default(activation, GELU)

@ -499,17 +564,49 @@ class FeedForward(nn.Module):
x = self.w2(x)
return x


# self attention


class SelfAttention(nn.Module):
def __init__(self, dim, max_seq_len, heads, local_attn_heads, window_size, dim_head=None, local_attn_window_size=None, local_attn_radius_blocks=1, causal=False, attn_dropout=0., dropout=0., kmeans_ema_decay=0.999, commitment_factor=1e-4, receives_context=False, context_window_size=None, rel_pos_emb=True, num_mem_kv=0, shared_qk=False, conv_query_kernel=9):
def __init__(
self,
dim,
max_seq_len,
heads,
local_attn_heads,
window_size,
dim_head=None,
local_attn_window_size=None,
local_attn_radius_blocks=1,
causal=False,
attn_dropout=0.0,
dropout=0.0,
kmeans_ema_decay=0.999,
commitment_factor=1e-4,
receives_context=False,
context_window_size=None,
rel_pos_emb=True,
num_mem_kv=0,
shared_qk=False,
conv_query_kernel=9,
):
super().__init__()
assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads'
assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size'
assert local_attn_heads <= heads, 'number of local attention heads must be less than total heads'
assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context'
assert not (receives_context and causal), 'contextual attention layer cannot be causal'
assert (
dim_head or (dim % heads) == 0
), "hidden dimension must be divisible by number of heads"
assert (
max_seq_len % window_size
) == 0, "maximum sequence length must be divisible by the target window size"
assert (
local_attn_heads <= heads
), "number of local attention heads must be less than total heads"
assert not (
receives_context and local_attn_heads > 0
), "local attention cannot be used for self attention with context"
assert not (
receives_context and causal
), "contextual attention layer cannot be causal"

local_attn_window_size = default(local_attn_window_size, window_size)
context_window_size = default(context_window_size, window_size)

@ -535,7 +632,15 @@ class SelfAttention(nn.Module):

if self.local_attn_heads > 0:
rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
self.local_attn = LocalAttention(dim=dim_head, window_size=local_attn_window_size, causal=causal, dropout=attn_dropout, rel_pos_emb_config=rel_pos_emb_config, look_backward=local_attn_radius_blocks, look_forward=0 if causal else local_attn_radius_blocks)
self.local_attn = LocalAttention(
dim=dim_head,
window_size=local_attn_window_size,
causal=causal,
dropout=attn_dropout,
rel_pos_emb_config=rel_pos_emb_config,
look_backward=local_attn_radius_blocks,
look_forward=0 if causal else local_attn_radius_blocks,
)
self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)

# global

@ -543,12 +648,24 @@ class SelfAttention(nn.Module):
global_dim_heads = dim_head * self.global_attn_heads

if self.global_attn_heads > 0:
self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal=causal, dropout=attn_dropout, ema_decay=kmeans_ema_decay, commitment=commitment_factor, receives_context=receives_context, num_mem_kv=num_mem_kv, shared_qk=shared_qk)
self.global_attn = KmeansAttention(
num_clusters,
window_size,
self.global_attn_heads,
dim_head,
causal=causal,
dropout=attn_dropout,
ema_decay=kmeans_ema_decay,
commitment=commitment_factor,
receives_context=receives_context,
num_mem_kv=num_mem_kv,
shared_qk=shared_qk,
)

self.to_q = nn.Sequential(
Rearrange('b n c -> b c n'),
Rearrange("b n c -> b c n"),
DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal),
Rearrange('b c n -> b n c')
Rearrange("b c n -> b n c"),
)

self.to_v = nn.Linear(dim, global_dim_heads, bias=False)

@ -561,14 +678,30 @@ class SelfAttention(nn.Module):
self.to_out = nn.Linear(dim_heads, dim, bias=False)
self.dropout = nn.Dropout(dropout)

def forward(self, query, key, value, context=None, key_padding_mask=None, context_mask=None, pos_emb=None, **kwargs):
assert not (self.receives_context and not exists(context)), 'context must be passed if self attention is set to receive context'
def forward(
self,
query,
key,
value,
context=None,
key_padding_mask=None,
context_mask=None,
pos_emb=None,
**kwargs
):
assert not (
self.receives_context and not exists(context)
), "context must be passed if self attention is set to receive context"
input_mask = key_padding_mask
x = query.transpose(0, 1)
b, t, _, h, dh = *x.shape, self.heads, self.dim_head
has_local, has_global = map(lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads))
has_local, has_global = map(
lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads)
)

split_heads = lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()
split_heads = (
lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()
)

if has_local:
local_qkv = self.local_to_qkv(x).chunk(3, dim=-1)

@ -587,7 +720,7 @@ class SelfAttention(nn.Module):
q, k, v = map(split_heads, (q, k, v))

out = []
total_loss = torch.tensor(0., requires_grad=True, **to(x))
total_loss = torch.tensor(0.0, requires_grad=True, **to(x))

if has_local:
local_out = self.local_attn(lq, lk, lv, input_mask=input_mask)

@ -597,7 +730,9 @@ class SelfAttention(nn.Module):
if not self.receives_context and exists(pos_emb):
q, k = apply_rotary_pos_emb(q, k, pos_emb)

global_out, loss = self.global_attn(q, k, v, query_mask=input_mask, key_mask=context_mask)
global_out, loss = self.global_attn(
q, k, v, query_mask=input_mask, key_mask=context_mask
)
total_loss = total_loss + loss

out.append(global_out)
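rotate_every_two and apply_rotary_pos_emb above implement the usual rotary-embedding trick: split features into consecutive pairs and rotate each pair by a position-dependent angle. A pure-torch sketch of the core rotation (the diff's version uses einops rearrange; the angles here are placeholders):

import torch

def rotate_every_two(x):
    # map each consecutive pair (d0, d1) to (-d1, d0)
    x = x.view(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

q = torch.randn(1, 8, 64)  # (batch, seq, dim)
angles = torch.randn(1, 8, 64)  # placeholder position-dependent angles
q_rot = (q * angles.cos()) + (rotate_every_two(q) * angles.sin())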
@ -13,6 +13,7 @@ from .conv_tbc import ConvTBC
from typing import Dict, Optional
from torch import Tensor


@with_incremental_state
class LinearizedConvolution(ConvTBC):
"""An optimized version of nn.Conv1d.

@ -41,7 +42,11 @@ class LinearizedConvolution(ConvTBC):
del state_dict[prefix + "_linearized_weight"]

@torch.jit.export
def forward(self, input, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None):
def forward(
self,
input,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
"""
Args:
incremental_state: Used to buffer signal; if not None, then input is

@ -80,18 +85,28 @@ class LinearizedConvolution(ConvTBC):
return output.view(bsz, 1, -1)

@torch.jit.unused
def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order):
def reorder_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
new_order,
):
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
input_buffer = input_buffer.index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)

@torch.jit.unused
def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
return utils.get_incremental_state(self, incremental_state, "input_buffer")

@torch.jit.unused
def _set_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_buffer):
def _set_input_buffer(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
new_buffer,
):
return utils.set_incremental_state(
self, incremental_state, "input_buffer", new_buffer
)

@ -20,9 +20,16 @@ class LocationAttention(nn.Module):
:param int conv_kernel_size: filter size of attention convolution
"""

def __init__(self, attn_dim, encoder_dim, decoder_dim,
attn_state_kernel_size, conv_dim, conv_kernel_size,
scaling=2.0):
def __init__(
self,
attn_dim,
encoder_dim,
decoder_dim,
attn_state_kernel_size,
conv_dim,
conv_kernel_size,
scaling=2.0,
):
super(LocationAttention, self).__init__()
self.attn_dim = attn_dim
self.decoder_dim = decoder_dim

@ -30,9 +37,13 @@ class LocationAttention(nn.Module):
self.proj_enc = nn.Linear(encoder_dim, attn_dim)
self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
self.conv = nn.Conv1d(attn_state_kernel_size, conv_dim,
2 * conv_kernel_size + 1,
padding=conv_kernel_size, bias=False)
self.conv = nn.Conv1d(
attn_state_kernel_size,
conv_dim,
2 * conv_kernel_size + 1,
padding=conv_kernel_size,
bias=False,
)
self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))

self.proj_enc_out = None # cache

@ -12,20 +12,20 @@ class LSTMCellWithZoneOut(nn.Module):
https://arxiv.org/abs/1606.01305
"""

def __init__(self, prob: float, input_size: int, hidden_size: int,
bias: bool = True):
def __init__(
self, prob: float, input_size: int, hidden_size: int, bias: bool = True
):
super(LSTMCellWithZoneOut, self).__init__()
self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
self.prob = prob
if prob > 1.0 or prob < 0.0:
raise ValueError("zoneout probability must be in the range from "
"0.0 to 1.0.")
raise ValueError(
"zoneout probability must be in the range from " "0.0 to 1.0."
)

def zoneout(self, h, next_h, prob):
if isinstance(h, tuple):
return tuple(
[self.zoneout(h[i], next_h[i], prob) for i in range(len(h))]
)
return tuple([self.zoneout(h[i], next_h[i], prob) for i in range(len(h))])

if self.training:
mask = h.new_zeros(*h.size()).bernoulli_(prob)
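Zoneout, per the class above, randomly keeps components of the previous hidden state instead of the new one during training. The mixing step itself falls outside this hunk, but it typically looks like:

import torch

prob = 0.1
h = torch.zeros(2, 4)  # previous hidden state
next_h = torch.ones(2, 4)  # candidate new hidden state
mask = h.new_zeros(*h.size()).bernoulli_(prob)  # 1 = keep the old value
mixed = mask * h + (1 - mask) * next_h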
@ -60,7 +60,9 @@ def quantize_model_(
to layers_to_quantize[step]
"""

quantized_layers = get_layers(model, layers_to_quantize[step], remove_weights=remove_weights)
quantized_layers = get_layers(
model, layers_to_quantize[step], remove_weights=remove_weights
)

for layer in quantized_layers:

@ -108,8 +110,8 @@ def quantize_model_(
centroids = torch.rand(centroids.size())
centroids.cuda()
# Get counts and assignment keys from layer in loaded checkpoint.
counts_key = layer+"."+"counts"
assignment_key = layer+"."+"assignments"
counts_key = layer + "." + "counts"
assignment_key = layer + "." + "assignments"
# Get number of different bins to include.
counts = list(state_dict[counts_key].shape)[0]
print(layer)

@ -122,7 +124,7 @@ def quantize_model_(
print(num_assignments)
print(num_extra)
assignments_bins = torch.arange(counts)
assignments_rand = torch.randint(0, counts-1, (num_extra, ))
assignments_rand = torch.randint(0, counts - 1, (num_extra,))
assignments = torch.cat((assignments_bins, assignments_rand), 0)
# assignments = assignments.type(torch.IntTensor)
assignments.cuda()

@ -16,7 +16,9 @@ from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear
MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d}


def quantize_model_(model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False):
def quantize_model_(
model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False
):
"""
Replaces all modules with their scalar quantized counterpart and
registers hooks to quantize the post-ativations of those modules.

@ -132,8 +132,7 @@ class TransformerEncoderLayerBase(nn.Module):
# will become -inf, which results in NaN in model parameters
if attn_mask is not None:
attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool),
-1e8 if x.dtype == torch.float32 else -1e4
attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
)

residual = x

@ -213,11 +212,19 @@ class TransformerDecoderLayerBase(nn.Module):
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
)
self.attn_ln = LayerNorm(self.embed_dim) if utils.safe_getattr(cfg, 'scale_attn', False) else None
self.attn_ln = (
LayerNorm(self.embed_dim)
if utils.safe_getattr(cfg, "scale_attn", False)
else None
)
self.nh = self.self_attn.num_heads
self.head_dim = self.self_attn.head_dim
scale_heads = utils.safe_getattr(cfg, 'scale_heads', False)
self.c_attn = nn.Parameter(torch.ones((self.nh,)), requires_grad=True) if scale_heads else None
scale_heads = utils.safe_getattr(cfg, "scale_heads", False)
self.c_attn = (
nn.Parameter(torch.ones((self.nh,)), requires_grad=True)
if scale_heads
else None
)

self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn)
activation_dropout_p = cfg.activation_dropout

@ -238,8 +245,21 @@ class TransformerDecoderLayerBase(nn.Module):
self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)

self.ffn_layernorm = LayerNorm(cfg.decoder.ffn_embed_dim) if utils.safe_getattr(cfg, 'scale_fc', False) else None
self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if utils.safe_getattr(cfg, 'scale_resids', False) else None
self.ffn_layernorm = (
LayerNorm(cfg.decoder.ffn_embed_dim)
if utils.safe_getattr(cfg, "scale_fc", False)
else None
)
self.w_resid = (
nn.Parameter(
torch.ones(
self.embed_dim,
),
requires_grad=True,
)
if utils.safe_getattr(cfg, "scale_resids", False)
else None
)

self.fc1 = self.build_fc1(
self.embed_dim,

@ -297,7 +317,6 @@ class TransformerDecoderLayerBase(nn.Module):
def residual_connection(self, x, residual):
return residual + x


def forward(
self,
x,

@ -377,7 +396,7 @@ class TransformerDecoderLayerBase(nn.Module):
if self.c_attn is not None:
tgt_len, bsz = x.size(0), x.size(1)
x = x.view(tgt_len, bsz, self.nh, self.head_dim)
x = torch.einsum('tbhd,h->tbhd', x, self.c_attn)
x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
x = x.reshape(tgt_len, bsz, self.embed_dim)
if self.attn_ln is not None:
x = self.attn_ln(x)

@ -35,9 +35,7 @@ def init_bert_params(module):
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
)
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

if isinstance(module, nn.Linear):
normal_(module.weight.data)

@ -276,7 +274,9 @@ class TransformerSentenceEncoder(nn.Module):
inner_states.append(x)

for layer in self.layers:
x, _ = layer(x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask)
x, _ = layer(
x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask
)
if not last_state_only:
inner_states.append(x)
@ -2,13 +2,13 @@
# Licensed under the MIT License.

""" Wrapper for ngram_repeat_block cuda extension """
import math
import warnings
from typing import Dict, List, Optional

import torch
from torch import nn

import math
from typing import Dict, List, Optional
import warnings

try:
from fairseq import ngram_repeat_block_cuda

@ -37,7 +37,7 @@ def is_cuda_extension_usable() -> bool:


class NGramRepeatBlock(nn.Module):
""" Wrapper class for calling ngram_repeat_block cuda extension """
"""Wrapper class for calling ngram_repeat_block cuda extension"""

def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
super().__init__()

@ -67,13 +67,13 @@ class FairseqAdam(FairseqOptimizer):
elif use_fused_adam:
logger.info("using FusedAdam")
self._optimizer = fused_adam_cls(
params,
use_fp16_stats=self.cfg.fp16_adam_stats,
**self.optimizer_config
params, use_fp16_stats=self.cfg.fp16_adam_stats, **self.optimizer_config
)
else:
if self.cfg.fp16_adam_stats:
raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1")
raise NotImplementedError(
"--fp16-adam-stats is only supported with FusedAdamV1"
)
self._optimizer = Adam(params, **self.optimizer_config)

@property

@ -63,8 +63,9 @@ class AMPOptimizer(optim.FairseqOptimizer):
).format(self.min_loss_scale, new_loss_scale)
)
else:
logger.info("AMP: overflow detected, setting scale to "
f"to {new_loss_scale}")
logger.info(
"AMP: overflow detected, setting scale to " f"to {new_loss_scale}"
)
return grad_norm

@property

@ -23,7 +23,9 @@ class OptimizerAndSchedulerConfig(FairseqDataclass):
optimizer: Any = None
lr_scheduler: Optional[Any] = None
lr: List = II("optimization.lr")
lr_float: Optional[float] = None # this makes it easier to sweep on learning rate with auto sweepers
lr_float: Optional[
float
] = None # this makes it easier to sweep on learning rate with auto sweepers


@dataclass

@ -16,6 +16,7 @@ from omegaconf import II, DictConfig

try:
import deepspeed

has_deepspeed = True
except ImportError as e:
has_deepspeed = False

@ -24,12 +25,15 @@ except ImportError as e:
def _get_cpu_adam():
try:
from deepspeed.ops.op_builder import CPUAdamBuilder

return CPUAdamBuilder().load()
except ImportError:
# fbcode
from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam

return ds_opt_adam


@dataclass
class FairseqCPUAdamConfig(FairseqDataclass):
adam_betas: str = field(

@ -64,9 +64,9 @@ class _FP16OptimizerMixin(object):
fp32_params = []
for p in params:
p32 = torch.nn.Parameter(p.data.float())
if hasattr(p, 'expert'):
if hasattr(p, "expert"):
p32.expert = True
elif hasattr(p, 'base_expert'):
elif hasattr(p, "base_expert"):
p32.base_expert = True
p32.grad = torch.zeros_like(p32.data)
if hasattr(p, "param_group"):

@ -209,7 +209,9 @@ class _FP16OptimizerMixin(object):
self._sync_fp16_grads_to_fp32()

if getattr(self, "supports_step_with_scale", False):
self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
self.fp32_optimizer.step(
closure, scale=(1.0 / self._multiply_factor), groups=groups
)
else:
self._unscale_grads()
self.fp32_optimizer.step(closure, groups=groups)

@ -434,7 +436,9 @@ class _MemoryEfficientFP16OptimizerMixin(object):
"""Performs a single optimization step."""
if getattr(self, "supports_step_with_scale", False):
# NOTE(msb) optimizer divides by scale factor
self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
self.wrapped_optimizer.step(
closure, scale=(1.0 / self._multiply_factor), groups=groups
)
else:
self._unscale_grads()
self.wrapped_optimizer.step(closure, groups=groups)

@ -179,7 +179,7 @@ class FusedAdamV1(torch.optim.Optimizer):

if p.device.type == "cpu":
p_data_fp32 = p.data.cuda(non_blocking=True).float()
out_p = torch.tensor([], dtype = torch.float)
out_p = torch.tensor([], dtype=torch.float)
else:
p_data_fp32 = p.data.float()
out_p = p.data

@ -234,6 +234,7 @@ class FusedAdamV1(torch.optim.Optimizer):
p.data.copy_(p_data_fp32, non_blocking=True)

if self.use_fp16_stats:

def inf_norm(t):
return torch.norm(t, float("inf"))

@ -262,7 +263,9 @@ try:

def __init__(self, *args, use_fp16_stats=False, **kwargs):
if use_fp16_stats:
raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1")
raise NotImplementedError(
"--fp16-adam-stats is only supported with FusedAdamV1"
)
super().__init__(*args, **kwargs)
if not hasattr(self, "multi_tensor_adam"):
raise Exception(

@ -32,7 +32,7 @@ class ManualSchedule(LegacyFairseqLRScheduler):
self.optimizer.set_lr(self.lr) # Set the beginning of the epoch.

def parse_manuallr_args(self, lr_args_str):
lr_dict = ast.literal_eval(lr_args_str.replace(' ', ''))
lr_dict = ast.literal_eval(lr_args_str.replace(" ", ""))
if not isinstance(lr_dict, dict):
raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict")

@ -84,9 +84,14 @@ class ManualSchedule(LegacyFairseqLRScheduler):
if manual_keys:
manual_lr = self.epoch2lr[max(manual_keys)]
else:
logger.warning("@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format(
epoch, list(self.epoch2lr.items())[:min(10, len(self.epoch2lr.keys())-1)]
))
logger.warning(
"@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format(
epoch,
list(self.epoch2lr.items())[
: min(10, len(self.epoch2lr.keys()) - 1)
],
)
)
manual_lr = self.optimizer.get_lr()
return manual_lr

@ -102,8 +107,14 @@ class ManualSchedule(LegacyFairseqLRScheduler):
if manual_keys:
manual_lr = self.update2lr[max(manual_keys)]
else:
logger.warning("epoch={} does not exist in manual lr input update2lr={}...".format(
num_updates, list(self.update2lr.items())[:min(10, len(self.update2lr.keys())-1)]))
logger.warning(
"epoch={} does not exist in manual lr input update2lr={}...".format(
num_updates,
list(self.update2lr.items())[
: min(10, len(self.update2lr.keys()) - 1)
],
)
)
manual_lr = self.optimizer.get_lr()

self.optimizer.set_lr(manual_lr)

@ -36,8 +36,7 @@ class StepLRScheduleConfig(FairseqDataclass):

@register_lr_scheduler("step", dataclass=StepLRScheduleConfig)
class StepLRSchedule(FairseqLRScheduler):
"""Decay learning rate every k updates by a fixed factor
"""
"""Decay learning rate every k updates by a fixed factor"""

def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer):
super().__init__(cfg, fairseq_optimizer)

@ -50,16 +49,16 @@ class StepLRSchedule(FairseqLRScheduler):
cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr
)

assert(self.lr_deacy_period > 0)
assert(self.lr_decay <= 1)
assert(self.min_lr >= 0)
assert(self.max_lr > self.min_lr)
assert self.lr_deacy_period > 0
assert self.lr_decay <= 1
assert self.min_lr >= 0
assert self.max_lr > self.min_lr

if cfg.warmup_updates > 0:
# linearly warmup for the first cfg.warmup_updates
self.warmup_lr_step = (
(self.max_lr - self.warmup_init_lr) / self.warmup_updates
)
self.max_lr - self.warmup_init_lr
) / self.warmup_updates
else:
self.warmup_lr_step = 1
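The warmup branch above linearly interpolates the learning rate from warmup_init_lr to max_lr over warmup_updates steps. As a sketch with made-up values:

max_lr, warmup_init_lr, warmup_updates = 1e-3, 1e-7, 4000
warmup_lr_step = (max_lr - warmup_init_lr) / warmup_updates
lr_at = lambda num_updates: warmup_init_lr + num_updates * warmup_lr_step
assert abs(lr_at(warmup_updates) - max_lr) < 1e-12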
@ -171,7 +171,9 @@ class SequenceGenerator(nn.Module):
|
||||
yield id, src, ref, hypos[i]
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]:
|
||||
def generate(
|
||||
self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
|
||||
) -> List[List[Dict[str, Tensor]]]:
|
||||
"""Generate translations. Match the api of other fairseq generators.
|
||||
|
||||
Args:
|
||||
@ -223,7 +225,10 @@ class SequenceGenerator(nn.Module):
|
||||
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
||||
)
|
||||
else:
|
||||
raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys()))
|
||||
raise Exception(
|
||||
"expected src_tokens or source in net input. input keys: "
|
||||
+ str(net_input.keys())
|
||||
)
|
||||
|
||||
# bsz: total number of sentences in beam
|
||||
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
|
||||
@ -328,7 +333,9 @@ class SequenceGenerator(nn.Module):
|
||||
encoder_outs = self.model.reorder_encoder_out(
|
||||
encoder_outs, reorder_state
|
||||
)
|
||||
with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"):
|
||||
with torch.autograd.profiler.record_function(
|
||||
"EnsembleModel: forward_decoder"
|
||||
):
|
||||
lprobs, avg_attn_scores = self.model.forward_decoder(
|
||||
tokens[:, : step + 1],
|
||||
encoder_outs,
|
||||
@ -751,7 +758,14 @@ class EnsembleModel(nn.Module):
|
||||
return self.has_incremental
|
||||
|
||||
def max_decoder_positions(self):
|
||||
return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize])
|
||||
return min(
|
||||
[
|
||||
m.max_decoder_positions()
|
||||
for m in self.models
|
||||
if hasattr(m, "max_decoder_positions")
|
||||
]
|
||||
+ [sys.maxsize]
|
||||
)
|
||||
|
||||
@torch.jit.export
|
||||
def forward_encoder(self, net_input: Dict[str, Tensor]):
|
||||
|
@ -35,8 +35,12 @@ class SpeechGenerator(object):
|
||||
|
||||
class AutoRegressiveSpeechGenerator(SpeechGenerator):
|
||||
def __init__(
|
||||
self, model, vocoder, data_cfg, max_iter: int = 6000,
|
||||
eos_prob_threshold: float = 0.5,
|
||||
self,
|
||||
model,
|
||||
vocoder,
|
||||
data_cfg,
|
||||
max_iter: int = 6000,
|
||||
eos_prob_threshold: float = 0.5,
|
||||
):
|
||||
super().__init__(model, vocoder, data_cfg)
|
||||
self.max_iter = max_iter
|
||||
@ -54,8 +58,9 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
|
||||
raw_dim = out_dim // n_frames_per_step
|
||||
|
||||
# initialize
|
||||
encoder_out = model.forward_encoder(src_tokens, src_lengths,
|
||||
speaker=sample["speaker"])
|
||||
encoder_out = model.forward_encoder(
|
||||
src_tokens, src_lengths, speaker=sample["speaker"]
|
||||
)
|
||||
incremental_state = {}
|
||||
feat, attn, eos_prob = [], [], []
|
||||
finished = src_tokens.new_zeros((bsz,)).bool()
|
||||
@ -66,21 +71,24 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
|
||||
cur_out_lens = out_lens.clone()
|
||||
cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1)
|
||||
_, cur_eos_out, cur_extra = model.forward_decoder(
|
||||
prev_feat_out, encoder_out=encoder_out,
|
||||
prev_feat_out,
|
||||
encoder_out=encoder_out,
|
||||
incremental_state=incremental_state,
|
||||
target_lengths=cur_out_lens, speaker=sample["speaker"], **kwargs
|
||||
target_lengths=cur_out_lens,
|
||||
speaker=sample["speaker"],
|
||||
**kwargs
|
||||
)
|
||||
cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
|
||||
feat.append(cur_extra['feature_out'])
|
||||
attn.append(cur_extra['attn'])
|
||||
feat.append(cur_extra["feature_out"])
|
||||
attn.append(cur_extra["attn"])
|
||||
eos_prob.append(cur_eos_prob)
|
||||
|
||||
cur_finished = (cur_eos_prob.squeeze(1) > self.eos_prob_threshold)
|
||||
cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold
|
||||
out_lens.masked_fill_((~finished) & cur_finished, step + 1)
|
||||
finished = finished | cur_finished
|
||||
if finished.sum().item() == bsz:
|
||||
break
|
||||
prev_feat_out = cur_extra['feature_out']
|
||||
prev_feat_out = cur_extra["feature_out"]
|
||||
|
||||
feat = torch.cat(feat, dim=1)
|
||||
feat = model.decoder.postnet(feat) + feat
|
||||
@ -98,11 +106,11 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
|
||||
|
||||
finalized = [
|
||||
{
|
||||
'feature': feat[b, :out_len],
|
||||
'eos_prob': eos_prob[b, :out_len],
|
||||
'attn': attn[b, :, :out_len],
|
||||
'alignment': alignment[b, :out_len],
|
||||
'waveform': self.get_waveform(feat[b, :out_len]),
|
||||
"feature": feat[b, :out_len],
|
||||
"eos_prob": eos_prob[b, :out_len],
|
||||
"attn": attn[b, :, :out_len],
|
||||
"alignment": alignment[b, :out_len],
|
||||
"waveform": self.get_waveform(feat[b, :out_len]),
|
||||
}
|
||||
for b, out_len in zip(range(bsz), out_lens)
|
||||
]
|
||||
@ -134,7 +142,7 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
|
||||
prev_output_tokens=sample["net_input"]["prev_output_tokens"],
|
||||
incremental_state=None,
|
||||
target_lengths=sample["target_lengths"],
|
||||
speaker=sample["speaker"]
|
||||
speaker=sample["speaker"],
|
||||
)
|
||||
if feat_post is not None:
|
||||
feat = feat_post
|
||||
@@ -142,9 +150,7 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
         feat = feat.view(bsz, -1, raw_dim)
         feat = self.gcmvn_denormalize(feat)
 
-        dur_out = torch.clamp(
-            torch.round(torch.exp(log_dur_out) - 1).long(), min=0
-        )
+        dur_out = torch.clamp(torch.round(torch.exp(log_dur_out) - 1).long(), min=0)
 
         def get_dur_plot_data(d):
             r = []
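In the non-autoregressive path, durations are predicted in log space: torch.exp(log_dur_out) - 1 inverts what is evidently a log(1 + d) target encoding, and the round/clamp keeps frame counts as non-negative integers. A quick standalone check of that inverse (toy values, plain PyTorch):

    import torch

    # assume durations were encoded as log(1 + d) during training
    log_dur_out = torch.log1p(torch.tensor([0.0, 2.0, 5.0]))
    dur_out = torch.clamp(torch.round(torch.exp(log_dur_out) - 1).long(), min=0)
    print(dur_out)  # tensor([0, 2, 5])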
@@ -155,11 +161,11 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
         out_lens = out_lens * n_frames_per_step
         finalized = [
             {
-                'feature': feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]),
-                'waveform': self.get_waveform(
+                "feature": feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]),
+                "waveform": self.get_waveform(
                     feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim])
                 ),
-                'attn': feat.new_tensor(get_dur_plot_data(dur_out[b])),
+                "attn": feat.new_tensor(get_dur_plot_data(dur_out[b])),
             }
             for b, l in zip(range(bsz), out_lens)
         ]
@@ -188,8 +194,12 @@ class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator)
         bsz = src_tokens.shape[0]
 
         feat, eos_prob, extra = model(
-            src_tokens, src_lens, prev_out_tokens, incremental_state=None,
-            target_lengths=tgt_lens, speaker=sample["speaker"]
+            src_tokens,
+            src_lens,
+            prev_out_tokens,
+            incremental_state=None,
+            target_lengths=tgt_lens,
+            speaker=sample["speaker"],
         )
 
         attn = extra["attn"]  # B x T_s x T_t
@@ -203,11 +213,11 @@ class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator)
 
         finalized = [
             {
-                'feature': feat[b, :tgt_len],
-                'eos_prob': eos_prob[b, :tgt_len],
-                'attn': attn[b, :, :tgt_len],
-                'alignment': alignment[b, :tgt_len],
-                'waveform': self.get_waveform(feat[b, :tgt_len]),
+                "feature": feat[b, :tgt_len],
+                "eos_prob": eos_prob[b, :tgt_len],
+                "attn": attn[b, :, :tgt_len],
+                "alignment": alignment[b, :tgt_len],
+                "waveform": self.get_waveform(feat[b, :tgt_len]),
             }
             for b, tgt_len in zip(range(bsz), tgt_lens)
         ]

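Nearly all of the churn in the hunks above is mechanical: black prefers double quotes and explodes any call or literal that would overflow its line length, adding trailing commas as it goes. To preview that rewrite on a snippet without touching files, black's string entry point can be used directly (a sketch; assumes black is installed and that format_str/Mode remain its public API):

    import black

    src = "feat.append(cur_extra['feature_out'])\n"
    print(black.format_str(src, mode=black.Mode()))
    # -> feat.append(cur_extra["feature_out"])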
@@ -67,31 +67,31 @@ class AudioFinetuningConfig(AudioPretrainingConfig):
         default=False, metadata={"help": "evaluation with BLEU scores"}
     )
     eval_bleu_detok: Optional[str] = field(
-        default=None, metadata={
+        default=None,
+        metadata={
             "help": "detokenize before computing BLEU (e.g., 'moses'); "
-                    "required if using --eval-bleu; use 'space' to disable "
-                    "detokenization; see fairseq.data.encoders for other options"
-        }
+            "required if using --eval-bleu; use 'space' to disable "
+            "detokenization; see fairseq.data.encoders for other options"
+        },
     )
     eval_bleu_detok_args: str = field(
-        default="{}",
-        metadata={"help": "args for building the tokenizer, if needed"}
+        default="{}", metadata={"help": "args for building the tokenizer, if needed"}
     )
     eval_tokenized_bleu: bool = field(
-        default=False,
-        metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
+        default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
     )
     eval_bleu_remove_bpe: Optional[str] = field(
         default=None, metadata={"help": "remove BPE before computing BLEU"}
     )
     eval_bleu_args: str = field(
         default="{}",
-        metadata={"help": "generation args for BLUE scoring, e.g., "
-                  "'{\"beam\": 4, \"lenpen\": 0.6}'"}
+        metadata={
+            "help": "generation args for BLUE scoring, e.g., "
+            '\'{"beam": 4, "lenpen": 0.6}\''
+        },
     )
     eval_bleu_print_samples: bool = field(
-        default=False,
-        metadata={"help": "print sample generations during validation"}
+        default=False, metadata={"help": "print sample generations during validation"}
     )
     autoregressive: bool = field(
         default=False,
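The options above follow the dataclass-config pattern used throughout fairseq: each flag is a typed field whose metadata["help"] supplies the CLI help string. A minimal sketch of the pattern with plain dataclasses (illustrative names; none of fairseq's FairseqDataclass machinery):

    from dataclasses import dataclass, field, fields
    from typing import Optional

    @dataclass
    class EvalBleuConfig:
        eval_bleu: bool = field(
            default=False, metadata={"help": "evaluation with BLEU scores"}
        )
        eval_bleu_detok: Optional[str] = field(
            default=None, metadata={"help": "detokenize before computing BLEU"}
        )

    cfg = EvalBleuConfig(eval_bleu=True)
    for f in fields(cfg):
        print(f.name, getattr(cfg, f.name), "--", f.metadata["help"])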
@@ -123,7 +123,9 @@ class AudioFinetuningTask(AudioPretrainingTask):
             return Dictionary.load(dict_path)
         return None
 
-    def load_dataset(self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs):
+    def load_dataset(
+        self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs
+    ):
         super().load_dataset(split, task_cfg, **kwargs)
 
         task_cfg = task_cfg or self.cfg
@@ -138,7 +140,8 @@ class AudioFinetuningTask(AudioPretrainingTask):
             with open(label_path, "r") as f:
                 labels = [
                     text_compressor.compress(l)
-                    for i, l in enumerate(f) if i not in skipped_indices
+                    for i, l in enumerate(f)
+                    if i not in skipped_indices
                 ]
 
             assert len(labels) == len(self.datasets[split]), (
@@ -157,7 +160,7 @@ class AudioFinetuningTask(AudioPretrainingTask):
                 process_label=process_label,
                 label_len_fn=label_len_fn,
                 add_to_input=task_cfg.get("autoregressive", False),
-                text_compression_level=text_compression_level
+                text_compression_level=text_compression_level,
             )
 
     @property
@@ -176,8 +179,8 @@ class AudioFinetuningTask(AudioPretrainingTask):
             logging_output["_num_words"] = metrics["num_words"]
         if self.cfg.eval_bleu and self.cfg.autoregressive:
             metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
-            logging_output['_bleu_sys_len'] = metrics.sys_len
-            logging_output['_bleu_ref_len'] = metrics.ref_len
+            logging_output["_bleu_sys_len"] = metrics.sys_len
+            logging_output["_bleu_ref_len"] = metrics.ref_len
             # we split counts into separate entries so that they can be
             # summed efficiently across workers using fast-stat-sync
             assert len(metrics.counts) == 4
@@ -200,9 +203,9 @@ class AudioFinetuningTask(AudioPretrainingTask):
         self.tokenizer = None
         if self.cfg.eval_bleu and self.cfg.autoregressive:
             assert self.cfg.eval_bleu_detok is not None, (
-                '--eval-bleu-detok is required if using --eval-bleu; '
-                'try --eval-bleu-detok=moses (or --eval-bleu-detok=space '
-                'to disable detokenization, e.g., when using sentencepiece)'
+                "--eval-bleu-detok is required if using --eval-bleu; "
+                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
+                "to disable detokenization, e.g., when using sentencepiece)"
             )
             detok_args = json.loads(self.cfg.eval_bleu_detok_args)
             self.tokenizer = encoders.build_tokenizer(
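Note that eval_bleu_detok_args (and eval_bleu_args above) are JSON strings: they are parsed with json.loads before being handed to the tokenizer and generator builders, as the detok_args line shows. The round trip, using the example value from the help text (the fairseq-internal builder step is omitted here):

    import json

    eval_bleu_args = '{"beam": 4, "lenpen": 0.6}'
    gen_args = json.loads(eval_bleu_args)
    print(gen_args["beam"], gen_args["lenpen"])  # 4 0.6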
@@ -261,9 +264,7 @@ class AudioFinetuningTask(AudioPretrainingTask):
                 # BLEU scores. Instead, we use a somewhat more verbose
                 # alternative that is unlikely to appear in the real
                 # reference, but doesn't get split into multiple tokens.
-                unk_string=(
-                    "UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"
-                ),
+                unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
             )
             if self.tokenizer:
                 s = self.tokenizer.decode(s)
@@ -272,21 +273,18 @@ class AudioFinetuningTask(AudioPretrainingTask):
         gen_out = self.inference_step(generator, [model], sample)
         hyps, refs = [], []
         for i in range(len(gen_out)):
-            hyps.append(decode(gen_out[i][0]['tokens'], is_ref=False))
+            hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
             refs.append(
                 decode(
-                    utils.strip_pad(
-                        sample['target'][i],
-                        self.target_dictionary.pad()
-                    ),
+                    utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
                     is_ref=True,  # don't count <unk> as matches to the hypo
                 )
             )
         if self.cfg.eval_bleu_print_samples:
-            logger.info('H-{} {}'.format(sample["id"][0], hyps[0]))
-            logger.info('T-{} {}'.format(sample["id"][0], refs[0]))
+            logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
+            logger.info("T-{} {}".format(sample["id"][0], refs[0]))
 
-        eval_tokenization = 'none' if self.cfg.eval_tokenized_bleu else '13a'
+        eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
         return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)
 
     def reduce_metrics(self, logging_outputs, criterion):
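_inference_with_bleu ends with a single corpus-level scoring call: sacrebleu.corpus_bleu takes a list of hypothesis strings plus a list of reference streams, and tokenize="none" is only sensible when the text is already tokenized. A standalone sketch (assumes sacrebleu is installed; toy sentences):

    import sacrebleu

    hyps = ["the cat sat on the mat"]
    refs = ["the cat sat on the mat"]
    bleu = sacrebleu.corpus_bleu(hyps, [refs], tokenize="13a")
    print(bleu.score, bleu.sys_len, bleu.ref_len)  # 100.0 for identical text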
@@ -329,18 +327,17 @@ class AudioFinetuningTask(AudioPretrainingTask):
             count_keys = [f"_bleu_counts_{i}" for i in range(4)]
             total_keys = [f"_bleu_totals_{i}" for i in range(4)]
             for k in len_keys + count_keys + total_keys:
-                metrics.log_scalar(
-                    k, sum(log.get(k, 0) for log in logging_outputs)
-                )
+                metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))
 
             import sacrebleu
 
             metrics.log_derived(
-                'bleu',
+                "bleu",
                 lambda meters: sacrebleu.compute_bleu(
                     correct=[meters[k].sum for k in count_keys],
                     total=[meters[k].sum for k in total_keys],
-                    sys_len=meters['_bleu_sys_len'].sum,
-                    ref_len=meters['_bleu_ref_len'].sum,
-                    smooth_method="exp"
-                ).score
+                    sys_len=meters["_bleu_sys_len"].sum,
+                    ref_len=meters["_bleu_ref_len"].sum,
+                    smooth_method="exp",
+                ).score,
             )
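This reduction is what makes BLEU work under distributed training: per-batch n-gram match counts, totals, and lengths are logged as plain scalars (cheap to sum across workers), and the corpus score is recomposed from their sums. A toy sketch of the recomposition (numbers invented; sacrebleu 1.x exposes compute_bleu at module level as called above, while in 2.x it lives on the BLEU class):

    import sacrebleu

    # summed across workers: matched and total n-grams for n = 1..4
    correct = [9, 7, 5, 3]
    total = [10, 9, 8, 7]
    bleu = sacrebleu.compute_bleu(
        correct=correct, total=total, sys_len=10, ref_len=10, smooth_method="exp"
    )
    print(round(bleu.score, 2))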

@@ -50,8 +50,7 @@ class AudioPretrainingConfig(FairseqDataclass):
     data: str = field(default=MISSING, metadata={"help": "path to data directory"})
     labels: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "extension of the label file to load, used for fine-tuning"},
+        metadata={"help": "extension of the label file to load, used for fine-tuning"},
     )
     binarized_dataset: bool = field(
         default=False,
@@ -102,8 +101,8 @@ class AudioPretrainingConfig(FairseqDataclass):
         default="none",
         metadata={
             "help": "compression level for texts (e.g. audio filenames, "
-                    "target texts): none/low/high (default: none). "
-        }
+            "target texts): none/low/high (default: none). "
+        },
     )
 
 
@@ -135,7 +135,6 @@ class DenoisingTask(LegacyFairseqTask):
             'e.g., "train,valid" (default: all dataset splits)',
         )
 
-
     def __init__(self, args, dictionary):
         super().__init__(args)
         self.dictionary = dictionary

@@ -11,20 +11,19 @@ from fairseq.tasks.text_to_speech import TextToSpeechTask
 
 
 logging.basicConfig(
-    format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO
+    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
 )
 logger = logging.getLogger(__name__)
 
 
-@register_task('frm_text_to_speech')
+@register_task("frm_text_to_speech")
 class FrmTextToSpeechTask(TextToSpeechTask):
     @staticmethod
     def add_args(parser):
         TextToSpeechTask.add_args(parser)
-        parser.add_argument(
-            "--do_chunk", action="store_true", help="train on chunks"
-        )
+        parser.add_argument("--do_chunk", action="store_true", help="train on chunks")
         parser.add_argument("--chunk_bound", default=-1, type=int)
         parser.add_argument("--chunk_init", default=50, type=int)
         parser.add_argument("--chunk_incr", default=5, type=int)

@@ -52,5 +51,5 @@ class FrmTextToSpeechTask(TextToSpeechTask):
             chunk_incr=self.args.chunk_incr,
             add_eos=self.args.add_eos,
             dedup=self.args.dedup,
-            ref_fpu=self.args.ref_fpu
+            ref_fpu=self.args.ref_fpu,
         )