Add linting with black (#2678)

Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding!

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2678

Reviewed By: Mortimerp9

Differential Revision: D32653381

Pulled By: dianaml0

fbshipit-source-id: 2810d14867cd7d64f4d340740e2b590b82de47fe
Authored by dianaml0 on 2021-11-29 12:30:10 -08:00; committed by Facebook GitHub Bot
Parent: 3dc1691df1
Commit: 0dfd6b6240
137 changed files with 2139 additions and 1353 deletions

@ -53,3 +53,8 @@ jobs:
    - name: Run tests
      run: |
        python setup.py test
    - name: Lint with black
      run: |
        pip install black
        black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron'
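
Contributors can reproduce this CI step locally by installing black and running the same `black --check` command shown above; the `--extend-exclude` regex skips `examples` and the vendored `fairseq/model_parallel/megatron` sources. As a minimal sketch of what the check enforces (not part of this PR; assumes black is installed with its default settings), a file passes only if black's output is identical to its current contents, which can also be exercised through black's Python API:

import black

src = "x = {'a':1,'b':2}\n"                             # not yet black-formatted
formatted = black.format_str(src, mode=black.Mode())    # default 88-column mode
print(formatted, end="")                                 # x = {"a": 1, "b": 2}
print(src == formatted)                                  # False -> `black --check` would flag this file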

@ -27,6 +27,7 @@ sys.modules["fairseq.progress_bar"] = progress_bar
# initialize hydra # initialize hydra
from fairseq.dataclass.initialize import hydra_init from fairseq.dataclass.initialize import hydra_init
hydra_init() hydra_init()
import fairseq.criterions # noqa import fairseq.criterions # noqa

@ -7,10 +7,10 @@ import logging
import numpy as np import numpy as np
import torch import torch
from fairseq.data import Dictionary, FairseqDataset from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import LegacyFairseqTask, register_task from fairseq.tasks import LegacyFairseqTask, register_task
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,7 +36,7 @@ class DummyMTTask(LegacyFairseqTask):
@classmethod @classmethod
def setup_task(cls, args, **kwargs): def setup_task(cls, args, **kwargs):
"""Setup the task. """ """Setup the task."""
dictionary = Dictionary() dictionary = Dictionary()
for i in range(args.dict_size): for i in range(args.dict_size):
dictionary.add_symbol("word{}".format(i)) dictionary.add_symbol("word{}".format(i))

@ -96,10 +96,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
checkpoint_conds[ checkpoint_conds[
"checkpoint.best_{}_{:.3f}{}{}.pt".format( "checkpoint.best_{}_{:.3f}{}{}.pt".format(
cfg.best_checkpoint_metric, cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
val_loss,
rand_sfx,
suffix
) )
] = worst_best is None or is_better(val_loss, worst_best) ] = worst_best is None or is_better(val_loss, worst_best)
checkpoint_conds[ checkpoint_conds[
@ -468,9 +465,7 @@ def load_model_ensemble_and_task(
and len(state["optimizer_history"]) > 0 and len(state["optimizer_history"]) > 0
and "num_updates" in state["optimizer_history"][-1] and "num_updates" in state["optimizer_history"][-1]
): ):
model.set_num_updates( model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
state["optimizer_history"][-1]["num_updates"]
)
model.load_state_dict( model.load_state_dict(
state["model"], strict=strict, model_cfg=cfg.model state["model"], strict=strict, model_cfg=cfg.model
) )
@ -588,9 +583,8 @@ def _upgrade_state_dict(state):
# backward compatibility, cfg updates # backward compatibility, cfg updates
if "args" in state and state["args"] is not None: if "args" in state and state["args"] is not None:
# old model checkpoints may not have separate source/target positions # old model checkpoints may not have separate source/target positions
if ( if hasattr(state["args"], "max_positions") and not hasattr(
hasattr(state["args"], "max_positions") state["args"], "max_source_positions"
and not hasattr(state["args"], "max_source_positions")
): ):
state["args"].max_source_positions = state["args"].max_positions state["args"].max_source_positions = state["args"].max_positions
state["args"].max_target_positions = state["args"].max_positions state["args"].max_target_positions = state["args"].max_positions
@ -615,13 +609,10 @@ def _upgrade_state_dict(state):
state["args"].stop_min_lr = state["args"].min_lr state["args"].stop_min_lr = state["args"].min_lr
del state["args"].min_lr del state["args"].min_lr
# binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
if ( if hasattr(state["args"], "criterion") and state["args"].criterion in [
hasattr(state["args"], "criterion") "binary_cross_entropy",
and state["args"].criterion in [ "kd_binary_cross_entropy",
"binary_cross_entropy", ]:
"kd_binary_cross_entropy",
]
):
state["args"].criterion = "wav2vec" state["args"].criterion = "wav2vec"
# remove log_keys if it's None (criteria will supply a default value of []) # remove log_keys if it's None (criteria will supply a default value of [])
if hasattr(state["args"], "log_keys") and state["args"].log_keys is None: if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
@ -659,7 +650,9 @@ def _upgrade_state_dict(state):
): ):
cfg.task.eval_wer_config.print_alignment = "hard" cfg.task.eval_wer_config.print_alignment = "hard"
if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
cfg.generation.print_alignment = "hard" if cfg.generation.print_alignment else None cfg.generation.print_alignment = (
"hard" if cfg.generation.print_alignment else None
)
if ( if (
"model" in cfg "model" in cfg
and "w2v_args" in cfg.model and "w2v_args" in cfg.model
@ -833,16 +826,16 @@ def load_ema_from_checkpoint(fpath):
params_dict = collections.OrderedDict() params_dict = collections.OrderedDict()
new_state = None new_state = None
with PathManager.open(fpath, 'rb') as f: with PathManager.open(fpath, "rb") as f:
new_state = torch.load( new_state = torch.load(
f, f,
map_location=( map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu') lambda s, _: torch.serialization.default_restore_location(s, "cpu")
), ),
) )
# EMA model is stored in a separate "extra state" # EMA model is stored in a separate "extra state"
model_params = new_state['extra_state']['ema'] model_params = new_state["extra_state"]["ema"]
for key in list(model_params.keys()): for key in list(model_params.keys()):
p = model_params[key] p = model_params[key]
@ -860,5 +853,5 @@ def load_ema_from_checkpoint(fpath):
"ema model weights, is this model trained with EMA?" "ema model weights, is this model trained with EMA?"
) )
new_state['model'] = params_dict new_state["model"] = params_dict
return new_state return new_state
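
The changes in this hunk are typical of the rest of the diff: single-quoted strings are normalized to double quotes, hand-wrapped calls that fit within black's default 88-column limit are joined onto one line, and calls that do not fit are exploded to one argument per line with a trailing comma. A minimal runnable sketch of the joining behavior, reusing the `set_num_updates` call from above (assumes black's default Mode; not part of the PR itself):

import black

src = (
    "model.set_num_updates(\n"
    "    state['optimizer_history'][-1]['num_updates']\n"
    ")\n"
)
# black re-wraps from scratch: the argument fits on one 88-column line,
# so the call is collapsed and the quotes are normalized.
print(black.format_str(src, mode=black.Mode()), end="")
# -> model.set_num_updates(state["optimizer_history"][-1]["num_updates"])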

@ -20,9 +20,7 @@ from fairseq.models.fairseq_model import FairseqEncoderModel
@dataclass @dataclass
class FastSpeech2CriterionConfig(FairseqDataclass): class FastSpeech2CriterionConfig(FairseqDataclass):
ctc_weight: float = field( ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
default=0.0, metadata={"help": "weight for CTC loss"}
)
@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig) @register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
@ -44,7 +42,7 @@ class FastSpeech2Loss(FairseqCriterion):
speaker=sample["speaker"], speaker=sample["speaker"],
durations=sample["durations"], durations=sample["durations"],
pitches=sample["pitches"], pitches=sample["pitches"],
energies=sample["energies"] energies=sample["energies"],
) )
src_mask = lengths_to_mask(sample["net_input"]["src_lengths"]) src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
@ -57,8 +55,7 @@ class FastSpeech2Loss(FairseqCriterion):
feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask] feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
l1_loss = F.l1_loss(feat_out, feat, reduction=reduction) l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
if _feat_out_post is not None: if _feat_out_post is not None:
l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)
reduction=reduction)
pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction) pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
energy_loss = F.mse_loss(energy_out, energies, reduction=reduction) energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
@ -69,16 +66,23 @@ class FastSpeech2Loss(FairseqCriterion):
log_dur = torch.log(dur + 1)[src_mask] log_dur = torch.log(dur + 1)[src_mask]
dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction) dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)
ctc_loss = torch.tensor(0.).type_as(l1_loss) ctc_loss = torch.tensor(0.0).type_as(l1_loss)
if self.ctc_weight > 0.: if self.ctc_weight > 0.0:
lprobs = model.get_normalized_probs((_feat_out,), log_probs=True) lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens) src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask) src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = F.ctc_loss( ctc_loss = (
lprobs, src_tokens_flat, tgt_lens, src_lens, F.ctc_loss(
reduction=reduction, zero_infinity=True lprobs,
) * self.ctc_weight src_tokens_flat,
tgt_lens,
src_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss
@ -102,8 +106,12 @@ class FastSpeech2Loss(FairseqCriterion):
ntot = sum(ns) ntot = sum(ns)
ws = [n / (ntot + 1e-8) for n in ns] ws = [n / (ntot + 1e-8) for n in ns]
for key in [ for key in [
"loss", "l1_loss", "dur_loss", "pitch_loss", "energy_loss", "loss",
"ctc_loss" "l1_loss",
"dur_loss",
"pitch_loss",
"energy_loss",
"ctc_loss",
]: ]:
vals = [log.get(key, 0) for log in logging_outputs] vals = [log.get(key, 0) for log in logging_outputs]
val = sum(val * w for val, w in zip(vals, ws)) val = sum(val * w for val, w in zip(vals, ws))
@ -115,10 +123,10 @@ class FastSpeech2Loss(FairseqCriterion):
return return
n = sum(log.get("targ_frames", 0) for log in logging_outputs) n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [ for key, new_key in [
("mcd_loss", "mcd_loss"), ("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"), ("pred_frames", "pred_ratio"),
("nins", "ins_rate"), ("nins", "ins_rate"),
("ndel", "del_rate"), ("ndel", "del_rate"),
]: ]:
val = sum(log.get(key, 0) for log in logging_outputs) val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3) metrics.log_scalar(new_key, val / n, n, round=3)

@ -37,7 +37,14 @@ class HubertCriterionConfig(FairseqDataclass):
@register_criterion("hubert", dataclass=HubertCriterionConfig) @register_criterion("hubert", dataclass=HubertCriterionConfig)
class HubertCriterion(FairseqCriterion): class HubertCriterion(FairseqCriterion):
def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None): def __init__(
self,
task,
pred_masked_weight,
pred_nomask_weight,
loss_weights=None,
log_keys=None,
):
super().__init__(task) super().__init__(task)
self.pred_masked_weight = pred_masked_weight self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight self.pred_nomask_weight = pred_nomask_weight
@ -52,7 +59,7 @@ class HubertCriterion(FairseqCriterion):
3) logging outputs to display while training 3) logging outputs to display while training
""" """
net_output = model(target_list=sample["target_list"], **sample["net_input"]) net_output = model(target_list=sample["target_list"], **sample["net_input"])
loss = 0. loss = 0.0
sample_size = 0 sample_size = 0
logging_output = {} logging_output = {}
reduction = "sum" if reduce else "none" reduction = "sum" if reduce else "none"
@ -89,7 +96,9 @@ class HubertCriterion(FairseqCriterion):
names = [names] names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1: if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses) self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}" assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights): for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None: if coef != 0 and p is not None:
p = coef * p.float() * sample_size p = coef * p.float() * sample_size
@ -140,12 +149,20 @@ class HubertCriterion(FairseqCriterion):
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3) metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
if sample_size != ntokens: if sample_size != ntokens:
metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3) metrics.log_scalar(
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)) "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
)
else: else:
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)) metrics.log_derived(
"ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
)
counts = {} counts = {}
for lk in logging_outputs[0].keys(): for lk in logging_outputs[0].keys():

@ -9,19 +9,20 @@ from fairseq import metrics, utils
from fairseq.criterions import register_criterion from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import ( from fairseq.criterions.label_smoothed_cross_entropy import (
LabelSmoothedCrossEntropyCriterion, LabelSmoothedCrossEntropyCriterion,
LabelSmoothedCrossEntropyCriterionConfig LabelSmoothedCrossEntropyCriterionConfig,
) )
try: try:
from simuleval.metrics.latency import ( from simuleval.metrics.latency import (
AverageLagging, AverageLagging,
AverageProportion, AverageProportion,
DifferentiableAverageLagging DifferentiableAverageLagging,
) )
LATENCY_METRICS = { LATENCY_METRICS = {
"average_lagging": AverageLagging, "average_lagging": AverageLagging,
"average_proportion": AverageProportion, "average_proportion": AverageProportion,
"differentiable_average_lagging": DifferentiableAverageLagging, "differentiable_average_lagging": DifferentiableAverageLagging,
} }
except ImportError: except ImportError:
LATENCY_METRICS = None LATENCY_METRICS = None
@ -56,9 +57,10 @@ class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
metadata={"help": "Add latency loss after certain steps"}, metadata={"help": "Add latency loss after certain steps"},
) )
@register_criterion( @register_criterion(
"latency_augmented_label_smoothed_cross_entropy", "latency_augmented_label_smoothed_cross_entropy",
dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
) )
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
LabelSmoothedCrossEntropyCriterion LabelSmoothedCrossEntropyCriterion
@ -101,9 +103,9 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
if self.latency_update_after > 0: if self.latency_update_after > 0:
num_updates = getattr(model.decoder, "num_updates", None) num_updates = getattr(model.decoder, "num_updates", None)
assert num_updates is not None, ( assert (
"model.decoder doesn't have attribute 'num_updates'" num_updates is not None
) ), "model.decoder doesn't have attribute 'num_updates'"
if num_updates <= self.latency_update_after: if num_updates <= self.latency_update_after:
latency_loss = 0 latency_loss = 0
@ -134,9 +136,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
assert ( assert (
net_output[-1].encoder_padding_mask is None net_output[-1].encoder_padding_mask is None
or not net_output[-1].encoder_padding_mask[:, 0].any() or not net_output[-1].encoder_padding_mask[:, 0].any()
), ( ), "Only right padding on source is supported."
"Only right padding on source is supported."
)
# 1. Obtain the expected alignment # 1. Obtain the expected alignment
alpha_list = [item["alpha"] for item in net_output[1].attn_list] alpha_list = [item["alpha"] for item in net_output[1].attn_list]
num_layers = len(alpha_list) num_layers = len(alpha_list)
@ -174,8 +174,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
.view(-1) .view(-1)
) )
expected_latency = LATENCY_METRICS[self.latency_avg_type]( expected_latency = LATENCY_METRICS[self.latency_avg_type](
expected_delays, src_lengths, None, expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
target_padding_mask=target_padding_mask
) )
# 2.1 average expected latency of heads # 2.1 average expected latency of heads
@ -210,24 +209,12 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
@classmethod @classmethod
def reduce_metrics(cls, logging_outputs) -> None: def reduce_metrics(cls, logging_outputs) -> None:
super().reduce_metrics(logging_outputs) super().reduce_metrics(logging_outputs)
latency = sum( latency = sum(log.get("latency", 0) for log in logging_outputs)
log.get("latency", 0) for log in logging_outputs delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
) latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
delays_var = sum(
log.get("delays_var", 0) for log in logging_outputs
)
latency_loss = sum(
log.get("latency_loss", 0) for log in logging_outputs
)
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3)
metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
metrics.log_scalar( metrics.log_scalar(
"latency", latency.float() / nsentences, nsentences, round=3 "latency_loss", latency_loss / nsentences, nsentences, round=3
)
metrics.log_scalar(
"delays_var", delays_var / nsentences,
nsentences, round=3
)
metrics.log_scalar(
"latency_loss", latency_loss / nsentences,
nsentences, round=3
) )

@ -41,9 +41,7 @@ class Tacotron2CriterionConfig(FairseqDataclass):
default=0.4, default=0.4,
metadata={"help": "weight of positive examples for BCE loss"}, metadata={"help": "weight of positive examples for BCE loss"},
) )
ctc_weight: float = field( ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})
default=0.0, metadata={"help": "weight for CTC loss"}
)
sentence_avg: bool = II("optimization.sentence_avg") sentence_avg: bool = II("optimization.sentence_avg")
@ -70,8 +68,7 @@ class GuidedAttentionLoss(torch.nn.Module):
bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens) bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens)
weights = torch.zeros((bsz, max_t_len, max_s_len)) weights = torch.zeros((bsz, max_t_len, max_s_len))
for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)): for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)):
weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma)
self.sigma)
return weights return weights
@staticmethod @staticmethod
@ -90,9 +87,16 @@ class GuidedAttentionLoss(torch.nn.Module):
@register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig) @register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig)
class Tacotron2Criterion(FairseqCriterion): class Tacotron2Criterion(FairseqCriterion):
def __init__(self, task, sentence_avg, n_frames_per_step, def __init__(
use_guided_attention_loss, guided_attention_loss_sigma, self,
bce_pos_weight, ctc_weight): task,
sentence_avg,
n_frames_per_step,
use_guided_attention_loss,
guided_attention_loss_sigma,
bce_pos_weight,
ctc_weight,
):
super().__init__(task) super().__init__(task)
self.sentence_avg = sentence_avg self.sentence_avg = sentence_avg
self.n_frames_per_step = n_frames_per_step self.n_frames_per_step = n_frames_per_step
@ -120,31 +124,42 @@ class Tacotron2Criterion(FairseqCriterion):
prev_output_tokens=sample["net_input"]["prev_output_tokens"], prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None, incremental_state=None,
target_lengths=tgt_lens, target_lengths=tgt_lens,
speaker=sample["speaker"] speaker=sample["speaker"],
) )
l1_loss, mse_loss, eos_loss = self.compute_loss( l1_loss, mse_loss, eos_loss = self.compute_loss(
extra["feature_out"], feat_out, eos_out, feat_tgt, eos_tgt, extra["feature_out"],
tgt_lens, reduction, feat_out,
eos_out,
feat_tgt,
eos_tgt,
tgt_lens,
reduction,
) )
attn_loss = torch.tensor(0.).type_as(l1_loss) attn_loss = torch.tensor(0.0).type_as(l1_loss)
if self.guided_attn is not None: if self.guided_attn is not None:
attn_loss = self.guided_attn(extra['attn'], src_lens, tgt_lens, reduction) attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction)
ctc_loss = torch.tensor(0.).type_as(l1_loss) ctc_loss = torch.tensor(0.0).type_as(l1_loss)
if self.ctc_weight > 0.: if self.ctc_weight > 0.0:
net_output = (feat_out, eos_out, extra) net_output = (feat_out, eos_out, extra)
lprobs = model.get_normalized_probs(net_output, log_probs=True) lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.transpose(0, 1) # T x B x C lprobs = lprobs.transpose(0, 1) # T x B x C
src_mask = lengths_to_mask(src_lens) src_mask = lengths_to_mask(src_lens)
src_tokens_flat = src_tokens.masked_select(src_mask) src_tokens_flat = src_tokens.masked_select(src_mask)
ctc_loss = F.ctc_loss( ctc_loss = (
lprobs, src_tokens_flat, tgt_lens, src_lens, F.ctc_loss(
reduction=reduction, zero_infinity=True lprobs,
) * self.ctc_weight src_tokens_flat,
tgt_lens,
src_lens,
reduction=reduction,
zero_infinity=True,
)
* self.ctc_weight
)
loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss
sample_size = sample["nsentences"] if self.sentence_avg \ sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"]
else sample["ntokens"]
logging_output = { logging_output = {
"loss": utils.item(loss.data), "loss": utils.item(loss.data),
"ntokens": sample["ntokens"], "ntokens": sample["ntokens"],
@ -158,8 +173,16 @@ class Tacotron2Criterion(FairseqCriterion):
} }
return loss, sample_size, logging_output return loss, sample_size, logging_output
def compute_loss(self, feat_out, feat_out_post, eos_out, feat_tgt, def compute_loss(
eos_tgt, tgt_lens, reduction="mean"): self,
feat_out,
feat_out_post,
eos_out,
feat_tgt,
eos_tgt,
tgt_lens,
reduction="mean",
):
mask = lengths_to_mask(tgt_lens) mask = lengths_to_mask(tgt_lens)
_eos_out = eos_out[mask].squeeze() _eos_out = eos_out[mask].squeeze()
_eos_tgt = eos_tgt[mask] _eos_tgt = eos_tgt[mask]
@ -167,17 +190,17 @@ class Tacotron2Criterion(FairseqCriterion):
_feat_out = feat_out[mask] _feat_out = feat_out[mask]
_feat_out_post = feat_out_post[mask] _feat_out_post = feat_out_post[mask]
l1_loss = ( l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss(
F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + _feat_out_post, _feat_tgt, reduction=reduction
F.l1_loss(_feat_out_post, _feat_tgt, reduction=reduction)
) )
mse_loss = ( mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss(
F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + _feat_out_post, _feat_tgt, reduction=reduction
F.mse_loss(_feat_out_post, _feat_tgt, reduction=reduction)
) )
eos_loss = F.binary_cross_entropy_with_logits( eos_loss = F.binary_cross_entropy_with_logits(
_eos_out, _eos_tgt, pos_weight=torch.tensor(self.bce_pos_weight), _eos_out,
reduction=reduction _eos_tgt,
pos_weight=torch.tensor(self.bce_pos_weight),
reduction=reduction,
) )
return l1_loss, mse_loss, eos_loss return l1_loss, mse_loss, eos_loss
@ -197,10 +220,10 @@ class Tacotron2Criterion(FairseqCriterion):
return return
n = sum(log.get("targ_frames", 0) for log in logging_outputs) n = sum(log.get("targ_frames", 0) for log in logging_outputs)
for key, new_key in [ for key, new_key in [
("mcd_loss", "mcd_loss"), ("mcd_loss", "mcd_loss"),
("pred_frames", "pred_ratio"), ("pred_frames", "pred_ratio"),
("nins", "ins_rate"), ("nins", "ins_rate"),
("ndel", "del_rate"), ("ndel", "del_rate"),
]: ]:
val = sum(log.get(key, 0) for log in logging_outputs) val = sum(log.get(key, 0) for log in logging_outputs)
metrics.log_scalar(new_key, val / n, n, round=3) metrics.log_scalar(new_key, val / n, n, round=3)

@ -33,6 +33,7 @@ class Wav2VecCriterionConfig(FairseqDataclass):
metadata={"help": "output keys to log"}, metadata={"help": "output keys to log"},
) )
@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig)
class Wav2vecCriterion(FairseqCriterion): class Wav2vecCriterion(FairseqCriterion):
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
@ -76,16 +77,16 @@ class Wav2vecCriterion(FairseqCriterion):
# we don't shrink tensors using mask_indices. # we don't shrink tensors using mask_indices.
# Instead, we use mask indices to adjust loss. # Instead, we use mask indices to adjust loss.
mi = ( mi = (
sample['net_input']['mask_indices'] sample["net_input"]["mask_indices"]
.transpose(0, 1) # logits are transposed in `model.get_logits` .transpose(0, 1) # logits are transposed in `model.get_logits`
.reshape(logits.size(0)) .reshape(logits.size(0))
) )
loss = (loss * mi).sum() if reduce else (loss * mi) loss = (loss * mi).sum() if reduce else (loss * mi)
if 'sample_size' in sample: if "sample_size" in sample:
sample_size = sample['sample_size'] sample_size = sample["sample_size"]
elif 'mask_indices' in sample['net_input']: elif "mask_indices" in sample["net_input"]:
sample_size = sample['net_input']['mask_indices'].sum() sample_size = sample["net_input"]["mask_indices"].sum()
else: else:
sample_size = target.numel() if self.infonce else target.long().sum().item() sample_size = target.numel() if self.infonce else target.long().sum().item()
losses.append(loss.detach().clone()) losses.append(loss.detach().clone())
@ -216,8 +217,8 @@ class Wav2vecCriterion(FairseqCriterion):
metrics.log_scalar(k, val / len(logging_outputs), round=3) metrics.log_scalar(k, val / len(logging_outputs), round=3)
# FIXME: revert when gather based xla reduction is implemented # FIXME: revert when gather based xla reduction is implemented
#@staticmethod # @staticmethod
#def logging_outputs_can_be_summed() -> bool: # def logging_outputs_can_be_summed() -> bool:
def logging_outputs_can_be_summed(self) -> bool: def logging_outputs_can_be_summed(self) -> bool:
""" """
Whether the logging outputs returned by `forward` can be summed Whether the logging outputs returned by `forward` can be summed

@ -20,7 +20,7 @@ class AddTargetDataset(BaseWrapperDataset):
process_label=None, process_label=None,
label_len_fn=None, label_len_fn=None,
add_to_input=False, add_to_input=False,
text_compression_level=TextCompressionLevel.none text_compression_level=TextCompressionLevel.none,
): ):
super().__init__(dataset) super().__init__(dataset)
self.labels = labels self.labels = labels

@ -18,26 +18,28 @@ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}
def convert_waveform( def convert_waveform(
waveform: Union[np.ndarray, torch.Tensor], sample_rate: int, waveform: Union[np.ndarray, torch.Tensor],
normalize_volume: bool = False, to_mono: bool = False, sample_rate: int,
to_sample_rate: Optional[int] = None normalize_volume: bool = False,
to_mono: bool = False,
to_sample_rate: Optional[int] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], int]: ) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
"""convert a waveform: """convert a waveform:
- to a target sample rate - to a target sample rate
- from multi-channel to mono channel - from multi-channel to mono channel
- volume normalization - volume normalization
Args: Args:
waveform (numpy.ndarray or torch.Tensor): 2D original waveform waveform (numpy.ndarray or torch.Tensor): 2D original waveform
(channels x length) (channels x length)
sample_rate (int): original sample rate sample_rate (int): original sample rate
normalize_volume (bool): perform volume normalization normalize_volume (bool): perform volume normalization
to_mono (bool): convert to mono channel if having multiple channels to_mono (bool): convert to mono channel if having multiple channels
to_sample_rate (Optional[int]): target sample rate to_sample_rate (Optional[int]): target sample rate
Returns: Returns:
waveform (numpy.ndarray): converted 2D waveform (channels x length) waveform (numpy.ndarray): converted 2D waveform (channels x length)
sample_rate (float): target sample rate sample_rate (float): target sample rate
""" """
try: try:
import torchaudio.sox_effects as ta_sox import torchaudio.sox_effects as ta_sox
except ImportError: except ImportError:
@ -63,10 +65,14 @@ def convert_waveform(
def get_waveform( def get_waveform(
path_or_fp: Union[str, BinaryIO], normalization: bool = True, path_or_fp: Union[str, BinaryIO],
mono: bool = True, frames: int = -1, start: int = 0, normalization: bool = True,
always_2d: bool = True, output_sample_rate: Optional[int] = None, mono: bool = True,
normalize_volume: bool = False frames: int = -1,
start: int = 0,
always_2d: bool = True,
output_sample_rate: Optional[int] = None,
normalize_volume: bool = False,
) -> Tuple[np.ndarray, int]: ) -> Tuple[np.ndarray, int]:
"""Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio.
@ -98,8 +104,11 @@ def get_waveform(
) )
waveform = waveform.T # T x C -> C x T waveform = waveform.T # T x C -> C x T
waveform, sample_rate = convert_waveform( waveform, sample_rate = convert_waveform(
waveform, sample_rate, normalize_volume=normalize_volume, to_mono=mono, waveform,
to_sample_rate=output_sample_rate sample_rate,
normalize_volume=normalize_volume,
to_mono=mono,
to_sample_rate=output_sample_rate,
) )
if not normalization: if not normalization:
@ -182,7 +191,7 @@ def is_sf_audio_data(data: bytes) -> bool:
def mmap_read(path: str, offset: int, length: int) -> bytes: def mmap_read(path: str, offset: int, length: int) -> bytes:
with open(path, "rb") as f: with open(path, "rb") as f:
with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o: with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o:
data = mmap_o[offset: offset + length] data = mmap_o[offset : offset + length]
return data return data
@ -215,9 +224,7 @@ def parse_path(path: str) -> Tuple[str, List[int]]:
return _path, slice_ptr return _path, slice_ptr
def get_window( def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor:
window_fn: callable, n_fft: int, win_length: int
) -> torch.Tensor:
padding = n_fft - win_length padding = n_fft - win_length
assert padding >= 0 assert padding >= 0
return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2)) return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2))
@ -226,13 +233,13 @@ def get_window(
def get_fourier_basis(n_fft: int) -> torch.Tensor: def get_fourier_basis(n_fft: int) -> torch.Tensor:
basis = np.fft.fft(np.eye(n_fft)) basis = np.fft.fft(np.eye(n_fft))
basis = np.vstack( basis = np.vstack(
[np.real(basis[:n_fft // 2 + 1, :]), np.imag(basis[:n_fft // 2 + 1, :])] [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])]
) )
return torch.from_numpy(basis).float() return torch.from_numpy(basis).float()
def get_mel_filters( def get_mel_filters(
sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float
) -> torch.Tensor: ) -> torch.Tensor:
try: try:
import librosa import librosa
@ -244,8 +251,12 @@ def get_mel_filters(
class TTSSpectrogram(torch.nn.Module): class TTSSpectrogram(torch.nn.Module):
def __init__( def __init__(
self, n_fft: int, win_length: int, hop_length: int, self,
window_fn: callable = torch.hann_window, return_phase: bool = False n_fft: int,
win_length: int,
hop_length: int,
window_fn: callable = torch.hann_window,
return_phase: bool = False,
) -> None: ) -> None:
super(TTSSpectrogram, self).__init__() super(TTSSpectrogram, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
@ -254,16 +265,16 @@ class TTSSpectrogram(torch.nn.Module):
basis = get_fourier_basis(n_fft).unsqueeze(1) basis = get_fourier_basis(n_fft).unsqueeze(1)
basis *= get_window(window_fn, n_fft, win_length) basis *= get_window(window_fn, n_fft, win_length)
self.register_buffer('basis', basis) self.register_buffer("basis", basis)
def forward( def forward(
self, waveform: torch.Tensor self, waveform: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
padding = (self.n_fft // 2, self.n_fft // 2) padding = (self.n_fft // 2, self.n_fft // 2)
x = F.pad(waveform.unsqueeze(1), padding, mode='reflect') x = F.pad(waveform.unsqueeze(1), padding, mode="reflect")
x = F.conv1d(x, self.basis, stride=self.hop_length) x = F.conv1d(x, self.basis, stride=self.hop_length)
real_part = x[:, :self.n_fft // 2 + 1, :] real_part = x[:, : self.n_fft // 2 + 1, :]
imag_part = x[:, self.n_fft // 2 + 1:, :] imag_part = x[:, self.n_fft // 2 + 1 :, :]
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
if self.return_phase: if self.return_phase:
phase = torch.atan2(imag_part, real_part) phase = torch.atan2(imag_part, real_part)
@ -273,13 +284,11 @@ class TTSSpectrogram(torch.nn.Module):
class TTSMelScale(torch.nn.Module): class TTSMelScale(torch.nn.Module):
def __init__( def __init__(
self, n_mels: int, sample_rate: int, f_min: float, f_max: float, self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int
n_stft: int
) -> None: ) -> None:
super(TTSMelScale, self).__init__() super(TTSMelScale, self).__init__()
basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
f_max) self.register_buffer("basis", basis)
self.register_buffer('basis', basis)
def forward(self, specgram: torch.Tensor) -> torch.Tensor: def forward(self, specgram: torch.Tensor) -> torch.Tensor:
return torch.matmul(self.basis, specgram) return torch.matmul(self.basis, specgram)

@ -13,11 +13,10 @@ from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from fairseq.data import Dictionary from fairseq.data import Dictionary
from fairseq.data.audio.speech_to_text_dataset import ( from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
S2TDataConfig
)
from fairseq.data.audio.text_to_speech_dataset import ( from fairseq.data.audio.text_to_speech_dataset import (
TextToSpeechDataset, TextToSpeechDatasetCreator TextToSpeechDataset,
TextToSpeechDatasetCreator,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,7 +47,7 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
chunk_incr=5, chunk_incr=5,
add_eos=True, add_eos=True,
dedup=True, dedup=True,
ref_fpu=-1 ref_fpu=-1,
): ):
# It assumes texts are encoded at a fixed frame-rate # It assumes texts are encoded at a fixed frame-rate
super().__init__( super().__init__(
@ -67,7 +66,7 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
pre_tokenizer=pre_tokenizer, pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer, bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step, n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id speaker_to_id=speaker_to_id,
) )
self.do_chunk = do_chunk self.do_chunk = do_chunk
@ -92,24 +91,23 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
fpu = source.size(0) / target.size(0) # frame-per-unit fpu = source.size(0) / target.size(0) # frame-per-unit
fps = self.n_frames_per_step fps = self.n_frames_per_step
assert ( assert (
self.ref_fpu == -1 or self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1
), f"{fpu*fps} != {self.ref_fpu}" ), f"{fpu*fps} != {self.ref_fpu}"
# only chunk training split # only chunk training split
if self.is_train_split and self.do_chunk and self.chunk_size > 0: if self.is_train_split and self.do_chunk and self.chunk_size > 0:
lang = target[:int(self.data_cfg.prepend_tgt_lang_tag)] lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)]
text = target[int(self.data_cfg.prepend_tgt_lang_tag):] text = target[int(self.data_cfg.prepend_tgt_lang_tag) :]
size = len(text) size = len(text)
chunk_size = min(self.chunk_size, size) chunk_size = min(self.chunk_size, size)
chunk_start = np.random.randint(size - chunk_size + 1) chunk_start = np.random.randint(size - chunk_size + 1)
text = text[chunk_start:chunk_start+chunk_size] text = text[chunk_start : chunk_start + chunk_size]
target = torch.cat((lang, text), 0) target = torch.cat((lang, text), 0)
f_size = int(np.floor(chunk_size * fpu)) f_size = int(np.floor(chunk_size * fpu))
f_start = int(np.floor(chunk_start * fpu)) f_start = int(np.floor(chunk_start * fpu))
assert(f_size > 0) assert f_size > 0
source = source[f_start:f_start+f_size, :] source = source[f_start : f_start + f_size, :]
if self.dedup: if self.dedup:
target = torch.unique_consecutive(target) target = torch.unique_consecutive(target)
@ -126,10 +124,12 @@ class FrmTextToSpeechDataset(TextToSpeechDataset):
self.chunk_size = self.chunk_init + epoch * self.chunk_incr self.chunk_size = self.chunk_init + epoch * self.chunk_incr
if self.chunk_bound > 0: if self.chunk_bound > 0:
self.chunk_size = min(self.chunk_size, self.chunk_bound) self.chunk_size = min(self.chunk_size, self.chunk_bound)
logger.info(( logger.info(
f"{self.split}: setting chunk size " (
f"from {old} to {self.chunk_size}" f"{self.split}: setting chunk size "
)) f"from {old} to {self.chunk_size}"
)
)
class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator): class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
@ -152,7 +152,7 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
chunk_incr: int = 5, chunk_incr: int = 5,
add_eos: bool = True, add_eos: bool = True,
dedup: bool = True, dedup: bool = True,
ref_fpu: float = -1 ref_fpu: float = -1,
) -> FrmTextToSpeechDataset: ) -> FrmTextToSpeechDataset:
tsv_path = op.join(root, f"{split}.tsv") tsv_path = op.join(root, f"{split}.tsv")
if not op.isfile(tsv_path): if not op.isfile(tsv_path):
@ -170,9 +170,7 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
assert len(s) > 0 assert len(s) > 0
ids = [ss[cls.KEY_ID] for ss in s] ids = [ss[cls.KEY_ID] for ss in s]
audio_paths = [ audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s
]
n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s] n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s]
tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s] tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s]
src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s] src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
@ -203,5 +201,5 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator):
chunk_incr=chunk_incr, chunk_incr=chunk_incr,
add_eos=add_eos, add_eos=add_eos,
dedup=dedup, dedup=dedup,
ref_fpu=ref_fpu ref_fpu=ref_fpu,
) )

@ -152,10 +152,7 @@ class HubertDataset(FairseqDataset):
self.label_offsets_list = [ self.label_offsets_list = [
load_label_offset(p, inds, tot) for p in label_paths load_label_offset(p, inds, tot) for p in label_paths
] ]
assert ( assert label_processors is None or len(label_processors) == self.num_labels
label_processors is None
or len(label_processors) == self.num_labels
)
for label_path, label_rate in zip(label_paths, self.label_rates): for label_path, label_rate in zip(label_paths, self.label_rates):
verify_label_lengths( verify_label_lengths(
self.sizes, sample_rate, label_path, label_rate, inds, tot self.sizes, sample_rate, label_path, label_rate, inds, tot
@ -234,8 +231,7 @@ class HubertDataset(FairseqDataset):
) )
targets_by_label = [ targets_by_label = [
[s["label_list"][i] for s in samples] [s["label_list"][i] for s in samples] for i in range(self.num_labels)
for i in range(self.num_labels)
] ]
targets_list, lengths_list, ntokens_list = self.collater_label( targets_list, lengths_list, ntokens_list = self.collater_label(
targets_by_label, audio_size, audio_starts targets_by_label, audio_size, audio_starts
@ -270,9 +266,7 @@ class HubertDataset(FairseqDataset):
collated_audios[i] = audio collated_audios[i] = audio
elif diff < 0: elif diff < 0:
assert self.pad_audio assert self.pad_audio
collated_audios[i] = torch.cat( collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
[audio, audio.new_full((-diff,), 0.0)]
)
padding_mask[i, diff:] = True padding_mask[i, diff:] = True
else: else:
collated_audios[i], audio_starts[i] = self.crop_to_max_size( collated_audios[i], audio_starts[i] = self.crop_to_max_size(
@ -280,9 +274,7 @@ class HubertDataset(FairseqDataset):
) )
return collated_audios, padding_mask, audio_starts return collated_audios, padding_mask, audio_starts
def collater_frm_label( def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
self, targets, audio_size, audio_starts, label_rate, pad
):
assert label_rate > 0 assert label_rate > 0
s2f = label_rate / self.sample_rate s2f = label_rate / self.sample_rate
frm_starts = [int(round(s * s2f)) for s in audio_starts] frm_starts = [int(round(s * s2f)) for s in audio_starts]
@ -290,24 +282,20 @@ class HubertDataset(FairseqDataset):
if not self.pad_audio: if not self.pad_audio:
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
frm_size = min(frm_size, *rem_size) frm_size = min(frm_size, *rem_size)
targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)] targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
logger.debug(f"audio_starts={audio_starts}") logger.debug(f"audio_starts={audio_starts}")
logger.debug(f"frame_starts={frm_starts}") logger.debug(f"frame_starts={frm_starts}")
logger.debug(f"frame_size={frm_size}") logger.debug(f"frame_size={frm_size}")
lengths = torch.LongTensor([len(t) for t in targets]) lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item() ntokens = lengths.sum().item()
targets = data_utils.collate_tokens( targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
targets, pad_idx=pad, left_pad=False
)
return targets, lengths, ntokens return targets, lengths, ntokens
def collater_seq_label(self, targets, pad): def collater_seq_label(self, targets, pad):
lengths = torch.LongTensor([len(t) for t in targets]) lengths = torch.LongTensor([len(t) for t in targets])
ntokens = lengths.sum().item() ntokens = lengths.sum().item()
targets = data_utils.collate_tokens( targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
targets, pad_idx=pad, left_pad=False
)
return targets, lengths, ntokens return targets, lengths, ntokens
def collater_label(self, targets_by_label, audio_size, audio_starts): def collater_label(self, targets_by_label, audio_size, audio_starts):
@ -315,9 +303,7 @@ class HubertDataset(FairseqDataset):
itr = zip(targets_by_label, self.label_rates, self.pad_list) itr = zip(targets_by_label, self.label_rates, self.pad_list)
for targets, label_rate, pad in itr: for targets, label_rate, pad in itr:
if label_rate == -1: if label_rate == -1:
targets, lengths, ntokens = self.collater_seq_label( targets, lengths, ntokens = self.collater_seq_label(targets, pad)
targets, pad
)
else: else:
targets, lengths, ntokens = self.collater_frm_label( targets, lengths, ntokens = self.collater_frm_label(
targets, audio_size, audio_starts, label_rate, pad targets, audio_size, audio_starts, label_rate, pad

@ -29,6 +29,7 @@ class ModalityDatasetItem(NamedTuple):
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
max_sentences: Optional[int] = None max_sentences: Optional[int] = None
# MultiModalityDataset: it concate multiple datasets with different modalities. # MultiModalityDataset: it concate multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given the ratios for different datasets # Compared with ConcatDataset it can 1) sample data given the ratios for different datasets
# 2) it adds mode to indicate what type of the data samples come from. # 2) it adds mode to indicate what type of the data samples come from.

@ -308,6 +308,7 @@ class FileAudioDataset(RawAudioDataset):
def __getitem__(self, index): def __getitem__(self, index):
import soundfile as sf import soundfile as sf
fn = self.fnames[index] fn = self.fnames[index]
fn = fn if isinstance(self.fnames, list) else fn.as_py() fn = fn if isinstance(self.fnames, list) else fn.as_py()
fn = self.text_compressor.decompress(fn) fn = self.text_compressor.decompress(fn)

@ -45,7 +45,11 @@ def get_features_from_npy_or_audio(path):
def get_features_or_waveform_from_stored_zip( def get_features_or_waveform_from_stored_zip(
path, byte_offset, byte_size, need_waveform=False, use_sample_rate=None, path,
byte_offset,
byte_size,
need_waveform=False,
use_sample_rate=None,
): ):
assert path.endswith(".zip") assert path.endswith(".zip")
data = read_from_stored_zip(path, byte_offset, byte_size) data = read_from_stored_zip(path, byte_offset, byte_size)
@ -53,18 +57,17 @@ def get_features_or_waveform_from_stored_zip(
if is_npy_data(data): if is_npy_data(data):
features_or_waveform = np.load(f) features_or_waveform = np.load(f)
elif is_sf_audio_data(data): elif is_sf_audio_data(data):
features_or_waveform = \ features_or_waveform = (
get_waveform( get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0]
f, always_2d=False, output_sample_rate=use_sample_rate if need_waveform
)[0] if need_waveform else get_fbank(f) else get_fbank(f)
)
else: else:
raise ValueError(f'Unknown file format for "{path}"') raise ValueError(f'Unknown file format for "{path}"')
return features_or_waveform return features_or_waveform
def get_features_or_waveform( def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=None):
path: str, need_waveform=False, use_sample_rate=None
):
"""Get speech features from .npy file or waveform from .wav/.flac file. """Get speech features from .npy file or waveform from .wav/.flac file.
The file may be inside an uncompressed ZIP file and is accessed via byte The file may be inside an uncompressed ZIP file and is accessed via byte
offset and length. offset and length.
@ -87,8 +90,11 @@ def get_features_or_waveform(
return get_features_from_npy_or_audio(_path) return get_features_from_npy_or_audio(_path)
elif len(slice_ptr) == 2: elif len(slice_ptr) == 2:
features_or_waveform = get_features_or_waveform_from_stored_zip( features_or_waveform = get_features_or_waveform_from_stored_zip(
_path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform, _path,
use_sample_rate=use_sample_rate slice_ptr[0],
slice_ptr[1],
need_waveform=need_waveform,
use_sample_rate=use_sample_rate,
) )
else: else:
raise ValueError(f"Invalid path: {path}") raise ValueError(f"Invalid path: {path}")
@ -145,7 +151,7 @@ class SpeechToTextDataset(FairseqDataset):
pre_tokenizer=None, pre_tokenizer=None,
bpe_tokenizer=None, bpe_tokenizer=None,
n_frames_per_step=1, n_frames_per_step=1,
speaker_to_id=None speaker_to_id=None,
): ):
self.split, self.is_train_split = split, is_train_split self.split, self.is_train_split = split, is_train_split
self.cfg = cfg self.cfg = cfg
@ -235,7 +241,7 @@ class SpeechToTextDataset(FairseqDataset):
if self.n_frames_per_step == 1: if self.n_frames_per_step == 1:
return feature return feature
n_packed_frames = feature.shape[0] // self.n_frames_per_step n_packed_frames = feature.shape[0] // self.n_frames_per_step
feature = feature[:self.n_frames_per_step * n_packed_frames] feature = feature[: self.n_frames_per_step * n_packed_frames]
return feature.reshape(n_packed_frames, -1) return feature.reshape(n_packed_frames, -1)
@classmethod @classmethod
@ -318,9 +324,11 @@ class SpeechToTextDataset(FairseqDataset):
speaker = None speaker = None
if self.speaker_to_id is not None: if self.speaker_to_id is not None:
speaker = torch.tensor( speaker = (
[s.speaker_id for s in samples], dtype=torch.long torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
).index_select(0, order).view(-1, 1) .index_select(0, order)
.view(-1, 1)
)
net_input = { net_input = {
"src_tokens": frames, "src_tokens": frames,
@ -388,7 +396,7 @@ class SpeechToTextDatasetCreator(object):
pre_tokenizer, pre_tokenizer,
bpe_tokenizer, bpe_tokenizer,
n_frames_per_step, n_frames_per_step,
speaker_to_id speaker_to_id,
) -> SpeechToTextDataset: ) -> SpeechToTextDataset:
audio_root = Path(cfg.audio_root) audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples] ids = [s[cls.KEY_ID] for s in samples]
@ -415,7 +423,7 @@ class SpeechToTextDatasetCreator(object):
pre_tokenizer=pre_tokenizer, pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer, bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step, n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id speaker_to_id=speaker_to_id,
) )
@classmethod @classmethod
@ -481,12 +489,19 @@ class SpeechToTextDatasetCreator(object):
pre_tokenizer, pre_tokenizer,
bpe_tokenizer, bpe_tokenizer,
n_frames_per_step, n_frames_per_step,
speaker_to_id speaker_to_id,
) -> SpeechToTextDataset: ) -> SpeechToTextDataset:
samples = cls._load_samples_from_tsv(root, split) samples = cls._load_samples_from_tsv(root, split)
return cls._from_list( return cls._from_list(
split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer, split,
bpe_tokenizer, n_frames_per_step, speaker_to_id is_train_split,
samples,
cfg,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
) )
@classmethod @classmethod
@ -502,12 +517,19 @@ class SpeechToTextDatasetCreator(object):
epoch: int, epoch: int,
seed: int, seed: int,
n_frames_per_step: int = 1, n_frames_per_step: int = 1,
speaker_to_id=None speaker_to_id=None,
) -> SpeechToTextDataset: ) -> SpeechToTextDataset:
datasets = [ datasets = [
cls._from_tsv( cls._from_tsv(
root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, root,
bpe_tokenizer, n_frames_per_step, speaker_to_id cfg,
split,
tgt_dict,
is_train_split,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
) )
for split in splits.split(",") for split in splits.split(",")
] ]

@ -13,8 +13,11 @@ import numpy as np
import torch import torch
from fairseq.data.audio.speech_to_text_dataset import ( from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig, SpeechToTextDataset,
_collate_frames, get_features_or_waveform SpeechToTextDatasetCreator,
S2TDataConfig,
_collate_frames,
get_features_or_waveform,
) )
from fairseq.data import Dictionary, data_utils as fairseq_data_utils from fairseq.data import Dictionary, data_utils as fairseq_data_utils
@ -32,34 +35,44 @@ class TextToSpeechDatasetItem(object):
class TextToSpeechDataset(SpeechToTextDataset): class TextToSpeechDataset(SpeechToTextDataset):
def __init__( def __init__(
self, self,
split: str, split: str,
is_train_split: bool, is_train_split: bool,
cfg: S2TDataConfig, cfg: S2TDataConfig,
audio_paths: List[str], audio_paths: List[str],
n_frames: List[int], n_frames: List[int],
src_texts: Optional[List[str]] = None, src_texts: Optional[List[str]] = None,
tgt_texts: Optional[List[str]] = None, tgt_texts: Optional[List[str]] = None,
speakers: Optional[List[str]] = None, speakers: Optional[List[str]] = None,
src_langs: Optional[List[str]] = None, src_langs: Optional[List[str]] = None,
tgt_langs: Optional[List[str]] = None, tgt_langs: Optional[List[str]] = None,
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
tgt_dict: Optional[Dictionary] = None, tgt_dict: Optional[Dictionary] = None,
pre_tokenizer=None, pre_tokenizer=None,
bpe_tokenizer=None, bpe_tokenizer=None,
n_frames_per_step=1, n_frames_per_step=1,
speaker_to_id=None, speaker_to_id=None,
durations: Optional[List[List[int]]] = None, durations: Optional[List[List[int]]] = None,
pitches: Optional[List[str]] = None, pitches: Optional[List[str]] = None,
energies: Optional[List[str]] = None energies: Optional[List[str]] = None,
): ):
super(TextToSpeechDataset, self).__init__( super(TextToSpeechDataset, self).__init__(
split, is_train_split, cfg, audio_paths, n_frames, split,
src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers, is_train_split,
src_langs=src_langs, tgt_langs=tgt_langs, ids=ids, cfg,
tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer, audio_paths,
bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step, n_frames,
speaker_to_id=speaker_to_id src_texts=src_texts,
tgt_texts=tgt_texts,
speakers=speakers,
src_langs=src_langs,
tgt_langs=tgt_langs,
ids=ids,
tgt_dict=tgt_dict,
pre_tokenizer=pre_tokenizer,
bpe_tokenizer=bpe_tokenizer,
n_frames_per_step=n_frames_per_step,
speaker_to_id=speaker_to_id,
) )
self.durations = durations self.durations = durations
self.pitches = pitches self.pitches = pitches
@ -84,9 +97,13 @@ class TextToSpeechDataset(SpeechToTextDataset):
np.concatenate((energy, [0])) # pad 0 for EOS np.concatenate((energy, [0])) # pad 0 for EOS
).float() ).float()
return TextToSpeechDatasetItem( return TextToSpeechDatasetItem(
index=index, source=s2t_item.source, target=s2t_item.target, index=index,
speaker_id=s2t_item.speaker_id, duration=duration, pitch=pitch, source=s2t_item.source,
energy=energy target=s2t_item.target,
speaker_id=s2t_item.speaker_id,
duration=duration,
pitch=pitch,
energy=energy,
) )
def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]: def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
@ -96,8 +113,9 @@ class TextToSpeechDataset(SpeechToTextDataset):
src_lengths, order = torch.tensor( src_lengths, order = torch.tensor(
[s.target.shape[0] for s in samples], dtype=torch.long [s.target.shape[0] for s in samples], dtype=torch.long
).sort(descending=True) ).sort(descending=True)
id_ = torch.tensor([s.index for s in samples], id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
dtype=torch.long).index_select(0, order) 0, order
)
feat = _collate_frames( feat = _collate_frames(
[s.source for s in samples], self.cfg.use_audio_input [s.source for s in samples], self.cfg.use_audio_input
).index_select(0, order) ).index_select(0, order)
@ -115,9 +133,11 @@ class TextToSpeechDataset(SpeechToTextDataset):
speaker = None speaker = None
if self.speaker_to_id is not None: if self.speaker_to_id is not None:
speaker = torch.tensor( speaker = (
[s.speaker_id for s in samples], dtype=torch.long torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
).index_select(0, order).view(-1, 1) .index_select(0, order)
.view(-1, 1)
)
bsz, _, d = feat.size() bsz, _, d = feat.size()
prev_output_tokens = torch.cat( prev_output_tokens = torch.cat(
@ -175,7 +195,7 @@ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
pre_tokenizer, pre_tokenizer,
bpe_tokenizer, bpe_tokenizer,
n_frames_per_step, n_frames_per_step,
speaker_to_id speaker_to_id,
) -> TextToSpeechDataset: ) -> TextToSpeechDataset:
audio_root = Path(cfg.audio_root) audio_root = Path(cfg.audio_root)
ids = [s[cls.KEY_ID] for s in samples] ids = [s[cls.KEY_ID] for s in samples]
@ -189,27 +209,40 @@ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
durations = [s.get(cls.KEY_DURATION, None) for s in samples] durations = [s.get(cls.KEY_DURATION, None) for s in samples]
durations = [ durations = [
None if dd is None else [int(d) for d in dd.split(" ")] None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
for dd in durations
] ]
durations = None if any(dd is None for dd in durations) else durations durations = None if any(dd is None for dd in durations) else durations
pitches = [s.get(cls.KEY_PITCH, None) for s in samples] pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
pitches = [ pitches = [
None if pp is None else (audio_root / pp).as_posix() None if pp is None else (audio_root / pp).as_posix() for pp in pitches
for pp in pitches
] ]
pitches = None if any(pp is None for pp in pitches) else pitches pitches = None if any(pp is None for pp in pitches) else pitches
energies = [s.get(cls.KEY_ENERGY, None) for s in samples] energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
energies = [ energies = [
None if ee is None else (audio_root / ee).as_posix() None if ee is None else (audio_root / ee).as_posix() for ee in energies
for ee in energies] ]
energies = None if any(ee is None for ee in energies) else energies energies = None if any(ee is None for ee in energies) else energies
return TextToSpeechDataset( return TextToSpeechDataset(
split_name, is_train_split, cfg, audio_paths, n_frames, split_name,
src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict, is_train_split,
pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, cfg,
durations, pitches, energies audio_paths,
n_frames,
src_texts,
tgt_texts,
speakers,
src_langs,
tgt_langs,
ids,
tgt_dict,
pre_tokenizer,
bpe_tokenizer,
n_frames_per_step,
speaker_to_id,
durations,
pitches,
energies,
) )

@ -9,7 +9,7 @@ from . import BaseWrapperDataset
class ColorizeDataset(BaseWrapperDataset): class ColorizeDataset(BaseWrapperDataset):
""" Adds 'colors' property to net input that is obtained from the provided color getter for use by models """ """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""
def __init__(self, dataset, color_getter): def __init__(self, dataset, color_getter):
super().__init__(dataset) super().__init__(dataset)

@ -69,6 +69,7 @@ def collate_tokens(
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
return res return res
def load_indexed_dataset( def load_indexed_dataset(
path, dictionary=None, dataset_impl=None, combine=False, default="cached" path, dictionary=None, dataset_impl=None, combine=False, default="cached"
): ):
@ -324,9 +325,7 @@ def batch_by_size(
) )
# added int() to avoid TypeError: an integer is required # added int() to avoid TypeError: an integer is required
max_tokens = ( max_tokens = int(max_tokens) if max_tokens is not None else -1
int(max_tokens) if max_tokens is not None else -1
)
max_sentences = max_sentences if max_sentences is not None else -1 max_sentences = max_sentences if max_sentences is not None else -1
bsz_mult = required_batch_size_multiple bsz_mult = required_batch_size_multiple
@ -375,8 +374,9 @@ def post_process(sentence: str, symbol: str):
sentence = sentence.replace(" ", "").replace("|", " ").strip() sentence = sentence.replace(" ", "").replace("|", " ").strip()
elif symbol == "silence": elif symbol == "silence":
import re import re
sentence = sentence.replace("<SIL>", "") sentence = sentence.replace("<SIL>", "")
sentence = re.sub(' +', ' ', sentence).strip() sentence = re.sub(" +", " ", sentence).strip()
elif symbol == "_EOW": elif symbol == "_EOW":
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
elif symbol in {"subword_nmt", "@@ ", "@@"}: elif symbol in {"subword_nmt", "@@ ", "@@"}:
@ -547,7 +547,7 @@ def get_buckets(sizes, num_buckets):
np.percentile( np.percentile(
sizes, sizes,
np.linspace(0, 100, num_buckets + 1), np.linspace(0, 100, num_buckets + 1),
interpolation='lower', interpolation="lower",
)[1:] )[1:]
) )
return buckets return buckets
@ -564,7 +564,6 @@ def get_bucketed_sizes(orig_sizes, buckets):
return sizes return sizes
def _find_extra_valid_paths(dataset_path: str) -> set: def _find_extra_valid_paths(dataset_path: str) -> set:
paths = utils.split_paths(dataset_path) paths = utils.split_paths(dataset_path)
all_valid_paths = set() all_valid_paths = set()

@ -21,8 +21,10 @@ class SentencepieceConfig(FairseqDataclass):
) )
sentencepiece_alpha: Optional[float] = field( sentencepiece_alpha: Optional[float] = field(
default=None, default=None,
metadata={"help": "soothing parameter for unigram sampling, " metadata={
"and merge probability for BPE-dropout"} "help": "soothing parameter for unigram sampling, "
"and merge probability for BPE-dropout"
},
) )
@ -45,8 +47,7 @@ class SentencepieceBPE(object):
def encode(self, x: str) -> str: def encode(self, x: str) -> str:
return " ".join( return " ".join(
self.sp.Encode( self.sp.Encode(
x, out_type=str, enable_sampling=self.enable_sampling, x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha
alpha=self.alpha
) )
) )

@ -138,7 +138,7 @@ class FairseqDataset(torch.utils.data.Dataset, EpochListening):
) )
try: try:
num_tokens_vec = self.num_tokens_vec(indices).astype('int64') num_tokens_vec = self.num_tokens_vec(indices).astype("int64")
except NotImplementedError: except NotImplementedError:
num_tokens_vec = None num_tokens_vec = None

@ -140,7 +140,9 @@ class HuffmanNode:
def is_leaf(self) -> bool: def is_leaf(self) -> bool:
return self.left is None and self.right is None return self.left is None and self.right is None
def code_table(self, prefix: tp.Optional[bitarray] = None) -> tp.Dict[str, "HuffmanNode"]: def code_table(
self, prefix: tp.Optional[bitarray] = None
) -> tp.Dict[str, "HuffmanNode"]:
defaulted_prefix = prefix if prefix is not None else bitarray() defaulted_prefix = prefix if prefix is not None else bitarray()
if self.is_leaf(): if self.is_leaf():
self.code = ( self.code = (

@ -67,7 +67,9 @@ def make_builder(out_file, impl, vocab_size=None):
elif impl == "fasta": elif impl == "fasta":
raise NotImplementedError raise NotImplementedError
elif impl == "huffman": elif impl == "huffman":
raise ValueError("Use HuffmanCodeBuilder directly as it has a different interface.") raise ValueError(
"Use HuffmanCodeBuilder directly as it has a different interface."
)
else: else:
return IndexedDatasetBuilder(out_file) return IndexedDatasetBuilder(out_file)

@ -380,7 +380,9 @@ class EpochBatchIterator(EpochBatchIterating):
# reset _frozen_batches to refresh the next epoch # reset _frozen_batches to refresh the next epoch
self._frozen_batches = None self._frozen_batches = None
self._cur_epoch_itr = self._get_iterator_for_epoch( self._cur_epoch_itr = self._get_iterator_for_epoch(
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, self.epoch,
shuffle,
fix_batches_to_gpus=fix_batches_to_gpus,
) )
self.shuffle = shuffle self.shuffle = shuffle
return self._cur_epoch_itr return self._cur_epoch_itr
@ -421,7 +423,9 @@ class EpochBatchIterator(EpochBatchIterating):
if itr_pos > 0: if itr_pos > 0:
# fast-forward epoch iterator # fast-forward epoch iterator
self._next_epoch_itr = self._get_iterator_for_epoch( self._next_epoch_itr = self._get_iterator_for_epoch(
self.epoch, shuffle=state_dict.get("shuffle", True), offset=itr_pos, self.epoch,
shuffle=state_dict.get("shuffle", True),
offset=itr_pos,
) )
if self._next_epoch_itr is None: if self._next_epoch_itr is None:
if version == 1: if version == 1:

@ -114,7 +114,10 @@ def collate(
"id": id, "id": id,
"nsentences": len(samples), "nsentences": len(samples),
"ntokens": ntokens, "ntokens": ntokens,
"net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths,}, "net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
},
"target": target, "target": target,
} }
if prev_output_tokens is not None: if prev_output_tokens is not None:
@ -467,5 +470,8 @@ class LanguagePairDataset(FairseqDataset):
list: list of removed indices list: list of removed indices
""" """
return data_utils.filter_paired_dataset_indices_by_size( return data_utils.filter_paired_dataset_indices_by_size(
self.src_sizes, self.tgt_sizes, indices, max_sizes, self.src_sizes,
self.tgt_sizes,
indices,
max_sizes,
) )

@ -80,7 +80,9 @@ class MultiCorpusDataset(FairseqDataset):
def ordered_indices(self): def ordered_indices(self):
start = time.time() start = time.time()
with data_utils.numpy_seed(self.seed, self.epoch): with data_utils.numpy_seed(self.seed, self.epoch):
logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}") logger.info(
f"sampling new dataset with seed {self.seed} epoch {self.epoch}"
)
sampled_indices = [] sampled_indices = []
num_selected_instances = 0 num_selected_instances = 0

@ -40,8 +40,8 @@ from fairseq.utils import FileContentsAction, csv_str_list, eval_str_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SRC_DICT_NAME = 'src' SRC_DICT_NAME = "src"
TGT_DICT_NAME = 'tgt' TGT_DICT_NAME = "tgt"
def _lang_id(dic: Dictionary, lang: str): def _lang_id(dic: Dictionary, lang: str):
@ -64,14 +64,16 @@ class MultilingualDatasetManager(object):
self.seed = args.seed self.seed = args.seed
self.lang_pairs = lang_pairs self.lang_pairs = lang_pairs
self.extra_lang_pairs = ( self.extra_lang_pairs = (
list( list({p for _, v in args.extra_lang_pairs.items() for p in v.split(",")})
{p for _, v in args.extra_lang_pairs.items() for p in v.split(",")} if args.extra_lang_pairs
) else []
if args.extra_lang_pairs )
else [] self.src_langs = {
) p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs
self.src_langs = {p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs} }
self.tgt_langs = {p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs} self.tgt_langs = {
p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs
}
self.langs = langs self.langs = langs
self.dicts = dicts self.dicts = dicts
self.lang_dict = self.create_lang_dictionary(self.langs) self.lang_dict = self.create_lang_dictionary(self.langs)
@ -111,10 +113,18 @@ class MultilingualDatasetManager(object):
"note that the ordering determines language token IDs; " "note that the ordering determines language token IDs; "
"--langs and --lang-dict are two exclusive options", "--langs and --lang-dict are two exclusive options",
) )
parser.add_argument('--source-dict', default=None, type=str, parser.add_argument(
help='path to source dictionary; if specified it will override per language dictionary loading') "--source-dict",
parser.add_argument('--target-dict', default=None, type=str, default=None,
help='path to target dictionary; if specified it will override per language dictionary loading') type=str,
help="path to source dictionary; if specified it will override per language dictionary loading",
)
parser.add_argument(
"--target-dict",
default=None,
type=str,
help="path to target dictionary; if specified it will override per language dictionary loading",
)
parser.add_argument( parser.add_argument(
"--lang-tok-style", "--lang-tok-style",
default=LangTokStyle.multilingual.value, default=LangTokStyle.multilingual.value,
@ -378,7 +388,9 @@ class MultilingualDatasetManager(object):
) )
return d return d
dicts = cls.load_all_dictionaries(args, language_list, load_dictionary_and_postproc, training) dicts = cls.load_all_dictionaries(
args, language_list, load_dictionary_and_postproc, training
)
return language_list, dicts, training return language_list, dicts, training
@classmethod @classmethod
@ -424,7 +436,10 @@ class MultilingualDatasetManager(object):
if args.fixed_dictionary is not None: if args.fixed_dictionary is not None:
fixed_dict = load_dictionary(args.fixed_dictionary) fixed_dict = load_dictionary(args.fixed_dictionary)
dicts = {lang: fixed_dict for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts} dicts = {
lang: fixed_dict
for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts
}
else: else:
if args.source_dict is None: if args.source_dict is None:
load_dicts(src_langs_to_load_dicts) load_dicts(src_langs_to_load_dicts)
@ -477,7 +492,10 @@ class MultilingualDatasetManager(object):
lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec
) )
return self.get_langtok_index( return self.get_langtok_index(
langtok, self.get_source_dictionary(src_lang) if src_lang else self.get_target_dictionary(tgt_lang) langtok,
self.get_source_dictionary(src_lang)
if src_lang
else self.get_target_dictionary(tgt_lang),
) )
def get_decoder_langtok(self, tgt_lang, spec=None): def get_decoder_langtok(self, tgt_lang, spec=None):
@ -819,7 +837,9 @@ class MultilingualDatasetManager(object):
if self.args.lang_tok_replacing_bos_eos: if self.args.lang_tok_replacing_bos_eos:
ds = self.alter_dataset_langtok( ds = self.alter_dataset_langtok(
langpair_ds, langpair_ds,
src_eos=self.get_source_dictionary(src).eos() if src else self.get_target_dictionary(tgt).eos(), src_eos=self.get_source_dictionary(src).eos()
if src
else self.get_target_dictionary(tgt).eos(),
src_lang=src, src_lang=src,
tgt_eos=self.get_target_dictionary(tgt).eos(), tgt_eos=self.get_target_dictionary(tgt).eos(),
tgt_lang=tgt, tgt_lang=tgt,

@ -298,7 +298,6 @@ class NoisingDataset(torch.utils.data.Dataset):
) )
self.sizes = src_dataset.sizes self.sizes = src_dataset.sizes
def __getitem__(self, index): def __getitem__(self, index):
""" """
Returns a single noisy sample. Multiple samples are fed to the collater Returns a single noisy sample. Multiple samples are fed to the collater

@ -14,8 +14,7 @@ class TextCompressionLevel(Enum):
class TextCompressor(object): class TextCompressor(object):
def __init__( def __init__(
self, level: TextCompressionLevel, self, level: TextCompressionLevel, max_input_byte_length: int = 2 ** 16
max_input_byte_length: int = 2 ** 16
): ):
self.level = level self.level = level
self.max_input_length = max_input_byte_length self.max_input_length = max_input_byte_length
@ -23,11 +22,13 @@ class TextCompressor(object):
def compress(self, text: str) -> bytes: def compress(self, text: str) -> bytes:
if self.level == TextCompressionLevel.low: if self.level == TextCompressionLevel.low:
import zlib import zlib
# zlib: built-in, fast # zlib: built-in, fast
return zlib.compress(text.encode(), level=0) return zlib.compress(text.encode(), level=0)
elif self.level == TextCompressionLevel.high: elif self.level == TextCompressionLevel.high:
try: try:
import unishox2 import unishox2
# unishox2: optimized for short text but slower # unishox2: optimized for short text but slower
except ImportError: except ImportError:
raise ImportError( raise ImportError(
@ -42,6 +43,7 @@ class TextCompressor(object):
def decompress(self, compressed: bytes) -> str: def decompress(self, compressed: bytes) -> str:
if self.level == TextCompressionLevel.low: if self.level == TextCompressionLevel.low:
import zlib import zlib
return zlib.decompress(compressed).decode() return zlib.decompress(compressed).decode()
elif self.level == TextCompressionLevel.high: elif self.level == TextCompressionLevel.high:
try: try:

@ -69,7 +69,10 @@ class TokenBlockDataset(FairseqDataset):
_sizes, split_path, (plasma_id, 1), plasma_path=plasma_path _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path
) )
self._block_to_dataset_index = plasma_utils.PlasmaView( self._block_to_dataset_index = plasma_utils.PlasmaView(
block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path, block_to_dataset_index,
split_path,
(plasma_id, 2),
plasma_path=plasma_path,
) )
else: else:
self._slice_indices = plasma_utils.PlasmaArray(slice_indices) self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
@ -127,7 +130,8 @@ class TokenBlockDataset(FairseqDataset):
) )
else: else:
block_to_dataset_index = _get_block_to_dataset_index_fast( block_to_dataset_index = _get_block_to_dataset_index_fast(
sizes, slice_indices, sizes,
slice_indices,
) )
size_dtype = np.uint16 if block_size < 65535 else np.uint32 size_dtype = np.uint16 if block_size < 65535 else np.uint32
num_tokens = slice_indices[-1].max() num_tokens = slice_indices[-1].max()

@ -52,7 +52,7 @@ class TransformEosLangPairDataset(FairseqDataset):
if len(samples) == 0: if len(samples) == 0:
return samples return samples
if 'net_input' not in samples: if "net_input" not in samples:
return samples return samples
if self.new_src_eos is not None: if self.new_src_eos is not None:

@ -126,7 +126,8 @@ class CommonConfig(FairseqDataclass):
metadata={"help": "Weights and Biases project name to use for logging"}, metadata={"help": "Weights and Biases project name to use for logging"},
) )
azureml_logging: Optional[bool] = field( azureml_logging: Optional[bool] = field(
default=False, metadata={"help": "Log scalars to AzureML context"}, default=False,
metadata={"help": "Log scalars to AzureML context"},
) )
seed: int = field( seed: int = field(
default=1, metadata={"help": "pseudo random number generator seed"} default=1, metadata={"help": "pseudo random number generator seed"}
@ -428,19 +429,23 @@ class DistributedTrainingConfig(FairseqDataclass):
tpu: bool = II("common.tpu") tpu: bool = II("common.tpu")
# configuration for --ddp-backend=fully_sharded # configuration for --ddp-backend=fully_sharded
no_reshard_after_forward: bool = field( no_reshard_after_forward: bool = field(
default=False, metadata={"help": "don't reshard parameters after forward pass"}, default=False,
metadata={"help": "don't reshard parameters after forward pass"},
) )
fp32_reduce_scatter: bool = field( fp32_reduce_scatter: bool = field(
default=False, metadata={"help": "reduce-scatter grads in FP32"}, default=False,
metadata={"help": "reduce-scatter grads in FP32"},
) )
cpu_offload: bool = field( cpu_offload: bool = field(
default=False, metadata={"help": "offload FP32 params to CPU"} default=False, metadata={"help": "offload FP32 params to CPU"}
) )
use_sharded_state: bool = field( use_sharded_state: bool = field(
default=False, metadata={"help": "use sharded checkpoint files"}, default=False,
metadata={"help": "use sharded checkpoint files"},
) )
not_fsdp_flatten_parameters: bool = field( not_fsdp_flatten_parameters: bool = field(
default=False, metadata={"help": "not flatten parameter param for fsdp"}, default=False,
metadata={"help": "not flatten parameter param for fsdp"},
) )
@ -786,10 +791,12 @@ class FairseqBMUFConfig(FairseqDataclass):
@dataclass @dataclass
class GenerationConfig(FairseqDataclass): class GenerationConfig(FairseqDataclass):
beam: int = field( beam: int = field(
default=5, metadata={"help": "beam size"}, default=5,
metadata={"help": "beam size"},
) )
nbest: int = field( nbest: int = field(
default=1, metadata={"help": "number of hypotheses to output"}, default=1,
metadata={"help": "number of hypotheses to output"},
) )
max_len_a: float = field( max_len_a: float = field(
default=0, default=0,
@ -804,19 +811,24 @@ class GenerationConfig(FairseqDataclass):
}, },
) )
min_len: int = field( min_len: int = field(
default=1, metadata={"help": "minimum generation length"}, default=1,
metadata={"help": "minimum generation length"},
) )
match_source_len: bool = field( match_source_len: bool = field(
default=False, metadata={"help": "generations should match the source length"}, default=False,
metadata={"help": "generations should match the source length"},
) )
unnormalized: bool = field( unnormalized: bool = field(
default=False, metadata={"help": "compare unnormalized hypothesis scores"}, default=False,
metadata={"help": "compare unnormalized hypothesis scores"},
) )
no_early_stop: bool = field( no_early_stop: bool = field(
default=False, metadata={"help": "deprecated"}, default=False,
metadata={"help": "deprecated"},
) )
no_beamable_mm: bool = field( no_beamable_mm: bool = field(
default=False, metadata={"help": "don't use BeamableMM in attention layers"}, default=False,
metadata={"help": "don't use BeamableMM in attention layers"},
) )
lenpen: float = field( lenpen: float = field(
default=1, default=1,
@ -838,10 +850,12 @@ class GenerationConfig(FairseqDataclass):
}, },
) )
sacrebleu: bool = field( sacrebleu: bool = field(
default=False, metadata={"help": "score with sacrebleu"}, default=False,
metadata={"help": "score with sacrebleu"},
) )
score_reference: bool = field( score_reference: bool = field(
default=False, metadata={"help": "just score the reference translation"}, default=False,
metadata={"help": "just score the reference translation"},
) )
prefix_size: int = field( prefix_size: int = field(
default=0, default=0,
@ -875,10 +889,12 @@ class GenerationConfig(FairseqDataclass):
}, },
) )
temperature: float = field( temperature: float = field(
default=1.0, metadata={"help": "temperature for generation"}, default=1.0,
metadata={"help": "temperature for generation"},
) )
diverse_beam_groups: int = field( diverse_beam_groups: int = field(
default=-1, metadata={"help": "number of groups for Diverse Beam Search"}, default=-1,
metadata={"help": "number of groups for Diverse Beam Search"},
) )
diverse_beam_strength: float = field( diverse_beam_strength: float = field(
default=0.5, default=0.5,
@ -897,13 +913,16 @@ class GenerationConfig(FairseqDataclass):
}, },
) )
print_step: bool = field( print_step: bool = field(
default=False, metadata={"help": "print steps"}, default=False,
metadata={"help": "print steps"},
) )
lm_path: Optional[str] = field( lm_path: Optional[str] = field(
default=None, metadata={"help": "path to lm checkpoint for lm fusion"}, default=None,
metadata={"help": "path to lm checkpoint for lm fusion"},
) )
lm_weight: float = field( lm_weight: float = field(
default=0.0, metadata={"help": "weight for lm probs for lm fusion"}, default=0.0,
metadata={"help": "weight for lm probs for lm fusion"},
) )
# arguments for iterative refinement generator # arguments for iterative refinement generator
@ -912,7 +931,8 @@ class GenerationConfig(FairseqDataclass):
metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
) )
iter_decode_max_iter: int = field( iter_decode_max_iter: int = field(
default=10, metadata={"help": "maximum iterations for iterative refinement."}, default=10,
metadata={"help": "maximum iterations for iterative refinement."},
) )
iter_decode_force_max_iter: bool = field( iter_decode_force_max_iter: bool = field(
default=False, default=False,
@ -939,7 +959,8 @@ class GenerationConfig(FairseqDataclass):
}, },
) )
retain_dropout: bool = field( retain_dropout: bool = field(
default=False, metadata={"help": "Use dropout at inference time"}, default=False,
metadata={"help": "Use dropout at inference time"},
) )
# temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed
# retain_dropout_modules: Optional[List[str]] = field( # retain_dropout_modules: Optional[List[str]] = field(
@ -964,7 +985,8 @@ class GenerationConfig(FairseqDataclass):
@dataclass @dataclass
class CommonEvalConfig(FairseqDataclass): class CommonEvalConfig(FairseqDataclass):
path: Optional[str] = field( path: Optional[str] = field(
default=None, metadata={"help": "path(s) to model file(s), colon separated"}, default=None,
metadata={"help": "path(s) to model file(s), colon separated"},
) )
post_process: Optional[str] = field( post_process: Optional[str] = field(
default=None, default=None,
@ -1026,7 +1048,8 @@ class InteractiveConfig(FairseqDataclass):
}, },
) )
input: str = field( input: str = field(
default="-", metadata={"help": "file to read from; use - for stdin"}, default="-",
metadata={"help": "file to read from; use - for stdin"},
) )

@ -35,14 +35,16 @@ def ChoiceEnum(choices: List[str]):
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum([ DDP_BACKEND_CHOICES = ChoiceEnum(
"c10d", # alias for pytorch_ddp [
"fully_sharded", # FullyShardedDataParallel from fairscale "c10d", # alias for pytorch_ddp
"legacy_ddp", "fully_sharded", # FullyShardedDataParallel from fairscale
"no_c10d", # alias for legacy_ddp "legacy_ddp",
"pytorch_ddp", "no_c10d", # alias for legacy_ddp
"slowmo", "pytorch_ddp",
]) "slowmo",
]
)
DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"]) DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"])
DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"]) DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"])
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])

@ -28,7 +28,7 @@ def hydra_init(cfg_name="config") -> None:
def add_defaults(cfg: DictConfig) -> None: def add_defaults(cfg: DictConfig) -> None:
"""This function adds default values that are stored in dataclasses that hydra doesn't know about """ """This function adds default values that are stored in dataclasses that hydra doesn't know about"""
from fairseq.registry import REGISTRIES from fairseq.registry import REGISTRIES
from fairseq.tasks import TASK_DATACLASS_REGISTRY from fairseq.tasks import TASK_DATACLASS_REGISTRY

@ -57,21 +57,21 @@ def gen_parser_from_dataclass(
with_prefix: Optional[str] = None, with_prefix: Optional[str] = None,
) -> None: ) -> None:
""" """
convert a dataclass instance to tailing parser arguments. convert a dataclass instance to tailing parser arguments.
If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are
building a flat namespace from a structured dataclass (see transformer_config.py for example). building a flat namespace from a structured dataclass (see transformer_config.py for example).
""" """
def argparse_name(name: str): def argparse_name(name: str):
if name == "data" and (with_prefix is None or with_prefix == ''): if name == "data" and (with_prefix is None or with_prefix == ""):
# normally data is positional args, so we don't add the -- nor the prefix # normally data is positional args, so we don't add the -- nor the prefix
return name return name
if name == "_name": if name == "_name":
# private member, skip # private member, skip
return None return None
full_name = "--" + name.replace("_", "-") full_name = "--" + name.replace("_", "-")
if with_prefix is not None and with_prefix != '': if with_prefix is not None and with_prefix != "":
# if a prefix is specified, construct the prefixed arg name # if a prefix is specified, construct the prefixed arg name
full_name = with_prefix + "-" + full_name[2:] # strip -- when composing full_name = with_prefix + "-" + full_name[2:] # strip -- when composing
return full_name return full_name
@ -143,8 +143,8 @@ def gen_parser_from_dataclass(
kwargs["default"] = field_default kwargs["default"] = field_default
# build the help with the hierarchical prefix # build the help with the hierarchical prefix
if with_prefix is not None and with_prefix != '' and field_help is not None: if with_prefix is not None and with_prefix != "" and field_help is not None:
field_help = with_prefix[2:] + ': ' + field_help field_help = with_prefix[2:] + ": " + field_help
kwargs["help"] = field_help kwargs["help"] = field_help
if field_const is not None: if field_const is not None:

@ -4,7 +4,11 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .distributed_timeout_wrapper import DistributedTimeoutWrapper from .distributed_timeout_wrapper import DistributedTimeoutWrapper
from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel from .fully_sharded_data_parallel import (
fsdp_enable_wrap,
fsdp_wrap,
FullyShardedDataParallel,
)
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel from .legacy_distributed_data_parallel import LegacyDistributedDataParallel
from .module_proxy_wrapper import ModuleProxyWrapper from .module_proxy_wrapper import ModuleProxyWrapper
from .tpu_distributed_data_parallel import TPUDistributedDataParallel from .tpu_distributed_data_parallel import TPUDistributedDataParallel

@ -33,6 +33,7 @@ class DistributedTimeoutWrapper(nn.Module):
(set to a value <= 0 to disable the timeout) (set to a value <= 0 to disable the timeout)
signal (Optional): signal to send once timeout is triggered signal (Optional): signal to send once timeout is triggered
""" """
def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT): def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT):
super().__init__() super().__init__()
self.module = module self.module = module
@ -86,9 +87,11 @@ class DistributedTimeoutWrapper(nn.Module):
if self._terminated: if self._terminated:
break break
elif not success: elif not success:
logger.error(( logger.error(
"Killing job for not making progress in {} seconds. " (
"Set --heartbeat-timeout=-1 to disable this timeout." "Killing job for not making progress in {} seconds. "
).format(int(self.timeout))) "Set --heartbeat-timeout=-1 to disable this timeout."
).format(int(self.timeout))
)
os.kill(parent_pid, self.signal) os.kill(parent_pid, self.signal)
return return

@ -137,7 +137,7 @@ class LegacyDistributedDataParallel(nn.Module):
if param.grad is None: if param.grad is None:
param.grad = torch.zeros_like(param) param.grad = torch.zeros_like(param)
if hasattr(param, 'expert'): if hasattr(param, "expert"):
# Skip gradient sync for unshared parameters # Skip gradient sync for unshared parameters
continue continue

@ -26,8 +26,9 @@ class ModuleProxyWrapper(nn.Module):
def __init__(self, module: nn.Module): def __init__(self, module: nn.Module):
super().__init__() super().__init__()
assert hasattr(module, "module"), \ assert hasattr(
"ModuleProxyWrapper expects input to wrap another module" module, "module"
), "ModuleProxyWrapper expects input to wrap another module"
self.module = module self.module = module
def __getattr__(self, name): def __getattr__(self, name):

@ -10,7 +10,6 @@ from fairseq.distributed import utils
class TPUDistributedDataParallel(nn.Module): class TPUDistributedDataParallel(nn.Module):
def __init__(self, module, process_group): def __init__(self, module, process_group):
super().__init__() super().__init__()
self.module = module self.module = module
@ -35,9 +34,10 @@ class TPUDistributedDataParallel(nn.Module):
gradients.append(p.grad) gradients.append(p.grad)
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
xm.all_reduce( xm.all_reduce(
'sum', "sum",
gradients, gradients,
scale=1. / self.world_size, scale=1.0 / self.world_size,
groups=self.process_group[1], groups=self.process_group[1],
) )

@ -201,9 +201,7 @@ def _pipeline_parallel_post_init(
# distributed_world_size to be based on the total number of GPUs, so # distributed_world_size to be based on the total number of GPUs, so
# we need to correct them to be based on the number of pipelines. # we need to correct them to be based on the number of pipelines.
assert cfg.distributed_world_size % num_pipeline_devices == 0 assert cfg.distributed_world_size % num_pipeline_devices == 0
cfg.distributed_world_size = ( cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices
cfg.distributed_world_size // num_pipeline_devices
)
# In the case of 4-way MP on nodes with 8 GPUs, we want # In the case of 4-way MP on nodes with 8 GPUs, we want
# distributed_rank to be the starting GPU index for each pipeline # distributed_rank to be the starting GPU index for each pipeline
# i.e., 0, 2, ... # i.e., 0, 2, ...
@ -306,8 +304,10 @@ def distributed_init(cfg: FairseqConfig):
model_part_number = get_model_parallel_rank() model_part_number = get_model_parallel_rank()
cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0: if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}" cfg.checkpoint.checkpoint_suffix = (
f"-rank-{cfg.distributed_training.distributed_rank}"
)
return cfg.distributed_training.distributed_rank return cfg.distributed_training.distributed_rank
@ -696,7 +696,7 @@ def broadcast_tensors(
dist_device = torch.device("cpu") dist_device = torch.device("cpu")
# share metadata first to simplify transfer # share metadata first to simplify transfer
is_src_rank = (get_rank(group) == src_rank) is_src_rank = get_rank(group) == src_rank
if is_src_rank: if is_src_rank:
metadata = [ metadata = [
{"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
@ -747,7 +747,10 @@ def broadcast_object(
def _broadcast_object_slow( def _broadcast_object_slow(
obj: Any, src_rank: int, group: object, dist_device: torch.device, obj: Any,
src_rank: int,
group: object,
dist_device: torch.device,
) -> Any: ) -> Any:
if get_rank(group) == src_rank: if get_rank(group) == src_rank:
# Emit data # Emit data

@ -152,6 +152,7 @@ class PathManager:
""" """
ioPath async PathManager methods: ioPath async PathManager methods:
""" """
@staticmethod @staticmethod
def opena( def opena(
path: str, path: str,
@ -169,6 +170,7 @@ class PathManager:
logging.info("ioPath is initializing PathManager.") logging.info("ioPath is initializing PathManager.")
try: try:
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
IOPathManager = PathManager() IOPathManager = PathManager()
except Exception: except Exception:
logging.exception("Failed to initialize ioPath PathManager object.") logging.exception("Failed to initialize ioPath PathManager object.")

@ -146,6 +146,7 @@ def cached_path_from_pm(url_or_filename):
""" """
try: try:
from fairseq.file_io import PathManager from fairseq.file_io import PathManager
local_path = PathManager.get_local_path(url_or_filename) local_path = PathManager.get_local_path(url_or_filename)
return local_path return local_path
except Exception: except Exception:

@ -130,6 +130,7 @@ def log_scalar(
agg.add_meter(key, AverageMeter(round=round), priority) agg.add_meter(key, AverageMeter(round=round), priority)
agg[key].update(value, weight) agg[key].update(value, weight)
def log_scalar_sum( def log_scalar_sum(
key: str, key: str,
value: float, value: float,
@ -309,6 +310,7 @@ def load_state_dict(state_dict):
def xla_metrics_report(): def xla_metrics_report():
try: try:
import torch_xla.debug.metrics as met import torch_xla.debug.metrics as met
print(met.metrics_report()) print(met.metrics_report())
except ImportError: except ImportError:
return return

@ -52,8 +52,7 @@ class MegatronTrainer(Trainer):
def save_checkpoint(self, filename, extra_state): def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file.""" """Save all training state in a checkpoint file."""
extra_state['rng_tracker_states'] \ extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
= get_cuda_rng_tracker().get_states()
super().save_checkpoint(filename, extra_state) super().save_checkpoint(filename, extra_state)
def load_checkpoint( def load_checkpoint(
@ -64,8 +63,13 @@ class MegatronTrainer(Trainer):
optimizer_overrides=None, optimizer_overrides=None,
reset_meters=False, reset_meters=False,
): ):
extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters) extra_state = super().load_checkpoint(
if extra_state is not None and 'rng_tracker_states' in extra_state: filename,
get_cuda_rng_tracker().set_states( reset_optimizer=reset_optimizer,
extra_state['rng_tracker_states']) reset_lr_scheduler=reset_lr_scheduler,
optimizer_overrides=optimizer_overrides,
reset_meters=reset_meters,
)
if extra_state is not None and "rng_tracker_states" in extra_state:
get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
return extra_state return extra_state

@ -9,6 +9,7 @@ from collections import namedtuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import options, utils from fairseq import options, utils
from fairseq.modules import ( from fairseq.modules import (
AdaptiveSoftmax, AdaptiveSoftmax,
@ -17,7 +18,6 @@ from fairseq.modules import (
PositionalEmbedding, PositionalEmbedding,
) )
EncoderOut = namedtuple( EncoderOut = namedtuple(
"TransformerEncoderOut", "TransformerEncoderOut",
[ [
@ -30,7 +30,7 @@ EncoderOut = namedtuple(
class TransformerEncoderEmbedding(nn.Module): class TransformerEncoderEmbedding(nn.Module):
""" Encoder Embedding + Positional Embedding """ """Encoder Embedding + Positional Embedding"""
def __init__(self, args, embed_tokens): def __init__(self, args, embed_tokens):
super().__init__() super().__init__()
@ -109,7 +109,7 @@ class TransformerEncoderLayerNorm(nn.Module):
class TransformerDecoderEmbedding(nn.Module): class TransformerDecoderEmbedding(nn.Module):
""" Decoder Embedding + Positional Embedding """ """Decoder Embedding + Positional Embedding"""
def __init__(self, args, embed_tokens): def __init__(self, args, embed_tokens):
super().__init__() super().__init__()

@ -42,16 +42,20 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024
TORCH_PIPE = False TORCH_PIPE = False
RPC_INIT = False RPC_INIT = False
def import_pipe(): def import_pipe():
global TORCH_PIPE global TORCH_PIPE
global RPC_INIT global RPC_INIT
try: try:
from torch.distributed.pipeline.sync import Pipe # noqa from torch.distributed.pipeline.sync import Pipe # noqa
global Pipe global Pipe
from torch.distributed.pipeline.sync.utils import partition_model from torch.distributed.pipeline.sync.utils import partition_model
global partition_model global partition_model
from torch.distributed import rpc from torch.distributed import rpc
import tempfile import tempfile
TORCH_PIPE = True TORCH_PIPE = True
# Initialize single process RPC agent since TORCH_PIPE requires # Initialize single process RPC agent since TORCH_PIPE requires
# RRef. RRef depends on RPC being initialized and as a result we initialize # RRef. RRef depends on RPC being initialized and as a result we initialize
@ -64,14 +68,15 @@ def import_pipe():
world_size=1, world_size=1,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions( rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
init_method="file://{}".format(tmpfile.name), init_method="file://{}".format(tmpfile.name),
) ),
) )
RPC_INIT = True RPC_INIT = True
logger.info('Using torch pipe') logger.info("Using torch pipe")
except ImportError: except ImportError:
try: try:
from fairscale.nn import Pipe # noqa from fairscale.nn import Pipe # noqa
logger.info('Using fairscale pipe')
logger.info("Using fairscale pipe")
except ImportError: except ImportError:
raise ImportError("Please install fairscale with: pip install fairscale") raise ImportError("Please install fairscale with: pip install fairscale")
@ -153,9 +158,14 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
decoder_module_list.append(module) decoder_module_list.append(module)
module_count += 1 module_count += 1
self.model = None self.model = None
self.encoder = TransformerEncoder(cfg.distributed_training, None, None, encoder_module_list) self.encoder = TransformerEncoder(
cfg.distributed_training, None, None, encoder_module_list
)
self.decoder = TransformerDecoder( self.decoder = TransformerDecoder(
cfg.distributed_training, None, None, decoder_module_list=decoder_module_list cfg.distributed_training,
None,
None,
decoder_module_list=decoder_module_list,
) )
@staticmethod @staticmethod
@ -471,7 +481,9 @@ class TransformerEncoder(FairseqEncoder):
self.use_pipeline = encoder_module_list is not None self.use_pipeline = encoder_module_list is not None
if not self.use_pipeline: if not self.use_pipeline:
self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens)
self.encoder_layers = nn.Sequential(*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]) self.encoder_layers = nn.Sequential(
*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]
)
if isinstance(embed_tokens, nn.ModuleList): if isinstance(embed_tokens, nn.ModuleList):
emb_dim = sum(e.embedding_dim for e in embed_tokens) emb_dim = sum(e.embedding_dim for e in embed_tokens)
else: else:
@ -490,7 +502,11 @@ class TransformerEncoder(FairseqEncoder):
) )
if TORCH_PIPE: if TORCH_PIPE:
self.model = Pipe( self.model = Pipe(
module=partition_model(nn.Sequential(*encoder_module_list), encoder_balance, encoder_devices), module=partition_model(
nn.Sequential(*encoder_module_list),
encoder_balance,
encoder_devices,
),
chunks=args.pipeline_chunks, chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint, checkpoint=args.pipeline_checkpoint,
) )
@ -614,10 +630,12 @@ class TransformerDecoder(FairseqDecoder):
self.use_pipeline = decoder_module_list is not None self.use_pipeline = decoder_module_list is not None
if not self.use_pipeline: if not self.use_pipeline:
self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens)
self.decoder_layers = nn.Sequential(*[ self.decoder_layers = nn.Sequential(
TransformerDecoderLayer(args, no_encoder_attn) *[
for _ in range(args.decoder_layers) TransformerDecoderLayer(args, no_encoder_attn)
]) for _ in range(args.decoder_layers)
]
)
self.decoder_output_layer = TransformerDecoderOutputLayer( self.decoder_output_layer = TransformerDecoderOutputLayer(
args, embed_tokens, dictionary args, embed_tokens, dictionary
) )
@ -634,7 +652,11 @@ class TransformerDecoder(FairseqDecoder):
) )
if TORCH_PIPE: if TORCH_PIPE:
self.model = Pipe( self.model = Pipe(
module=partition_model(nn.Sequential(*decoder_module_list), decoder_balance, decoder_devices), module=partition_model(
nn.Sequential(*decoder_module_list),
decoder_balance,
decoder_devices,
),
chunks=args.pipeline_chunks, chunks=args.pipeline_chunks,
checkpoint=args.pipeline_checkpoint, checkpoint=args.pipeline_checkpoint,
) )

@ -4,11 +4,11 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch.nn as nn import torch.nn as nn
from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from fairseq.models import register_model, register_model_architecture from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer_lm import TransformerLanguageModel from fairseq.models.transformer_lm import TransformerLanguageModel
try: try:
from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding
@ -22,7 +22,6 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model("model_parallel_transformer_lm") @register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel): class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
TransformerLanguageModel.add_args(parser) TransformerLanguageModel.add_args(parser)
@ -72,10 +71,6 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
) )
return cls(decoder) return cls(decoder)
@staticmethod
def add_args(parser):
TransformerLanguageModel.add_args(parser)
@classmethod @classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None): def build_embedding(cls, args, dictionary, embed_dim, path=None):
def _vocab_init(tensor, **kwargs): def _vocab_init(tensor, **kwargs):

@ -98,9 +98,7 @@ def build_model(cfg: FairseqDataclass, task):
assert model is not None, ( assert model is not None, (
f"Could not infer model type from {cfg}. " f"Could not infer model type from {cfg}. "
"Available models: {}".format( "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys())
MODEL_DATACLASS_REGISTRY.keys()
)
+ f" Requested model type: {model_type}" + f" Requested model type: {model_type}"
) )

@ -100,8 +100,8 @@ class BARTHubInterface(GeneratorHubInterface):
raise NotImplementedError("prefix generation not implemented for BART") raise NotImplementedError("prefix generation not implemented for BART")
res = [] res = []
for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
src_tokens = batch['net_input']['src_tokens'] src_tokens = batch["net_input"]["src_tokens"]
inference_step_args["prefix_tokens"] =src_tokens.new_full( inference_step_args["prefix_tokens"] = src_tokens.new_full(
(src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos() (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos()
).to(device=self.device) ).to(device=self.device)
results = super().generate( results = super().generate(
@ -111,7 +111,7 @@ class BARTHubInterface(GeneratorHubInterface):
skip_invalid_size_inputs=skip_invalid_size_inputs, skip_invalid_size_inputs=skip_invalid_size_inputs,
**kwargs **kwargs
) )
for id, hypos in zip(batch['id'].tolist(), results): for id, hypos in zip(batch["id"].tolist(), results):
res.append((id, hypos)) res.append((id, hypos))
res = [hypos for _, hypos in sorted(res, key=lambda x: x[0])] res = [hypos for _, hypos in sorted(res, key=lambda x: x[0])]
return res return res
@ -177,32 +177,35 @@ class BARTHubInterface(GeneratorHubInterface):
match_source_len: bool = True, match_source_len: bool = True,
**generate_kwargs **generate_kwargs
): ):
masked_token = '<mask>' masked_token = "<mask>"
batch_tokens = [] batch_tokens = []
for masked_input in masked_inputs: for masked_input in masked_inputs:
assert masked_token in masked_input, \ assert (
"please add one {} token for the input".format(masked_token) masked_token in masked_input
), "please add one {} token for the input".format(masked_token)
text_spans = masked_input.split(masked_token) text_spans = masked_input.split(masked_token)
text_spans_bpe = (' {0} '.format(masked_token)).join( text_spans_bpe = (
[self.bpe.encode(text_span.rstrip()) for text_span in text_spans] (" {0} ".format(masked_token))
).strip() .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans])
.strip()
)
tokens = self.task.source_dictionary.encode_line( tokens = self.task.source_dictionary.encode_line(
'<s> ' + text_spans_bpe + ' </s>', "<s> " + text_spans_bpe + " </s>",
append_eos=False, append_eos=False,
add_if_not_exist=False, add_if_not_exist=False,
).long() ).long()
batch_tokens.append(tokens) batch_tokens.append(tokens)
# ensure beam size is at least as big as topk # ensure beam size is at least as big as topk
generate_kwargs['beam'] = max( generate_kwargs["beam"] = max(
topk, topk,
generate_kwargs.get('beam', -1), generate_kwargs.get("beam", -1),
) )
generate_kwargs['match_source_len'] = match_source_len generate_kwargs["match_source_len"] = match_source_len
batch_hypos = self.generate(batch_tokens, **generate_kwargs) batch_hypos = self.generate(batch_tokens, **generate_kwargs)
return [ return [
[(self.decode(hypo['tokens']), hypo['score']) for hypo in hypos[:topk]] [(self.decode(hypo["tokens"]), hypo["score"]) for hypo in hypos[:topk]]
for hypos in batch_hypos for hypos in batch_hypos
] ]

@ -90,7 +90,7 @@ class BARTModel(TransformerModel):
src_tokens, src_tokens,
src_lengths=src_lengths, src_lengths=src_lengths,
token_embeddings=token_embeddings, token_embeddings=token_embeddings,
return_all_hiddens=return_all_hiddens return_all_hiddens=return_all_hiddens,
) )
x, extra = self.decoder( x, extra = self.decoder(
prev_output_tokens, prev_output_tokens,
@ -103,9 +103,9 @@ class BARTModel(TransformerModel):
) )
eos: int = self.eos eos: int = self.eos
if classification_head_name is not None: if classification_head_name is not None:
sentence_representation = x[ sentence_representation = x[src_tokens.eq(eos), :].view(
src_tokens.eq(eos), : x.size(0), -1, x.size(-1)
].view(x.size(0), -1, x.size(-1))[:, -1, :] )[:, -1, :]
for k, head in self.classification_heads.items(): for k, head in self.classification_heads.items():
# for torch script only supports iteration # for torch script only supports iteration
if k == classification_head_name: if k == classification_head_name:

@ -25,7 +25,10 @@ logger = logging.getLogger(__name__)
_SLOWMO_DDP_DISABLED = False _SLOWMO_DDP_DISABLED = False
try: try:
from fairscale.experimental.nn.data_parallel import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel from fairscale.experimental.nn.data_parallel import (
SlowMoBaseAlgorithm,
SlowMoDistributedDataParallel,
)
except ImportError: except ImportError:
_SLOWMO_DDP_DISABLED = True _SLOWMO_DDP_DISABLED = True

@ -22,6 +22,7 @@ import copy
import logging import logging
import torch import torch
from fairseq import checkpoint_utils from fairseq import checkpoint_utils
@ -78,7 +79,9 @@ class EMA(object):
self.fp32_params = {} self.fp32_params = {}
if self.config.ema_seed_model is not None: if self.config.ema_seed_model is not None:
state = checkpoint_utils.load_ema_from_checkpoint(self.config.ema_seed_model) state = checkpoint_utils.load_ema_from_checkpoint(
self.config.ema_seed_model
)
self.model.load_state_dict(state["model"], strict=True) self.model.load_state_dict(state["model"], strict=True)
if device is not None: if device is not None:
@ -119,7 +122,7 @@ class EMA(object):
self.fp32_params[param_key] = _to_float(state_dict[param_key]) self.fp32_params[param_key] = _to_float(state_dict[param_key])
def restore(self, state_dict, build_fp32_params=False): def restore(self, state_dict, build_fp32_params=False):
""" Load data from a model spec into EMA model """ """Load data from a model spec into EMA model"""
self.model.load_state_dict(state_dict, strict=False) self.model.load_state_dict(state_dict, strict=False)
if build_fp32_params: if build_fp32_params:
self.build_fp32_params(state_dict) self.build_fp32_params(state_dict)
@ -131,16 +134,20 @@ class EMA(object):
return self.decay return self.decay
def _step_internal(self, new_model, updates=None): def _step_internal(self, new_model, updates=None):
""" One update of the EMA model based on new model weights """ """One update of the EMA model based on new model weights"""
decay = self.decay decay = self.decay
ema_state_dict = {} ema_state_dict = {}
ema_params = self.fp32_params if self.config.ema_fp32 else self.model.state_dict() ema_params = (
self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
)
for key, param in new_model.state_dict().items(): for key, param in new_model.state_dict().items():
try: try:
ema_param = ema_params[key] ema_param = ema_params[key]
except KeyError: except KeyError:
ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) ema_param = (
param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
)
if param.shape != ema_param.shape: if param.shape != ema_param.shape:
raise ValueError( raise ValueError(
@ -151,7 +158,7 @@ class EMA(object):
# Do not decay a model.version pytorch param # Do not decay a model.version pytorch param
continue continue
ema_param.mul_(decay) ema_param.mul_(decay)
ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1-decay) ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
ema_state_dict[key] = ema_param ema_state_dict[key] = ema_param
self.restore(ema_state_dict, build_fp32_params=False) self.restore(ema_state_dict, build_fp32_params=False)
@ -168,8 +175,7 @@ class EMA(object):
""" """
self._set_decay( self._set_decay(
0 0
if updates is not None if updates is not None and updates < self.config.ema_start_update
and updates < self.config.ema_start_update
else self.config.ema_decay else self.config.ema_decay
) )
if updates is not None and self.config.ema_update_freq > 1: if updates is not None and self.config.ema_update_freq > 1:

@ -19,7 +19,6 @@ class FairseqDecoder(nn.Module):
self.onnx_trace = False self.onnx_trace = False
self.adaptive_softmax = None self.adaptive_softmax = None
def forward(self, prev_output_tokens, encoder_out=None, **kwargs): def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
""" """
Args: Args:

@ -29,8 +29,9 @@ logger = logging.getLogger(__name__)
def check_type(module, expected_type): def check_type(module, expected_type):
if hasattr(module, "unwrapped_module"): if hasattr(module, "unwrapped_module"):
assert isinstance(module.unwrapped_module, expected_type), \ assert isinstance(
f"{type(module.unwrapped_module)} != {expected_type}" module.unwrapped_module, expected_type
), f"{type(module.unwrapped_module)} != {expected_type}"
else: else:
assert isinstance(module, expected_type), f"{type(module)} != {expected_type}" assert isinstance(module, expected_type), f"{type(module)} != {expected_type}"
@ -114,7 +115,9 @@ class BaseFairseqModel(nn.Module):
""" """
if model_cfg is None and args is not None: if model_cfg is None and args is not None:
logger.warn("using 'args' is deprecated, please update your code to use dataclass config") logger.warn(
"using 'args' is deprecated, please update your code to use dataclass config"
)
model_cfg = convert_namespace_to_omegaconf(args).model model_cfg = convert_namespace_to_omegaconf(args).model
self.upgrade_state_dict(state_dict) self.upgrade_state_dict(state_dict)
@ -454,7 +457,9 @@ class FairseqMultiModel(BaseFairseqModel):
""" """
if model_cfg is None and args is not None: if model_cfg is None and args is not None:
logger.warn("using 'args' is deprecated, please update your code to use dataclass config") logger.warn(
"using 'args' is deprecated, please update your code to use dataclass config"
)
model_cfg = convert_namespace_to_omegaconf(args).model model_cfg = convert_namespace_to_omegaconf(args).model
self.upgrade_state_dict(state_dict) self.upgrade_state_dict(state_dict)

@ -30,9 +30,7 @@ from omegaconf import II
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum( MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
["static", "uniform", "normal", "poisson"]
)
@dataclass @dataclass
@ -86,9 +84,7 @@ class HubertConfig(FairseqDataclass):
) )
dropout_features: float = field( dropout_features: float = field(
default=0.0, default=0.0,
metadata={ metadata={"help": "dropout to apply to the features (after feat extr)"},
"help": "dropout to apply to the features (after feat extr)"
},
) )
final_dim: int = field( final_dim: int = field(
@ -150,9 +146,7 @@ class HubertConfig(FairseqDataclass):
) )
mask_min_space: int = field( mask_min_space: int = field(
default=1, default=1,
metadata={ metadata={"help": "min space between spans (if no overlap is enabled)"},
"help": "min space between spans (if no overlap is enabled)"
},
) )
# channel masking # channel masking
@ -182,23 +176,17 @@ class HubertConfig(FairseqDataclass):
) )
mask_channel_min_space: int = field( mask_channel_min_space: int = field(
default=1, default=1,
metadata={ metadata={"help": "min space between spans (if no overlap is enabled)"},
"help": "min space between spans (if no overlap is enabled)"
},
) )
# positional embeddings # positional embeddings
conv_pos: int = field( conv_pos: int = field(
default=128, default=128,
metadata={ metadata={"help": "number of filters for convolutional positional embeddings"},
"help": "number of filters for convolutional positional embeddings"
},
) )
conv_pos_groups: int = field( conv_pos_groups: int = field(
default=16, default=16,
metadata={ metadata={"help": "number of groups for convolutional positional embedding"},
"help": "number of groups for convolutional positional embedding"
},
) )
latent_temp: Tuple[float, float, float] = field( latent_temp: Tuple[float, float, float] = field(
@ -238,9 +226,7 @@ class HubertModel(BaseFairseqModel):
conv_bias=cfg.conv_bias, conv_bias=cfg.conv_bias,
) )
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = ( self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
)
self.post_extract_proj = ( self.post_extract_proj = (
nn.Linear(self.embed, cfg.encoder_embed_dim) nn.Linear(self.embed, cfg.encoder_embed_dim)
@ -270,9 +256,7 @@ class HubertModel(BaseFairseqModel):
self.skip_masked = cfg.skip_masked self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask self.skip_nomask = cfg.skip_nomask
final_dim = ( final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
)
self.mask_emb = nn.Parameter( self.mask_emb = nn.Parameter(
torch.FloatTensor(cfg.encoder_embed_dim).uniform_() torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
@ -297,9 +281,7 @@ class HubertModel(BaseFairseqModel):
# modules below are not needed during fine-tuning # modules below are not needed during fine-tuning
if any([d is None for d in dictionaries]): if any([d is None for d in dictionaries]):
logger.info( logger.info("cannot find dictionary. assume will be used for fine-tuning")
"cannot find dictionary. assume will be used for fine-tuning"
)
else: else:
self.num_classes = [len(d) for d in dictionaries] self.num_classes = [len(d) for d in dictionaries]
self.label_embs_concat = nn.Parameter( self.label_embs_concat = nn.Parameter(
@ -365,9 +347,7 @@ class HubertModel(BaseFairseqModel):
pos = pos.unsqueeze(0) pos = pos.unsqueeze(0)
targets = torch.cat([pos, negs], dim=0) targets = torch.cat([pos, negs], dim=0)
logits = torch.cosine_similarity( logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
x.float(), targets.float(), dim=-1
).type_as(x)
logits /= self.logit_temp logits /= self.logit_temp
if neg_is_pos.any(): if neg_is_pos.any():
logits[1:][neg_is_pos] = float("-inf") logits[1:][neg_is_pos] = float("-inf")
@ -385,7 +365,9 @@ class HubertModel(BaseFairseqModel):
return features return features
def forward_targets( def forward_targets(
self, features: torch.Tensor, target_list: List[torch.Tensor], self,
features: torch.Tensor,
target_list: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Trim features to ensure labels exist and then get aligned labels # Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.size(2) feat_tsz = features.size(2)
@ -398,14 +380,14 @@ class HubertModel(BaseFairseqModel):
return features, target_list return features, target_list
def forward_padding_mask( def forward_padding_mask(
self, features: torch.Tensor, padding_mask: torch.Tensor, self,
features: torch.Tensor,
padding_mask: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
extra = padding_mask.size(1) % features.size(1) extra = padding_mask.size(1) % features.size(1)
if extra > 0: if extra > 0:
padding_mask = padding_mask[:, :-extra] padding_mask = padding_mask[:, :-extra]
padding_mask = padding_mask.view( padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
padding_mask.size(0), features.size(1), -1
)
padding_mask = padding_mask.all(-1) padding_mask = padding_mask.all(-1)
return padding_mask return padding_mask
@ -439,9 +421,7 @@ class HubertModel(BaseFairseqModel):
unmasked_features = self.dropout_features(unmasked_features) unmasked_features = self.dropout_features(unmasked_features)
if mask: if mask:
x, mask_indices = self.apply_mask( x, mask_indices = self.apply_mask(features, padding_mask, target_list)
features, padding_mask, target_list
)
else: else:
x = features x = features
mask_indices = None mask_indices = None
@ -454,7 +434,7 @@ class HubertModel(BaseFairseqModel):
x, _ = self.encoder( x, _ = self.encoder(
x, x,
padding_mask=padding_mask, padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1 layer=None if output_layer is None else output_layer - 1,
) )
if features_only: if features_only:
@ -483,9 +463,7 @@ class HubertModel(BaseFairseqModel):
proj_x_m_list = [proj_x_m for _ in range(len(target_list))] proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
logit_m_list = [ logit_m_list = [
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
for i, (proj_x_m, t) in enumerate( for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
zip(proj_x_m_list, target_list)
)
] ]
else: else:
logit_m_list = [None for _ in target_list] logit_m_list = [None for _ in target_list]
@ -500,9 +478,7 @@ class HubertModel(BaseFairseqModel):
logit_u_list = [ logit_u_list = [
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
for i, (proj_x_u, t) in enumerate( for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
zip(proj_x_u_list, target_list)
)
] ]
else: else:
logit_u_list = [None for _ in target_list] logit_u_list = [None for _ in target_list]
@ -543,9 +519,7 @@ class HubertModel(BaseFairseqModel):
def get_targets(self, net_output, is_masked=True): def get_targets(self, net_output, is_masked=True):
logits_list = self.get_logits(net_output, is_masked) logits_list = self.get_logits(net_output, is_masked)
targets_list = [ targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list]
x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list
]
return targets_list return targets_list
def get_extra_losses(self, net_output): def get_extra_losses(self, net_output):
@ -21,9 +21,7 @@ from omegaconf import II, MISSING
@dataclass @dataclass
class HubertAsrConfig(FairseqDataclass): class HubertAsrConfig(FairseqDataclass):
w2v_path: str = field( w2v_path: str = field(default=MISSING, metadata={"help": "path to hubert model"})
default=MISSING, metadata={"help": "path to hubert model"}
)
no_pretrained_weights: bool = field( no_pretrained_weights: bool = field(
default=False, default=False,
metadata={"help": "if true, does not load pretrained weights"}, metadata={"help": "if true, does not load pretrained weights"},
@ -34,9 +32,7 @@ class HubertAsrConfig(FairseqDataclass):
) )
final_dropout: float = field( final_dropout: float = field(
default=0.0, default=0.0,
metadata={ metadata={"help": "dropout after transformer and before final projection"},
"help": "dropout after transformer and before final projection"
},
) )
dropout: float = field( dropout: float = field(
default=0.0, default=0.0,
@ -45,15 +41,13 @@ class HubertAsrConfig(FairseqDataclass):
attention_dropout: float = field( attention_dropout: float = field(
default=0.0, default=0.0,
metadata={ metadata={
"help": "dropout probability for attention weights " "help": "dropout probability for attention weights " "inside hubert model"
"inside hubert model"
}, },
) )
activation_dropout: float = field( activation_dropout: float = field(
default=0.0, default=0.0,
metadata={ metadata={
"help": "dropout probability after activation in FFN " "help": "dropout probability after activation in FFN " "inside hubert model"
"inside hubert model"
}, },
) )
@ -184,9 +178,7 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
decoder_ffn_embed_dim: int = field( decoder_ffn_embed_dim: int = field(
default=3072, metadata={"help": "decoder embedding dimension for FFN"} default=3072, metadata={"help": "decoder embedding dimension for FFN"}
) )
decoder_layers: int = field( decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
default=6, metadata={"help": "num of decoder layers"}
)
decoder_layerdrop: float = field( decoder_layerdrop: float = field(
default=0.0, metadata={"help": "decoder layerdrop chance"} default=0.0, metadata={"help": "decoder layerdrop chance"}
) )
@ -204,8 +196,7 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
no_token_positional_embeddings: bool = field( no_token_positional_embeddings: bool = field(
default=False, default=False,
metadata={ metadata={
"help": "if set, disables positional embeddings " "help": "if set, disables positional embeddings " "(outside self attention)"
"(outside self attention)"
}, },
) )
decoder_dropout: float = field( decoder_dropout: float = field(
@ -214,15 +205,13 @@ class HubertSeq2SeqConfig(HubertAsrConfig):
decoder_attention_dropout: float = field( decoder_attention_dropout: float = field(
default=0.0, default=0.0,
metadata={ metadata={
"help": "dropout probability for attention weights " "help": "dropout probability for attention weights " "inside the decoder"
"inside the decoder"
}, },
) )
decoder_activation_dropout: float = field( decoder_activation_dropout: float = field(
default=0.0, default=0.0,
metadata={ metadata={
"help": "dropout probability after activation in FFN " "help": "dropout probability after activation in FFN " "inside the decoder"
"inside the decoder"
}, },
) )
max_target_positions: int = field( max_target_positions: int = field(
@ -258,9 +247,7 @@ class HubertEncoder(FairseqEncoder):
} }
if cfg.w2v_args is None: if cfg.w2v_args is None:
state = checkpoint_utils.load_checkpoint_to_cpu( state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
cfg.w2v_path, arg_overrides
)
w2v_args = state.get("cfg", None) w2v_args = state.get("cfg", None)
if w2v_args is None: if w2v_args is None:
w2v_args = convert_namespace_to_omegaconf(state["args"]) w2v_args = convert_namespace_to_omegaconf(state["args"])
@ -269,9 +256,7 @@ class HubertEncoder(FairseqEncoder):
state = None state = None
w2v_args = cfg.w2v_args w2v_args = cfg.w2v_args
if isinstance(w2v_args, Namespace): if isinstance(w2v_args, Namespace):
cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf( cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)
w2v_args
)
assert cfg.normalize == w2v_args.task.normalize, ( assert cfg.normalize == w2v_args.task.normalize, (
"Fine-tuning works best when data normalization is the same. " "Fine-tuning works best when data normalization is the same. "
@ -344,9 +329,9 @@ class HubertEncoder(FairseqEncoder):
def reorder_encoder_out(self, encoder_out, new_order): def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out["encoder_out"] is not None: if encoder_out["encoder_out"] is not None:
encoder_out["encoder_out"] = encoder_out[ encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
"encoder_out" 1, new_order
].index_select(1, new_order) )
if encoder_out["encoder_padding_mask"] is not None: if encoder_out["encoder_padding_mask"] is not None:
encoder_out["encoder_padding_mask"] = encoder_out[ encoder_out["encoder_padding_mask"] = encoder_out[
"encoder_padding_mask" "encoder_padding_mask"
@ -225,10 +225,10 @@ class LSTMEncoder(FairseqEncoder):
super().__init__(dictionary) super().__init__(dictionary)
self.num_layers = num_layers self.num_layers = num_layers
self.dropout_in_module = FairseqDropout( self.dropout_in_module = FairseqDropout(
dropout_in*1.0, module_name=self.__class__.__name__ dropout_in * 1.0, module_name=self.__class__.__name__
) )
self.dropout_out_module = FairseqDropout( self.dropout_out_module = FairseqDropout(
dropout_out*1.0, module_name=self.__class__.__name__ dropout_out * 1.0, module_name=self.__class__.__name__
) )
self.bidirectional = bidirectional self.bidirectional = bidirectional
self.hidden_size = hidden_size self.hidden_size = hidden_size
@ -329,7 +329,9 @@ class LSTMEncoder(FairseqEncoder):
out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
return out.view(self.num_layers, bsz, -1) return out.view(self.num_layers, bsz, -1)
def reorder_encoder_out(self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order): def reorder_encoder_out(
self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order
):
return tuple( return tuple(
( (
encoder_out[0].index_select(1, new_order), encoder_out[0].index_select(1, new_order),
@ -402,10 +404,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
): ):
super().__init__(dictionary) super().__init__(dictionary)
self.dropout_in_module = FairseqDropout( self.dropout_in_module = FairseqDropout(
dropout_in*1.0, module_name=self.__class__.__name__ dropout_in * 1.0, module_name=self.__class__.__name__
) )
self.dropout_out_module = FairseqDropout( self.dropout_out_module = FairseqDropout(
dropout_out*1.0, module_name=self.__class__.__name__ dropout_out * 1.0, module_name=self.__class__.__name__
) )
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.share_input_output_embed = share_input_output_embed self.share_input_output_embed = share_input_output_embed
@ -18,7 +18,10 @@ def ensemble_encoder(func):
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
if self.ensemble_models is None or len(self.ensemble_models) == 1: if self.ensemble_models is None or len(self.ensemble_models) == 1:
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
encoder_outs = [func(model, *args, **kwargs, return_all_hiddens=True) for model in self.ensemble_models] encoder_outs = [
func(model, *args, **kwargs, return_all_hiddens=True)
for model in self.ensemble_models
]
_encoder_out = encoder_outs[0].copy() _encoder_out = encoder_outs[0].copy()
def stack(key): def stack(key):
@ -56,8 +59,7 @@ def ensemble_decoder(func):
model, model,
normalize=normalize, normalize=normalize,
encoder_out=_replace( encoder_out=_replace(
encoder_out, encoder_out, encoder_out["encoder_out"][0][:, :, :, i]
encoder_out["encoder_out"][0][:, :, :, i]
), ),
*args, *args,
**kwargs **kwargs
@ -85,7 +85,8 @@ class EnsembleLevT(BasicEnsembleModel):
else: else:
if not encoder_outs[0]["encoder_padding_mask"]: if not encoder_outs[0]["encoder_padding_mask"]:
src_lens = ( src_lens = (
encoder_outs[0]["encoder_out"][0].new(bsz) encoder_outs[0]["encoder_out"][0]
.new(bsz)
.fill_(encoder_outs[0]["encoder_out"][0].size(1)) .fill_(encoder_outs[0]["encoder_out"][0].size(1))
) )
else: else:
@ -183,7 +183,7 @@ class RobertaModel(FairseqEncoderModel):
"communication less efficient due to smaller input sizes. This option " "communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or " "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed." "--offload-activations are passed."
) ),
) )
@classmethod @classmethod
@ -542,7 +542,9 @@ def base_architecture(args):
args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", True) args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", True)
args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", True) args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", True)
args.activation_fn = safe_getattr(args, "activation_fn", "gelu") args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
args.encoder_normalize_before = safe_getattr(args, "encoder_normalize_before", False) args.encoder_normalize_before = safe_getattr(
args, "encoder_normalize_before", False
)
args.pooler_activation_fn = safe_getattr(args, "pooler_activation_fn", "tanh") args.pooler_activation_fn = safe_getattr(args, "pooler_activation_fn", "tanh")
args.untie_weights_roberta = safe_getattr(args, "untie_weights_roberta", False) args.untie_weights_roberta = safe_getattr(args, "untie_weights_roberta", False)
@ -12,26 +12,26 @@ from .hub_interface import RobertaHubInterface
from .model import RobertaModel from .model import RobertaModel
@register_model('gottbert') @register_model("gottbert")
class GottbertModel(RobertaModel): class GottbertModel(RobertaModel):
@classmethod @classmethod
def hub_models(cls): def hub_models(cls):
return { return {
'gottbert-base': 'https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz', "gottbert-base": "https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz",
} }
@classmethod @classmethod
def from_pretrained(cls, def from_pretrained(
model_name_or_path, cls,
checkpoint_file='model.pt', model_name_or_path,
data_name_or_path='.', checkpoint_file="model.pt",
bpe='hf_byte_bpe', data_name_or_path=".",
bpe_vocab='vocab.json', bpe="hf_byte_bpe",
bpe_merges='merges.txt', bpe_vocab="vocab.json",
bpe_add_prefix_space=False, bpe_merges="merges.txt",
**kwargs bpe_add_prefix_space=False,
): **kwargs
):
from fairseq import hub_utils from fairseq import hub_utils
x = hub_utils.from_pretrained( x = hub_utils.from_pretrained(
@ -46,4 +46,4 @@ class GottbertModel(RobertaModel):
bpe_add_prefix_space=bpe_add_prefix_space, bpe_add_prefix_space=bpe_add_prefix_space,
**kwargs, **kwargs,
) )
return RobertaHubInterface(x['args'], x['task'], x['models'][0]) return RobertaHubInterface(x["args"], x["task"], x["models"][0])
@ -202,10 +202,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel):
help="model to take encoder weights from (for initialization)", help="model to take encoder weights from (for initialization)",
) )
parser.add_argument( parser.add_argument(
'--encoder-freezing-updates', "--encoder-freezing-updates",
type=int, type=int,
metavar='N', metavar="N",
help='freeze encoder for first N updates' help="freeze encoder for first N updates",
) )
@classmethod @classmethod
@ -329,7 +329,9 @@ class S2TTransformerEncoder(FairseqEncoder):
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T "encoder_padding_mask": [encoder_padding_mask]
if encoder_padding_mask.any()
else [], # B x T
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
"encoder_states": encoder_states, # List[T x B x C] "encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [], "src_tokens": [],
@ -339,27 +341,37 @@ class S2TTransformerEncoder(FairseqEncoder):
def forward(self, src_tokens, src_lengths, return_all_hiddens=False): def forward(self, src_tokens, src_lengths, return_all_hiddens=False):
if self.num_updates < self.encoder_freezing_updates: if self.num_updates < self.encoder_freezing_updates:
with torch.no_grad(): with torch.no_grad():
x = self._forward(src_tokens, src_lengths, x = self._forward(
return_all_hiddens=return_all_hiddens) src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
)
else: else:
x = self._forward(src_tokens, src_lengths, x = self._forward(
return_all_hiddens=return_all_hiddens) src_tokens, src_lengths, return_all_hiddens=return_all_hiddens
)
return x return x
def reorder_encoder_out(self, encoder_out, new_order): def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = ( new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0 []
if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
) )
new_encoder_padding_mask = ( new_encoder_padding_mask = (
[] if len(encoder_out["encoder_padding_mask"]) == 0 []
else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]] if len(encoder_out["encoder_padding_mask"]) == 0
else [
x.index_select(0, new_order)
for x in encoder_out["encoder_padding_mask"]
]
) )
new_encoder_embedding = ( new_encoder_embedding = (
[] if len(encoder_out["encoder_embedding"]) == 0 []
else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]] if len(encoder_out["encoder_embedding"]) == 0
else [
x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
]
) )
encoder_states = encoder_out["encoder_states"] encoder_states = encoder_out["encoder_states"]
@ -9,8 +9,12 @@ import copy
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from fairseq import utils, checkpoint_utils from fairseq import utils, checkpoint_utils
from fairseq.models import (FairseqEncoderDecoderModel, FairseqEncoder, from fairseq.models import (
register_model, register_model_architecture) FairseqEncoderDecoderModel,
FairseqEncoder,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import Embedding, TransformerDecoder from fairseq.models.transformer import Embedding, TransformerDecoder
from fairseq.models.wav2vec import Wav2VecEncoder from fairseq.models.wav2vec import Wav2VecEncoder
from fairseq.modules.layer_norm import LayerNorm from fairseq.modules.layer_norm import LayerNorm
@ -24,18 +28,23 @@ logger = logging.getLogger(__name__)
class Conv1dAdaptor(nn.Module): class Conv1dAdaptor(nn.Module):
def __init__(self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2, def __init__(
add_layernorm=False): self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2, add_layernorm=False
):
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
nn.Conv1d(in_dim if i == 0 else out_dim, out_dim * 2, kernel_size, nn.Conv1d(
stride=stride, padding=kernel_size // 2) in_dim if i == 0 else out_dim,
out_dim * 2,
kernel_size,
stride=stride,
padding=kernel_size // 2,
)
for i in range(n_layers) for i in range(n_layers)
) )
self.layernorms = None self.layernorms = None
if add_layernorm: if add_layernorm:
self.layernorms = nn.ModuleList(LayerNorm(out_dim) self.layernorms = nn.ModuleList(LayerNorm(out_dim) for _ in range(n_layers))
for _ in range(n_layers))
self.stride = stride self.stride = stride
@classmethod @classmethod
@ -43,7 +52,7 @@ class Conv1dAdaptor(nn.Module):
parser.add_argument("--adaptor-n-layers", type=int) parser.add_argument("--adaptor-n-layers", type=int)
parser.add_argument("--adaptor-kernel-size", type=int) parser.add_argument("--adaptor-kernel-size", type=int)
parser.add_argument("--adaptor-stride", type=int) parser.add_argument("--adaptor-stride", type=int)
parser.add_argument("--adaptor-layernorm", action='store_true') parser.add_argument("--adaptor-layernorm", action="store_true")
def get_out_seq_lens_tensor(self, in_seq_lens_tensor): def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
out = in_seq_lens_tensor.clone() out = in_seq_lens_tensor.clone()
@ -197,15 +206,18 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
encoder_out_dim = self.w2v_encoder.w2v_model.encoder.embedding_dim encoder_out_dim = self.w2v_encoder.w2v_model.encoder.embedding_dim
# Projection + 8x shrinking # Projection + 8x shrinking
self.adaptor = Conv1dAdaptor( self.adaptor = Conv1dAdaptor(
encoder_out_dim, args.decoder_embed_dim, encoder_out_dim,
args.decoder_embed_dim,
n_layers=args.adaptor_n_layers, n_layers=args.adaptor_n_layers,
kernel_size=args.adaptor_kernel_size, stride=args.adaptor_stride, kernel_size=args.adaptor_kernel_size,
add_layernorm=args.adaptor_layernorm stride=args.adaptor_stride,
add_layernorm=args.adaptor_layernorm,
) )
for k, p in self.w2v_encoder.w2v_model.named_parameters(): for k, p in self.w2v_encoder.w2v_model.named_parameters():
# Freeze pretrained models by default # Freeze pretrained models by default
if safe_hasattr(args, 'finetune_w2v_params') and XMTransformerModel.finetune_params( if safe_hasattr(
args.finetune_w2v_params, k): args, "finetune_w2v_params"
) and XMTransformerModel.finetune_params(args.finetune_w2v_params, k):
p.requires_grad = True p.requires_grad = True
else: else:
p.requires_grad = False p.requires_grad = False
@ -214,11 +226,16 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
def add_args(cls, parser): def add_args(cls, parser):
add_wav2vec_asr_args(parser) add_wav2vec_asr_args(parser)
parser.add_argument( parser.add_argument(
"--normalize", action="store_true", "--normalize",
action="store_true",
help="if set, normalizes input to have 0 mean and unit variance", help="if set, normalizes input to have 0 mean and unit variance",
) )
parser.add_argument("--finetune-w2v-params", type=str, metavar="STR", parser.add_argument(
help="comma-separated param strings to finetune.") "--finetune-w2v-params",
type=str,
metavar="STR",
help="comma-separated param strings to finetune.",
)
Conv1dAdaptor.add_args(parser) Conv1dAdaptor.add_args(parser)
def forward(self, src_tokens, src_lengths=None, **kwargs): def forward(self, src_tokens, src_lengths=None, **kwargs):
@ -227,13 +244,17 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
x = out["encoder_out"] x = out["encoder_out"]
enc_padding_mask = None enc_padding_mask = None
if out["encoder_padding_mask"] is not None: if out["encoder_padding_mask"] is not None:
enc_padding_mask = out["encoder_padding_mask"].transpose(0, 1) # T X B --> B X T enc_padding_mask = out["encoder_padding_mask"].transpose(
0, 1
) # T X B --> B X T
x, enc_padding_mask = self.adaptor(x, enc_padding_mask) x, enc_padding_mask = self.adaptor(x, enc_padding_mask)
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"encoder_padding_mask": [enc_padding_mask] if enc_padding_mask.any() else [], # B x T "encoder_padding_mask": [enc_padding_mask]
if enc_padding_mask.any()
else [], # B x T
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C] "encoder_states": [], # List[T x B x C]
"src_tokens": [], "src_tokens": [],
@ -242,20 +263,26 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
def reorder_encoder_out(self, encoder_out, new_order): def reorder_encoder_out(self, encoder_out, new_order):
new_encoder_out = ( new_encoder_out = (
[] if len(encoder_out["encoder_out"]) == 0 []
if len(encoder_out["encoder_out"]) == 0
else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]]
) )
new_encoder_padding_mask = ( new_encoder_padding_mask = (
[] if len(encoder_out["encoder_padding_mask"]) == 0 []
else [x.index_select(0, new_order) for x in if len(encoder_out["encoder_padding_mask"]) == 0
encoder_out["encoder_padding_mask"]] else [
x.index_select(0, new_order)
for x in encoder_out["encoder_padding_mask"]
]
) )
new_encoder_embedding = ( new_encoder_embedding = (
[] if len(encoder_out["encoder_embedding"]) == 0 []
else [x.index_select(0, new_order) for x in if len(encoder_out["encoder_embedding"]) == 0
encoder_out["encoder_embedding"]] else [
x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]
]
) )
encoder_states = encoder_out["encoder_states"] encoder_states = encoder_out["encoder_states"]
@ -274,38 +301,71 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder):
def add_decoder_args(parser): def add_decoder_args(parser):
parser.add_argument("--activation-fn", type=str, default='relu',
choices=utils.get_available_activation_fns(),
help="activation function to use")
parser.add_argument("--decoder-dropout", type=float, metavar="D",
help="dropout probability")
parser.add_argument("--decoder-attention-dropout", type=float,
metavar="D",
help="dropout probability for attention weights")
parser.add_argument("--decoder-activation-dropout", type=float,
metavar="D",
help="dropout probability after activation in FFN.")
parser.add_argument("--decoder-embed-dim", type=int, metavar="N",
help="decoder embedding dimension")
parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N",
help="decoder embedding dimension for FFN")
parser.add_argument("--decoder-layers", type=int, metavar="N",
help="num decoder layers")
parser.add_argument("--decoder-attention-heads", type=int, metavar="N",
help="num decoder attention heads")
parser.add_argument("--decoder-normalize-before", action="store_true",
help="apply layernorm before each decoder block")
parser.add_argument("--layernorm-embedding", action="store_true",
help="add layernorm to embedding")
parser.add_argument("--no-scale-embedding", action="store_true",
help="if True, dont scale embeddings")
parser.add_argument( parser.add_argument(
"--load-pretrained-decoder-from", type=str, metavar="STR", "--activation-fn",
help="model to take decoder weights from (for initialization)" type=str,
default="relu",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--decoder-dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--decoder-attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--decoder-activation-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN.",
)
parser.add_argument(
"--decoder-embed-dim", type=int, metavar="N", help="decoder embedding dimension"
)
parser.add_argument(
"--decoder-ffn-embed-dim",
type=int,
metavar="N",
help="decoder embedding dimension for FFN",
)
parser.add_argument(
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
)
parser.add_argument(
"--decoder-attention-heads",
type=int,
metavar="N",
help="num decoder attention heads",
)
parser.add_argument(
"--decoder-normalize-before",
action="store_true",
help="apply layernorm before each decoder block",
)
parser.add_argument(
"--layernorm-embedding", action="store_true", help="add layernorm to embedding"
)
parser.add_argument(
"--no-scale-embedding",
action="store_true",
help="if True, dont scale embeddings",
)
parser.add_argument(
"--load-pretrained-decoder-from",
type=str,
metavar="STR",
help="model to take decoder weights from (for initialization)",
)
parser.add_argument(
"--finetune-decoder-params",
type=str,
metavar="STR",
help="comma-separated param strings to finetune.",
) )
parser.add_argument("--finetune-decoder-params", type=str,
metavar="STR",
help="comma-separated param strings to finetune.")
parser.add_argument("--checkpoint-activations", action="store_true") parser.add_argument("--checkpoint-activations", action="store_true")
@ -342,16 +402,16 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
_args.activation_dropout = args.decoder_activation_dropout _args.activation_dropout = args.decoder_activation_dropout
_args.max_target_positions = 1024 _args.max_target_positions = 1024
decoder = TransformerDecoder(_args, task.target_dictionary, decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens)
embed_tokens)
if getattr(args, "load_pretrained_decoder_from", None): if getattr(args, "load_pretrained_decoder_from", None):
decoder = checkpoint_utils.load_pretrained_component_from_model( decoder = checkpoint_utils.load_pretrained_component_from_model(
component=decoder, checkpoint=args.load_pretrained_decoder_from component=decoder, checkpoint=args.load_pretrained_decoder_from
) )
for k, p in decoder.named_parameters(): for k, p in decoder.named_parameters():
# Freeze pretrained models by default # Freeze pretrained models by default
if safe_hasattr(args, 'finetune_decoder_params') and XMTransformerModel.finetune_params( if safe_hasattr(
args.finetune_decoder_params, k): args, "finetune_decoder_params"
) and XMTransformerModel.finetune_params(args.finetune_decoder_params, k):
p.requires_grad = True p.requires_grad = True
else: else:
p.requires_grad = False p.requires_grad = False
@ -369,8 +429,9 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
padding_idx = dictionary.pad() padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx) return Embedding(num_embeddings, embed_dim, padding_idx)
decoder_embed_tokens = build_embedding(task.target_dictionary, decoder_embed_tokens = build_embedding(
args.decoder_embed_dim) task.target_dictionary, args.decoder_embed_dim
)
encoder = cls.build_encoder(args) encoder = cls.build_encoder(args)
decoder = cls.build_decoder(args, task, decoder_embed_tokens) decoder = cls.build_decoder(args, task, decoder_embed_tokens)
return cls(encoder, decoder) return cls(encoder, decoder)
@ -382,8 +443,7 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
sample: Optional[Dict[str, Tensor]] = None, sample: Optional[Dict[str, Tensor]] = None,
): ):
# net_output['encoder_out'] is a (B, T, D) tensor # net_output['encoder_out'] is a (B, T, D) tensor
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
sample)
lprobs.batch_first = True lprobs.batch_first = True
return lprobs return lprobs
@ -393,17 +453,19 @@ class XMTransformerModel(FairseqEncoderDecoderModel):
argument in its input, which is not supported in torchscript. This argument in its input, which is not supported in torchscript. This
method overrites the forward method definition without **kwargs. method overrites the forward method definition without **kwargs.
""" """
encoder_out = self.encoder(src_tokens=src_tokens, encoder_out = self.encoder(
src_lengths=src_lengths, **kwargs) src_tokens=src_tokens, src_lengths=src_lengths, **kwargs
decoder_out = self.decoder(prev_output_tokens=prev_output_tokens, )
encoder_out=encoder_out) decoder_out = self.decoder(
prev_output_tokens=prev_output_tokens, encoder_out=encoder_out
)
return decoder_out return decoder_out
def upgrade_state_dict(self, state_dict): def upgrade_state_dict(self, state_dict):
for k, _ in state_dict.items(): for k, _ in state_dict.items():
if 'adaptor.layers' in state_dict: if "adaptor.layers" in state_dict:
print(k) print(k)
new = k.replace('adaptor.layers', 'adaptor_layers') new = k.replace("adaptor.layers", "adaptor_layers")
state_dict[new] = state_dict[k] state_dict[new] = state_dict[k]
del state_dict[k] del state_dict[k]
@ -435,11 +497,9 @@ def set_default_w2v_encoder_args(args):
args.mask_channel_length = getattr(args, "mask_channel_length", 10) args.mask_channel_length = getattr(args, "mask_channel_length", 10)
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
args.mask_channel_before = getattr(args, "mask_channel_before", False) args.mask_channel_before = getattr(args, "mask_channel_before", False)
args.mask_channel_selection = getattr(args, "mask_channel_selection", args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
"static")
args.mask_channel_other = getattr(args, "mask_channel_other", 0) args.mask_channel_other = getattr(args, "mask_channel_other", 0)
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
False)
args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0) args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0)
args.feature_grad_mult = 0.1 args.feature_grad_mult = 0.1
@ -456,49 +516,43 @@ def set_default_adaptor_args(args):
def set_default_mbart_decoder_args(args): def set_default_mbart_decoder_args(args):
args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024)
4 * 1024) args.decoder_layers = getattr(args, "decoder_layers", 12)
args.decoder_layers = getattr(args, 'decoder_layers', 12) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
True)
args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
args.adaptive_input = getattr(args, "adaptive_input", False) args.adaptive_input = getattr(args, "adaptive_input", False)
args.decoder_attention_dropout = getattr(args, 'decoder_attention_dropout', args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0.0)
0.) args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0.0)
args.decoder_activation_dropout = getattr(args, args.decoder_dropout = getattr(args, "decoder_dropout", 0.1)
'decoder_activation_dropout', 0.) args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
args.decoder_dropout = getattr(args, 'decoder_dropout', 0.1) args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff',
None)
args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
args.share_decoder_input_output_embed = getattr( args.share_decoder_input_output_embed = getattr(
args, 'share_decoder_input_output_embed', True args, "share_decoder_input_output_embed", True
) )
args.no_token_positional_embeddings = getattr( args.no_token_positional_embeddings = getattr(
args, "no_token_positional_embeddings", False args, "no_token_positional_embeddings", False
) )
args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_output_dim = getattr(
args.decoder_embed_dim) args, "decoder_output_dim", args.decoder_embed_dim
args.decoder_input_dim = getattr(args, 'decoder_input_dim', )
args.decoder_embed_dim) args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.layernorm_embedding = getattr(args, 'layernorm_embedding', True) args.layernorm_embedding = getattr(args, "layernorm_embedding", True)
args.activation_fn = getattr(args, 'activation_fn', 'gelu') args.activation_fn = getattr(args, "activation_fn", "gelu")
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
args.checkpoint_activations = getattr(args, "checkpoint_activations", False) args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
@register_model_architecture(model_name="xm_transformer", @register_model_architecture(model_name="xm_transformer", arch_name="xm_transformer")
arch_name="xm_transformer")
def base_architecture(args): def base_architecture(args):
set_default_w2v_encoder_args(args) set_default_w2v_encoder_args(args)
set_default_adaptor_args(args) set_default_adaptor_args(args)
@ -8,10 +8,17 @@ import logging
import torch import torch
from torch import nn from torch import nn
from fairseq.models import (FairseqEncoder, FairseqEncoderModel, register_model, from fairseq.models import (
register_model_architecture) FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import ( from fairseq.modules import (
LayerNorm, PositionalEmbedding, FairseqDropout, MultiheadAttention LayerNorm,
PositionalEmbedding,
FairseqDropout,
MultiheadAttention,
) )
from fairseq import utils from fairseq import utils
from fairseq.data.data_utils import lengths_to_padding_mask from fairseq.data.data_utils import lengths_to_padding_mask
@ -36,11 +43,19 @@ class PositionwiseFeedForward(nn.Module):
def __init__(self, in_dim, hidden_dim, kernel_size, dropout): def __init__(self, in_dim, hidden_dim, kernel_size, dropout):
super().__init__() super().__init__()
self.ffn = nn.Sequential( self.ffn = nn.Sequential(
nn.Conv1d(in_dim, hidden_dim, kernel_size=kernel_size, nn.Conv1d(
padding=(kernel_size - 1) // 2), in_dim,
hidden_dim,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
),
nn.ReLU(), nn.ReLU(),
nn.Conv1d(hidden_dim, in_dim, kernel_size=kernel_size, nn.Conv1d(
padding=(kernel_size - 1) // 2) hidden_dim,
in_dim,
kernel_size=kernel_size,
padding=(kernel_size - 1) // 2,
),
) )
self.layer_norm = LayerNorm(in_dim) self.layer_norm = LayerNorm(in_dim)
self.dropout = self.dropout_module = FairseqDropout( self.dropout = self.dropout_module = FairseqDropout(
@ -57,8 +72,7 @@ class PositionwiseFeedForward(nn.Module):
class FFTLayer(torch.nn.Module): class FFTLayer(torch.nn.Module):
def __init__( def __init__(
self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, attention_dropout
attention_dropout
): ):
super().__init__() super().__init__()
self.self_attn = MultiheadAttention( self.self_attn = MultiheadAttention(
@ -74,8 +88,7 @@ class FFTLayer(torch.nn.Module):
residual = x residual = x
x = x.transpose(0, 1) x = x.transpose(0, 1)
x, _ = self.self_attn( x, _ = self.self_attn(
query=x, key=x, value=x, key_padding_mask=padding_mask, query=x, key=x, value=x, key_padding_mask=padding_mask, need_weights=False
need_weights=False
) )
x = x.transpose(0, 1) x = x.transpose(0, 1)
x = self.layer_norm(x + residual) x = self.layer_norm(x + residual)
@ -106,11 +119,12 @@ class VariancePredictor(nn.Module):
super().__init__() super().__init__()
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv1d( nn.Conv1d(
args.encoder_embed_dim, args.var_pred_hidden_dim, args.encoder_embed_dim,
args.var_pred_hidden_dim,
kernel_size=args.var_pred_kernel_size, kernel_size=args.var_pred_kernel_size,
padding=(args.var_pred_kernel_size - 1) // 2 padding=(args.var_pred_kernel_size - 1) // 2,
), ),
nn.ReLU() nn.ReLU(),
) )
self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim) self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim)
self.dropout_module = FairseqDropout( self.dropout_module = FairseqDropout(
@ -118,10 +132,12 @@ class VariancePredictor(nn.Module):
) )
self.conv2 = nn.Sequential( self.conv2 = nn.Sequential(
nn.Conv1d( nn.Conv1d(
args.var_pred_hidden_dim, args.var_pred_hidden_dim, args.var_pred_hidden_dim,
kernel_size=args.var_pred_kernel_size, padding=1 args.var_pred_hidden_dim,
kernel_size=args.var_pred_kernel_size,
padding=1,
), ),
nn.ReLU() nn.ReLU(),
) )
self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim) self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim)
self.proj = nn.Linear(args.var_pred_hidden_dim, 1) self.proj = nn.Linear(args.var_pred_hidden_dim, 1)
@ -171,8 +187,15 @@ class VarianceAdaptor(nn.Module):
return out, emb return out, emb
def forward( def forward(
self, x, padding_mask, durations=None, pitches=None, energies=None, self,
d_factor=1.0, p_factor=1.0, e_factor=1.0 x,
padding_mask,
durations=None,
pitches=None,
energies=None,
d_factor=1.0,
p_factor=1.0,
e_factor=1.0,
): ):
# x: B x T x C # x: B x T x C
log_dur_out = self.duration_predictor(x) log_dur_out = self.duration_predictor(x)
@ -205,8 +228,7 @@ class FastSpeech2Encoder(FairseqEncoder):
self.spk_emb_proj = None self.spk_emb_proj = None
if embed_speaker is not None: if embed_speaker is not None:
self.spk_emb_proj = nn.Linear( self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
args.encoder_embed_dim
) )
self.dropout_module = FairseqDropout( self.dropout_module = FairseqDropout(
@ -224,9 +246,12 @@ class FastSpeech2Encoder(FairseqEncoder):
self.encoder_fft_layers = nn.ModuleList( self.encoder_fft_layers = nn.ModuleList(
FFTLayer( FFTLayer(
args.encoder_embed_dim, args.encoder_attention_heads, args.encoder_embed_dim,
args.fft_hidden_dim, args.fft_kernel_size, args.encoder_attention_heads,
dropout=args.dropout, attention_dropout=args.attention_dropout args.fft_hidden_dim,
args.fft_kernel_size,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
) )
for _ in range(args.encoder_layers) for _ in range(args.encoder_layers)
) )
@ -235,9 +260,12 @@ class FastSpeech2Encoder(FairseqEncoder):
self.decoder_fft_layers = nn.ModuleList( self.decoder_fft_layers = nn.ModuleList(
FFTLayer( FFTLayer(
args.decoder_embed_dim, args.decoder_attention_heads, args.decoder_embed_dim,
args.fft_hidden_dim, args.fft_kernel_size, args.decoder_attention_heads,
dropout=args.dropout, attention_dropout=args.attention_dropout args.fft_hidden_dim,
args.fft_kernel_size,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
) )
for _ in range(args.decoder_layers) for _ in range(args.decoder_layers)
) )
@ -247,15 +275,25 @@ class FastSpeech2Encoder(FairseqEncoder):
self.postnet = None self.postnet = None
if args.add_postnet: if args.add_postnet:
self.postnet = Postnet( self.postnet = Postnet(
self.out_dim, args.postnet_conv_dim, self.out_dim,
args.postnet_conv_dim,
args.postnet_conv_kernel_size, args.postnet_conv_kernel_size,
args.postnet_layers, args.postnet_dropout args.postnet_layers,
args.postnet_dropout,
) )
self.apply(model_init) self.apply(model_init)
def forward(self, src_tokens, src_lengths=None, speaker=None, def forward(
durations=None, pitches=None, energies=None, **kwargs): self,
src_tokens,
src_lengths=None,
speaker=None,
durations=None,
pitches=None,
energies=None,
**kwargs
):
x = self.embed_tokens(src_tokens) x = self.embed_tokens(src_tokens)
enc_padding_mask = src_tokens.eq(self.padding_idx) enc_padding_mask = src_tokens.eq(self.padding_idx)
@ -270,8 +308,9 @@ class FastSpeech2Encoder(FairseqEncoder):
emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1) emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
x = self.spk_emb_proj(torch.cat([x, emb], dim=2)) x = self.spk_emb_proj(torch.cat([x, emb], dim=2))
x, out_lens, log_dur_out, pitch_out, energy_out = \ x, out_lens, log_dur_out, pitch_out, energy_out = self.var_adaptor(
self.var_adaptor(x, enc_padding_mask, durations, pitches, energies) x, enc_padding_mask, durations, pitches, energies
)
dec_padding_mask = lengths_to_padding_mask(out_lens) dec_padding_mask = lengths_to_padding_mask(out_lens)
x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask) x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask)
@ -326,7 +365,7 @@ class FastSpeech2Model(FairseqEncoderModel):
out_dim = args.output_frame_dim * args.n_frames_per_step out_dim = args.output_frame_dim * args.n_frames_per_step
self.ctc_proj = None self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.: if getattr(args, "ctc_weight", 0.0) > 0.0:
self.ctc_proj = nn.Linear(out_dim, len(src_dict)) self.ctc_proj = nn.Linear(out_dim, len(src_dict))
@classmethod @classmethod
@ -119,7 +119,7 @@ class Generator(torch.nn.Module):
self.ups = nn.ModuleList() self.ups = nn.ModuleList()
for i, (u, k) in enumerate( for i, (u, k) in enumerate(
zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"]) zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"])
): ):
self.ups.append( self.ups.append(
weight_norm( weight_norm(
@ -137,7 +137,7 @@ class Generator(torch.nn.Module):
for i in range(len(self.ups)): for i in range(len(self.ups)):
ch = cfg["upsample_initial_channel"] // (2 ** (i + 1)) ch = cfg["upsample_initial_channel"] // (2 ** (i + 1))
for k, d in zip( for k, d in zip(
cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"] cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"]
): ):
self.resblocks.append(ResBlock(ch, k, d)) self.resblocks.append(ResBlock(ch, k, d))
@ -9,9 +9,13 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel, from fairseq.models import (
FairseqIncrementalDecoder, register_model, FairseqEncoder,
register_model_architecture) FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import LSTMCellWithZoneOut, LocationAttention from fairseq.modules import LSTMCellWithZoneOut, LocationAttention
@ -31,29 +35,36 @@ class Tacotron2Encoder(FairseqEncoder):
self.spk_emb_proj = None self.spk_emb_proj = None
if embed_speaker is not None: if embed_speaker is not None:
self.spk_emb_proj = nn.Linear( self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
args.encoder_embed_dim
) )
self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, self.embed_tokens = nn.Embedding(
padding_idx=self.padding_idx) len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
)
assert(args.encoder_conv_kernel_size % 2 == 1) assert args.encoder_conv_kernel_size % 2 == 1
self.convolutions = nn.ModuleList( self.convolutions = nn.ModuleList(
nn.Sequential( nn.Sequential(
nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim, nn.Conv1d(
kernel_size=args.encoder_conv_kernel_size, args.encoder_embed_dim,
padding=((args.encoder_conv_kernel_size - 1) // 2)), args.encoder_embed_dim,
kernel_size=args.encoder_conv_kernel_size,
padding=((args.encoder_conv_kernel_size - 1) // 2),
),
nn.BatchNorm1d(args.encoder_embed_dim), nn.BatchNorm1d(args.encoder_embed_dim),
nn.ReLU(), nn.ReLU(),
nn.Dropout(args.encoder_dropout) nn.Dropout(args.encoder_dropout),
) )
for _ in range(args.encoder_conv_layers) for _ in range(args.encoder_conv_layers)
) )
self.lstm = nn.LSTM(args.encoder_embed_dim, args.encoder_embed_dim // 2, self.lstm = nn.LSTM(
num_layers=args.encoder_lstm_layers, args.encoder_embed_dim,
batch_first=True, bidirectional=True) args.encoder_embed_dim // 2,
num_layers=args.encoder_lstm_layers,
batch_first=True,
bidirectional=True,
)
self.apply(encoder_init) self.apply(encoder_init)
@ -78,7 +89,7 @@ class Tacotron2Encoder(FairseqEncoder):
return { return {
"encoder_out": [x], # B x T x C "encoder_out": [x], # B x T x C
"encoder_padding_mask": encoder_padding_mask, # B x T "encoder_padding_mask": encoder_padding_mask, # B x T
} }
@ -86,8 +97,7 @@ class Prenet(nn.Module):
def __init__(self, in_dim, n_layers, n_units, dropout): def __init__(self, in_dim, n_layers, n_units, dropout):
super().__init__() super().__init__()
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units), nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units), nn.ReLU())
nn.ReLU())
for i in range(n_layers) for i in range(n_layers)
) )
self.dropout = dropout self.dropout = dropout
@ -102,20 +112,24 @@ class Postnet(nn.Module):
def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout): def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout):
super(Postnet, self).__init__() super(Postnet, self).__init__()
self.convolutions = nn.ModuleList() self.convolutions = nn.ModuleList()
assert(kernel_size % 2 == 1) assert kernel_size % 2 == 1
for i in range(n_layers): for i in range(n_layers):
cur_layers = [ cur_layers = (
nn.Conv1d(in_dim if i == 0 else n_channels, [
n_channels if i < n_layers - 1 else in_dim, nn.Conv1d(
kernel_size=kernel_size, in_dim if i == 0 else n_channels,
padding=((kernel_size - 1) // 2)), n_channels if i < n_layers - 1 else in_dim,
nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim) kernel_size=kernel_size,
] + ([nn.Tanh()] if i < n_layers - 1 else []) + [nn.Dropout(dropout)] padding=((kernel_size - 1) // 2),
),
nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim),
]
+ ([nn.Tanh()] if i < n_layers - 1 else [])
+ [nn.Dropout(dropout)]
)
nn.init.xavier_uniform_( nn.init.xavier_uniform_(
cur_layers[0].weight, cur_layers[0].weight,
torch.nn.init.calculate_gain( torch.nn.init.calculate_gain("tanh" if i < n_layers - 1 else "linear"),
"tanh" if i < n_layers - 1 else "linear"
)
) )
self.convolutions.append(nn.Sequential(*cur_layers)) self.convolutions.append(nn.Sequential(*cur_layers))
@ -138,21 +152,25 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
self.n_frames_per_step = args.n_frames_per_step self.n_frames_per_step = args.n_frames_per_step
self.out_dim = args.output_frame_dim * args.n_frames_per_step self.out_dim = args.output_frame_dim * args.n_frames_per_step
self.prenet = Prenet(self.out_dim, args.prenet_layers, args.prenet_dim, self.prenet = Prenet(
args.prenet_dropout) self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
)
# take prev_context, prev_frame, (speaker embedding) as input # take prev_context, prev_frame, (speaker embedding) as input
self.attention_lstm = LSTMCellWithZoneOut( self.attention_lstm = LSTMCellWithZoneOut(
args.zoneout, args.zoneout,
args.prenet_dim + args.encoder_embed_dim, args.prenet_dim + args.encoder_embed_dim,
args.decoder_lstm_dim args.decoder_lstm_dim,
) )
# take attention_lstm output, attention_state, encoder_out as input # take attention_lstm output, attention_state, encoder_out as input
self.attention = LocationAttention( self.attention = LocationAttention(
args.attention_dim, args.encoder_embed_dim, args.decoder_lstm_dim, args.attention_dim,
args.encoder_embed_dim,
args.decoder_lstm_dim,
(1 + int(args.attention_use_cumprob)), (1 + int(args.attention_use_cumprob)),
args.attention_conv_dim, args.attention_conv_kernel_size args.attention_conv_dim,
args.attention_conv_kernel_size,
) )
# take attention_lstm output, context, (gated_latent) as input # take attention_lstm output, context, (gated_latent) as input
@ -160,7 +178,7 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
LSTMCellWithZoneOut( LSTMCellWithZoneOut(
args.zoneout, args.zoneout,
args.encoder_embed_dim + args.decoder_lstm_dim, args.encoder_embed_dim + args.decoder_lstm_dim,
args.decoder_lstm_dim args.decoder_lstm_dim,
) )
for i in range(args.decoder_lstm_layers) for i in range(args.decoder_lstm_layers)
) )
@ -169,12 +187,16 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
self.feat_proj = nn.Linear(proj_in_dim, self.out_dim) self.feat_proj = nn.Linear(proj_in_dim, self.out_dim)
self.eos_proj = nn.Linear(proj_in_dim, 1) self.eos_proj = nn.Linear(proj_in_dim, 1)
self.postnet = Postnet(self.out_dim, args.postnet_conv_dim, self.postnet = Postnet(
args.postnet_conv_kernel_size, self.out_dim,
args.postnet_layers, args.postnet_dropout) args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers,
args.postnet_dropout,
)
self.ctc_proj = None self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.: if getattr(args, "ctc_weight", 0.0) > 0.0:
self.ctc_proj = nn.Linear(self.out_dim, len(src_dict)) self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))
self.apply(decoder_init) self.apply(decoder_init)
@ -190,12 +212,16 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
lstm_h = self.get_incremental_state(incremental_state, "lstm_h") lstm_h = self.get_incremental_state(incremental_state, "lstm_h")
if lstm_h is None: if lstm_h is None:
lstm_h = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) lstm_h = [
for _ in range(self.args.decoder_lstm_layers)] enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
for _ in range(self.args.decoder_lstm_layers)
]
lstm_c = self.get_incremental_state(incremental_state, "lstm_c") lstm_c = self.get_incremental_state(incremental_state, "lstm_c")
if lstm_c is None: if lstm_c is None:
lstm_c = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) lstm_c = [
for _ in range(self.args.decoder_lstm_layers)] enc_out.new_zeros(bsz, self.args.decoder_lstm_dim)
for _ in range(self.args.decoder_lstm_layers)
]
attn_w = self.get_incremental_state(incremental_state, "attn_w") attn_w = self.get_incremental_state(incremental_state, "attn_w")
if attn_w is None: if attn_w is None:
@ -216,8 +242,14 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
else: else:
raise ValueError(f"{self.args.init_attn_c} not supported") raise ValueError(f"{self.args.init_attn_c} not supported")
def forward(self, prev_output_tokens, encoder_out=None, def forward(
incremental_state=None, target_lengths=None, **kwargs): self,
prev_output_tokens,
encoder_out=None,
incremental_state=None,
target_lengths=None,
**kwargs,
):
enc_mask = encoder_out["encoder_padding_mask"] enc_mask = encoder_out["encoder_padding_mask"]
enc_out = encoder_out["encoder_out"][0] enc_out = encoder_out["encoder_out"][0]
in_len = enc_out.size(1) in_len = enc_out.size(1)
@ -227,8 +259,9 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
bsz, out_len, _ = prev_output_tokens.size() bsz, out_len, _ = prev_output_tokens.size()
prenet_out = self.prenet(prev_output_tokens) prenet_out = self.prenet(prev_output_tokens)
(alstm_h, alstm_c, lstm_h, lstm_c, (alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum) = self._get_states(
attn_w, attn_w_cum) = self._get_states(incremental_state, enc_out) incremental_state, enc_out
)
attn_ctx = self._get_init_attn_c(enc_out, enc_mask) attn_ctx = self._get_init_attn_c(enc_out, enc_mask)
attn_out = enc_out.new_zeros(bsz, in_len, out_len) attn_out = enc_out.new_zeros(bsz, in_len, out_len)
@ -241,9 +274,7 @@ class Tacotron2Decoder(FairseqIncrementalDecoder):
attn_state = attn_w.unsqueeze(1) attn_state = attn_w.unsqueeze(1)
if self.args.attention_use_cumprob: if self.args.attention_use_cumprob:
attn_state = torch.stack((attn_w, attn_w_cum), dim=1) attn_state = torch.stack((attn_w, attn_w_cum), dim=1)
attn_ctx, attn_w = self.attention( attn_ctx, attn_w = self.attention(enc_out, enc_mask, alstm_h, attn_state)
enc_out, enc_mask, alstm_h, attn_state
)
attn_w_cum = attn_w_cum + attn_w attn_w_cum = attn_w_cum + attn_w
attn_out[:, :, t] = attn_w attn_out[:, :, t] = attn_w
@ -297,7 +328,7 @@ class Tacotron2Model(FairseqEncoderDecoderModel):
parser.add_argument("--postnet-conv-dim", type=int) parser.add_argument("--postnet-conv-dim", type=int)
parser.add_argument("--postnet-conv-kernel-size", type=int) parser.add_argument("--postnet-conv-kernel-size", type=int)
parser.add_argument("--init-attn-c", type=str) parser.add_argument("--init-attn-c", type=str)
parser.add_argument("--attention-use-cumprob", action='store_true') parser.add_argument("--attention-use-cumprob", action="store_true")
parser.add_argument("--zoneout", type=float) parser.add_argument("--zoneout", type=float)
parser.add_argument("--decoder-lstm-layers", type=int) parser.add_argument("--decoder-lstm-layers", type=int)
parser.add_argument("--decoder-lstm-dim", type=int) parser.add_argument("--decoder-lstm-dim", type=int)
@ -333,8 +364,7 @@ def base_architecture(args):
# decoder # decoder
args.attention_dim = getattr(args, "attention_dim", 128) args.attention_dim = getattr(args, "attention_dim", 128)
args.attention_conv_dim = getattr(args, "attention_conv_dim", 32) args.attention_conv_dim = getattr(args, "attention_conv_dim", 32)
args.attention_conv_kernel_size = getattr(args, args.attention_conv_kernel_size = getattr(args, "attention_conv_kernel_size", 15)
"attention_conv_kernel_size", 15)
args.prenet_dropout = getattr(args, "prenet_dropout", 0.5) args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
args.prenet_layers = getattr(args, "prenet_layers", 2) args.prenet_layers = getattr(args, "prenet_layers", 2)
args.prenet_dim = getattr(args, "prenet_dim", 256) args.prenet_dim = getattr(args, "prenet_dim", 256)
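Side note on the `base_architecture` hunk above: it relies on `getattr` with a default, so hyperparameters already set (e.g. from the command line) are kept and only missing ones pick up architecture defaults. A standalone sketch with a toy `Namespace`, not the actual fairseq args:

```python
# Toy illustration of the getattr-with-default pattern used in base_architecture.
from argparse import Namespace


def toy_base_architecture(args):
    args.prenet_layers = getattr(args, "prenet_layers", 2)
    args.prenet_dim = getattr(args, "prenet_dim", 256)
    return args


args = toy_base_architecture(Namespace(prenet_dim=128))
print(args.prenet_layers, args.prenet_dim)  # 2 128  (only the unset field is defaulted)
```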
@ -9,12 +9,14 @@ from typing import List, Optional
import torch import torch
from torch import nn from torch import nn
from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel, from fairseq.models import (
FairseqIncrementalDecoder, register_model, FairseqEncoder,
register_model_architecture) FairseqEncoderDecoderModel,
from fairseq.modules import ( FairseqIncrementalDecoder,
TransformerEncoderLayer, TransformerDecoderLayer register_model,
register_model_architecture,
) )
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from fairseq.models.text_to_speech.tacotron2 import Prenet, Postnet from fairseq.models.text_to_speech.tacotron2 import Prenet, Postnet
from fairseq.modules import LayerNorm, PositionalEmbedding, FairseqDropout from fairseq.modules import LayerNorm, PositionalEmbedding, FairseqDropout
from fairseq.data.data_utils import lengths_to_padding_mask from fairseq.data.data_utils import lengths_to_padding_mask
@ -42,30 +44,31 @@ class TTSTransformerEncoder(FairseqEncoder):
self.spk_emb_proj = None self.spk_emb_proj = None
if embed_speaker is not None: if embed_speaker is not None:
self.spk_emb_proj = nn.Linear( self.spk_emb_proj = nn.Linear(
args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
args.encoder_embed_dim
) )
self.dropout_module = FairseqDropout( self.dropout_module = FairseqDropout(
p=args.dropout, module_name=self.__class__.__name__ p=args.dropout, module_name=self.__class__.__name__
) )
self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, self.embed_tokens = nn.Embedding(
padding_idx=self.padding_idx) len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
assert(args.encoder_conv_kernel_size % 2 == 1) )
assert args.encoder_conv_kernel_size % 2 == 1
self.prenet = nn.ModuleList( self.prenet = nn.ModuleList(
nn.Sequential( nn.Sequential(
nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim, nn.Conv1d(
kernel_size=args.encoder_conv_kernel_size, args.encoder_embed_dim,
padding=((args.encoder_conv_kernel_size - 1) // 2)), args.encoder_embed_dim,
kernel_size=args.encoder_conv_kernel_size,
padding=((args.encoder_conv_kernel_size - 1) // 2),
),
nn.BatchNorm1d(args.encoder_embed_dim), nn.BatchNorm1d(args.encoder_embed_dim),
nn.ReLU(), nn.ReLU(),
nn.Dropout(args.encoder_dropout), nn.Dropout(args.encoder_dropout),
) )
for _ in range(args.encoder_conv_layers) for _ in range(args.encoder_conv_layers)
) )
self.prenet_proj = nn.Linear( self.prenet_proj = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim)
args.encoder_embed_dim, args.encoder_embed_dim
)
self.embed_positions = PositionalEmbedding( self.embed_positions = PositionalEmbedding(
args.max_source_positions, args.encoder_embed_dim, self.padding_idx args.max_source_positions, args.encoder_embed_dim, self.padding_idx
) )
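The encoder prenet above asserts an odd `encoder_conv_kernel_size`; with padding `(k - 1) // 2`, an odd kernel keeps the sequence length unchanged. A quick standalone check with arbitrary sizes (not the real config values):

```python
import torch
import torch.nn as nn

k = 5  # any odd kernel size
conv = nn.Conv1d(8, 8, kernel_size=k, padding=(k - 1) // 2)
x = torch.randn(2, 8, 50)  # (batch, channels, time)
print(conv(x).shape)  # torch.Size([2, 8, 50]); the time dimension is preserved
```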
@ -112,7 +115,9 @@ class TTSTransformerEncoder(FairseqEncoder):
return { return {
"encoder_out": [x], # T x B x C "encoder_out": [x], # T x B x C
"encoder_padding_mask": [padding_mask] if padding_mask.any() else [], # B x T "encoder_padding_mask": [padding_mask]
if padding_mask.any()
else [], # B x T
"encoder_embedding": [], # B x T x C "encoder_embedding": [], # B x T x C
"encoder_states": [], # List[T x B x C] "encoder_states": [], # List[T x B x C]
"src_tokens": [], "src_tokens": [],
@ -143,15 +148,15 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
) )
self.pos_emb_alpha = nn.Parameter(torch.ones(1)) self.pos_emb_alpha = nn.Parameter(torch.ones(1))
self.prenet = nn.Sequential( self.prenet = nn.Sequential(
Prenet(self.out_dim, args.prenet_layers, args.prenet_dim, Prenet(
args.prenet_dropout), self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
),
nn.Linear(args.prenet_dim, args.decoder_embed_dim), nn.Linear(args.prenet_dim, args.decoder_embed_dim),
) )
self.n_transformer_layers = args.decoder_transformer_layers self.n_transformer_layers = args.decoder_transformer_layers
self.transformer_layers = nn.ModuleList( self.transformer_layers = nn.ModuleList(
TransformerDecoderLayer(args) TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers)
for _ in range(self.n_transformer_layers)
) )
if args.decoder_normalize_before: if args.decoder_normalize_before:
self.layer_norm = LayerNorm(args.decoder_embed_dim) self.layer_norm = LayerNorm(args.decoder_embed_dim)
@ -161,19 +166,28 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim) self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
self.eos_proj = nn.Linear(args.decoder_embed_dim, 1) self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)
self.postnet = Postnet(self.out_dim, args.postnet_conv_dim, self.postnet = Postnet(
args.postnet_conv_kernel_size, self.out_dim,
args.postnet_layers, args.postnet_dropout) args.postnet_conv_dim,
args.postnet_conv_kernel_size,
args.postnet_layers,
args.postnet_dropout,
)
self.ctc_proj = None self.ctc_proj = None
if getattr(args, "ctc_weight", 0.) > 0.: if getattr(args, "ctc_weight", 0.0) > 0.0:
self.ctc_proj = nn.Linear(self.out_dim, len(src_dict)) self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))
self.apply(decoder_init) self.apply(decoder_init)
def extract_features( def extract_features(
self, prev_outputs, encoder_out=None, incremental_state=None, self,
target_lengths=None, speaker=None, **kwargs prev_outputs,
encoder_out=None,
incremental_state=None,
target_lengths=None,
speaker=None,
**kwargs
): ):
alignment_layer = self.n_transformer_layers - 1 alignment_layer = self.n_transformer_layers - 1
self_attn_padding_mask = lengths_to_padding_mask(target_lengths) self_attn_padding_mask = lengths_to_padding_mask(target_lengths)
@ -212,8 +226,8 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
else None, else None,
encoder_out["encoder_padding_mask"][0] encoder_out["encoder_padding_mask"][0]
if ( if (
encoder_out is not None encoder_out is not None
and len(encoder_out["encoder_padding_mask"]) > 0 and len(encoder_out["encoder_padding_mask"]) > 0
) )
else None, else None,
incremental_state, incremental_state,
@ -239,13 +253,22 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder):
return x, {"attn": attn, "inner_states": inner_states} return x, {"attn": attn, "inner_states": inner_states}
def forward(self, prev_output_tokens, encoder_out=None, def forward(
incremental_state=None, target_lengths=None, speaker=None, self,
**kwargs): prev_output_tokens,
encoder_out=None,
incremental_state=None,
target_lengths=None,
speaker=None,
**kwargs
):
x, extra = self.extract_features( x, extra = self.extract_features(
prev_output_tokens, encoder_out=encoder_out, prev_output_tokens,
incremental_state=incremental_state, target_lengths=target_lengths, encoder_out=encoder_out,
speaker=speaker, **kwargs incremental_state=incremental_state,
target_lengths=target_lengths,
speaker=speaker,
**kwargs
) )
attn = extra["attn"] attn = extra["attn"]
feat_out = self.feat_proj(x) feat_out = self.feat_proj(x)
@ -328,8 +351,9 @@ class TTSTransformerModel(FairseqEncoderDecoderModel):
return cls(encoder, decoder) return cls(encoder, decoder)
def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs): def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs):
return self.encoder(src_tokens, src_lengths=src_lengths, return self.encoder(
speaker=speaker, **kwargs) src_tokens, src_lengths=src_lengths, speaker=speaker, **kwargs
)
def set_num_updates(self, num_updates): def set_num_updates(self, num_updates):
super().set_num_updates(num_updates) super().set_num_updates(num_updates)
@ -348,7 +372,9 @@ def base_architecture(args):
# encoder transformer layers # encoder transformer layers
args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6) args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim) args.encoder_ffn_embed_dim = getattr(
args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
args.attention_dropout = getattr(args, "attention_dropout", 0.0) args.attention_dropout = getattr(args, "attention_dropout", 0.0)
@ -366,6 +392,8 @@ def base_architecture(args):
# decoder transformer layers # decoder transformer layers
args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6) args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6)
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim) args.decoder_ffn_embed_dim = getattr(
args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim
)
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
@ -13,7 +13,10 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq.data.audio.audio_utils import ( from fairseq.data.audio.audio_utils import (
get_window, get_fourier_basis, get_mel_filters, TTSSpectrogram get_window,
get_fourier_basis,
get_mel_filters,
TTSSpectrogram,
) )
from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig
from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel
@ -25,11 +28,9 @@ class PseudoInverseMelScale(torch.nn.Module):
def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None: def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None:
super(PseudoInverseMelScale, self).__init__() super(PseudoInverseMelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
basis = get_mel_filters( basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max)
sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max
)
basis = torch.pinverse(basis) # F x F_mel basis = torch.pinverse(basis) # F x F_mel
self.register_buffer('basis', basis) self.register_buffer("basis", basis)
def forward(self, melspec: torch.Tensor) -> torch.Tensor: def forward(self, melspec: torch.Tensor) -> torch.Tensor:
# pack batch # pack batch
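`PseudoInverseMelScale` maps mel spectrograms back to linear-frequency magnitudes by multiplying with the pseudo-inverse of the mel filterbank. A rough sketch of the idea using a random stand-in for the real filter matrix; the shapes and the final clamp are assumptions of this sketch, not the exact fairseq buffers:

```python
import torch

n_freq, n_mels = 513, 80
mel_basis = torch.rand(n_mels, n_freq)  # stand-in for get_mel_filters(...)
inv_basis = torch.pinverse(mel_basis)   # (n_freq, n_mels), analogous to the registered "basis"
melspec = torch.rand(4, n_mels, 100)    # (batch, n_mels, frames)
specgram = torch.matmul(inv_basis, melspec).clamp(min=0)  # keep magnitudes non-negative
print(specgram.shape)  # torch.Size([4, 513, 100])
```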
@ -48,8 +49,12 @@ class PseudoInverseMelScale(torch.nn.Module):
class GriffinLim(torch.nn.Module): class GriffinLim(torch.nn.Module):
def __init__( def __init__(
self, n_fft: int, win_length: int, hop_length: int, n_iter: int, self,
window_fn=torch.hann_window n_fft: int,
win_length: int,
hop_length: int,
n_iter: int,
window_fn=torch.hann_window,
): ):
super(GriffinLim, self).__init__() super(GriffinLim, self).__init__()
self.transform = TTSSpectrogram( self.transform = TTSSpectrogram(
@ -59,7 +64,7 @@ class GriffinLim(torch.nn.Module):
basis = get_fourier_basis(n_fft) basis = get_fourier_basis(n_fft)
basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :] basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :]
basis *= get_window(window_fn, n_fft, win_length) basis *= get_window(window_fn, n_fft, win_length)
self.register_buffer('basis', basis) self.register_buffer("basis", basis)
self.n_fft = n_fft self.n_fft = n_fft
self.win_length = win_length self.win_length = win_length
@ -70,33 +75,33 @@ class GriffinLim(torch.nn.Module):
@classmethod @classmethod
def get_window_sum_square( def get_window_sum_square(
cls, n_frames, hop_length, win_length, n_fft, cls, n_frames, hop_length, win_length, n_fft, window_fn=torch.hann_window
window_fn=torch.hann_window
) -> torch.Tensor: ) -> torch.Tensor:
w_sq = get_window(window_fn, n_fft, win_length) ** 2 w_sq = get_window(window_fn, n_fft, win_length) ** 2
n = n_fft + hop_length * (n_frames - 1) n = n_fft + hop_length * (n_frames - 1)
x = torch.zeros(n, dtype=torch.float32) x = torch.zeros(n, dtype=torch.float32)
for i in range(n_frames): for i in range(n_frames):
ofst = i * hop_length ofst = i * hop_length
x[ofst: min(n, ofst + n_fft)] += w_sq[:max(0, min(n_fft, n - ofst))] x[ofst : min(n, ofst + n_fft)] += w_sq[: max(0, min(n_fft, n - ofst))]
return x return x
def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor: def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor:
x = torch.cat( x = torch.cat(
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
dim=1
) )
x = F.conv_transpose1d(x, self.basis, stride=self.hop_length) x = F.conv_transpose1d(x, self.basis, stride=self.hop_length)
win_sum_sq = self.get_window_sum_square( win_sum_sq = self.get_window_sum_square(
magnitude.shape[-1], hop_length=self.hop_length, magnitude.shape[-1],
win_length=self.win_length, n_fft=self.n_fft hop_length=self.hop_length,
win_length=self.win_length,
n_fft=self.n_fft,
).to(magnitude.device) ).to(magnitude.device)
# remove modulation effects # remove modulation effects
approx_nonzero_indices = win_sum_sq > self.tiny approx_nonzero_indices = win_sum_sq > self.tiny
x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices] x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices]
x *= self.n_fft / self.hop_length x *= self.n_fft / self.hop_length
x = x[:, :, self.n_fft // 2:] x = x[:, :, self.n_fft // 2 :]
x = x[:, :, :-self.n_fft // 2:] x = x[:, :, : -self.n_fft // 2 :]
return x return x
def forward(self, specgram: torch.Tensor) -> torch.Tensor: def forward(self, specgram: torch.Tensor) -> torch.Tensor:
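`get_window_sum_square` builds the overlap-added squared-window envelope that `inverse` later divides out to remove windowing (modulation) effects. A tiny standalone version with made-up sizes:

```python
import torch

n_fft, hop_length, win_length, n_frames = 8, 2, 8, 4
w_sq = torch.hann_window(win_length) ** 2
n = n_fft + hop_length * (n_frames - 1)
x = torch.zeros(n, dtype=torch.float32)
for i in range(n_frames):
    ofst = i * hop_length
    x[ofst : min(n, ofst + n_fft)] += w_sq[: max(0, min(n_fft, n - ofst))]
print(x)  # per-sample sum of squared windows across overlapping frames
```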
@ -111,18 +116,33 @@ class GriffinLim(torch.nn.Module):
class GriffinLimVocoder(nn.Module): class GriffinLimVocoder(nn.Module):
def __init__(self, sample_rate, win_size, hop_size, n_fft, def __init__(
n_mels, f_min, f_max, window_fn, self,
spec_bwd_max_iter=32, sample_rate,
fp16=False): win_size,
hop_size,
n_fft,
n_mels,
f_min,
f_max,
window_fn,
spec_bwd_max_iter=32,
fp16=False,
):
super().__init__() super().__init__()
self.inv_mel_transform = PseudoInverseMelScale( self.inv_mel_transform = PseudoInverseMelScale(
n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate, n_stft=n_fft // 2 + 1,
f_min=f_min, f_max=f_max n_mels=n_mels,
sample_rate=sample_rate,
f_min=f_min,
f_max=f_max,
) )
self.gl_transform = GriffinLim( self.gl_transform = GriffinLim(
n_fft=n_fft, win_length=win_size, hop_length=hop_size, n_fft=n_fft,
window_fn=window_fn, n_iter=spec_bwd_max_iter win_length=win_size,
hop_length=hop_size,
window_fn=window_fn,
n_iter=spec_bwd_max_iter,
) )
if fp16: if fp16:
self.half() self.half()
@ -151,17 +171,19 @@ class GriffinLimVocoder(nn.Module):
sample_rate=feat_cfg["sample_rate"], sample_rate=feat_cfg["sample_rate"],
win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]), win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]),
hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]), hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]),
n_fft=feat_cfg["n_fft"], n_mels=feat_cfg["n_mels"], n_fft=feat_cfg["n_fft"],
f_min=feat_cfg["f_min"], f_max=feat_cfg["f_max"], n_mels=feat_cfg["n_mels"],
window_fn=window_fn, spec_bwd_max_iter=args.spec_bwd_max_iter, f_min=feat_cfg["f_min"],
fp16=args.fp16 f_max=feat_cfg["f_max"],
window_fn=window_fn,
spec_bwd_max_iter=args.spec_bwd_max_iter,
fp16=args.fp16,
) )
class HiFiGANVocoder(nn.Module): class HiFiGANVocoder(nn.Module):
def __init__( def __init__(
self, checkpoint_path: str, model_cfg: Dict[str, str], self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False
fp16: bool = False
) -> None: ) -> None:
super().__init__() super().__init__()
self.model = HiFiGANModel(model_cfg) self.model = HiFiGANModel(model_cfg)
@@ -29,8 +29,8 @@ from torch import Tensor
 # rewrite name for backward compatibility in `make_generation_fast_`
 def module_name_fordropout(module_name: str) -> str:
-    if module_name == 'TransformerDecoderBase':
-        return 'TransformerDecoder'
+    if module_name == "TransformerDecoderBase":
+        return "TransformerDecoder"
     else:
         return module_name
@@ -29,8 +29,8 @@ from fairseq.models.transformer import (
 # rewrite name for backward compatibility in `make_generation_fast_`
 def module_name_fordropout(module_name: str) -> str:
-    if module_name == 'TransformerEncoderBase':
-        return 'TransformerEncoder'
+    if module_name == "TransformerEncoderBase":
+        return "TransformerEncoder"
     else:
         return module_name
@@ -232,7 +232,12 @@ class TransformerEncoderBase(FairseqEncoder):
         # `forward` so we use a dictionary instead.
         # TorchScript does not support mixed values so the values are all lists.
         # The empty list is equivalent to None.
-        src_lengths = src_tokens.ne(self.padding_idx).sum(dim=1, dtype=torch.int32).reshape(-1, 1).contiguous()
+        src_lengths = (
+            src_tokens.ne(self.padding_idx)
+            .sum(dim=1, dtype=torch.int32)
+            .reshape(-1, 1)
+            .contiguous()
+        )
         return {
             "encoder_out": [x],  # T x B x C
             "encoder_padding_mask": [encoder_padding_mask],  # B x T
@@ -15,7 +15,9 @@ from fairseq.models import (
     register_model_architecture,
 )
 from fairseq.models.transformer import (
-    DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder
+    DEFAULT_MIN_PARAMS_TO_WRAP,
+    Embedding,
+    TransformerDecoder,
 )
 from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
 from fairseq.utils import safe_getattr, safe_hasattr
@ -179,7 +181,7 @@ class TransformerLanguageModelConfig(FairseqDataclass):
"is set to 0 (i.e., always wrap) when --checkpoint-activations or " "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed." "--offload-activations are passed."
) )
} },
) )
# config for "BASE Layers: Simplifying Training of Large, Sparse Models" # config for "BASE Layers: Simplifying Training of Large, Sparse Models"
base_layers: Optional[int] = field( base_layers: Optional[int] = field(
@ -189,13 +191,25 @@ class TransformerLanguageModelConfig(FairseqDataclass):
default=1, metadata={"help": "number of sublayers in each BASE layer"} default=1, metadata={"help": "number of sublayers in each BASE layer"}
) )
base_shuffle: Optional[int] = field( base_shuffle: Optional[int] = field(
default=1, metadata={"help": "shuffle tokens between workers before computing assignment"} default=1,
metadata={"help": "shuffle tokens between workers before computing assignment"},
) )
# NormFormer # NormFormer
scale_fc: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm between fully connected layers'}) scale_fc: Optional[bool] = field(
scale_attn: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm after attention'}) default=False,
scale_heads: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each attention head'}) metadata={"help": "Insert LayerNorm between fully connected layers"},
scale_resids: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each residual connection'}) )
scale_attn: Optional[bool] = field(
default=False, metadata={"help": "Insert LayerNorm after attention"}
)
scale_heads: Optional[bool] = field(
default=False,
metadata={"help": "Learn a scale coefficient for each attention head"},
)
scale_resids: Optional[bool] = field(
default=False,
metadata={"help": "Learn a scale coefficient for each residual connection"},
)
# options from other parts of the config # options from other parts of the config
add_bos_token: bool = II("task.add_bos_token") add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample") tokens_per_sample: int = II("task.tokens_per_sample")
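The NormFormer options above follow the usual config-field pattern: an `Optional[bool]` defaulting to `False`, with the help string carried in the field metadata. A hedged sketch of that pattern with plain `dataclasses` (not the fairseq dataclass or Hydra machinery):

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class NormFormerToyConfig:
    scale_fc: Optional[bool] = field(
        default=False,
        metadata={"help": "Insert LayerNorm between fully connected layers"},
    )


print(fields(NormFormerToyConfig)[0].metadata["help"])
```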
@ -345,7 +359,9 @@ def base_lm_architecture(args):
args.decoder_output_dim = safe_getattr( args.decoder_output_dim = safe_getattr(
args, "decoder_output_dim", args.decoder_embed_dim args, "decoder_output_dim", args.decoder_embed_dim
) )
args.decoder_input_dim = safe_getattr(args, "decoder_input_dim", args.decoder_embed_dim) args.decoder_input_dim = safe_getattr(
args, "decoder_input_dim", args.decoder_embed_dim
)
# Model training is not stable without this # Model training is not stable without this
args.decoder_normalize_before = True args.decoder_normalize_before = True
@ -362,10 +378,10 @@ def base_lm_architecture(args):
args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False) args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False) args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False)
args.offload_activations = safe_getattr(args, "offload_activations", False) args.offload_activations = safe_getattr(args, "offload_activations", False)
args.scale_fc = safe_getattr(args, 'scale_fc', False) args.scale_fc = safe_getattr(args, "scale_fc", False)
args.scale_attn = safe_getattr(args, 'scale_attn', False) args.scale_attn = safe_getattr(args, "scale_attn", False)
args.scale_heads = safe_getattr(args, 'scale_heads', False) args.scale_heads = safe_getattr(args, "scale_heads", False)
args.scale_resids = safe_getattr(args, 'scale_resids', False) args.scale_resids = safe_getattr(args, "scale_resids", False)
if args.offload_activations: if args.offload_activations:
args.checkpoint_activations = True args.checkpoint_activations = True
@ -387,7 +403,9 @@ def transformer_lm_baevski_wiki103(args):
args.dropout = safe_getattr(args, "dropout", 0.3) args.dropout = safe_getattr(args, "dropout", 0.3)
args.adaptive_input = safe_getattr(args, "adaptive_input", True) args.adaptive_input = safe_getattr(args, "adaptive_input", True)
args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", True) args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", True)
args.adaptive_input_cutoff = safe_getattr(args, "adaptive_input_cutoff", "20000,60000") args.adaptive_input_cutoff = safe_getattr(
args, "adaptive_input_cutoff", "20000,60000"
)
args.adaptive_softmax_cutoff = safe_getattr( args.adaptive_softmax_cutoff = safe_getattr(
args, "adaptive_softmax_cutoff", "20000,60000" args, "adaptive_softmax_cutoff", "20000,60000"
) )
@ -472,7 +490,9 @@ def transformer_lm_gpt2_big(args):
def base_gpt3_architecture(args): def base_gpt3_architecture(args):
args.decoder_input_dim = args.decoder_embed_dim args.decoder_input_dim = args.decoder_embed_dim
args.decoder_output_dim = args.decoder_embed_dim args.decoder_output_dim = args.decoder_embed_dim
args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4) args.decoder_ffn_embed_dim = safe_getattr(
args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4
)
# GPT-3 used learned positional embeddings, rather than sinusoidal # GPT-3 used learned positional embeddings, rather than sinusoidal
args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", True) args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", True)
args.dropout = safe_getattr(args, "dropout", 0.0) args.dropout = safe_getattr(args, "dropout", 0.0)

View File

@ -232,9 +232,11 @@ class Wav2Vec2Config(FairseqDataclass):
) )
checkpoint_activations: bool = field( checkpoint_activations: bool = field(
default=False, metadata={"help": "recompute activations and save memory for extra compute"} default=False,
metadata={"help": "recompute activations and save memory for extra compute"},
) )
@register_model("wav2vec2", dataclass=Wav2Vec2Config) @register_model("wav2vec2", dataclass=Wav2Vec2Config)
class Wav2Vec2Model(BaseFairseqModel): class Wav2Vec2Model(BaseFairseqModel):
def __init__(self, cfg: Wav2Vec2Config): def __init__(self, cfg: Wav2Vec2Config):
@ -844,14 +846,14 @@ class TransformerEncoder(nn.Module):
layers = [] layers = []
for _ in range(args.encoder_layers): for _ in range(args.encoder_layers):
layer = TransformerSentenceEncoderLayer( layer = TransformerSentenceEncoderLayer(
embedding_dim=self.embedding_dim, embedding_dim=self.embedding_dim,
ffn_embedding_dim=args.encoder_ffn_embed_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim,
num_attention_heads=args.encoder_attention_heads, num_attention_heads=args.encoder_attention_heads,
dropout=self.dropout, dropout=self.dropout,
attention_dropout=args.attention_dropout, attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout, activation_dropout=args.activation_dropout,
activation_fn=args.activation_fn, activation_fn=args.activation_fn,
layer_norm_first=args.layer_norm_first, layer_norm_first=args.layer_norm_first,
) )
if args.checkpoint_activations: if args.checkpoint_activations:
layer = fsdp_wrap(layer) layer = fsdp_wrap(layer)
@ -152,10 +152,12 @@ class Wav2Vec2AsrConfig(FairseqDataclass):
w2v_args: Any = None w2v_args: Any = None
checkpoint_activations: bool = field( checkpoint_activations: bool = field(
default=False, metadata={"help": "recompute activations and save memory for extra compute"} default=False,
metadata={"help": "recompute activations and save memory for extra compute"},
) )
ddp_backend: str = II("distributed_training.ddp_backend") ddp_backend: str = II("distributed_training.ddp_backend")
@dataclass @dataclass
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
blank_weight: float = 0 blank_weight: float = 0
@ -268,6 +270,7 @@ class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig):
) )
autoregressive: bool = II("task.autoregressive") autoregressive: bool = II("task.autoregressive")
@register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig) @register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig)
class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel): class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder): def __init__(self, encoder, decoder):
@ -394,12 +397,17 @@ class Wav2VecEncoder(FairseqEncoder):
def load_model_weights(self, state, model, cfg): def load_model_weights(self, state, model, cfg):
if cfg.ddp_backend == "fully_sharded": if cfg.ddp_backend == "fully_sharded":
from fairseq.distributed import FullyShardedDataParallel from fairseq.distributed import FullyShardedDataParallel
for name, module in model.named_modules(): for name, module in model.named_modules():
if "encoder.layers" in name and len(name.split(".")) == 3: if "encoder.layers" in name and len(name.split(".")) == 3:
# Only for layers, we do a special handling and load the weights one by one # Only for layers, we do a special handling and load the weights one by one
# We dont load all weights together as that wont be memory efficient and may # We dont load all weights together as that wont be memory efficient and may
# cause oom # cause oom
new_dict = {k.replace(name+".", "") : v for (k, v) in state["model"].items() if name+"." in k} new_dict = {
k.replace(name + ".", ""): v
for (k, v) in state["model"].items()
if name + "." in k
}
assert isinstance(module, FullyShardedDataParallel) assert isinstance(module, FullyShardedDataParallel)
with module.summon_full_params(): with module.summon_full_params():
module.load_state_dict(new_dict, strict=True) module.load_state_dict(new_dict, strict=True)
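The per-layer loading path above builds a sub-state-dict for one FSDP-wrapped layer by filtering on the module name and stripping that prefix. The same dict comprehension on invented keys:

```python
# Invented keys for illustration only.
state_model = {
    "encoder.layers.0.fc1.weight": "w0",
    "encoder.layers.0.fc1.bias": "b0",
    "encoder.layers.1.fc1.weight": "w1",
}
name = "encoder.layers.0"
new_dict = {
    k.replace(name + ".", ""): v for (k, v) in state_model.items() if name + "." in k
}
print(new_dict)  # {'fc1.weight': 'w0', 'fc1.bias': 'b0'}
```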
@ -409,7 +417,9 @@ class Wav2VecEncoder(FairseqEncoder):
r = re.compile("encoder.layers.\d.") r = re.compile("encoder.layers.\d.")
filtered_list = list(filter(r.match, state["model"].keys())) filtered_list = list(filter(r.match, state["model"].keys()))
new_big_dict = {k: v for (k, v) in state["model"].items() if k not in filtered_list} new_big_dict = {
k: v for (k, v) in state["model"].items() if k not in filtered_list
}
model.load_state_dict(new_big_dict, strict=False) model.load_state_dict(new_big_dict, strict=False)
else: else:
@ -462,9 +472,9 @@ class Wav2VecEncoder(FairseqEncoder):
1, new_order 1, new_order
) )
if encoder_out["padding_mask"] is not None: if encoder_out["padding_mask"] is not None:
encoder_out["padding_mask"] = encoder_out[ encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
"padding_mask" 0, new_order
].index_select(0, new_order) )
return encoder_out return encoder_out
def max_positions(self): def max_positions(self):
@ -640,7 +650,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self_attn_mask=self.buffered_future_mask(x) self_attn_mask=self.buffered_future_mask(x)
if incremental_state is None if incremental_state is None
else None, else None,
self_attn_padding_mask=self_attn_padding_mask self_attn_padding_mask=self_attn_padding_mask,
) )
inner_states.append(x) inner_states.append(x)
@ -12,14 +12,17 @@ from fairseq.modules.layer_norm import LayerNorm
class BaseLayer(nn.Module): class BaseLayer(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.num_workers = distributed_utils.get_data_parallel_world_size() self.num_workers = distributed_utils.get_data_parallel_world_size()
expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim) expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim)
torch.nn.init.orthogonal_(expert_centroids, gain=0.1) torch.nn.init.orthogonal_(expert_centroids, gain=0.1)
self.register_parameter("expert_centroids", torch.nn.Parameter(expert_centroids)) self.register_parameter(
self.expert_network = nn.Sequential(*([BaseSublayer(args) for _ in range(args.base_sublayers)])) "expert_centroids", torch.nn.Parameter(expert_centroids)
)
self.expert_network = nn.Sequential(
*([BaseSublayer(args) for _ in range(args.base_sublayers)])
)
self.expert_id = distributed_utils.get_data_parallel_rank() self.expert_id = distributed_utils.get_data_parallel_rank()
self.shuffle = args.base_shuffle self.shuffle = args.base_shuffle
self.cpp = self.load_assignment() self.cpp = self.load_assignment()
@ -39,20 +42,34 @@ class BaseLayer(nn.Module):
with torch.no_grad(): with torch.no_grad():
# Compute similarity of each token to each expert, for routing # Compute similarity of each token to each expert, for routing
token_expert_affinities = features.matmul(self.expert_centroids.transpose(0, 1)) token_expert_affinities = features.matmul(
self.expert_centroids.transpose(0, 1)
)
# Compute which token goes to which expert # Compute which token goes to which expert
sort_by_expert, input_splits, output_splits = self.balanced_assignment(token_expert_affinities) \ sort_by_expert, input_splits, output_splits = (
if is_training else self.greedy_assignment(token_expert_affinities) self.balanced_assignment(token_expert_affinities)
if is_training
else self.greedy_assignment(token_expert_affinities)
)
# Swap these tokens for the right ones for our expert # Swap these tokens for the right ones for our expert
routed_features = All2All.apply(features[sort_by_expert], output_splits, input_splits) routed_features = All2All.apply(
features[sort_by_expert], output_splits, input_splits
)
if routed_features.size(0) > 0: if routed_features.size(0) > 0:
# Mix in the expert network based on how appropriate it is for these tokens # Mix in the expert network based on how appropriate it is for these tokens
alpha = torch.sigmoid(routed_features.mv(self.expert_centroids[self.expert_id])).unsqueeze(1) alpha = torch.sigmoid(
routed_features = alpha * self.expert_network(routed_features) + (1 - alpha) * routed_features routed_features.mv(self.expert_centroids[self.expert_id])
).unsqueeze(1)
routed_features = (
alpha * self.expert_network(routed_features)
+ (1 - alpha) * routed_features
)
# Return to original worker and ordering # Return to original worker and ordering
result = All2All.apply(routed_features, input_splits, output_splits)[self.inverse_sort(sort_by_expert)] result = All2All.apply(routed_features, input_splits, output_splits)[
self.inverse_sort(sort_by_expert)
]
if self.shuffle and is_training: if self.shuffle and is_training:
# Undo shuffling # Undo shuffling
@ -63,7 +80,9 @@ class BaseLayer(nn.Module):
def inverse_sort(self, order): def inverse_sort(self, order):
# Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)] # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)]
return torch.empty_like(order).scatter_(0, order, torch.arange(0, order.size(0), device=order.device)) return torch.empty_like(order).scatter_(
0, order, torch.arange(0, order.size(0), device=order.device)
)
def balanced_assignment(self, scores): def balanced_assignment(self, scores):
ok = scores.isfinite() ok = scores.isfinite()
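`inverse_sort` builds the permutation that undoes a sort via `scatter_`. A quick check on a toy tensor:

```python
import torch

xs = torch.tensor([30, 10, 20])
order = xs.argsort()  # tensor([1, 2, 0])
inverse = torch.empty_like(order).scatter_(
    0, order, torch.arange(0, order.size(0), device=order.device)
)
print(torch.equal(xs[order][inverse], xs))  # True: xs == xs[order][inverse_sort(order)]
```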
@ -79,7 +98,9 @@ class BaseLayer(nn.Module):
worker2token = sort_ordering // k worker2token = sort_ordering // k
# Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers) # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers)
output_splits = torch.zeros((self.num_workers,), dtype=torch.long, device=scores.device) output_splits = torch.zeros(
(self.num_workers,), dtype=torch.long, device=scores.device
)
workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True) workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True)
output_splits[workers] = counts output_splits[workers] = counts
# Tell other workers how many tokens to expect from us # Tell other workers how many tokens to expect from us
@ -103,7 +124,7 @@ class BaseSublayer(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.activation_fn = utils.get_activation_fn( self.activation_fn = utils.get_activation_fn(
activation=getattr(args, 'activation_fn', 'relu') or "relu" activation=getattr(args, "activation_fn", "relu") or "relu"
) )
self.norm = LayerNorm(args.decoder_embed_dim, export=False) self.norm = LayerNorm(args.decoder_embed_dim, export=False)
self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim) self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim)
@ -121,15 +142,29 @@ class All2All(torch.autograd.Function):
ctx.input_splits = input_splits ctx.input_splits = input_splits
ctx.output_splits = output_splits ctx.output_splits = output_splits
ys = torch.empty_like(xs) if output_splits is None else \ ys = (
xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:])) torch.empty_like(xs)
torch.distributed.all_to_all_single(ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits) if output_splits is None
else xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:]))
)
torch.distributed.all_to_all_single(
ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits
)
return ys return ys
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
result = torch.empty_like(grad_output) if ctx.input_splits is None else \ result = (
grad_output.new_empty(size=[sum(ctx.input_splits)] + list(grad_output.size()[1:])) torch.empty_like(grad_output)
torch.distributed.all_to_all_single(result, grad_output, if ctx.input_splits is None
output_split_sizes=ctx.input_splits, input_split_sizes=ctx.output_splits) else grad_output.new_empty(
size=[sum(ctx.input_splits)] + list(grad_output.size()[1:])
)
)
torch.distributed.all_to_all_single(
result,
grad_output,
output_split_sizes=ctx.input_splits,
input_split_sizes=ctx.output_splits,
)
return result, None, None return result, None, None
@ -166,7 +166,9 @@ class CheckpointFunction(torch.autograd.Function):
if parent_ctx_dict["offload"]: if parent_ctx_dict["offload"]:
ctx.fwd_device = tuple(x.device for x in tensor_inputs) ctx.fwd_device = tuple(x.device for x in tensor_inputs)
ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
tensor_inputs = tuple(x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs) tensor_inputs = tuple(
x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs
)
else: else:
ctx.fwd_device, ctx.grad_requirements = None, None ctx.fwd_device, ctx.grad_requirements = None, None
@ -199,7 +201,8 @@ class CheckpointFunction(torch.autograd.Function):
tensor_inputs = checkpoint.detach_variable(tensor_inputs) tensor_inputs = checkpoint.detach_variable(tensor_inputs)
if ctx.fwd_device is not None: if ctx.fwd_device is not None:
tensor_inputs = [ tensor_inputs = [
t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs) t.to(ctx.fwd_device[i], non_blocking=True)
for i, t in enumerate(tensor_inputs)
] ]
for i, need_grad in enumerate(ctx.grad_requirements): for i, need_grad in enumerate(ctx.grad_requirements):
tensor_inputs[i].requires_grad = need_grad tensor_inputs[i].requires_grad = need_grad
@ -75,6 +75,7 @@ class GumbelVectorQuantizer(nn.Module):
if isinstance(temp, str): if isinstance(temp, str):
import ast import ast
temp = ast.literal_eval(temp) temp = ast.literal_eval(temp)
assert len(temp) == 3, f"{temp}, {len(temp)}" assert len(temp) == 3, f"{temp}, {len(temp)}"
@ -47,11 +47,12 @@ def cache_fn(f):
return cache return cache
cache = f(*args, **kwargs) cache = f(*args, **kwargs)
return cache return cache
return cached_fn return cached_fn
def to(t): def to(t):
return {'device': t.device, 'dtype': t.dtype} return {"device": t.device, "dtype": t.dtype}
def find_modules(nn_module, type): def find_modules(nn_module, type):
@ -102,7 +103,7 @@ def reshape_dim(t, dim, split_dims):
shape = list(t.shape) shape = list(t.shape)
num_dims = len(shape) num_dims = len(shape)
dim = (dim + num_dims) % num_dims dim = (dim + num_dims) % num_dims
shape[dim:dim+1] = split_dims shape[dim : dim + 1] = split_dims
return t.reshape(shape) return t.reshape(shape)
@ -118,6 +119,7 @@ def ema_inplace(moving_avg, new, decay):
return return
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
# helper classes # helper classes
@ -173,6 +175,7 @@ class ScaleNorm(nn.Module):
def norm(t): def norm(t):
n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps) n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps)
return t / n * self.g return t / n * self.g
return map_first_tuple_or_el(x, norm) return map_first_tuple_or_el(x, norm)
@ -202,51 +205,62 @@ class MatrixMultiply(nn.Module):
tensor = tensor.t() tensor = tensor.t()
return x @ tensor return x @ tensor
# positional embeddings # positional embeddings
class DepthWiseConv1d(nn.Module): class DepthWiseConv1d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False): def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False):
super().__init__() super().__init__()
self.padding = ((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2) self.padding = (
((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2)
)
self.net = nn.Sequential( self.net = nn.Sequential(
nn.Conv1d(dim_in, dim_in, kernel_size=kernel_size, groups=dim_in, stride=stride, bias=bias), nn.Conv1d(
nn.Conv1d(dim_in, dim_out, 1, bias=bias) dim_in,
dim_in,
kernel_size=kernel_size,
groups=dim_in,
stride=stride,
bias=bias,
),
nn.Conv1d(dim_in, dim_out, 1, bias=bias),
) )
def forward(self, x): def forward(self, x):
x = F.pad(x, self.padding, value=0.) x = F.pad(x, self.padding, value=0.0)
return self.net(x) return self.net(x)
class FixedPositionalEmbedding(nn.Module): class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len): def __init__(self, dim, max_seq_len):
super().__init__() super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
position = torch.arange(0, max_seq_len, dtype=torch.float) position = torch.arange(0, max_seq_len, dtype=torch.float)
sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq) sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
self.register_buffer('emb', emb) self.register_buffer("emb", emb)
def forward(self, x): def forward(self, x):
return self.emb[None, :x.shape[1], :].to(x) return self.emb[None, : x.shape[1], :].to(x)
def rotate_every_two(x): def rotate_every_two(x):
x = rearrange(x, '... (d j) -> ... d j', j=2) x = rearrange(x, "... (d j) -> ... d j", j=2)
x1, x2 = x.unbind(dim=-1) x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1) x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)') return rearrange(x, "... d j -> ... (d j)")
def apply_rotary_pos_emb(q, k, sinu_pos): def apply_rotary_pos_emb(q, k, sinu_pos):
sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j=2) sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2)
sin, cos = sinu_pos.unbind(dim=-2) sin, cos = sinu_pos.unbind(dim=-2)
sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2), (sin, cos)) sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos))
q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
return q, k return q, k
# kmeans related function and class # kmeans related function and class
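`rotate_every_two` pairs up the last dimension and maps `(x1, x2)` to `(-x2, x1)`, the rotation that rotary position embeddings apply. The same operation in plain torch (no einops), on a toy tensor:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
pairs = x.view(*x.shape[:-1], -1, 2)          # "... (d j) -> ... d j", j = 2
x1, x2 = pairs.unbind(dim=-1)
rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)  # "... d j -> ... (d j)"
print(rotated)  # tensor([[-2., 1., -4., 3.]])
```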
@ -261,7 +275,7 @@ def update_kmeans_on_backwards(module):
def similarity(x, means): def similarity(x, means):
return torch.einsum('bhld,hcd->bhlc', x, means) return torch.einsum("bhld,hcd->bhlc", x, means)
def dists_and_buckets(x, means): def dists_and_buckets(x, means):
@ -303,13 +317,15 @@ def distribution(dists, window_size):
class Kmeans(nn.Module): class Kmeans(nn.Module):
def __init__(self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4): def __init__(
self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4
):
super().__init__() super().__init__()
self.commitment = commitment self.commitment = commitment
self.ema_decay = ema_decay self.ema_decay = ema_decay
self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim)) self.register_buffer("means", torch.randn(num_heads, num_clusters, head_dim))
self.register_buffer('initted', torch.tensor(False)) self.register_buffer("initted", torch.tensor(False))
self.num_new_means = 0 self.num_new_means = 0
self.new_means = None self.new_means = None
@ -341,7 +357,7 @@ class Kmeans(nn.Module):
@torch.no_grad() @torch.no_grad()
def update(self, new_means=None): def update(self, new_means=None):
new_means = default(new_means, self.new_means) new_means = default(new_means, self.new_means)
assert exists(new_means), 'new kmeans has not been supplied' assert exists(new_means), "new kmeans has not been supplied"
ema_inplace(self.means, new_means, self.ema_decay) ema_inplace(self.means, new_means, self.ema_decay)
del self.new_means del self.new_means
@ -364,16 +380,33 @@ class Kmeans(nn.Module):
if update_means: if update_means:
with torch.no_grad(): with torch.no_grad():
means = kmeans_iter(x, means, buckets) means = kmeans_iter(x, means, buckets)
self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1)) self.new_means = ema(
self.new_means, means, self.num_new_means / (self.num_new_means + 1)
)
self.num_new_means += 1 self.num_new_means += 1
return dists, loss return dists, loss
# kmeans attention class # kmeans attention class
class KmeansAttention(nn.Module): class KmeansAttention(nn.Module):
def __init__(self, num_clusters, window_size, num_heads, head_dim, causal=False, dropout=0., ema_decay=0.999, commitment=1e-4, context_window_size=None, receives_context=False, num_mem_kv=0, shared_qk=False): def __init__(
self,
num_clusters,
window_size,
num_heads,
head_dim,
causal=False,
dropout=0.0,
ema_decay=0.999,
commitment=1e-4,
context_window_size=None,
receives_context=False,
num_mem_kv=0,
shared_qk=False,
):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.num_clusters = num_clusters self.num_clusters = num_clusters
@ -389,18 +422,32 @@ class KmeansAttention(nn.Module):
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0) self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0)
self.mem_key = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)) self.mem_key = nn.Parameter(
self.mem_value = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)) torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
)
self.mem_value = nn.Parameter(
torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)
)
def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs): def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs):
b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = (
is_reverse = kwargs.pop('_reverse', False) *q.shape,
k.shape[2],
self.window_size,
self.context_window_size,
self.num_clusters,
q.device,
q.dtype,
)
is_reverse = kwargs.pop("_reverse", False)
out = torch.zeros_like(q, dtype=dtype) out = torch.zeros_like(q, dtype=dtype)
update_kmeans = self.training and not is_reverse update_kmeans = self.training and not is_reverse
key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask key_mask = (
default(key_mask, query_mask) if not self.receives_context else key_mask
)
kv_wsz = wsz if not self.receives_context else c_wsz kv_wsz = wsz if not self.receives_context else c_wsz
wsz = min(wsz, t) wsz = min(wsz, t)
@ -424,16 +471,22 @@ class KmeansAttention(nn.Module):
reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d) reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d)
q, k, v = map(reshape_with_window, (q, k, v)) q, k, v = map(reshape_with_window, (q, k, v))
m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)) m_k, m_v = map(
lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)
)
k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v))) k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v)))
dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5) dots = torch.einsum("bhnid,bhnjd->bhnij", q, k) * (d ** -0.5)
mask_value = max_neg_value(dots) mask_value = max_neg_value(dots)
if exists(query_mask) or exists(key_mask): if exists(query_mask) or exists(key_mask):
query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool()) query_mask = default(
key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool()) query_mask, lambda: torch.ones((b, t), device=device).bool()
)
key_mask = default(
key_mask, lambda: torch.ones((b, kv_t), device=device).bool()
)
q_mask = expand_dim(query_mask, 1, h).gather(2, indices) q_mask = expand_dim(query_mask, 1, h).gather(2, indices)
kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices) kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices)
@ -444,14 +497,18 @@ class KmeansAttention(nn.Module):
del mask del mask
if self.causal: if self.causal:
q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)) q_mask, kv_mask = map(
lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
)
mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :] mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=1) mask = F.pad(mask, (self.num_mem_kv, 0), value=1)
dots.masked_fill_(~mask, mask_value) dots.masked_fill_(~mask, mask_value)
del mask del mask
if self.shared_qk: if self.shared_qk:
q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)) q_mask, kv_mask = map(
lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)
)
mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :] mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :]
mask = F.pad(mask, (self.num_mem_kv, 0), value=0) mask = F.pad(mask, (self.num_mem_kv, 0), value=0)
dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE) dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE)
@ -460,24 +517,32 @@ class KmeansAttention(nn.Module):
dots = dots.softmax(dim=-1) dots = dots.softmax(dim=-1)
dots = self.dropout(dots) dots = self.dropout(dots)
bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v) bo = torch.einsum("bhcij,bhcjd->bhcid", dots, v)
so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype) so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype)
out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2) out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2)
return out, aux_loss return out, aux_loss
# feedforward # feedforward
class GELU_(nn.Module): class GELU_(nn.Module):
def forward(self, x): def forward(self, x):
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) return (
0.5
* x
* (
1
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))
)
)
GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_ GELU = nn.GELU if hasattr(nn, "GELU") else GELU_
class FeedForward(nn.Module): class FeedForward(nn.Module):
def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False): def __init__(self, dim, mult=4, dropout=0.0, activation=None, glu=False):
super().__init__() super().__init__()
activation = default(activation, GELU) activation = default(activation, GELU)
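The `GELU_` fallback above implements the standard tanh approximation of GELU, which tracks torch's erf-based `F.gelu` closely. A quick numeric comparison:

```python
import math

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=61)
approx = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x.pow(3))))
exact = F.gelu(x)
print((approx - exact).abs().max())  # small (well below 1e-2)
```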
@ -499,17 +564,49 @@ class FeedForward(nn.Module):
x = self.w2(x) x = self.w2(x)
return x return x
# self attention # self attention
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
def __init__(self, dim, max_seq_len, heads, local_attn_heads, window_size, dim_head=None, local_attn_window_size=None, local_attn_radius_blocks=1, causal=False, attn_dropout=0., dropout=0., kmeans_ema_decay=0.999, commitment_factor=1e-4, receives_context=False, context_window_size=None, rel_pos_emb=True, num_mem_kv=0, shared_qk=False, conv_query_kernel=9): def __init__(
self,
dim,
max_seq_len,
heads,
local_attn_heads,
window_size,
dim_head=None,
local_attn_window_size=None,
local_attn_radius_blocks=1,
causal=False,
attn_dropout=0.0,
dropout=0.0,
kmeans_ema_decay=0.999,
commitment_factor=1e-4,
receives_context=False,
context_window_size=None,
rel_pos_emb=True,
num_mem_kv=0,
shared_qk=False,
conv_query_kernel=9,
):
super().__init__() super().__init__()
assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads' assert (
assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size' dim_head or (dim % heads) == 0
assert local_attn_heads <= heads, 'number of local attention heads must be less than total heads' ), "hidden dimension must be divisible by number of heads"
assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context' assert (
assert not (receives_context and causal), 'contextual attention layer cannot be causal' max_seq_len % window_size
) == 0, "maximum sequence length must be divisible by the target window size"
assert (
local_attn_heads <= heads
), "number of local attention heads must be less than total heads"
assert not (
receives_context and local_attn_heads > 0
), "local attention cannot be used for self attention with context"
assert not (
receives_context and causal
), "contextual attention layer cannot be causal"
local_attn_window_size = default(local_attn_window_size, window_size) local_attn_window_size = default(local_attn_window_size, window_size)
context_window_size = default(context_window_size, window_size) context_window_size = default(context_window_size, window_size)
@ -535,7 +632,15 @@ class SelfAttention(nn.Module):
if self.local_attn_heads > 0: if self.local_attn_heads > 0:
rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None
self.local_attn = LocalAttention(dim=dim_head, window_size=local_attn_window_size, causal=causal, dropout=attn_dropout, rel_pos_emb_config=rel_pos_emb_config, look_backward=local_attn_radius_blocks, look_forward=0 if causal else local_attn_radius_blocks) self.local_attn = LocalAttention(
dim=dim_head,
window_size=local_attn_window_size,
causal=causal,
dropout=attn_dropout,
rel_pos_emb_config=rel_pos_emb_config,
look_backward=local_attn_radius_blocks,
look_forward=0 if causal else local_attn_radius_blocks,
)
self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads) self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads)
# global # global
@ -543,12 +648,24 @@ class SelfAttention(nn.Module):
global_dim_heads = dim_head * self.global_attn_heads global_dim_heads = dim_head * self.global_attn_heads
if self.global_attn_heads > 0: if self.global_attn_heads > 0:
self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal=causal, dropout=attn_dropout, ema_decay=kmeans_ema_decay, commitment=commitment_factor, receives_context=receives_context, num_mem_kv=num_mem_kv, shared_qk=shared_qk) self.global_attn = KmeansAttention(
num_clusters,
window_size,
self.global_attn_heads,
dim_head,
causal=causal,
dropout=attn_dropout,
ema_decay=kmeans_ema_decay,
commitment=commitment_factor,
receives_context=receives_context,
num_mem_kv=num_mem_kv,
shared_qk=shared_qk,
)
self.to_q = nn.Sequential( self.to_q = nn.Sequential(
Rearrange('b n c -> b c n'), Rearrange("b n c -> b c n"),
DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal), DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal),
Rearrange('b c n -> b n c') Rearrange("b c n -> b n c"),
) )
self.to_v = nn.Linear(dim, global_dim_heads, bias=False) self.to_v = nn.Linear(dim, global_dim_heads, bias=False)
@ -561,14 +678,30 @@ class SelfAttention(nn.Module):
self.to_out = nn.Linear(dim_heads, dim, bias=False) self.to_out = nn.Linear(dim_heads, dim, bias=False)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, context=None, key_padding_mask=None, context_mask=None, pos_emb=None, **kwargs): def forward(
assert not (self.receives_context and not exists(context)), 'context must be passed if self attention is set to receive context' self,
query,
key,
value,
context=None,
key_padding_mask=None,
context_mask=None,
pos_emb=None,
**kwargs
):
assert not (
self.receives_context and not exists(context)
), "context must be passed if self attention is set to receive context"
input_mask = key_padding_mask input_mask = key_padding_mask
x = query.transpose(0, 1) x = query.transpose(0, 1)
b, t, _, h, dh = *x.shape, self.heads, self.dim_head b, t, _, h, dh = *x.shape, self.heads, self.dim_head
has_local, has_global = map(lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads)) has_local, has_global = map(
lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads)
)
split_heads = lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous() split_heads = (
lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous()
)
if has_local: if has_local:
local_qkv = self.local_to_qkv(x).chunk(3, dim=-1) local_qkv = self.local_to_qkv(x).chunk(3, dim=-1)
@ -587,7 +720,7 @@ class SelfAttention(nn.Module):
q, k, v = map(split_heads, (q, k, v)) q, k, v = map(split_heads, (q, k, v))
out = [] out = []
total_loss = torch.tensor(0., requires_grad=True, **to(x)) total_loss = torch.tensor(0.0, requires_grad=True, **to(x))
if has_local: if has_local:
local_out = self.local_attn(lq, lk, lv, input_mask=input_mask) local_out = self.local_attn(lq, lk, lv, input_mask=input_mask)
@ -597,7 +730,9 @@ class SelfAttention(nn.Module):
if not self.receives_context and exists(pos_emb): if not self.receives_context and exists(pos_emb):
q, k = apply_rotary_pos_emb(q, k, pos_emb) q, k = apply_rotary_pos_emb(q, k, pos_emb)
global_out, loss = self.global_attn(q, k, v, query_mask=input_mask, key_mask=context_mask) global_out, loss = self.global_attn(
q, k, v, query_mask=input_mask, key_mask=context_mask
)
total_loss = total_loss + loss total_loss = total_loss + loss
out.append(global_out) out.append(global_out)
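Almost every hunk in this file is a mechanical black rewrite of an over-long line. As a minimal sketch (assuming black's in-memory API, black.format_str with a default Mode), one of the asserts above reformats like this:

import black

# One over-long assert from the hunk above; black wraps the condition in
# parentheses and normalizes the quotes, which is exactly the change shown.
src = (
    "assert local_attn_heads <= heads, "
    "'number of local attention heads must be less than total heads'\n"
)
print(black.format_str(src, mode=black.Mode()))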
View File
@ -13,6 +13,7 @@ from .conv_tbc import ConvTBC
from typing import Dict, Optional from typing import Dict, Optional
from torch import Tensor from torch import Tensor
@with_incremental_state @with_incremental_state
class LinearizedConvolution(ConvTBC): class LinearizedConvolution(ConvTBC):
"""An optimized version of nn.Conv1d. """An optimized version of nn.Conv1d.
@ -41,7 +42,11 @@ class LinearizedConvolution(ConvTBC):
del state_dict[prefix + "_linearized_weight"] del state_dict[prefix + "_linearized_weight"]
@torch.jit.export @torch.jit.export
def forward(self, input, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None): def forward(
self,
input,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
""" """
Args: Args:
incremental_state: Used to buffer signal; if not None, then input is incremental_state: Used to buffer signal; if not None, then input is
@ -80,18 +85,28 @@ class LinearizedConvolution(ConvTBC):
return output.view(bsz, 1, -1) return output.view(bsz, 1, -1)
@torch.jit.unused @torch.jit.unused
def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order): def reorder_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
new_order,
):
input_buffer = self._get_input_buffer(incremental_state) input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None: if input_buffer is not None:
input_buffer = input_buffer.index_select(0, new_order) input_buffer = input_buffer.index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer) self._set_input_buffer(incremental_state, input_buffer)
@torch.jit.unused @torch.jit.unused
def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
):
return utils.get_incremental_state(self, incremental_state, "input_buffer") return utils.get_incremental_state(self, incremental_state, "input_buffer")
@torch.jit.unused @torch.jit.unused
def _set_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_buffer): def _set_input_buffer(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
new_buffer,
):
return utils.set_incremental_state( return utils.set_incremental_state(
self, incremental_state, "input_buffer", new_buffer self, incremental_state, "input_buffer", new_buffer
) )
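A short illustration of why reorder_incremental_state above index-selects the cached buffer: during beam search the beam order changes between steps, so any cached state must be permuted to stay aligned with the new order (toy tensors, not fairseq's API):

import torch

input_buffer = torch.arange(6.0).view(3, 2)   # 3 beams, 2 cached conv inputs each
new_order = torch.tensor([2, 0, 1])           # beams were re-ranked this step
print(input_buffer.index_select(0, new_order))  # rows follow the new beam order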
View File
@ -20,9 +20,16 @@ class LocationAttention(nn.Module):
:param int conv_kernel_size: filter size of attention convolution :param int conv_kernel_size: filter size of attention convolution
""" """
def __init__(self, attn_dim, encoder_dim, decoder_dim, def __init__(
attn_state_kernel_size, conv_dim, conv_kernel_size, self,
scaling=2.0): attn_dim,
encoder_dim,
decoder_dim,
attn_state_kernel_size,
conv_dim,
conv_kernel_size,
scaling=2.0,
):
super(LocationAttention, self).__init__() super(LocationAttention, self).__init__()
self.attn_dim = attn_dim self.attn_dim = attn_dim
self.decoder_dim = decoder_dim self.decoder_dim = decoder_dim
@ -30,9 +37,13 @@ class LocationAttention(nn.Module):
self.proj_enc = nn.Linear(encoder_dim, attn_dim) self.proj_enc = nn.Linear(encoder_dim, attn_dim)
self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False) self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False)
self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False) self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False)
self.conv = nn.Conv1d(attn_state_kernel_size, conv_dim, self.conv = nn.Conv1d(
2 * conv_kernel_size + 1, attn_state_kernel_size,
padding=conv_kernel_size, bias=False) conv_dim,
2 * conv_kernel_size + 1,
padding=conv_kernel_size,
bias=False,
)
self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1)) self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1))
self.proj_enc_out = None # cache self.proj_enc_out = None # cache
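For context on the Conv1d being reformatted above, a hedged sketch of the location-aware term: the previous attention weights are convolved with a kernel of size 2 * conv_kernel_size + 1 and matching padding, so the source length is preserved (shapes below are invented):

import torch
import torch.nn as nn

attn_state_kernel_size, conv_dim, conv_kernel_size = 1, 10, 3
conv = nn.Conv1d(
    attn_state_kernel_size,
    conv_dim,
    2 * conv_kernel_size + 1,
    padding=conv_kernel_size,
    bias=False,
)
prev_attn = torch.softmax(torch.randn(2, 1, 50), dim=-1)  # (batch, channels, src_len)
print(conv(prev_attn).shape)  # torch.Size([2, 10, 50]) -- length preserved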
View File
@ -12,20 +12,20 @@ class LSTMCellWithZoneOut(nn.Module):
https://arxiv.org/abs/1606.01305 https://arxiv.org/abs/1606.01305
""" """
def __init__(self, prob: float, input_size: int, hidden_size: int, def __init__(
bias: bool = True): self, prob: float, input_size: int, hidden_size: int, bias: bool = True
):
super(LSTMCellWithZoneOut, self).__init__() super(LSTMCellWithZoneOut, self).__init__()
self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias) self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
self.prob = prob self.prob = prob
if prob > 1.0 or prob < 0.0: if prob > 1.0 or prob < 0.0:
raise ValueError("zoneout probability must be in the range from " raise ValueError(
"0.0 to 1.0.") "zoneout probability must be in the range from " "0.0 to 1.0."
)
def zoneout(self, h, next_h, prob): def zoneout(self, h, next_h, prob):
if isinstance(h, tuple): if isinstance(h, tuple):
return tuple( return tuple([self.zoneout(h[i], next_h[i], prob) for i in range(len(h))])
[self.zoneout(h[i], next_h[i], prob) for i in range(len(h))]
)
if self.training: if self.training:
mask = h.new_zeros(*h.size()).bernoulli_(prob) mask = h.new_zeros(*h.size()).bernoulli_(prob)
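A minimal, self-contained sketch of the zoneout rule the class above implements (Krueger et al., https://arxiv.org/abs/1606.01305): with probability prob a hidden unit keeps its previous value. The eval-time interpolation is an assumption of this sketch, not a quote of fairseq's code:

import torch

def zoneout(h_prev, h_next, prob, training=True):
    if training:
        # with probability `prob`, keep the previous hidden value
        mask = torch.bernoulli(torch.full_like(h_prev, prob))
        return mask * h_prev + (1 - mask) * h_next
    # at inference, use the expected value instead of sampling
    return prob * h_prev + (1 - prob) * h_next

print(zoneout(torch.zeros(2, 4), torch.ones(2, 4), prob=0.1))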
View File
@ -60,7 +60,9 @@ def quantize_model_(
to layers_to_quantize[step] to layers_to_quantize[step]
""" """
quantized_layers = get_layers(model, layers_to_quantize[step], remove_weights=remove_weights) quantized_layers = get_layers(
model, layers_to_quantize[step], remove_weights=remove_weights
)
for layer in quantized_layers: for layer in quantized_layers:
@ -108,8 +110,8 @@ def quantize_model_(
centroids = torch.rand(centroids.size()) centroids = torch.rand(centroids.size())
centroids.cuda() centroids.cuda()
# Get counts and assignment keys from layer in loaded checkpoint. # Get counts and assignment keys from layer in loaded checkpoint.
counts_key = layer+"."+"counts" counts_key = layer + "." + "counts"
assignment_key = layer+"."+"assignments" assignment_key = layer + "." + "assignments"
# Get number of different bins to include. # Get number of different bins to include.
counts = list(state_dict[counts_key].shape)[0] counts = list(state_dict[counts_key].shape)[0]
print(layer) print(layer)
@ -122,7 +124,7 @@ def quantize_model_(
print(num_assignments) print(num_assignments)
print(num_extra) print(num_extra)
assignments_bins = torch.arange(counts) assignments_bins = torch.arange(counts)
assignments_rand = torch.randint(0, counts-1, (num_extra, )) assignments_rand = torch.randint(0, counts - 1, (num_extra,))
assignments = torch.cat((assignments_bins, assignments_rand), 0) assignments = torch.cat((assignments_bins, assignments_rand), 0)
# assignments = assignments.type(torch.IntTensor) # assignments = assignments.type(torch.IntTensor)
assignments.cuda() assignments.cuda()
View File
@ -16,7 +16,9 @@ from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear
MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d} MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d}
def quantize_model_(model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False): def quantize_model_(
model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False
):
""" """
Replaces all modules with their scalar quantized counterpart and Replaces all modules with their scalar quantized counterpart and
registers hooks to quantize the post-ativations of those modules. registers hooks to quantize the post-ativations of those modules.
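The MAPPING table above drives a module swap. A rough, hypothetical sketch of that pattern follows; the class name and traversal are placeholders, not fairseq's actual IntLinear or quantize_model_ implementation:

import torch.nn as nn

class FakeIntLinear(nn.Linear):
    """Stand-in for a scalar-quantized Linear layer."""

MAPPING = {nn.Linear: FakeIntLinear}

def swap_modules(module):
    # recursively replace children whose type appears in MAPPING
    for name, child in module.named_children():
        if type(child) in MAPPING:
            setattr(module, name, MAPPING[type(child)](child.in_features, child.out_features))
        else:
            swap_modules(child)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
swap_modules(model)
print(model)  # Linear replaced by FakeIntLinear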
View File
@ -132,8 +132,7 @@ class TransformerEncoderLayerBase(nn.Module):
# will become -inf, which results in NaN in model parameters # will become -inf, which results in NaN in model parameters
if attn_mask is not None: if attn_mask is not None:
attn_mask = attn_mask.masked_fill( attn_mask = attn_mask.masked_fill(
attn_mask.to(torch.bool), attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
-1e8 if x.dtype == torch.float32 else -1e4
) )
residual = x residual = x
@ -213,11 +212,19 @@ class TransformerDecoderLayerBase(nn.Module):
add_bias_kv=add_bias_kv, add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn, add_zero_attn=add_zero_attn,
) )
self.attn_ln = LayerNorm(self.embed_dim) if utils.safe_getattr(cfg, 'scale_attn', False) else None self.attn_ln = (
LayerNorm(self.embed_dim)
if utils.safe_getattr(cfg, "scale_attn", False)
else None
)
self.nh = self.self_attn.num_heads self.nh = self.self_attn.num_heads
self.head_dim = self.self_attn.head_dim self.head_dim = self.self_attn.head_dim
scale_heads = utils.safe_getattr(cfg, 'scale_heads', False) scale_heads = utils.safe_getattr(cfg, "scale_heads", False)
self.c_attn = nn.Parameter(torch.ones((self.nh,)), requires_grad=True) if scale_heads else None self.c_attn = (
nn.Parameter(torch.ones((self.nh,)), requires_grad=True)
if scale_heads
else None
)
self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn) self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn)
activation_dropout_p = cfg.activation_dropout activation_dropout_p = cfg.activation_dropout
@ -238,8 +245,21 @@ class TransformerDecoderLayerBase(nn.Module):
self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export)
self.ffn_layernorm = LayerNorm(cfg.decoder.ffn_embed_dim) if utils.safe_getattr(cfg, 'scale_fc', False) else None self.ffn_layernorm = (
self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if utils.safe_getattr(cfg, 'scale_resids', False) else None LayerNorm(cfg.decoder.ffn_embed_dim)
if utils.safe_getattr(cfg, "scale_fc", False)
else None
)
self.w_resid = (
nn.Parameter(
torch.ones(
self.embed_dim,
),
requires_grad=True,
)
if utils.safe_getattr(cfg, "scale_resids", False)
else None
)
self.fc1 = self.build_fc1( self.fc1 = self.build_fc1(
self.embed_dim, self.embed_dim,
@ -297,7 +317,6 @@ class TransformerDecoderLayerBase(nn.Module):
def residual_connection(self, x, residual): def residual_connection(self, x, residual):
return residual + x return residual + x
def forward( def forward(
self, self,
x, x,
@ -377,7 +396,7 @@ class TransformerDecoderLayerBase(nn.Module):
if self.c_attn is not None: if self.c_attn is not None:
tgt_len, bsz = x.size(0), x.size(1) tgt_len, bsz = x.size(0), x.size(1)
x = x.view(tgt_len, bsz, self.nh, self.head_dim) x = x.view(tgt_len, bsz, self.nh, self.head_dim)
x = torch.einsum('tbhd,h->tbhd', x, self.c_attn) x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
x = x.reshape(tgt_len, bsz, self.embed_dim) x = x.reshape(tgt_len, bsz, self.embed_dim)
if self.attn_ln is not None: if self.attn_ln is not None:
x = self.attn_ln(x) x = self.attn_ln(x)
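The einsum above (active when cfg.scale_heads is set) multiplies each attention head's output by a learned scalar before the output projection; a toy reproduction with made-up shapes:

import torch

tgt_len, bsz, nh, head_dim = 5, 2, 4, 16
x = torch.randn(tgt_len, bsz, nh * head_dim)  # attention output, heads flattened
c_attn = torch.ones(nh)                       # learned per-head scale, initialized to 1

x = x.view(tgt_len, bsz, nh, head_dim)
x = torch.einsum("tbhd,h->tbhd", x, c_attn)   # scale head h by c_attn[h]
x = x.reshape(tgt_len, bsz, nh * head_dim)
print(x.shape)  # torch.Size([5, 2, 64])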
View File
@ -35,9 +35,7 @@ def init_bert_params(module):
def normal_(data): def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU # with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP # so that the RNG is consistent with and without FSDP
data.copy_( data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
)
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
normal_(module.weight.data) normal_(module.weight.data)
@ -276,7 +274,9 @@ class TransformerSentenceEncoder(nn.Module):
inner_states.append(x) inner_states.append(x)
for layer in self.layers: for layer in self.layers:
x, _ = layer(x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask) x, _ = layer(
x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask
)
if not last_state_only: if not last_state_only:
inner_states.append(x) inner_states.append(x)
View File
@ -2,13 +2,13 @@
# Licensed under the MIT License. # Licensed under the MIT License.
""" Wrapper for ngram_repeat_block cuda extension """ """ Wrapper for ngram_repeat_block cuda extension """
import math
import warnings
from typing import Dict, List, Optional
import torch import torch
from torch import nn from torch import nn
import math
from typing import Dict, List, Optional
import warnings
try: try:
from fairseq import ngram_repeat_block_cuda from fairseq import ngram_repeat_block_cuda
@ -37,7 +37,7 @@ def is_cuda_extension_usable() -> bool:
class NGramRepeatBlock(nn.Module): class NGramRepeatBlock(nn.Module):
""" Wrapper class for calling ngram_repeat_block cuda extension """ """Wrapper class for calling ngram_repeat_block cuda extension"""
def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
super().__init__() super().__init__()
View File
@ -67,13 +67,13 @@ class FairseqAdam(FairseqOptimizer):
elif use_fused_adam: elif use_fused_adam:
logger.info("using FusedAdam") logger.info("using FusedAdam")
self._optimizer = fused_adam_cls( self._optimizer = fused_adam_cls(
params, params, use_fp16_stats=self.cfg.fp16_adam_stats, **self.optimizer_config
use_fp16_stats=self.cfg.fp16_adam_stats,
**self.optimizer_config
) )
else: else:
if self.cfg.fp16_adam_stats: if self.cfg.fp16_adam_stats:
raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1") raise NotImplementedError(
"--fp16-adam-stats is only supported with FusedAdamV1"
)
self._optimizer = Adam(params, **self.optimizer_config) self._optimizer = Adam(params, **self.optimizer_config)
@property @property
View File
@ -63,8 +63,9 @@ class AMPOptimizer(optim.FairseqOptimizer):
).format(self.min_loss_scale, new_loss_scale) ).format(self.min_loss_scale, new_loss_scale)
) )
else: else:
logger.info("AMP: overflow detected, setting scale to " logger.info(
f"to {new_loss_scale}") "AMP: overflow detected, setting scale to " f"to {new_loss_scale}"
)
return grad_norm return grad_norm
@property @property
View File
@ -23,7 +23,9 @@ class OptimizerAndSchedulerConfig(FairseqDataclass):
optimizer: Any = None optimizer: Any = None
lr_scheduler: Optional[Any] = None lr_scheduler: Optional[Any] = None
lr: List = II("optimization.lr") lr: List = II("optimization.lr")
lr_float: Optional[float] = None # this makes it easier to sweep on learning rate with auto sweepers lr_float: Optional[
float
] = None # this makes it easier to sweep on learning rate with auto sweepers
@dataclass @dataclass
View File
@ -16,6 +16,7 @@ from omegaconf import II, DictConfig
try: try:
import deepspeed import deepspeed
has_deepspeed = True has_deepspeed = True
except ImportError as e: except ImportError as e:
has_deepspeed = False has_deepspeed = False
@ -24,12 +25,15 @@ except ImportError as e:
def _get_cpu_adam(): def _get_cpu_adam():
try: try:
from deepspeed.ops.op_builder import CPUAdamBuilder from deepspeed.ops.op_builder import CPUAdamBuilder
return CPUAdamBuilder().load() return CPUAdamBuilder().load()
except ImportError: except ImportError:
# fbcode # fbcode
from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam
return ds_opt_adam return ds_opt_adam
@dataclass @dataclass
class FairseqCPUAdamConfig(FairseqDataclass): class FairseqCPUAdamConfig(FairseqDataclass):
adam_betas: str = field( adam_betas: str = field(
View File
@ -64,9 +64,9 @@ class _FP16OptimizerMixin(object):
fp32_params = [] fp32_params = []
for p in params: for p in params:
p32 = torch.nn.Parameter(p.data.float()) p32 = torch.nn.Parameter(p.data.float())
if hasattr(p, 'expert'): if hasattr(p, "expert"):
p32.expert = True p32.expert = True
elif hasattr(p, 'base_expert'): elif hasattr(p, "base_expert"):
p32.base_expert = True p32.base_expert = True
p32.grad = torch.zeros_like(p32.data) p32.grad = torch.zeros_like(p32.data)
if hasattr(p, "param_group"): if hasattr(p, "param_group"):
@ -209,7 +209,9 @@ class _FP16OptimizerMixin(object):
self._sync_fp16_grads_to_fp32() self._sync_fp16_grads_to_fp32()
if getattr(self, "supports_step_with_scale", False): if getattr(self, "supports_step_with_scale", False):
self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups) self.fp32_optimizer.step(
closure, scale=(1.0 / self._multiply_factor), groups=groups
)
else: else:
self._unscale_grads() self._unscale_grads()
self.fp32_optimizer.step(closure, groups=groups) self.fp32_optimizer.step(closure, groups=groups)
@ -434,7 +436,9 @@ class _MemoryEfficientFP16OptimizerMixin(object):
"""Performs a single optimization step.""" """Performs a single optimization step."""
if getattr(self, "supports_step_with_scale", False): if getattr(self, "supports_step_with_scale", False):
# NOTE(msb) optimizer divides by scale factor # NOTE(msb) optimizer divides by scale factor
self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups) self.wrapped_optimizer.step(
closure, scale=(1.0 / self._multiply_factor), groups=groups
)
else: else:
self._unscale_grads() self._unscale_grads()
self.wrapped_optimizer.step(closure, groups=groups) self.wrapped_optimizer.step(closure, groups=groups)
View File
@ -179,7 +179,7 @@ class FusedAdamV1(torch.optim.Optimizer):
if p.device.type == "cpu": if p.device.type == "cpu":
p_data_fp32 = p.data.cuda(non_blocking=True).float() p_data_fp32 = p.data.cuda(non_blocking=True).float()
out_p = torch.tensor([], dtype = torch.float) out_p = torch.tensor([], dtype=torch.float)
else: else:
p_data_fp32 = p.data.float() p_data_fp32 = p.data.float()
out_p = p.data out_p = p.data
@ -234,6 +234,7 @@ class FusedAdamV1(torch.optim.Optimizer):
p.data.copy_(p_data_fp32, non_blocking=True) p.data.copy_(p_data_fp32, non_blocking=True)
if self.use_fp16_stats: if self.use_fp16_stats:
def inf_norm(t): def inf_norm(t):
return torch.norm(t, float("inf")) return torch.norm(t, float("inf"))
@ -262,7 +263,9 @@ try:
def __init__(self, *args, use_fp16_stats=False, **kwargs): def __init__(self, *args, use_fp16_stats=False, **kwargs):
if use_fp16_stats: if use_fp16_stats:
raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1") raise NotImplementedError(
"--fp16-adam-stats is only supported with FusedAdamV1"
)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if not hasattr(self, "multi_tensor_adam"): if not hasattr(self, "multi_tensor_adam"):
raise Exception( raise Exception(
View File
@ -32,7 +32,7 @@ class ManualSchedule(LegacyFairseqLRScheduler):
self.optimizer.set_lr(self.lr) # Set the beginning of the epoch. self.optimizer.set_lr(self.lr) # Set the beginning of the epoch.
def parse_manuallr_args(self, lr_args_str): def parse_manuallr_args(self, lr_args_str):
lr_dict = ast.literal_eval(lr_args_str.replace(' ', '')) lr_dict = ast.literal_eval(lr_args_str.replace(" ", ""))
if not isinstance(lr_dict, dict): if not isinstance(lr_dict, dict):
raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict") raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict")
@ -84,9 +84,14 @@ class ManualSchedule(LegacyFairseqLRScheduler):
if manual_keys: if manual_keys:
manual_lr = self.epoch2lr[max(manual_keys)] manual_lr = self.epoch2lr[max(manual_keys)]
else: else:
logger.warning("@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format( logger.warning(
epoch, list(self.epoch2lr.items())[:min(10, len(self.epoch2lr.keys())-1)] "@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format(
)) epoch,
list(self.epoch2lr.items())[
: min(10, len(self.epoch2lr.keys()) - 1)
],
)
)
manual_lr = self.optimizer.get_lr() manual_lr = self.optimizer.get_lr()
return manual_lr return manual_lr
@ -102,8 +107,14 @@ class ManualSchedule(LegacyFairseqLRScheduler):
if manual_keys: if manual_keys:
manual_lr = self.update2lr[max(manual_keys)] manual_lr = self.update2lr[max(manual_keys)]
else: else:
logger.warning("epoch={} does not exist in manual lr input update2lr={}...".format( logger.warning(
num_updates, list(self.update2lr.items())[:min(10, len(self.update2lr.keys())-1)])) "epoch={} does not exist in manual lr input update2lr={}...".format(
num_updates,
list(self.update2lr.items())[
: min(10, len(self.update2lr.keys()) - 1)
],
)
)
manual_lr = self.optimizer.get_lr() manual_lr = self.optimizer.get_lr()
self.optimizer.set_lr(manual_lr) self.optimizer.set_lr(manual_lr)
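A hedged sketch of the lookup the scheduler above performs: the manual schedule string is parsed with ast.literal_eval, and the largest key not exceeding the current epoch (or update count) selects the learning rate; the values below are invented:

import ast

epoch2lr = ast.literal_eval("{1: 5e-4, 10: 1e-4, 20: 5e-5}".replace(" ", ""))
epoch = 15
manual_keys = [k for k in epoch2lr if k <= epoch]
# fairseq falls back to the optimizer's current lr when no key matches
lr = epoch2lr[max(manual_keys)] if manual_keys else None
print(lr)  # 0.0001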
View File
@ -36,8 +36,7 @@ class StepLRScheduleConfig(FairseqDataclass):
@register_lr_scheduler("step", dataclass=StepLRScheduleConfig) @register_lr_scheduler("step", dataclass=StepLRScheduleConfig)
class StepLRSchedule(FairseqLRScheduler): class StepLRSchedule(FairseqLRScheduler):
"""Decay learning rate every k updates by a fixed factor """Decay learning rate every k updates by a fixed factor"""
"""
def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer): def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer):
super().__init__(cfg, fairseq_optimizer) super().__init__(cfg, fairseq_optimizer)
@ -50,16 +49,16 @@ class StepLRSchedule(FairseqLRScheduler):
cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr
) )
assert(self.lr_deacy_period > 0) assert self.lr_deacy_period > 0
assert(self.lr_decay <= 1) assert self.lr_decay <= 1
assert(self.min_lr >= 0) assert self.min_lr >= 0
assert(self.max_lr > self.min_lr) assert self.max_lr > self.min_lr
if cfg.warmup_updates > 0: if cfg.warmup_updates > 0:
# linearly warmup for the first cfg.warmup_updates # linearly warmup for the first cfg.warmup_updates
self.warmup_lr_step = ( self.warmup_lr_step = (
(self.max_lr - self.warmup_init_lr) / self.warmup_updates self.max_lr - self.warmup_init_lr
) ) / self.warmup_updates
else: else:
self.warmup_lr_step = 1 self.warmup_lr_step = 1
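A rough sketch of the schedule the asserts above constrain: linear warmup from warmup_init_lr to max_lr, then decay by a fixed factor every lr_decay_period updates. The exact anchoring of the decay counter in fairseq's StepLRSchedule may differ; the numbers are illustrative:

def step_lr(num_updates, max_lr=1e-3, warmup_init_lr=1e-7, warmup_updates=4000,
            lr_decay=0.5, lr_decay_period=50000, min_lr=1e-9):
    if num_updates < warmup_updates:
        warmup_lr_step = (max_lr - warmup_init_lr) / warmup_updates
        return warmup_init_lr + num_updates * warmup_lr_step
    decays = (num_updates - warmup_updates) // lr_decay_period
    return max(max_lr * lr_decay ** decays, min_lr)

print(step_lr(0), step_lr(4000), step_lr(104000))  # 1e-07 0.001 0.00025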
View File
@ -171,7 +171,9 @@ class SequenceGenerator(nn.Module):
yield id, src, ref, hypos[i] yield id, src, ref, hypos[i]
@torch.no_grad() @torch.no_grad()
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]: def generate(
self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs
) -> List[List[Dict[str, Tensor]]]:
"""Generate translations. Match the api of other fairseq generators. """Generate translations. Match the api of other fairseq generators.
Args: Args:
@ -223,7 +225,10 @@ class SequenceGenerator(nn.Module):
else torch.tensor(src_tokens.size(-1)).to(src_tokens) else torch.tensor(src_tokens.size(-1)).to(src_tokens)
) )
else: else:
raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys())) raise Exception(
"expected src_tokens or source in net input. input keys: "
+ str(net_input.keys())
)
# bsz: total number of sentences in beam # bsz: total number of sentences in beam
# Note that src_tokens may have more than 2 dimensions (i.e. audio features) # Note that src_tokens may have more than 2 dimensions (i.e. audio features)
@ -328,7 +333,9 @@ class SequenceGenerator(nn.Module):
encoder_outs = self.model.reorder_encoder_out( encoder_outs = self.model.reorder_encoder_out(
encoder_outs, reorder_state encoder_outs, reorder_state
) )
with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"): with torch.autograd.profiler.record_function(
"EnsembleModel: forward_decoder"
):
lprobs, avg_attn_scores = self.model.forward_decoder( lprobs, avg_attn_scores = self.model.forward_decoder(
tokens[:, : step + 1], tokens[:, : step + 1],
encoder_outs, encoder_outs,
@ -751,7 +758,14 @@ class EnsembleModel(nn.Module):
return self.has_incremental return self.has_incremental
def max_decoder_positions(self): def max_decoder_positions(self):
return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize]) return min(
[
m.max_decoder_positions()
for m in self.models
if hasattr(m, "max_decoder_positions")
]
+ [sys.maxsize]
)
@torch.jit.export @torch.jit.export
def forward_encoder(self, net_input: Dict[str, Tensor]): def forward_encoder(self, net_input: Dict[str, Tensor]):
View File
@ -35,8 +35,12 @@ class SpeechGenerator(object):
class AutoRegressiveSpeechGenerator(SpeechGenerator): class AutoRegressiveSpeechGenerator(SpeechGenerator):
def __init__( def __init__(
self, model, vocoder, data_cfg, max_iter: int = 6000, self,
eos_prob_threshold: float = 0.5, model,
vocoder,
data_cfg,
max_iter: int = 6000,
eos_prob_threshold: float = 0.5,
): ):
super().__init__(model, vocoder, data_cfg) super().__init__(model, vocoder, data_cfg)
self.max_iter = max_iter self.max_iter = max_iter
@ -54,8 +58,9 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
raw_dim = out_dim // n_frames_per_step raw_dim = out_dim // n_frames_per_step
# initialize # initialize
encoder_out = model.forward_encoder(src_tokens, src_lengths, encoder_out = model.forward_encoder(
speaker=sample["speaker"]) src_tokens, src_lengths, speaker=sample["speaker"]
)
incremental_state = {} incremental_state = {}
feat, attn, eos_prob = [], [], [] feat, attn, eos_prob = [], [], []
finished = src_tokens.new_zeros((bsz,)).bool() finished = src_tokens.new_zeros((bsz,)).bool()
@ -66,21 +71,24 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
cur_out_lens = out_lens.clone() cur_out_lens = out_lens.clone()
cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1) cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1)
_, cur_eos_out, cur_extra = model.forward_decoder( _, cur_eos_out, cur_extra = model.forward_decoder(
prev_feat_out, encoder_out=encoder_out, prev_feat_out,
encoder_out=encoder_out,
incremental_state=incremental_state, incremental_state=incremental_state,
target_lengths=cur_out_lens, speaker=sample["speaker"], **kwargs target_lengths=cur_out_lens,
speaker=sample["speaker"],
**kwargs
) )
cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2)
feat.append(cur_extra['feature_out']) feat.append(cur_extra["feature_out"])
attn.append(cur_extra['attn']) attn.append(cur_extra["attn"])
eos_prob.append(cur_eos_prob) eos_prob.append(cur_eos_prob)
cur_finished = (cur_eos_prob.squeeze(1) > self.eos_prob_threshold) cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold
out_lens.masked_fill_((~finished) & cur_finished, step + 1) out_lens.masked_fill_((~finished) & cur_finished, step + 1)
finished = finished | cur_finished finished = finished | cur_finished
if finished.sum().item() == bsz: if finished.sum().item() == bsz:
break break
prev_feat_out = cur_extra['feature_out'] prev_feat_out = cur_extra["feature_out"]
feat = torch.cat(feat, dim=1) feat = torch.cat(feat, dim=1)
feat = model.decoder.postnet(feat) + feat feat = model.decoder.postnet(feat) + feat
@ -98,11 +106,11 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator):
finalized = [ finalized = [
{ {
'feature': feat[b, :out_len], "feature": feat[b, :out_len],
'eos_prob': eos_prob[b, :out_len], "eos_prob": eos_prob[b, :out_len],
'attn': attn[b, :, :out_len], "attn": attn[b, :, :out_len],
'alignment': alignment[b, :out_len], "alignment": alignment[b, :out_len],
'waveform': self.get_waveform(feat[b, :out_len]), "waveform": self.get_waveform(feat[b, :out_len]),
} }
for b, out_len in zip(range(bsz), out_lens) for b, out_len in zip(range(bsz), out_lens)
] ]
@ -134,7 +142,7 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
prev_output_tokens=sample["net_input"]["prev_output_tokens"], prev_output_tokens=sample["net_input"]["prev_output_tokens"],
incremental_state=None, incremental_state=None,
target_lengths=sample["target_lengths"], target_lengths=sample["target_lengths"],
speaker=sample["speaker"] speaker=sample["speaker"],
) )
if feat_post is not None: if feat_post is not None:
feat = feat_post feat = feat_post
@ -142,9 +150,7 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
feat = feat.view(bsz, -1, raw_dim) feat = feat.view(bsz, -1, raw_dim)
feat = self.gcmvn_denormalize(feat) feat = self.gcmvn_denormalize(feat)
dur_out = torch.clamp( dur_out = torch.clamp(torch.round(torch.exp(log_dur_out) - 1).long(), min=0)
torch.round(torch.exp(log_dur_out) - 1).long(), min=0
)
def get_dur_plot_data(d): def get_dur_plot_data(d):
r = [] r = []
@ -155,11 +161,11 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator):
out_lens = out_lens * n_frames_per_step out_lens = out_lens * n_frames_per_step
finalized = [ finalized = [
{ {
'feature': feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]), "feature": feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]),
'waveform': self.get_waveform( "waveform": self.get_waveform(
feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]) feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim])
), ),
'attn': feat.new_tensor(get_dur_plot_data(dur_out[b])), "attn": feat.new_tensor(get_dur_plot_data(dur_out[b])),
} }
for b, l in zip(range(bsz), out_lens) for b, l in zip(range(bsz), out_lens)
] ]
@ -188,8 +194,12 @@ class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator)
bsz = src_tokens.shape[0] bsz = src_tokens.shape[0]
feat, eos_prob, extra = model( feat, eos_prob, extra = model(
src_tokens, src_lens, prev_out_tokens, incremental_state=None, src_tokens,
target_lengths=tgt_lens, speaker=sample["speaker"] src_lens,
prev_out_tokens,
incremental_state=None,
target_lengths=tgt_lens,
speaker=sample["speaker"],
) )
attn = extra["attn"] # B x T_s x T_t attn = extra["attn"] # B x T_s x T_t
@ -203,11 +213,11 @@ class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator)
finalized = [ finalized = [
{ {
'feature': feat[b, :tgt_len], "feature": feat[b, :tgt_len],
'eos_prob': eos_prob[b, :tgt_len], "eos_prob": eos_prob[b, :tgt_len],
'attn': attn[b, :, :tgt_len], "attn": attn[b, :, :tgt_len],
'alignment': alignment[b, :tgt_len], "alignment": alignment[b, :tgt_len],
'waveform': self.get_waveform(feat[b, :tgt_len]), "waveform": self.get_waveform(feat[b, :tgt_len]),
} }
for b, tgt_len in zip(range(bsz), tgt_lens) for b, tgt_len in zip(range(bsz), tgt_lens)
] ]
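A toy illustration of the stopping logic reformatted above: a sentence is marked finished once its sigmoid EOS probability crosses eos_prob_threshold, and out_lens records the step at which that happened (the logits are made up):

import torch

eos_prob_threshold = 0.5
finished = torch.zeros(3, dtype=torch.bool)
out_lens = torch.full((3,), 100, dtype=torch.long)  # 100 stands in for max_iter

for step, eos_logits in enumerate(torch.tensor([[-2.0, -1.0, 3.0], [1.0, -1.0, 4.0]])):
    cur_eos_prob = torch.sigmoid(eos_logits)
    cur_finished = cur_eos_prob > eos_prob_threshold
    out_lens.masked_fill_((~finished) & cur_finished, step + 1)
    finished = finished | cur_finished

print(finished, out_lens)  # sentences 0 and 2 stop at steps 2 and 1 respectively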
View File
@ -67,31 +67,31 @@ class AudioFinetuningConfig(AudioPretrainingConfig):
default=False, metadata={"help": "evaluation with BLEU scores"} default=False, metadata={"help": "evaluation with BLEU scores"}
) )
eval_bleu_detok: Optional[str] = field( eval_bleu_detok: Optional[str] = field(
default=None, metadata={ default=None,
metadata={
"help": "detokenize before computing BLEU (e.g., 'moses'); " "help": "detokenize before computing BLEU (e.g., 'moses'); "
"required if using --eval-bleu; use 'space' to disable " "required if using --eval-bleu; use 'space' to disable "
"detokenization; see fairseq.data.encoders for other options" "detokenization; see fairseq.data.encoders for other options"
} },
) )
eval_bleu_detok_args: str = field( eval_bleu_detok_args: str = field(
default="{}", default="{}", metadata={"help": "args for building the tokenizer, if needed"}
metadata={"help": "args for building the tokenizer, if needed"}
) )
eval_tokenized_bleu: bool = field( eval_tokenized_bleu: bool = field(
default=False, default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
metadata={"help": "compute tokenized BLEU instead of sacrebleu"}
) )
eval_bleu_remove_bpe: Optional[str] = field( eval_bleu_remove_bpe: Optional[str] = field(
default=None, metadata={"help": "remove BPE before computing BLEU"} default=None, metadata={"help": "remove BPE before computing BLEU"}
) )
eval_bleu_args: str = field( eval_bleu_args: str = field(
default="{}", default="{}",
metadata={"help": "generation args for BLUE scoring, e.g., " metadata={
"'{\"beam\": 4, \"lenpen\": 0.6}'"} "help": "generation args for BLUE scoring, e.g., "
'\'{"beam": 4, "lenpen": 0.6}\''
},
) )
eval_bleu_print_samples: bool = field( eval_bleu_print_samples: bool = field(
default=False, default=False, metadata={"help": "print sample generations during validation"}
metadata={"help": "print sample generations during validation"}
) )
autoregressive: bool = field( autoregressive: bool = field(
default=False, default=False,
@ -123,7 +123,9 @@ class AudioFinetuningTask(AudioPretrainingTask):
return Dictionary.load(dict_path) return Dictionary.load(dict_path)
return None return None
def load_dataset(self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs): def load_dataset(
self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs
):
super().load_dataset(split, task_cfg, **kwargs) super().load_dataset(split, task_cfg, **kwargs)
task_cfg = task_cfg or self.cfg task_cfg = task_cfg or self.cfg
@ -138,7 +140,8 @@ class AudioFinetuningTask(AudioPretrainingTask):
with open(label_path, "r") as f: with open(label_path, "r") as f:
labels = [ labels = [
text_compressor.compress(l) text_compressor.compress(l)
for i, l in enumerate(f) if i not in skipped_indices for i, l in enumerate(f)
if i not in skipped_indices
] ]
assert len(labels) == len(self.datasets[split]), ( assert len(labels) == len(self.datasets[split]), (
@ -157,7 +160,7 @@ class AudioFinetuningTask(AudioPretrainingTask):
process_label=process_label, process_label=process_label,
label_len_fn=label_len_fn, label_len_fn=label_len_fn,
add_to_input=task_cfg.get("autoregressive", False), add_to_input=task_cfg.get("autoregressive", False),
text_compression_level=text_compression_level text_compression_level=text_compression_level,
) )
@property @property
@ -176,8 +179,8 @@ class AudioFinetuningTask(AudioPretrainingTask):
logging_output["_num_words"] = metrics["num_words"] logging_output["_num_words"] = metrics["num_words"]
if self.cfg.eval_bleu and self.cfg.autoregressive: if self.cfg.eval_bleu and self.cfg.autoregressive:
metrics = self._inference_with_bleu(self.sequence_generator, sample, model) metrics = self._inference_with_bleu(self.sequence_generator, sample, model)
logging_output['_bleu_sys_len'] = metrics.sys_len logging_output["_bleu_sys_len"] = metrics.sys_len
logging_output['_bleu_ref_len'] = metrics.ref_len logging_output["_bleu_ref_len"] = metrics.ref_len
# we split counts into separate entries so that they can be # we split counts into separate entries so that they can be
# summed efficiently across workers using fast-stat-sync # summed efficiently across workers using fast-stat-sync
assert len(metrics.counts) == 4 assert len(metrics.counts) == 4
@ -200,9 +203,9 @@ class AudioFinetuningTask(AudioPretrainingTask):
self.tokenizer = None self.tokenizer = None
if self.cfg.eval_bleu and self.cfg.autoregressive: if self.cfg.eval_bleu and self.cfg.autoregressive:
assert self.cfg.eval_bleu_detok is not None, ( assert self.cfg.eval_bleu_detok is not None, (
'--eval-bleu-detok is required if using --eval-bleu; ' "--eval-bleu-detok is required if using --eval-bleu; "
'try --eval-bleu-detok=moses (or --eval-bleu-detok=space ' "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
'to disable detokenization, e.g., when using sentencepiece)' "to disable detokenization, e.g., when using sentencepiece)"
) )
detok_args = json.loads(self.cfg.eval_bleu_detok_args) detok_args = json.loads(self.cfg.eval_bleu_detok_args)
self.tokenizer = encoders.build_tokenizer( self.tokenizer = encoders.build_tokenizer(
@ -261,9 +264,7 @@ class AudioFinetuningTask(AudioPretrainingTask):
# BLEU scores. Instead, we use a somewhat more verbose # BLEU scores. Instead, we use a somewhat more verbose
# alternative that is unlikely to appear in the real # alternative that is unlikely to appear in the real
# reference, but doesn't get split into multiple tokens. # reference, but doesn't get split into multiple tokens.
unk_string=( unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"),
"UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"
),
) )
if self.tokenizer: if self.tokenizer:
s = self.tokenizer.decode(s) s = self.tokenizer.decode(s)
@ -272,21 +273,18 @@ class AudioFinetuningTask(AudioPretrainingTask):
gen_out = self.inference_step(generator, [model], sample) gen_out = self.inference_step(generator, [model], sample)
hyps, refs = [], [] hyps, refs = [], []
for i in range(len(gen_out)): for i in range(len(gen_out)):
hyps.append(decode(gen_out[i][0]['tokens'], is_ref=False)) hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False))
refs.append( refs.append(
decode( decode(
utils.strip_pad( utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
sample['target'][i],
self.target_dictionary.pad()
),
is_ref=True, # don't count <unk> as matches to the hypo is_ref=True, # don't count <unk> as matches to the hypo
) )
) )
if self.cfg.eval_bleu_print_samples: if self.cfg.eval_bleu_print_samples:
logger.info('H-{} {}'.format(sample["id"][0], hyps[0])) logger.info("H-{} {}".format(sample["id"][0], hyps[0]))
logger.info('T-{} {}'.format(sample["id"][0], refs[0])) logger.info("T-{} {}".format(sample["id"][0], refs[0]))
eval_tokenization = 'none' if self.cfg.eval_tokenized_bleu else '13a' eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a"
return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization) return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization)
def reduce_metrics(self, logging_outputs, criterion): def reduce_metrics(self, logging_outputs, criterion):
@ -329,18 +327,17 @@ class AudioFinetuningTask(AudioPretrainingTask):
count_keys = [f"_bleu_counts_{i}" for i in range(4)] count_keys = [f"_bleu_counts_{i}" for i in range(4)]
total_keys = [f"_bleu_totals_{i}" for i in range(4)] total_keys = [f"_bleu_totals_{i}" for i in range(4)]
for k in len_keys + count_keys + total_keys: for k in len_keys + count_keys + total_keys:
metrics.log_scalar( metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs))
k, sum(log.get(k, 0) for log in logging_outputs)
)
import sacrebleu import sacrebleu
metrics.log_derived( metrics.log_derived(
'bleu', "bleu",
lambda meters: sacrebleu.compute_bleu( lambda meters: sacrebleu.compute_bleu(
correct=[meters[k].sum for k in count_keys], correct=[meters[k].sum for k in count_keys],
total=[meters[k].sum for k in total_keys], total=[meters[k].sum for k in total_keys],
sys_len=meters['_bleu_sys_len'].sum, sys_len=meters["_bleu_sys_len"].sum,
ref_len=meters['_bleu_ref_len'].sum, ref_len=meters["_bleu_ref_len"].sum,
smooth_method="exp" smooth_method="exp",
).score ).score,
) )
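The BLEU validation path above ultimately calls sacrebleu.corpus_bleu on detokenized hypotheses and references; a minimal example with invented strings (the task passes tokenize="none" when --eval-tokenized-bleu is set, "13a" otherwise):

import sacrebleu

hyps = ["the cat sat on the mat"]
refs = ["the cat sat on a mat"]
print(sacrebleu.corpus_bleu(hyps, [refs], tokenize="13a").score)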
View File
@ -50,8 +50,7 @@ class AudioPretrainingConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"}) data: str = field(default=MISSING, metadata={"help": "path to data directory"})
labels: Optional[str] = field( labels: Optional[str] = field(
default=None, default=None,
metadata={ metadata={"help": "extension of the label file to load, used for fine-tuning"},
"help": "extension of the label file to load, used for fine-tuning"},
) )
binarized_dataset: bool = field( binarized_dataset: bool = field(
default=False, default=False,
@ -102,8 +101,8 @@ class AudioPretrainingConfig(FairseqDataclass):
default="none", default="none",
metadata={ metadata={
"help": "compression level for texts (e.g. audio filenames, " "help": "compression level for texts (e.g. audio filenames, "
"target texts): none/low/high (default: none). " "target texts): none/low/high (default: none). "
} },
) )
View File
@ -135,7 +135,6 @@ class DenoisingTask(LegacyFairseqTask):
'e.g., "train,valid" (default: all dataset splits)', 'e.g., "train,valid" (default: all dataset splits)',
) )
def __init__(self, args, dictionary): def __init__(self, args, dictionary):
super().__init__(args) super().__init__(args)
self.dictionary = dictionary self.dictionary = dictionary
View File
@ -11,20 +11,19 @@ from fairseq.tasks.text_to_speech import TextToSpeechTask
logging.basicConfig( logging.basicConfig(
format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@register_task('frm_text_to_speech') @register_task("frm_text_to_speech")
class FrmTextToSpeechTask(TextToSpeechTask): class FrmTextToSpeechTask(TextToSpeechTask):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
TextToSpeechTask.add_args(parser) TextToSpeechTask.add_args(parser)
parser.add_argument( parser.add_argument("--do_chunk", action="store_true", help="train on chunks")
"--do_chunk", action="store_true", help="train on chunks"
)
parser.add_argument("--chunk_bound", default=-1, type=int) parser.add_argument("--chunk_bound", default=-1, type=int)
parser.add_argument("--chunk_init", default=50, type=int) parser.add_argument("--chunk_init", default=50, type=int)
parser.add_argument("--chunk_incr", default=5, type=int) parser.add_argument("--chunk_incr", default=5, type=int)
@ -52,5 +51,5 @@ class FrmTextToSpeechTask(TextToSpeechTask):
chunk_incr=self.args.chunk_incr, chunk_incr=self.args.chunk_incr,
add_eos=self.args.add_eos, add_eos=self.args.add_eos,
dedup=self.args.dedup, dedup=self.args.dedup,
ref_fpu=self.args.ref_fpu ref_fpu=self.args.ref_fpu,
) )
Some files were not shown because too many files have changed in this diff