From 0dfd6b624081fc4e1c72fc74ae0cd2de199c334c Mon Sep 17 00:00:00 2001
From: dianaml0 <82468439+dianaml0@users.noreply.github.com>
Date: Mon, 29 Nov 2021 12:30:10 -0800
Subject: [PATCH] Add linting with black (#2678)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2678

Reviewed By: Mortimerp9

Differential Revision: D32653381

Pulled By: dianaml0

fbshipit-source-id: 2810d14867cd7d64f4d340740e2b590b82de47fe
---
 .github/workflows/build.yml | 5 +
 fairseq/__init__.py | 1 +
 fairseq/benchmark/dummy_mt.py | 4 +-
 fairseq/checkpoint_utils.py | 37 +--
 fairseq/criterions/fastspeech2_loss.py | 44 +--
 fairseq/criterions/hubert_criterion.py | 31 ++-
 ...moothed_cross_entropy_latency_augmented.py | 47 ++--
 fairseq/criterions/tacotron2_loss.py | 93 ++++---
 fairseq/criterions/wav2vec_criterion.py | 15 +-
 fairseq/data/add_target_dataset.py | 2 +-
 fairseq/data/audio/audio_utils.py | 91 +++---
 .../data/audio/frm_text_to_speech_dataset.py | 42 ++-
 fairseq/data/audio/hubert_dataset.py | 30 +-
 fairseq/data/audio/multi_modality_dataset.py | 1 +
 fairseq/data/audio/raw_audio_dataset.py | 1 +
 fairseq/data/audio/speech_to_text_dataset.py | 68 +++--
 fairseq/data/audio/text_to_speech_dataset.py | 127 +++++----
 fairseq/data/colorize_dataset.py | 2 +-
 fairseq/data/data_utils.py | 11 +-
 fairseq/data/encoders/sentencepiece_bpe.py | 9 +-
 fairseq/data/fairseq_dataset.py | 2 +-
 fairseq/data/huffman/huffman_coder.py | 4 +-
 fairseq/data/indexed_dataset.py | 4 +-
 fairseq/data/iterators.py | 8 +-
 fairseq/data/language_pair_dataset.py | 10 +-
 fairseq/data/multi_corpus_dataset.py | 4 +-
 .../multilingual/multilingual_data_manager.py | 56 ++--
 fairseq/data/noising.py | 1 -
 fairseq/data/text_compressor.py | 6 +-
 fairseq/data/token_block_dataset.py | 8 +-
 .../data/transform_eos_lang_pair_dataset.py | 2 +-
 fairseq/dataclass/configs.py | 69 +++--
 fairseq/dataclass/constants.py | 18 +-
 fairseq/dataclass/initialize.py | 2 +-
 fairseq/dataclass/utils.py | 14 +-
 fairseq/distributed/__init__.py | 6 +-
 .../distributed_timeout_wrapper.py | 11 +-
 .../legacy_distributed_data_parallel.py | 2 +-
 fairseq/distributed/module_proxy_wrapper.py | 5 +-
 .../tpu_distributed_data_parallel.py | 6 +-
 fairseq/distributed/utils.py | 17 +-
 fairseq/file_io.py | 2 +
 fairseq/file_utils.py | 1 +
 fairseq/logging/metrics.py | 2 +
 fairseq/model_parallel/megatron_trainer.py | 16 +-
 .../pipeline_parallel_transformer/layers.py | 6 +-
 .../pipeline_parallel_transformer/model.py | 50 +++-
 .../model_parallel/models/transformer_lm.py | 7 +-
 fairseq/models/__init__.py | 4 +-
 fairseq/models/bart/hub_interface.py | 31 ++-
 fairseq/models/bart/model.py | 8 +-
 fairseq/models/distributed_fairseq_model.py | 5 +-
 fairseq/models/ema/ema.py | 22 +-
 fairseq/models/fairseq_decoder.py | 1 -
 fairseq/models/fairseq_model.py | 13 +-
 fairseq/models/hubert/hubert.py | 70 ++---
fairseq/models/hubert/hubert_asr.py | 41 +-- fairseq/models/lstm.py | 12 +- fairseq/models/nat/fairseq_nat_model.py | 8 +- .../models/nat/nonautoregressive_ensembles.py | 3 +- fairseq/models/roberta/model.py | 6 +- fairseq/models/roberta/model_gottbert.py | 28 +- .../models/speech_to_text/s2t_transformer.py | 38 ++- .../models/speech_to_text/xm_transformer.py | 260 +++++++++++------- fairseq/models/text_to_speech/fastspeech2.py | 107 ++++--- fairseq/models/text_to_speech/hifigan.py | 4 +- fairseq/models/text_to_speech/tacotron2.py | 136 +++++---- .../models/text_to_speech/tts_transformer.py | 106 ++++--- fairseq/models/text_to_speech/vocoder.py | 84 +++--- .../models/transformer/transformer_decoder.py | 4 +- .../models/transformer/transformer_encoder.py | 11 +- fairseq/models/transformer_lm.py | 48 +++- fairseq/models/wav2vec/wav2vec2.py | 20 +- fairseq/models/wav2vec/wav2vec2_asr.py | 24 +- fairseq/modules/base_layer.py | 75 +++-- fairseq/modules/checkpoint_activations.py | 7 +- fairseq/modules/gumbel_vector_quantizer.py | 1 + fairseq/modules/kmeans_attention.py | 237 ++++++++++++---- fairseq/modules/linearized_convolution.py | 23 +- fairseq/modules/location_attention.py | 23 +- fairseq/modules/lstm_cell_with_zoneout.py | 14 +- fairseq/modules/quantization/pq/utils.py | 10 +- fairseq/modules/quantization/scalar/utils.py | 4 +- fairseq/modules/transformer_layer.py | 37 ++- .../modules/transformer_sentence_encoder.py | 8 +- fairseq/ngram_repeat_block.py | 10 +- fairseq/optim/adam.py | 8 +- fairseq/optim/amp_optimizer.py | 5 +- fairseq/optim/composite.py | 4 +- fairseq/optim/cpu_adam.py | 4 + fairseq/optim/fp16_optimizer.py | 12 +- fairseq/optim/fused_adam.py | 7 +- .../optim/lr_scheduler/manual_lr_scheduler.py | 23 +- .../optim/lr_scheduler/step_lr_scheduler.py | 15 +- fairseq/sequence_generator.py | 22 +- fairseq/speech_generator.py | 68 +++-- fairseq/tasks/audio_finetuning.py | 75 +++-- fairseq/tasks/audio_pretraining.py | 7 +- fairseq/tasks/denoising.py | 1 - fairseq/tasks/frm_text_to_speech.py | 13 +- fairseq/tasks/hubert_pretraining.py | 28 +- fairseq/tasks/language_modeling.py | 10 +- fairseq/tasks/simultaneous_translation.py | 9 +- fairseq/tasks/speech_to_text.py | 7 +- fairseq/tasks/text_to_speech.py | 184 ++++++++----- fairseq/tasks/translation.py | 6 +- fairseq/tasks/translation_lev.py | 12 +- .../tasks/translation_multi_simple_epoch.py | 6 +- fairseq/trainer.py | 50 +++- fairseq/utils.py | 5 +- fairseq_cli/generate.py | 23 +- fairseq_cli/hydra_train.py | 23 +- fairseq_cli/interactive.py | 7 +- fairseq_cli/preprocess.py | 4 +- fairseq_cli/train.py | 25 +- fairseq_cli/validate.py | 8 +- setup.py | 5 +- tests/distributed/test_bmuf.py | 8 +- .../test_distributed_timeout_wrapper.py | 2 - .../distributed/test_module_proxy_wrapper.py | 1 - tests/distributed/utils.py | 5 +- tests/gpu/test_binaries_gpu.py | 29 +- tests/gpu/test_ema_gpu.py | 44 ++- tests/test_amp_optimizer.py | 11 +- tests/test_binaries.py | 24 +- tests/test_checkpoint_utils.py | 34 +-- tests/test_data_utils.py | 6 +- tests/test_dataclass_utils.py | 2 +- tests/test_ema.py | 44 ++- tests/test_export.py | 3 +- tests/test_file_io.py | 1 + tests/test_iopath.py | 1 - tests/test_lm_context_window.py | 16 +- tests/test_multi_corpus_dataset.py | 7 +- tests/test_noising.py | 7 +- tests/test_sequence_generator.py | 18 +- tests/utils.py | 10 +- 137 files changed, 2139 insertions(+), 1353 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bd1d3b3c..a80e0f92 100644 --- 
a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -53,3 +53,8 @@ jobs: - name: Run tests run: | python setup.py test + + - name: Lint with black + run: | + pip install black + black --check . --extend-exclude 'examples|fairseq\/model_parallel\/megatron' diff --git a/fairseq/__init__.py b/fairseq/__init__.py index dc9fd188..080c988b 100644 --- a/fairseq/__init__.py +++ b/fairseq/__init__.py @@ -27,6 +27,7 @@ sys.modules["fairseq.progress_bar"] = progress_bar # initialize hydra from fairseq.dataclass.initialize import hydra_init + hydra_init() import fairseq.criterions # noqa diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py index 4ca7be93..28d78cff 100644 --- a/fairseq/benchmark/dummy_mt.py +++ b/fairseq/benchmark/dummy_mt.py @@ -7,10 +7,10 @@ import logging import numpy as np import torch + from fairseq.data import Dictionary, FairseqDataset from fairseq.tasks import LegacyFairseqTask, register_task - logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class DummyMTTask(LegacyFairseqTask): @classmethod def setup_task(cls, args, **kwargs): - """Setup the task. """ + """Setup the task.""" dictionary = Dictionary() for i in range(args.dict_size): dictionary.add_symbol("word{}".format(i)) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 289053e5..a5be928f 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -96,10 +96,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): checkpoint_conds[ "checkpoint.best_{}_{:.3f}{}{}.pt".format( - cfg.best_checkpoint_metric, - val_loss, - rand_sfx, - suffix + cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix ) ] = worst_best is None or is_better(val_loss, worst_best) checkpoint_conds[ @@ -468,9 +465,7 @@ def load_model_ensemble_and_task( and len(state["optimizer_history"]) > 0 and "num_updates" in state["optimizer_history"][-1] ): - model.set_num_updates( - state["optimizer_history"][-1]["num_updates"] - ) + model.set_num_updates(state["optimizer_history"][-1]["num_updates"]) model.load_state_dict( state["model"], strict=strict, model_cfg=cfg.model ) @@ -588,9 +583,8 @@ def _upgrade_state_dict(state): # backward compatibility, cfg updates if "args" in state and state["args"] is not None: # old model checkpoints may not have separate source/target positions - if ( - hasattr(state["args"], "max_positions") - and not hasattr(state["args"], "max_source_positions") + if hasattr(state["args"], "max_positions") and not hasattr( + state["args"], "max_source_positions" ): state["args"].max_source_positions = state["args"].max_positions state["args"].max_target_positions = state["args"].max_positions @@ -615,13 +609,10 @@ def _upgrade_state_dict(state): state["args"].stop_min_lr = state["args"].min_lr del state["args"].min_lr # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion - if ( - hasattr(state["args"], "criterion") - and state["args"].criterion in [ - "binary_cross_entropy", - "kd_binary_cross_entropy", - ] - ): + if hasattr(state["args"], "criterion") and state["args"].criterion in [ + "binary_cross_entropy", + "kd_binary_cross_entropy", + ]: state["args"].criterion = "wav2vec" # remove log_keys if it's None (criteria will supply a default value of []) if hasattr(state["args"], "log_keys") and state["args"].log_keys is None: @@ -659,7 +650,9 @@ def _upgrade_state_dict(state): ): cfg.task.eval_wer_config.print_alignment = "hard" if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): - 
cfg.generation.print_alignment = "hard" if cfg.generation.print_alignment else None + cfg.generation.print_alignment = ( + "hard" if cfg.generation.print_alignment else None + ) if ( "model" in cfg and "w2v_args" in cfg.model @@ -833,16 +826,16 @@ def load_ema_from_checkpoint(fpath): params_dict = collections.OrderedDict() new_state = None - with PathManager.open(fpath, 'rb') as f: + with PathManager.open(fpath, "rb") as f: new_state = torch.load( f, map_location=( - lambda s, _: torch.serialization.default_restore_location(s, 'cpu') + lambda s, _: torch.serialization.default_restore_location(s, "cpu") ), ) # EMA model is stored in a separate "extra state" - model_params = new_state['extra_state']['ema'] + model_params = new_state["extra_state"]["ema"] for key in list(model_params.keys()): p = model_params[key] @@ -860,5 +853,5 @@ def load_ema_from_checkpoint(fpath): "ema model weights, is this model trained with EMA?" ) - new_state['model'] = params_dict + new_state["model"] = params_dict return new_state diff --git a/fairseq/criterions/fastspeech2_loss.py b/fairseq/criterions/fastspeech2_loss.py index b17b5070..b317409e 100644 --- a/fairseq/criterions/fastspeech2_loss.py +++ b/fairseq/criterions/fastspeech2_loss.py @@ -20,9 +20,7 @@ from fairseq.models.fairseq_model import FairseqEncoderModel @dataclass class FastSpeech2CriterionConfig(FairseqDataclass): - ctc_weight: float = field( - default=0.0, metadata={"help": "weight for CTC loss"} - ) + ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"}) @register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig) @@ -44,7 +42,7 @@ class FastSpeech2Loss(FairseqCriterion): speaker=sample["speaker"], durations=sample["durations"], pitches=sample["pitches"], - energies=sample["energies"] + energies=sample["energies"], ) src_mask = lengths_to_mask(sample["net_input"]["src_lengths"]) @@ -57,8 +55,7 @@ class FastSpeech2Loss(FairseqCriterion): feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask] l1_loss = F.l1_loss(feat_out, feat, reduction=reduction) if _feat_out_post is not None: - l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, - reduction=reduction) + l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction) pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction) energy_loss = F.mse_loss(energy_out, energies, reduction=reduction) @@ -69,16 +66,23 @@ class FastSpeech2Loss(FairseqCriterion): log_dur = torch.log(dur + 1)[src_mask] dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction) - ctc_loss = torch.tensor(0.).type_as(l1_loss) - if self.ctc_weight > 0.: + ctc_loss = torch.tensor(0.0).type_as(l1_loss) + if self.ctc_weight > 0.0: lprobs = model.get_normalized_probs((_feat_out,), log_probs=True) lprobs = lprobs.transpose(0, 1) # T x B x C src_mask = lengths_to_mask(src_lens) src_tokens_flat = src_tokens.masked_select(src_mask) - ctc_loss = F.ctc_loss( - lprobs, src_tokens_flat, tgt_lens, src_lens, - reduction=reduction, zero_infinity=True - ) * self.ctc_weight + ctc_loss = ( + F.ctc_loss( + lprobs, + src_tokens_flat, + tgt_lens, + src_lens, + reduction=reduction, + zero_infinity=True, + ) + * self.ctc_weight + ) loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss @@ -102,8 +106,12 @@ class FastSpeech2Loss(FairseqCriterion): ntot = sum(ns) ws = [n / (ntot + 1e-8) for n in ns] for key in [ - "loss", "l1_loss", "dur_loss", "pitch_loss", "energy_loss", - "ctc_loss" + "loss", + "l1_loss", + "dur_loss", + "pitch_loss", + "energy_loss", + "ctc_loss", ]: 
vals = [log.get(key, 0) for log in logging_outputs] val = sum(val * w for val, w in zip(vals, ws)) @@ -115,10 +123,10 @@ class FastSpeech2Loss(FairseqCriterion): return n = sum(log.get("targ_frames", 0) for log in logging_outputs) for key, new_key in [ - ("mcd_loss", "mcd_loss"), - ("pred_frames", "pred_ratio"), - ("nins", "ins_rate"), - ("ndel", "del_rate"), + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), ]: val = sum(log.get(key, 0) for log in logging_outputs) metrics.log_scalar(new_key, val / n, n, round=3) diff --git a/fairseq/criterions/hubert_criterion.py b/fairseq/criterions/hubert_criterion.py index 68cb24e6..83b514ae 100644 --- a/fairseq/criterions/hubert_criterion.py +++ b/fairseq/criterions/hubert_criterion.py @@ -37,7 +37,14 @@ class HubertCriterionConfig(FairseqDataclass): @register_criterion("hubert", dataclass=HubertCriterionConfig) class HubertCriterion(FairseqCriterion): - def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None): + def __init__( + self, + task, + pred_masked_weight, + pred_nomask_weight, + loss_weights=None, + log_keys=None, + ): super().__init__(task) self.pred_masked_weight = pred_masked_weight self.pred_nomask_weight = pred_nomask_weight @@ -52,7 +59,7 @@ class HubertCriterion(FairseqCriterion): 3) logging outputs to display while training """ net_output = model(target_list=sample["target_list"], **sample["net_input"]) - loss = 0. + loss = 0.0 sample_size = 0 logging_output = {} reduction = "sum" if reduce else "none" @@ -89,7 +96,9 @@ class HubertCriterion(FairseqCriterion): names = [names] if len(self.loss_weights) == 1 and len(extra_losses) != 1: self.loss_weights = [self.loss_weights[0]] * len(extra_losses) - assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}" + assert len(extra_losses) == len( + self.loss_weights + ), f"{len(extra_losses)}, {len(self.loss_weights)}" for p, n, coef in zip(extra_losses, names, self.loss_weights): if coef != 0 and p is not None: p = coef * p.float() * sample_size @@ -140,12 +149,20 @@ class HubertCriterion(FairseqCriterion): ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) - metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) if sample_size != ntokens: - metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3) - metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)) + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) else: - metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) counts = {} for lk in logging_outputs[0].keys(): diff --git a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py index 223a16f7..d5fb390f 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/fairseq/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -9,19 +9,20 @@ from fairseq import metrics, utils from fairseq.criterions import 
register_criterion from fairseq.criterions.label_smoothed_cross_entropy import ( LabelSmoothedCrossEntropyCriterion, - LabelSmoothedCrossEntropyCriterionConfig + LabelSmoothedCrossEntropyCriterionConfig, ) try: from simuleval.metrics.latency import ( AverageLagging, AverageProportion, - DifferentiableAverageLagging + DifferentiableAverageLagging, ) + LATENCY_METRICS = { "average_lagging": AverageLagging, "average_proportion": AverageProportion, - "differentiable_average_lagging": DifferentiableAverageLagging, + "differentiable_average_lagging": DifferentiableAverageLagging, } except ImportError: LATENCY_METRICS = None @@ -56,9 +57,10 @@ class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig( metadata={"help": "Add latency loss after certain steps"}, ) + @register_criterion( "latency_augmented_label_smoothed_cross_entropy", - dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig + dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig, ) class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( LabelSmoothedCrossEntropyCriterion @@ -101,9 +103,9 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( if self.latency_update_after > 0: num_updates = getattr(model.decoder, "num_updates", None) - assert num_updates is not None, ( - "model.decoder doesn't have attribute 'num_updates'" - ) + assert ( + num_updates is not None + ), "model.decoder doesn't have attribute 'num_updates'" if num_updates <= self.latency_update_after: latency_loss = 0 @@ -134,9 +136,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( assert ( net_output[-1].encoder_padding_mask is None or not net_output[-1].encoder_padding_mask[:, 0].any() - ), ( - "Only right padding on source is supported." - ) + ), "Only right padding on source is supported." # 1. 
Obtain the expected alignment alpha_list = [item["alpha"] for item in net_output[1].attn_list] num_layers = len(alpha_list) @@ -174,8 +174,7 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( .view(-1) ) expected_latency = LATENCY_METRICS[self.latency_avg_type]( - expected_delays, src_lengths, None, - target_padding_mask=target_padding_mask + expected_delays, src_lengths, None, target_padding_mask=target_padding_mask ) # 2.1 average expected latency of heads @@ -210,24 +209,12 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( @classmethod def reduce_metrics(cls, logging_outputs) -> None: super().reduce_metrics(logging_outputs) - latency = sum( - log.get("latency", 0) for log in logging_outputs - ) - delays_var = sum( - log.get("delays_var", 0) for log in logging_outputs - ) - latency_loss = sum( - log.get("latency_loss", 0) for log in logging_outputs - ) + latency = sum(log.get("latency", 0) for log in logging_outputs) + delays_var = sum(log.get("delays_var", 0) for log in logging_outputs) + latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs) nsentences = sum(log.get("nsentences", 0) for log in logging_outputs) + metrics.log_scalar("latency", latency.float() / nsentences, nsentences, round=3) + metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3) metrics.log_scalar( - "latency", latency.float() / nsentences, nsentences, round=3 - ) - metrics.log_scalar( - "delays_var", delays_var / nsentences, - nsentences, round=3 - ) - metrics.log_scalar( - "latency_loss", latency_loss / nsentences, - nsentences, round=3 + "latency_loss", latency_loss / nsentences, nsentences, round=3 ) diff --git a/fairseq/criterions/tacotron2_loss.py b/fairseq/criterions/tacotron2_loss.py index 8c7b655c..11ebf2d8 100644 --- a/fairseq/criterions/tacotron2_loss.py +++ b/fairseq/criterions/tacotron2_loss.py @@ -41,9 +41,7 @@ class Tacotron2CriterionConfig(FairseqDataclass): default=0.4, metadata={"help": "weight of positive examples for BCE loss"}, ) - ctc_weight: float = field( - default=0.0, metadata={"help": "weight for CTC loss"} - ) + ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"}) sentence_avg: bool = II("optimization.sentence_avg") @@ -70,8 +68,7 @@ class GuidedAttentionLoss(torch.nn.Module): bsz, max_s_len, max_t_len = len(src_lens), max(src_lens), max(tgt_lens) weights = torch.zeros((bsz, max_t_len, max_s_len)) for i, (s_len, t_len) in enumerate(zip(src_lens, tgt_lens)): - weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, - self.sigma) + weights[i, :t_len, :s_len] = self._get_weight(s_len, t_len, self.sigma) return weights @staticmethod @@ -90,9 +87,16 @@ class GuidedAttentionLoss(torch.nn.Module): @register_criterion("tacotron2", dataclass=Tacotron2CriterionConfig) class Tacotron2Criterion(FairseqCriterion): - def __init__(self, task, sentence_avg, n_frames_per_step, - use_guided_attention_loss, guided_attention_loss_sigma, - bce_pos_weight, ctc_weight): + def __init__( + self, + task, + sentence_avg, + n_frames_per_step, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ): super().__init__(task) self.sentence_avg = sentence_avg self.n_frames_per_step = n_frames_per_step @@ -120,31 +124,42 @@ class Tacotron2Criterion(FairseqCriterion): prev_output_tokens=sample["net_input"]["prev_output_tokens"], incremental_state=None, target_lengths=tgt_lens, - speaker=sample["speaker"] + speaker=sample["speaker"], ) l1_loss, mse_loss, eos_loss = self.compute_loss( - 
extra["feature_out"], feat_out, eos_out, feat_tgt, eos_tgt, - tgt_lens, reduction, + extra["feature_out"], + feat_out, + eos_out, + feat_tgt, + eos_tgt, + tgt_lens, + reduction, ) - attn_loss = torch.tensor(0.).type_as(l1_loss) + attn_loss = torch.tensor(0.0).type_as(l1_loss) if self.guided_attn is not None: - attn_loss = self.guided_attn(extra['attn'], src_lens, tgt_lens, reduction) - ctc_loss = torch.tensor(0.).type_as(l1_loss) - if self.ctc_weight > 0.: + attn_loss = self.guided_attn(extra["attn"], src_lens, tgt_lens, reduction) + ctc_loss = torch.tensor(0.0).type_as(l1_loss) + if self.ctc_weight > 0.0: net_output = (feat_out, eos_out, extra) lprobs = model.get_normalized_probs(net_output, log_probs=True) lprobs = lprobs.transpose(0, 1) # T x B x C src_mask = lengths_to_mask(src_lens) src_tokens_flat = src_tokens.masked_select(src_mask) - ctc_loss = F.ctc_loss( - lprobs, src_tokens_flat, tgt_lens, src_lens, - reduction=reduction, zero_infinity=True - ) * self.ctc_weight + ctc_loss = ( + F.ctc_loss( + lprobs, + src_tokens_flat, + tgt_lens, + src_lens, + reduction=reduction, + zero_infinity=True, + ) + * self.ctc_weight + ) loss = l1_loss + mse_loss + eos_loss + attn_loss + ctc_loss - sample_size = sample["nsentences"] if self.sentence_avg \ - else sample["ntokens"] + sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"] logging_output = { "loss": utils.item(loss.data), "ntokens": sample["ntokens"], @@ -158,8 +173,16 @@ class Tacotron2Criterion(FairseqCriterion): } return loss, sample_size, logging_output - def compute_loss(self, feat_out, feat_out_post, eos_out, feat_tgt, - eos_tgt, tgt_lens, reduction="mean"): + def compute_loss( + self, + feat_out, + feat_out_post, + eos_out, + feat_tgt, + eos_tgt, + tgt_lens, + reduction="mean", + ): mask = lengths_to_mask(tgt_lens) _eos_out = eos_out[mask].squeeze() _eos_tgt = eos_tgt[mask] @@ -167,17 +190,17 @@ class Tacotron2Criterion(FairseqCriterion): _feat_out = feat_out[mask] _feat_out_post = feat_out_post[mask] - l1_loss = ( - F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + - F.l1_loss(_feat_out_post, _feat_tgt, reduction=reduction) + l1_loss = F.l1_loss(_feat_out, _feat_tgt, reduction=reduction) + F.l1_loss( + _feat_out_post, _feat_tgt, reduction=reduction ) - mse_loss = ( - F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + - F.mse_loss(_feat_out_post, _feat_tgt, reduction=reduction) + mse_loss = F.mse_loss(_feat_out, _feat_tgt, reduction=reduction) + F.mse_loss( + _feat_out_post, _feat_tgt, reduction=reduction ) eos_loss = F.binary_cross_entropy_with_logits( - _eos_out, _eos_tgt, pos_weight=torch.tensor(self.bce_pos_weight), - reduction=reduction + _eos_out, + _eos_tgt, + pos_weight=torch.tensor(self.bce_pos_weight), + reduction=reduction, ) return l1_loss, mse_loss, eos_loss @@ -197,10 +220,10 @@ class Tacotron2Criterion(FairseqCriterion): return n = sum(log.get("targ_frames", 0) for log in logging_outputs) for key, new_key in [ - ("mcd_loss", "mcd_loss"), - ("pred_frames", "pred_ratio"), - ("nins", "ins_rate"), - ("ndel", "del_rate"), + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), ]: val = sum(log.get(key, 0) for log in logging_outputs) metrics.log_scalar(new_key, val / n, n, round=3) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index e04786cc..e37274d5 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -33,6 +33,7 @@ class 
Wav2VecCriterionConfig(FairseqDataclass): metadata={"help": "output keys to log"}, ) + @register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) class Wav2vecCriterion(FairseqCriterion): def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): @@ -76,16 +77,16 @@ class Wav2vecCriterion(FairseqCriterion): # we don't shrink tensors using mask_indices. # Instead, we use mask indices to adjust loss. mi = ( - sample['net_input']['mask_indices'] + sample["net_input"]["mask_indices"] .transpose(0, 1) # logits are transposed in `model.get_logits` .reshape(logits.size(0)) ) loss = (loss * mi).sum() if reduce else (loss * mi) - if 'sample_size' in sample: - sample_size = sample['sample_size'] - elif 'mask_indices' in sample['net_input']: - sample_size = sample['net_input']['mask_indices'].sum() + if "sample_size" in sample: + sample_size = sample["sample_size"] + elif "mask_indices" in sample["net_input"]: + sample_size = sample["net_input"]["mask_indices"].sum() else: sample_size = target.numel() if self.infonce else target.long().sum().item() losses.append(loss.detach().clone()) @@ -216,8 +217,8 @@ class Wav2vecCriterion(FairseqCriterion): metrics.log_scalar(k, val / len(logging_outputs), round=3) # FIXME: revert when gather based xla reduction is implemented - #@staticmethod - #def logging_outputs_can_be_summed() -> bool: + # @staticmethod + # def logging_outputs_can_be_summed() -> bool: def logging_outputs_can_be_summed(self) -> bool: """ Whether the logging outputs returned by `forward` can be summed diff --git a/fairseq/data/add_target_dataset.py b/fairseq/data/add_target_dataset.py index d8a08e74..bf89f256 100644 --- a/fairseq/data/add_target_dataset.py +++ b/fairseq/data/add_target_dataset.py @@ -20,7 +20,7 @@ class AddTargetDataset(BaseWrapperDataset): process_label=None, label_len_fn=None, add_to_input=False, - text_compression_level=TextCompressionLevel.none + text_compression_level=TextCompressionLevel.none, ): super().__init__(dataset) self.labels = labels diff --git a/fairseq/data/audio/audio_utils.py b/fairseq/data/audio/audio_utils.py index 35e7fb2a..ac6b13d8 100644 --- a/fairseq/data/audio/audio_utils.py +++ b/fairseq/data/audio/audio_utils.py @@ -18,26 +18,28 @@ FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"} def convert_waveform( - waveform: Union[np.ndarray, torch.Tensor], sample_rate: int, - normalize_volume: bool = False, to_mono: bool = False, - to_sample_rate: Optional[int] = None + waveform: Union[np.ndarray, torch.Tensor], + sample_rate: int, + normalize_volume: bool = False, + to_mono: bool = False, + to_sample_rate: Optional[int] = None, ) -> Tuple[Union[np.ndarray, torch.Tensor], int]: """convert a waveform: - - to a target sample rate - - from multi-channel to mono channel - - volume normalization + - to a target sample rate + - from multi-channel to mono channel + - volume normalization - Args: - waveform (numpy.ndarray or torch.Tensor): 2D original waveform - (channels x length) - sample_rate (int): original sample rate - normalize_volume (bool): perform volume normalization - to_mono (bool): convert to mono channel if having multiple channels - to_sample_rate (Optional[int]): target sample rate - Returns: - waveform (numpy.ndarray): converted 2D waveform (channels x length) - sample_rate (float): target sample rate - """ + Args: + waveform (numpy.ndarray or torch.Tensor): 2D original waveform + (channels x length) + sample_rate (int): original sample rate + normalize_volume (bool): perform volume normalization + 
to_mono (bool): convert to mono channel if having multiple channels + to_sample_rate (Optional[int]): target sample rate + Returns: + waveform (numpy.ndarray): converted 2D waveform (channels x length) + sample_rate (float): target sample rate + """ try: import torchaudio.sox_effects as ta_sox except ImportError: @@ -63,10 +65,14 @@ def convert_waveform( def get_waveform( - path_or_fp: Union[str, BinaryIO], normalization: bool = True, - mono: bool = True, frames: int = -1, start: int = 0, - always_2d: bool = True, output_sample_rate: Optional[int] = None, - normalize_volume: bool = False + path_or_fp: Union[str, BinaryIO], + normalization: bool = True, + mono: bool = True, + frames: int = -1, + start: int = 0, + always_2d: bool = True, + output_sample_rate: Optional[int] = None, + normalize_volume: bool = False, ) -> Tuple[np.ndarray, int]: """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. @@ -98,8 +104,11 @@ def get_waveform( ) waveform = waveform.T # T x C -> C x T waveform, sample_rate = convert_waveform( - waveform, sample_rate, normalize_volume=normalize_volume, to_mono=mono, - to_sample_rate=output_sample_rate + waveform, + sample_rate, + normalize_volume=normalize_volume, + to_mono=mono, + to_sample_rate=output_sample_rate, ) if not normalization: @@ -182,7 +191,7 @@ def is_sf_audio_data(data: bytes) -> bool: def mmap_read(path: str, offset: int, length: int) -> bytes: with open(path, "rb") as f: with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o: - data = mmap_o[offset: offset + length] + data = mmap_o[offset : offset + length] return data @@ -215,9 +224,7 @@ def parse_path(path: str) -> Tuple[str, List[int]]: return _path, slice_ptr -def get_window( - window_fn: callable, n_fft: int, win_length: int -) -> torch.Tensor: +def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor: padding = n_fft - win_length assert padding >= 0 return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2)) @@ -226,13 +233,13 @@ def get_window( def get_fourier_basis(n_fft: int) -> torch.Tensor: basis = np.fft.fft(np.eye(n_fft)) basis = np.vstack( - [np.real(basis[:n_fft // 2 + 1, :]), np.imag(basis[:n_fft // 2 + 1, :])] + [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])] ) return torch.from_numpy(basis).float() def get_mel_filters( - sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float + sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float ) -> torch.Tensor: try: import librosa @@ -244,8 +251,12 @@ def get_mel_filters( class TTSSpectrogram(torch.nn.Module): def __init__( - self, n_fft: int, win_length: int, hop_length: int, - window_fn: callable = torch.hann_window, return_phase: bool = False + self, + n_fft: int, + win_length: int, + hop_length: int, + window_fn: callable = torch.hann_window, + return_phase: bool = False, ) -> None: super(TTSSpectrogram, self).__init__() self.n_fft = n_fft @@ -254,16 +265,16 @@ class TTSSpectrogram(torch.nn.Module): basis = get_fourier_basis(n_fft).unsqueeze(1) basis *= get_window(window_fn, n_fft, win_length) - self.register_buffer('basis', basis) + self.register_buffer("basis", basis) def forward( - self, waveform: torch.Tensor + self, waveform: torch.Tensor ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: padding = (self.n_fft // 2, self.n_fft // 2) - x = F.pad(waveform.unsqueeze(1), padding, mode='reflect') + x = F.pad(waveform.unsqueeze(1), padding, mode="reflect") x = F.conv1d(x, self.basis, 
stride=self.hop_length) - real_part = x[:, :self.n_fft // 2 + 1, :] - imag_part = x[:, self.n_fft // 2 + 1:, :] + real_part = x[:, : self.n_fft // 2 + 1, :] + imag_part = x[:, self.n_fft // 2 + 1 :, :] magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) if self.return_phase: phase = torch.atan2(imag_part, real_part) @@ -273,13 +284,11 @@ class TTSSpectrogram(torch.nn.Module): class TTSMelScale(torch.nn.Module): def __init__( - self, n_mels: int, sample_rate: int, f_min: float, f_max: float, - n_stft: int + self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int ) -> None: super(TTSMelScale, self).__init__() - basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, - f_max) - self.register_buffer('basis', basis) + basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max) + self.register_buffer("basis", basis) def forward(self, specgram: torch.Tensor) -> torch.Tensor: return torch.matmul(self.basis, specgram) diff --git a/fairseq/data/audio/frm_text_to_speech_dataset.py b/fairseq/data/audio/frm_text_to_speech_dataset.py index 125b1fc0..b54654d4 100644 --- a/fairseq/data/audio/frm_text_to_speech_dataset.py +++ b/fairseq/data/audio/frm_text_to_speech_dataset.py @@ -13,11 +13,10 @@ from typing import List, Optional import numpy as np import torch from fairseq.data import Dictionary -from fairseq.data.audio.speech_to_text_dataset import ( - S2TDataConfig -) +from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig from fairseq.data.audio.text_to_speech_dataset import ( - TextToSpeechDataset, TextToSpeechDatasetCreator + TextToSpeechDataset, + TextToSpeechDatasetCreator, ) logger = logging.getLogger(__name__) @@ -48,7 +47,7 @@ class FrmTextToSpeechDataset(TextToSpeechDataset): chunk_incr=5, add_eos=True, dedup=True, - ref_fpu=-1 + ref_fpu=-1, ): # It assumes texts are encoded at a fixed frame-rate super().__init__( @@ -67,7 +66,7 @@ class FrmTextToSpeechDataset(TextToSpeechDataset): pre_tokenizer=pre_tokenizer, bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step, - speaker_to_id=speaker_to_id + speaker_to_id=speaker_to_id, ) self.do_chunk = do_chunk @@ -92,24 +91,23 @@ class FrmTextToSpeechDataset(TextToSpeechDataset): fpu = source.size(0) / target.size(0) # frame-per-unit fps = self.n_frames_per_step assert ( - self.ref_fpu == -1 or - abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1 + self.ref_fpu == -1 or abs((fpu * fps - self.ref_fpu) / self.ref_fpu) < 0.1 ), f"{fpu*fps} != {self.ref_fpu}" # only chunk training split if self.is_train_split and self.do_chunk and self.chunk_size > 0: - lang = target[:int(self.data_cfg.prepend_tgt_lang_tag)] - text = target[int(self.data_cfg.prepend_tgt_lang_tag):] + lang = target[: int(self.data_cfg.prepend_tgt_lang_tag)] + text = target[int(self.data_cfg.prepend_tgt_lang_tag) :] size = len(text) chunk_size = min(self.chunk_size, size) chunk_start = np.random.randint(size - chunk_size + 1) - text = text[chunk_start:chunk_start+chunk_size] + text = text[chunk_start : chunk_start + chunk_size] target = torch.cat((lang, text), 0) f_size = int(np.floor(chunk_size * fpu)) f_start = int(np.floor(chunk_start * fpu)) - assert(f_size > 0) - source = source[f_start:f_start+f_size, :] + assert f_size > 0 + source = source[f_start : f_start + f_size, :] if self.dedup: target = torch.unique_consecutive(target) @@ -126,10 +124,12 @@ class FrmTextToSpeechDataset(TextToSpeechDataset): self.chunk_size = self.chunk_init + epoch * self.chunk_incr if self.chunk_bound > 0: self.chunk_size = 
min(self.chunk_size, self.chunk_bound) - logger.info(( - f"{self.split}: setting chunk size " - f"from {old} to {self.chunk_size}" - )) + logger.info( + ( + f"{self.split}: setting chunk size " + f"from {old} to {self.chunk_size}" + ) + ) class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator): @@ -152,7 +152,7 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator): chunk_incr: int = 5, add_eos: bool = True, dedup: bool = True, - ref_fpu: float = -1 + ref_fpu: float = -1, ) -> FrmTextToSpeechDataset: tsv_path = op.join(root, f"{split}.tsv") if not op.isfile(tsv_path): @@ -170,9 +170,7 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator): assert len(s) > 0 ids = [ss[cls.KEY_ID] for ss in s] - audio_paths = [ - op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s - ] + audio_paths = [op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s] n_frames = [int(ss[cls.KEY_N_FRAMES]) for ss in s] tgt_texts = [ss[cls.KEY_TGT_TEXT] for ss in s] src_texts = [ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s] @@ -203,5 +201,5 @@ class FrmTextToSpeechDatasetCreator(TextToSpeechDatasetCreator): chunk_incr=chunk_incr, add_eos=add_eos, dedup=dedup, - ref_fpu=ref_fpu + ref_fpu=ref_fpu, ) diff --git a/fairseq/data/audio/hubert_dataset.py b/fairseq/data/audio/hubert_dataset.py index f00fe301..1c0267bb 100644 --- a/fairseq/data/audio/hubert_dataset.py +++ b/fairseq/data/audio/hubert_dataset.py @@ -152,10 +152,7 @@ class HubertDataset(FairseqDataset): self.label_offsets_list = [ load_label_offset(p, inds, tot) for p in label_paths ] - assert ( - label_processors is None - or len(label_processors) == self.num_labels - ) + assert label_processors is None or len(label_processors) == self.num_labels for label_path, label_rate in zip(label_paths, self.label_rates): verify_label_lengths( self.sizes, sample_rate, label_path, label_rate, inds, tot @@ -234,8 +231,7 @@ class HubertDataset(FairseqDataset): ) targets_by_label = [ - [s["label_list"][i] for s in samples] - for i in range(self.num_labels) + [s["label_list"][i] for s in samples] for i in range(self.num_labels) ] targets_list, lengths_list, ntokens_list = self.collater_label( targets_by_label, audio_size, audio_starts @@ -270,9 +266,7 @@ class HubertDataset(FairseqDataset): collated_audios[i] = audio elif diff < 0: assert self.pad_audio - collated_audios[i] = torch.cat( - [audio, audio.new_full((-diff,), 0.0)] - ) + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) padding_mask[i, diff:] = True else: collated_audios[i], audio_starts[i] = self.crop_to_max_size( @@ -280,9 +274,7 @@ class HubertDataset(FairseqDataset): ) return collated_audios, padding_mask, audio_starts - def collater_frm_label( - self, targets, audio_size, audio_starts, label_rate, pad - ): + def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad): assert label_rate > 0 s2f = label_rate / self.sample_rate frm_starts = [int(round(s * s2f)) for s in audio_starts] @@ -290,24 +282,20 @@ class HubertDataset(FairseqDataset): if not self.pad_audio: rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] frm_size = min(frm_size, *rem_size) - targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)] + targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)] logger.debug(f"audio_starts={audio_starts}") logger.debug(f"frame_starts={frm_starts}") logger.debug(f"frame_size={frm_size}") lengths = torch.LongTensor([len(t) for t in targets]) ntokens = lengths.sum().item() - targets = 
data_utils.collate_tokens( - targets, pad_idx=pad, left_pad=False - ) + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) return targets, lengths, ntokens def collater_seq_label(self, targets, pad): lengths = torch.LongTensor([len(t) for t in targets]) ntokens = lengths.sum().item() - targets = data_utils.collate_tokens( - targets, pad_idx=pad, left_pad=False - ) + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) return targets, lengths, ntokens def collater_label(self, targets_by_label, audio_size, audio_starts): @@ -315,9 +303,7 @@ class HubertDataset(FairseqDataset): itr = zip(targets_by_label, self.label_rates, self.pad_list) for targets, label_rate, pad in itr: if label_rate == -1: - targets, lengths, ntokens = self.collater_seq_label( - targets, pad - ) + targets, lengths, ntokens = self.collater_seq_label(targets, pad) else: targets, lengths, ntokens = self.collater_frm_label( targets, audio_size, audio_starts, label_rate, pad diff --git a/fairseq/data/audio/multi_modality_dataset.py b/fairseq/data/audio/multi_modality_dataset.py index 69d23d31..625a16ec 100644 --- a/fairseq/data/audio/multi_modality_dataset.py +++ b/fairseq/data/audio/multi_modality_dataset.py @@ -29,6 +29,7 @@ class ModalityDatasetItem(NamedTuple): max_tokens: Optional[int] = None max_sentences: Optional[int] = None + # MultiModalityDataset: it concate multiple datasets with different modalities. # Compared with ConcatDataset it can 1) sample data given the ratios for different datasets # 2) it adds mode to indicate what type of the data samples come from. diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index f4e96549..181e2bbc 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -308,6 +308,7 @@ class FileAudioDataset(RawAudioDataset): def __getitem__(self, index): import soundfile as sf + fn = self.fnames[index] fn = fn if isinstance(self.fnames, list) else fn.as_py() fn = self.text_compressor.decompress(fn) diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 164bf413..b6dfd9ae 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -45,7 +45,11 @@ def get_features_from_npy_or_audio(path): def get_features_or_waveform_from_stored_zip( - path, byte_offset, byte_size, need_waveform=False, use_sample_rate=None, + path, + byte_offset, + byte_size, + need_waveform=False, + use_sample_rate=None, ): assert path.endswith(".zip") data = read_from_stored_zip(path, byte_offset, byte_size) @@ -53,18 +57,17 @@ def get_features_or_waveform_from_stored_zip( if is_npy_data(data): features_or_waveform = np.load(f) elif is_sf_audio_data(data): - features_or_waveform = \ - get_waveform( - f, always_2d=False, output_sample_rate=use_sample_rate - )[0] if need_waveform else get_fbank(f) + features_or_waveform = ( + get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0] + if need_waveform + else get_fbank(f) + ) else: raise ValueError(f'Unknown file format for "{path}"') return features_or_waveform -def get_features_or_waveform( - path: str, need_waveform=False, use_sample_rate=None -): +def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=None): """Get speech features from .npy file or waveform from .wav/.flac file. The file may be inside an uncompressed ZIP file and is accessed via byte offset and length. 
@@ -87,8 +90,11 @@ def get_features_or_waveform( return get_features_from_npy_or_audio(_path) elif len(slice_ptr) == 2: features_or_waveform = get_features_or_waveform_from_stored_zip( - _path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform, - use_sample_rate=use_sample_rate + _path, + slice_ptr[0], + slice_ptr[1], + need_waveform=need_waveform, + use_sample_rate=use_sample_rate, ) else: raise ValueError(f"Invalid path: {path}") @@ -145,7 +151,7 @@ class SpeechToTextDataset(FairseqDataset): pre_tokenizer=None, bpe_tokenizer=None, n_frames_per_step=1, - speaker_to_id=None + speaker_to_id=None, ): self.split, self.is_train_split = split, is_train_split self.cfg = cfg @@ -235,7 +241,7 @@ class SpeechToTextDataset(FairseqDataset): if self.n_frames_per_step == 1: return feature n_packed_frames = feature.shape[0] // self.n_frames_per_step - feature = feature[:self.n_frames_per_step * n_packed_frames] + feature = feature[: self.n_frames_per_step * n_packed_frames] return feature.reshape(n_packed_frames, -1) @classmethod @@ -318,9 +324,11 @@ class SpeechToTextDataset(FairseqDataset): speaker = None if self.speaker_to_id is not None: - speaker = torch.tensor( - [s.speaker_id for s in samples], dtype=torch.long - ).index_select(0, order).view(-1, 1) + speaker = ( + torch.tensor([s.speaker_id for s in samples], dtype=torch.long) + .index_select(0, order) + .view(-1, 1) + ) net_input = { "src_tokens": frames, @@ -388,7 +396,7 @@ class SpeechToTextDatasetCreator(object): pre_tokenizer, bpe_tokenizer, n_frames_per_step, - speaker_to_id + speaker_to_id, ) -> SpeechToTextDataset: audio_root = Path(cfg.audio_root) ids = [s[cls.KEY_ID] for s in samples] @@ -415,7 +423,7 @@ class SpeechToTextDatasetCreator(object): pre_tokenizer=pre_tokenizer, bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step, - speaker_to_id=speaker_to_id + speaker_to_id=speaker_to_id, ) @classmethod @@ -481,12 +489,19 @@ class SpeechToTextDatasetCreator(object): pre_tokenizer, bpe_tokenizer, n_frames_per_step, - speaker_to_id + speaker_to_id, ) -> SpeechToTextDataset: samples = cls._load_samples_from_tsv(root, split) return cls._from_list( - split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer, - bpe_tokenizer, n_frames_per_step, speaker_to_id + split, + is_train_split, + samples, + cfg, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, ) @classmethod @@ -502,12 +517,19 @@ class SpeechToTextDatasetCreator(object): epoch: int, seed: int, n_frames_per_step: int = 1, - speaker_to_id=None + speaker_to_id=None, ) -> SpeechToTextDataset: datasets = [ cls._from_tsv( - root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, - bpe_tokenizer, n_frames_per_step, speaker_to_id + root, + cfg, + split, + tgt_dict, + is_train_split, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, ) for split in splits.split(",") ] diff --git a/fairseq/data/audio/text_to_speech_dataset.py b/fairseq/data/audio/text_to_speech_dataset.py index abfcb2be..0e1489ae 100644 --- a/fairseq/data/audio/text_to_speech_dataset.py +++ b/fairseq/data/audio/text_to_speech_dataset.py @@ -13,8 +13,11 @@ import numpy as np import torch from fairseq.data.audio.speech_to_text_dataset import ( - SpeechToTextDataset, SpeechToTextDatasetCreator, S2TDataConfig, - _collate_frames, get_features_or_waveform + SpeechToTextDataset, + SpeechToTextDatasetCreator, + S2TDataConfig, + _collate_frames, + get_features_or_waveform, ) from fairseq.data import Dictionary, data_utils as fairseq_data_utils @@ -32,34 +35,44 
@@ class TextToSpeechDatasetItem(object): class TextToSpeechDataset(SpeechToTextDataset): def __init__( - self, - split: str, - is_train_split: bool, - cfg: S2TDataConfig, - audio_paths: List[str], - n_frames: List[int], - src_texts: Optional[List[str]] = None, - tgt_texts: Optional[List[str]] = None, - speakers: Optional[List[str]] = None, - src_langs: Optional[List[str]] = None, - tgt_langs: Optional[List[str]] = None, - ids: Optional[List[str]] = None, - tgt_dict: Optional[Dictionary] = None, - pre_tokenizer=None, - bpe_tokenizer=None, - n_frames_per_step=1, - speaker_to_id=None, - durations: Optional[List[List[int]]] = None, - pitches: Optional[List[str]] = None, - energies: Optional[List[str]] = None + self, + split: str, + is_train_split: bool, + cfg: S2TDataConfig, + audio_paths: List[str], + n_frames: List[int], + src_texts: Optional[List[str]] = None, + tgt_texts: Optional[List[str]] = None, + speakers: Optional[List[str]] = None, + src_langs: Optional[List[str]] = None, + tgt_langs: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + tgt_dict: Optional[Dictionary] = None, + pre_tokenizer=None, + bpe_tokenizer=None, + n_frames_per_step=1, + speaker_to_id=None, + durations: Optional[List[List[int]]] = None, + pitches: Optional[List[str]] = None, + energies: Optional[List[str]] = None, ): super(TextToSpeechDataset, self).__init__( - split, is_train_split, cfg, audio_paths, n_frames, - src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers, - src_langs=src_langs, tgt_langs=tgt_langs, ids=ids, - tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer, - bpe_tokenizer=bpe_tokenizer, n_frames_per_step=n_frames_per_step, - speaker_to_id=speaker_to_id + split, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, ) self.durations = durations self.pitches = pitches @@ -84,9 +97,13 @@ class TextToSpeechDataset(SpeechToTextDataset): np.concatenate((energy, [0])) # pad 0 for EOS ).float() return TextToSpeechDatasetItem( - index=index, source=s2t_item.source, target=s2t_item.target, - speaker_id=s2t_item.speaker_id, duration=duration, pitch=pitch, - energy=energy + index=index, + source=s2t_item.source, + target=s2t_item.target, + speaker_id=s2t_item.speaker_id, + duration=duration, + pitch=pitch, + energy=energy, ) def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]: @@ -96,8 +113,9 @@ class TextToSpeechDataset(SpeechToTextDataset): src_lengths, order = torch.tensor( [s.target.shape[0] for s in samples], dtype=torch.long ).sort(descending=True) - id_ = torch.tensor([s.index for s in samples], - dtype=torch.long).index_select(0, order) + id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select( + 0, order + ) feat = _collate_frames( [s.source for s in samples], self.cfg.use_audio_input ).index_select(0, order) @@ -115,9 +133,11 @@ class TextToSpeechDataset(SpeechToTextDataset): speaker = None if self.speaker_to_id is not None: - speaker = torch.tensor( - [s.speaker_id for s in samples], dtype=torch.long - ).index_select(0, order).view(-1, 1) + speaker = ( + torch.tensor([s.speaker_id for s in samples], dtype=torch.long) + .index_select(0, order) + .view(-1, 1) + ) bsz, _, d = feat.size() prev_output_tokens = torch.cat( @@ -175,7 +195,7 @@ class 
TextToSpeechDatasetCreator(SpeechToTextDatasetCreator): pre_tokenizer, bpe_tokenizer, n_frames_per_step, - speaker_to_id + speaker_to_id, ) -> TextToSpeechDataset: audio_root = Path(cfg.audio_root) ids = [s[cls.KEY_ID] for s in samples] @@ -189,27 +209,40 @@ class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator): durations = [s.get(cls.KEY_DURATION, None) for s in samples] durations = [ - None if dd is None else [int(d) for d in dd.split(" ")] - for dd in durations + None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations ] durations = None if any(dd is None for dd in durations) else durations pitches = [s.get(cls.KEY_PITCH, None) for s in samples] pitches = [ - None if pp is None else (audio_root / pp).as_posix() - for pp in pitches + None if pp is None else (audio_root / pp).as_posix() for pp in pitches ] pitches = None if any(pp is None for pp in pitches) else pitches energies = [s.get(cls.KEY_ENERGY, None) for s in samples] energies = [ - None if ee is None else (audio_root / ee).as_posix() - for ee in energies] + None if ee is None else (audio_root / ee).as_posix() for ee in energies + ] energies = None if any(ee is None for ee in energies) else energies return TextToSpeechDataset( - split_name, is_train_split, cfg, audio_paths, n_frames, - src_texts, tgt_texts, speakers, src_langs, tgt_langs, ids, tgt_dict, - pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, - durations, pitches, energies + split_name, + is_train_split, + cfg, + audio_paths, + n_frames, + src_texts, + tgt_texts, + speakers, + src_langs, + tgt_langs, + ids, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, + durations, + pitches, + energies, ) diff --git a/fairseq/data/colorize_dataset.py b/fairseq/data/colorize_dataset.py index 6ef097bf..7a6d2713 100644 --- a/fairseq/data/colorize_dataset.py +++ b/fairseq/data/colorize_dataset.py @@ -9,7 +9,7 @@ from . 
import BaseWrapperDataset class ColorizeDataset(BaseWrapperDataset): - """ Adds 'colors' property to net input that is obtained from the provided color getter for use by models """ + """Adds 'colors' property to net input that is obtained from the provided color getter for use by models""" def __init__(self, dataset, color_getter): super().__init__(dataset) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index b3de5768..7914e605 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -69,6 +69,7 @@ def collate_tokens( copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) return res + def load_indexed_dataset( path, dictionary=None, dataset_impl=None, combine=False, default="cached" ): @@ -324,9 +325,7 @@ def batch_by_size( ) # added int() to avoid TypeError: an integer is required - max_tokens = ( - int(max_tokens) if max_tokens is not None else -1 - ) + max_tokens = int(max_tokens) if max_tokens is not None else -1 max_sentences = max_sentences if max_sentences is not None else -1 bsz_mult = required_batch_size_multiple @@ -375,8 +374,9 @@ def post_process(sentence: str, symbol: str): sentence = sentence.replace(" ", "").replace("|", " ").strip() elif symbol == "silence": import re + sentence = sentence.replace("", "") - sentence = re.sub(' +', ' ', sentence).strip() + sentence = re.sub(" +", " ", sentence).strip() elif symbol == "_EOW": sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() elif symbol in {"subword_nmt", "@@ ", "@@"}: @@ -547,7 +547,7 @@ def get_buckets(sizes, num_buckets): np.percentile( sizes, np.linspace(0, 100, num_buckets + 1), - interpolation='lower', + interpolation="lower", )[1:] ) return buckets @@ -564,7 +564,6 @@ def get_bucketed_sizes(orig_sizes, buckets): return sizes - def _find_extra_valid_paths(dataset_path: str) -> set: paths = utils.split_paths(dataset_path) all_valid_paths = set() diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py index fc830f6e..0aa6cd76 100644 --- a/fairseq/data/encoders/sentencepiece_bpe.py +++ b/fairseq/data/encoders/sentencepiece_bpe.py @@ -21,8 +21,10 @@ class SentencepieceConfig(FairseqDataclass): ) sentencepiece_alpha: Optional[float] = field( default=None, - metadata={"help": "soothing parameter for unigram sampling, " - "and merge probability for BPE-dropout"} + metadata={ + "help": "soothing parameter for unigram sampling, " + "and merge probability for BPE-dropout" + }, ) @@ -45,8 +47,7 @@ class SentencepieceBPE(object): def encode(self, x: str) -> str: return " ".join( self.sp.Encode( - x, out_type=str, enable_sampling=self.enable_sampling, - alpha=self.alpha + x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha ) ) diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index 23e6992d..2bde7fc5 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -138,7 +138,7 @@ class FairseqDataset(torch.utils.data.Dataset, EpochListening): ) try: - num_tokens_vec = self.num_tokens_vec(indices).astype('int64') + num_tokens_vec = self.num_tokens_vec(indices).astype("int64") except NotImplementedError: num_tokens_vec = None diff --git a/fairseq/data/huffman/huffman_coder.py b/fairseq/data/huffman/huffman_coder.py index 6531f154..c04f8456 100644 --- a/fairseq/data/huffman/huffman_coder.py +++ b/fairseq/data/huffman/huffman_coder.py @@ -140,7 +140,9 @@ class HuffmanNode: def is_leaf(self) -> bool: return self.left is None and self.right is None - 
def code_table(self, prefix: tp.Optional[bitarray] = None) -> tp.Dict[str, "HuffmanNode"]: + def code_table( + self, prefix: tp.Optional[bitarray] = None + ) -> tp.Dict[str, "HuffmanNode"]: defaulted_prefix = prefix if prefix is not None else bitarray() if self.is_leaf(): self.code = ( diff --git a/fairseq/data/indexed_dataset.py b/fairseq/data/indexed_dataset.py index 23afb433..d0843926 100644 --- a/fairseq/data/indexed_dataset.py +++ b/fairseq/data/indexed_dataset.py @@ -67,7 +67,9 @@ def make_builder(out_file, impl, vocab_size=None): elif impl == "fasta": raise NotImplementedError elif impl == "huffman": - raise ValueError("Use HuffmanCodeBuilder directly as it has a different interface.") + raise ValueError( + "Use HuffmanCodeBuilder directly as it has a different interface." + ) else: return IndexedDatasetBuilder(out_file) diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index 81b5f565..e16d91ef 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -380,7 +380,9 @@ class EpochBatchIterator(EpochBatchIterating): # reset _frozen_batches to refresh the next epoch self._frozen_batches = None self._cur_epoch_itr = self._get_iterator_for_epoch( - self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus, + self.epoch, + shuffle, + fix_batches_to_gpus=fix_batches_to_gpus, ) self.shuffle = shuffle return self._cur_epoch_itr @@ -421,7 +423,9 @@ class EpochBatchIterator(EpochBatchIterating): if itr_pos > 0: # fast-forward epoch iterator self._next_epoch_itr = self._get_iterator_for_epoch( - self.epoch, shuffle=state_dict.get("shuffle", True), offset=itr_pos, + self.epoch, + shuffle=state_dict.get("shuffle", True), + offset=itr_pos, ) if self._next_epoch_itr is None: if version == 1: diff --git a/fairseq/data/language_pair_dataset.py b/fairseq/data/language_pair_dataset.py index ff3e14bf..fd356ddd 100644 --- a/fairseq/data/language_pair_dataset.py +++ b/fairseq/data/language_pair_dataset.py @@ -114,7 +114,10 @@ def collate( "id": id, "nsentences": len(samples), "ntokens": ntokens, - "net_input": {"src_tokens": src_tokens, "src_lengths": src_lengths,}, + "net_input": { + "src_tokens": src_tokens, + "src_lengths": src_lengths, + }, "target": target, } if prev_output_tokens is not None: @@ -467,5 +470,8 @@ class LanguagePairDataset(FairseqDataset): list: list of removed indices """ return data_utils.filter_paired_dataset_indices_by_size( - self.src_sizes, self.tgt_sizes, indices, max_sizes, + self.src_sizes, + self.tgt_sizes, + indices, + max_sizes, ) diff --git a/fairseq/data/multi_corpus_dataset.py b/fairseq/data/multi_corpus_dataset.py index 1566d8e0..a3f47c72 100644 --- a/fairseq/data/multi_corpus_dataset.py +++ b/fairseq/data/multi_corpus_dataset.py @@ -80,7 +80,9 @@ class MultiCorpusDataset(FairseqDataset): def ordered_indices(self): start = time.time() with data_utils.numpy_seed(self.seed, self.epoch): - logger.info(f"sampling new dataset with seed {self.seed} epoch {self.epoch}") + logger.info( + f"sampling new dataset with seed {self.seed} epoch {self.epoch}" + ) sampled_indices = [] num_selected_instances = 0 diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 137481b4..8dae99d9 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -40,8 +40,8 @@ from fairseq.utils import FileContentsAction, csv_str_list, eval_str_dict logger = logging.getLogger(__name__) -SRC_DICT_NAME = 'src' -TGT_DICT_NAME = 
'tgt' +SRC_DICT_NAME = "src" +TGT_DICT_NAME = "tgt" def _lang_id(dic: Dictionary, lang: str): @@ -64,14 +64,16 @@ class MultilingualDatasetManager(object): self.seed = args.seed self.lang_pairs = lang_pairs self.extra_lang_pairs = ( - list( - {p for _, v in args.extra_lang_pairs.items() for p in v.split(",")} - ) - if args.extra_lang_pairs - else [] - ) - self.src_langs = {p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs} - self.tgt_langs = {p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs} + list({p for _, v in args.extra_lang_pairs.items() for p in v.split(",")}) + if args.extra_lang_pairs + else [] + ) + self.src_langs = { + p.split("-")[0] for p in args.lang_pairs + self.extra_lang_pairs + } + self.tgt_langs = { + p.split("-")[1] for p in args.lang_pairs + self.extra_lang_pairs + } self.langs = langs self.dicts = dicts self.lang_dict = self.create_lang_dictionary(self.langs) @@ -111,10 +113,18 @@ class MultilingualDatasetManager(object): "note that the ordering determines language token IDs; " "--langs and --lang-dict are two exclusive options", ) - parser.add_argument('--source-dict', default=None, type=str, - help='path to source dictionary; if specified it will override per language dictionary loading') - parser.add_argument('--target-dict', default=None, type=str, - help='path to target dictionary; if specified it will override per language dictionary loading') + parser.add_argument( + "--source-dict", + default=None, + type=str, + help="path to source dictionary; if specified it will override per language dictionary loading", + ) + parser.add_argument( + "--target-dict", + default=None, + type=str, + help="path to target dictionary; if specified it will override per language dictionary loading", + ) parser.add_argument( "--lang-tok-style", default=LangTokStyle.multilingual.value, @@ -378,7 +388,9 @@ class MultilingualDatasetManager(object): ) return d - dicts = cls.load_all_dictionaries(args, language_list, load_dictionary_and_postproc, training) + dicts = cls.load_all_dictionaries( + args, language_list, load_dictionary_and_postproc, training + ) return language_list, dicts, training @classmethod @@ -424,7 +436,10 @@ class MultilingualDatasetManager(object): if args.fixed_dictionary is not None: fixed_dict = load_dictionary(args.fixed_dictionary) - dicts = {lang: fixed_dict for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts} + dicts = { + lang: fixed_dict + for lang in src_langs_to_load_dicts + tgt_langs_to_load_dicts + } else: if args.source_dict is None: load_dicts(src_langs_to_load_dicts) @@ -477,7 +492,10 @@ class MultilingualDatasetManager(object): lang=tgt_lang, lang_tok_style=self.args.lang_tok_style, spec=spec ) return self.get_langtok_index( - langtok, self.get_source_dictionary(src_lang) if src_lang else self.get_target_dictionary(tgt_lang) + langtok, + self.get_source_dictionary(src_lang) + if src_lang + else self.get_target_dictionary(tgt_lang), ) def get_decoder_langtok(self, tgt_lang, spec=None): @@ -819,7 +837,9 @@ class MultilingualDatasetManager(object): if self.args.lang_tok_replacing_bos_eos: ds = self.alter_dataset_langtok( langpair_ds, - src_eos=self.get_source_dictionary(src).eos() if src else self.get_target_dictionary(tgt).eos(), + src_eos=self.get_source_dictionary(src).eos() + if src + else self.get_target_dictionary(tgt).eos(), src_lang=src, tgt_eos=self.get_target_dictionary(tgt).eos(), tgt_lang=tgt, diff --git a/fairseq/data/noising.py b/fairseq/data/noising.py index 2b1cc347..e92e83c2 100644 --- 
a/fairseq/data/noising.py +++ b/fairseq/data/noising.py @@ -298,7 +298,6 @@ class NoisingDataset(torch.utils.data.Dataset): ) self.sizes = src_dataset.sizes - def __getitem__(self, index): """ Returns a single noisy sample. Multiple samples are fed to the collater diff --git a/fairseq/data/text_compressor.py b/fairseq/data/text_compressor.py index 561e9ac8..8a4e8daa 100644 --- a/fairseq/data/text_compressor.py +++ b/fairseq/data/text_compressor.py @@ -14,8 +14,7 @@ class TextCompressionLevel(Enum): class TextCompressor(object): def __init__( - self, level: TextCompressionLevel, - max_input_byte_length: int = 2 ** 16 + self, level: TextCompressionLevel, max_input_byte_length: int = 2 ** 16 ): self.level = level self.max_input_length = max_input_byte_length @@ -23,11 +22,13 @@ class TextCompressor(object): def compress(self, text: str) -> bytes: if self.level == TextCompressionLevel.low: import zlib + # zlib: built-in, fast return zlib.compress(text.encode(), level=0) elif self.level == TextCompressionLevel.high: try: import unishox2 + # unishox2: optimized for short text but slower except ImportError: raise ImportError( @@ -42,6 +43,7 @@ class TextCompressor(object): def decompress(self, compressed: bytes) -> str: if self.level == TextCompressionLevel.low: import zlib + return zlib.decompress(compressed).decode() elif self.level == TextCompressionLevel.high: try: diff --git a/fairseq/data/token_block_dataset.py b/fairseq/data/token_block_dataset.py index d2c65fd7..a414e7ef 100644 --- a/fairseq/data/token_block_dataset.py +++ b/fairseq/data/token_block_dataset.py @@ -69,7 +69,10 @@ class TokenBlockDataset(FairseqDataset): _sizes, split_path, (plasma_id, 1), plasma_path=plasma_path ) self._block_to_dataset_index = plasma_utils.PlasmaView( - block_to_dataset_index, split_path, (plasma_id, 2), plasma_path=plasma_path, + block_to_dataset_index, + split_path, + (plasma_id, 2), + plasma_path=plasma_path, ) else: self._slice_indices = plasma_utils.PlasmaArray(slice_indices) @@ -127,7 +130,8 @@ class TokenBlockDataset(FairseqDataset): ) else: block_to_dataset_index = _get_block_to_dataset_index_fast( - sizes, slice_indices, + sizes, + slice_indices, ) size_dtype = np.uint16 if block_size < 65535 else np.uint32 num_tokens = slice_indices[-1].max() diff --git a/fairseq/data/transform_eos_lang_pair_dataset.py b/fairseq/data/transform_eos_lang_pair_dataset.py index e21144a8..d8b21090 100644 --- a/fairseq/data/transform_eos_lang_pair_dataset.py +++ b/fairseq/data/transform_eos_lang_pair_dataset.py @@ -52,7 +52,7 @@ class TransformEosLangPairDataset(FairseqDataset): if len(samples) == 0: return samples - if 'net_input' not in samples: + if "net_input" not in samples: return samples if self.new_src_eos is not None: diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index b081e6ca..b6150ea3 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -126,7 +126,8 @@ class CommonConfig(FairseqDataclass): metadata={"help": "Weights and Biases project name to use for logging"}, ) azureml_logging: Optional[bool] = field( - default=False, metadata={"help": "Log scalars to AzureML context"}, + default=False, + metadata={"help": "Log scalars to AzureML context"}, ) seed: int = field( default=1, metadata={"help": "pseudo random number generator seed"} @@ -428,19 +429,23 @@ class DistributedTrainingConfig(FairseqDataclass): tpu: bool = II("common.tpu") # configuration for --ddp-backend=fully_sharded no_reshard_after_forward: bool = field( - default=False, metadata={"help": 
"don't reshard parameters after forward pass"}, + default=False, + metadata={"help": "don't reshard parameters after forward pass"}, ) fp32_reduce_scatter: bool = field( - default=False, metadata={"help": "reduce-scatter grads in FP32"}, + default=False, + metadata={"help": "reduce-scatter grads in FP32"}, ) cpu_offload: bool = field( default=False, metadata={"help": "offload FP32 params to CPU"} ) use_sharded_state: bool = field( - default=False, metadata={"help": "use sharded checkpoint files"}, + default=False, + metadata={"help": "use sharded checkpoint files"}, ) not_fsdp_flatten_parameters: bool = field( - default=False, metadata={"help": "not flatten parameter param for fsdp"}, + default=False, + metadata={"help": "not flatten parameter param for fsdp"}, ) @@ -786,10 +791,12 @@ class FairseqBMUFConfig(FairseqDataclass): @dataclass class GenerationConfig(FairseqDataclass): beam: int = field( - default=5, metadata={"help": "beam size"}, + default=5, + metadata={"help": "beam size"}, ) nbest: int = field( - default=1, metadata={"help": "number of hypotheses to output"}, + default=1, + metadata={"help": "number of hypotheses to output"}, ) max_len_a: float = field( default=0, @@ -804,19 +811,24 @@ class GenerationConfig(FairseqDataclass): }, ) min_len: int = field( - default=1, metadata={"help": "minimum generation length"}, + default=1, + metadata={"help": "minimum generation length"}, ) match_source_len: bool = field( - default=False, metadata={"help": "generations should match the source length"}, + default=False, + metadata={"help": "generations should match the source length"}, ) unnormalized: bool = field( - default=False, metadata={"help": "compare unnormalized hypothesis scores"}, + default=False, + metadata={"help": "compare unnormalized hypothesis scores"}, ) no_early_stop: bool = field( - default=False, metadata={"help": "deprecated"}, + default=False, + metadata={"help": "deprecated"}, ) no_beamable_mm: bool = field( - default=False, metadata={"help": "don't use BeamableMM in attention layers"}, + default=False, + metadata={"help": "don't use BeamableMM in attention layers"}, ) lenpen: float = field( default=1, @@ -838,10 +850,12 @@ class GenerationConfig(FairseqDataclass): }, ) sacrebleu: bool = field( - default=False, metadata={"help": "score with sacrebleu"}, + default=False, + metadata={"help": "score with sacrebleu"}, ) score_reference: bool = field( - default=False, metadata={"help": "just score the reference translation"}, + default=False, + metadata={"help": "just score the reference translation"}, ) prefix_size: int = field( default=0, @@ -875,10 +889,12 @@ class GenerationConfig(FairseqDataclass): }, ) temperature: float = field( - default=1.0, metadata={"help": "temperature for generation"}, + default=1.0, + metadata={"help": "temperature for generation"}, ) diverse_beam_groups: int = field( - default=-1, metadata={"help": "number of groups for Diverse Beam Search"}, + default=-1, + metadata={"help": "number of groups for Diverse Beam Search"}, ) diverse_beam_strength: float = field( default=0.5, @@ -897,13 +913,16 @@ class GenerationConfig(FairseqDataclass): }, ) print_step: bool = field( - default=False, metadata={"help": "print steps"}, + default=False, + metadata={"help": "print steps"}, ) lm_path: Optional[str] = field( - default=None, metadata={"help": "path to lm checkpoint for lm fusion"}, + default=None, + metadata={"help": "path to lm checkpoint for lm fusion"}, ) lm_weight: float = field( - default=0.0, metadata={"help": "weight for lm probs for lm 
fusion"}, + default=0.0, + metadata={"help": "weight for lm probs for lm fusion"}, ) # arguments for iterative refinement generator @@ -912,7 +931,8 @@ class GenerationConfig(FairseqDataclass): metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, ) iter_decode_max_iter: int = field( - default=10, metadata={"help": "maximum iterations for iterative refinement."}, + default=10, + metadata={"help": "maximum iterations for iterative refinement."}, ) iter_decode_force_max_iter: bool = field( default=False, @@ -939,7 +959,8 @@ class GenerationConfig(FairseqDataclass): }, ) retain_dropout: bool = field( - default=False, metadata={"help": "Use dropout at inference time"}, + default=False, + metadata={"help": "Use dropout at inference time"}, ) # temporarily set to Any until https://github.com/facebookresearch/hydra/issues/1117 is fixed # retain_dropout_modules: Optional[List[str]] = field( @@ -964,7 +985,8 @@ class GenerationConfig(FairseqDataclass): @dataclass class CommonEvalConfig(FairseqDataclass): path: Optional[str] = field( - default=None, metadata={"help": "path(s) to model file(s), colon separated"}, + default=None, + metadata={"help": "path(s) to model file(s), colon separated"}, ) post_process: Optional[str] = field( default=None, @@ -1026,7 +1048,8 @@ class InteractiveConfig(FairseqDataclass): }, ) input: str = field( - default="-", metadata={"help": "file to read from; use - for stdin"}, + default="-", + metadata={"help": "file to read from; use - for stdin"}, ) diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 7e5aef70..5af92f2b 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -35,14 +35,16 @@ def ChoiceEnum(choices: List[str]): LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) -DDP_BACKEND_CHOICES = ChoiceEnum([ - "c10d", # alias for pytorch_ddp - "fully_sharded", # FullyShardedDataParallel from fairscale - "legacy_ddp", - "no_c10d", # alias for legacy_ddp - "pytorch_ddp", - "slowmo", -]) +DDP_BACKEND_CHOICES = ChoiceEnum( + [ + "c10d", # alias for pytorch_ddp + "fully_sharded", # FullyShardedDataParallel from fairscale + "legacy_ddp", + "no_c10d", # alias for legacy_ddp + "pytorch_ddp", + "slowmo", + ] +) DDP_COMM_HOOK_CHOICES = ChoiceEnum(["none", "fp16"]) DATASET_IMPL_CHOICES = ChoiceEnum(["raw", "lazy", "cached", "mmap", "fasta", "huffman"]) GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) diff --git a/fairseq/dataclass/initialize.py b/fairseq/dataclass/initialize.py index 8f6cbafb..5a7784ba 100644 --- a/fairseq/dataclass/initialize.py +++ b/fairseq/dataclass/initialize.py @@ -28,7 +28,7 @@ def hydra_init(cfg_name="config") -> None: def add_defaults(cfg: DictConfig) -> None: - """This function adds default values that are stored in dataclasses that hydra doesn't know about """ + """This function adds default values that are stored in dataclasses that hydra doesn't know about""" from fairseq.registry import REGISTRIES from fairseq.tasks import TASK_DATACLASS_REGISTRY diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 1320ec47..b80315dd 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -57,21 +57,21 @@ def gen_parser_from_dataclass( with_prefix: Optional[str] = None, ) -> None: """ - convert a dataclass instance to tailing parser arguments. + convert a dataclass instance to tailing parser arguments. - If `with_prefix` is provided, prefix all the keys in the resulting parser with it. 
It means that we are - building a flat namespace from a structured dataclass (see transformer_config.py for example). + If `with_prefix` is provided, prefix all the keys in the resulting parser with it. It means that we are + building a flat namespace from a structured dataclass (see transformer_config.py for example). """ def argparse_name(name: str): - if name == "data" and (with_prefix is None or with_prefix == ''): + if name == "data" and (with_prefix is None or with_prefix == ""): # normally data is positional args, so we don't add the -- nor the prefix return name if name == "_name": # private member, skip return None full_name = "--" + name.replace("_", "-") - if with_prefix is not None and with_prefix != '': + if with_prefix is not None and with_prefix != "": # if a prefix is specified, construct the prefixed arg name full_name = with_prefix + "-" + full_name[2:] # strip -- when composing return full_name @@ -143,8 +143,8 @@ def gen_parser_from_dataclass( kwargs["default"] = field_default # build the help with the hierarchical prefix - if with_prefix is not None and with_prefix != '' and field_help is not None: - field_help = with_prefix[2:] + ': ' + field_help + if with_prefix is not None and with_prefix != "" and field_help is not None: + field_help = with_prefix[2:] + ": " + field_help kwargs["help"] = field_help if field_const is not None: diff --git a/fairseq/distributed/__init__.py b/fairseq/distributed/__init__.py index d0b96b73..9130db8f 100644 --- a/fairseq/distributed/__init__.py +++ b/fairseq/distributed/__init__.py @@ -4,7 +4,11 @@ # LICENSE file in the root directory of this source tree. from .distributed_timeout_wrapper import DistributedTimeoutWrapper -from .fully_sharded_data_parallel import fsdp_enable_wrap, fsdp_wrap, FullyShardedDataParallel +from .fully_sharded_data_parallel import ( + fsdp_enable_wrap, + fsdp_wrap, + FullyShardedDataParallel, +) from .legacy_distributed_data_parallel import LegacyDistributedDataParallel from .module_proxy_wrapper import ModuleProxyWrapper from .tpu_distributed_data_parallel import TPUDistributedDataParallel diff --git a/fairseq/distributed/distributed_timeout_wrapper.py b/fairseq/distributed/distributed_timeout_wrapper.py index 18107ef2..6e06b4b6 100644 --- a/fairseq/distributed/distributed_timeout_wrapper.py +++ b/fairseq/distributed/distributed_timeout_wrapper.py @@ -33,6 +33,7 @@ class DistributedTimeoutWrapper(nn.Module): (set to a value <= 0 to disable the timeout) signal (Optional): signal to send once timeout is triggered """ + def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT): super().__init__() self.module = module @@ -86,9 +87,11 @@ class DistributedTimeoutWrapper(nn.Module): if self._terminated: break elif not success: - logger.error(( - "Killing job for not making progress in {} seconds. " - "Set --heartbeat-timeout=-1 to disable this timeout." - ).format(int(self.timeout))) + logger.error( + ( + "Killing job for not making progress in {} seconds. " + "Set --heartbeat-timeout=-1 to disable this timeout." 
+ ).format(int(self.timeout)) + ) os.kill(parent_pid, self.signal) return diff --git a/fairseq/distributed/legacy_distributed_data_parallel.py b/fairseq/distributed/legacy_distributed_data_parallel.py index f2308f87..5f89e6c0 100644 --- a/fairseq/distributed/legacy_distributed_data_parallel.py +++ b/fairseq/distributed/legacy_distributed_data_parallel.py @@ -137,7 +137,7 @@ class LegacyDistributedDataParallel(nn.Module): if param.grad is None: param.grad = torch.zeros_like(param) - if hasattr(param, 'expert'): + if hasattr(param, "expert"): # Skip gradient sync for unshared parameters continue diff --git a/fairseq/distributed/module_proxy_wrapper.py b/fairseq/distributed/module_proxy_wrapper.py index fc2c6f8c..904dc0c2 100644 --- a/fairseq/distributed/module_proxy_wrapper.py +++ b/fairseq/distributed/module_proxy_wrapper.py @@ -26,8 +26,9 @@ class ModuleProxyWrapper(nn.Module): def __init__(self, module: nn.Module): super().__init__() - assert hasattr(module, "module"), \ - "ModuleProxyWrapper expects input to wrap another module" + assert hasattr( + module, "module" + ), "ModuleProxyWrapper expects input to wrap another module" self.module = module def __getattr__(self, name): diff --git a/fairseq/distributed/tpu_distributed_data_parallel.py b/fairseq/distributed/tpu_distributed_data_parallel.py index e971cf07..3b9e1033 100644 --- a/fairseq/distributed/tpu_distributed_data_parallel.py +++ b/fairseq/distributed/tpu_distributed_data_parallel.py @@ -10,7 +10,6 @@ from fairseq.distributed import utils class TPUDistributedDataParallel(nn.Module): - def __init__(self, module, process_group): super().__init__() self.module = module @@ -35,9 +34,10 @@ class TPUDistributedDataParallel(nn.Module): gradients.append(p.grad) import torch_xla.core.xla_model as xm + xm.all_reduce( - 'sum', + "sum", gradients, - scale=1. / self.world_size, + scale=1.0 / self.world_size, groups=self.process_group[1], ) diff --git a/fairseq/distributed/utils.py b/fairseq/distributed/utils.py index e6459447..2c52f76a 100644 --- a/fairseq/distributed/utils.py +++ b/fairseq/distributed/utils.py @@ -201,9 +201,7 @@ def _pipeline_parallel_post_init( # distributed_world_size to be based on the total number of GPUs, so # we need to correct them to be based on the number of pipelines. assert cfg.distributed_world_size % num_pipeline_devices == 0 - cfg.distributed_world_size = ( - cfg.distributed_world_size // num_pipeline_devices - ) + cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices # In the case of 4-way MP on nodes with 8 GPUs, we want # distributed_rank to be the starting GPU index for each pipeline # i.e., 0, 2, ... 
@@ -306,8 +304,10 @@ def distributed_init(cfg: FairseqConfig): model_part_number = get_model_parallel_rank() cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) - if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0: - cfg.checkpoint.checkpoint_suffix = f"-rank-{cfg.distributed_training.distributed_rank}" + if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0: + cfg.checkpoint.checkpoint_suffix = ( + f"-rank-{cfg.distributed_training.distributed_rank}" + ) return cfg.distributed_training.distributed_rank @@ -696,7 +696,7 @@ def broadcast_tensors( dist_device = torch.device("cpu") # share metadata first to simplify transfer - is_src_rank = (get_rank(group) == src_rank) + is_src_rank = get_rank(group) == src_rank if is_src_rank: metadata = [ {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors @@ -747,7 +747,10 @@ def broadcast_object( def _broadcast_object_slow( - obj: Any, src_rank: int, group: object, dist_device: torch.device, + obj: Any, + src_rank: int, + group: object, + dist_device: torch.device, ) -> Any: if get_rank(group) == src_rank: # Emit data diff --git a/fairseq/file_io.py b/fairseq/file_io.py index dba663d4..8eca70a0 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -152,6 +152,7 @@ class PathManager: """ ioPath async PathManager methods: """ + @staticmethod def opena( path: str, @@ -169,6 +170,7 @@ class PathManager: logging.info("ioPath is initializing PathManager.") try: from iopath.common.file_io import PathManager + IOPathManager = PathManager() except Exception: logging.exception("Failed to initialize ioPath PathManager object.") diff --git a/fairseq/file_utils.py b/fairseq/file_utils.py index d1d5ea65..b99da2e8 100644 --- a/fairseq/file_utils.py +++ b/fairseq/file_utils.py @@ -146,6 +146,7 @@ def cached_path_from_pm(url_or_filename): """ try: from fairseq.file_io import PathManager + local_path = PathManager.get_local_path(url_or_filename) return local_path except Exception: diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 58c2fb64..892b0ea4 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -130,6 +130,7 @@ def log_scalar( agg.add_meter(key, AverageMeter(round=round), priority) agg[key].update(value, weight) + def log_scalar_sum( key: str, value: float, @@ -309,6 +310,7 @@ def load_state_dict(state_dict): def xla_metrics_report(): try: import torch_xla.debug.metrics as met + print(met.metrics_report()) except ImportError: return diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index 8ab4657f..ca421186 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -52,8 +52,7 @@ class MegatronTrainer(Trainer): def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" - extra_state['rng_tracker_states'] \ - = get_cuda_rng_tracker().get_states() + extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states() super().save_checkpoint(filename, extra_state) def load_checkpoint( @@ -64,8 +63,13 @@ class MegatronTrainer(Trainer): optimizer_overrides=None, reset_meters=False, ): - extra_state = super().load_checkpoint(filename, reset_optimizer=reset_optimizer, reset_lr_scheduler=reset_lr_scheduler, optimizer_overrides=optimizer_overrides, reset_meters=reset_meters) - if extra_state is not None and 'rng_tracker_states' in extra_state: - get_cuda_rng_tracker().set_states( - 
extra_state['rng_tracker_states']) + extra_state = super().load_checkpoint( + filename, + reset_optimizer=reset_optimizer, + reset_lr_scheduler=reset_lr_scheduler, + optimizer_overrides=optimizer_overrides, + reset_meters=reset_meters, + ) + if extra_state is not None and "rng_tracker_states" in extra_state: + get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"]) return extra_state diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py index eb81ded3..c07a027a 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py @@ -9,6 +9,7 @@ from collections import namedtuple import torch import torch.nn as nn import torch.nn.functional as F + from fairseq import options, utils from fairseq.modules import ( AdaptiveSoftmax, @@ -17,7 +18,6 @@ from fairseq.modules import ( PositionalEmbedding, ) - EncoderOut = namedtuple( "TransformerEncoderOut", [ @@ -30,7 +30,7 @@ EncoderOut = namedtuple( class TransformerEncoderEmbedding(nn.Module): - """ Encoder Embedding + Positional Embedding """ + """Encoder Embedding + Positional Embedding""" def __init__(self, args, embed_tokens): super().__init__() @@ -109,7 +109,7 @@ class TransformerEncoderLayerNorm(nn.Module): class TransformerDecoderEmbedding(nn.Module): - """ Decoder Embedding + Positional Embedding """ + """Decoder Embedding + Positional Embedding""" def __init__(self, args, embed_tokens): super().__init__() diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index 7f30dd98..7bb0c9ad 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -42,16 +42,20 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024 TORCH_PIPE = False RPC_INIT = False + def import_pipe(): global TORCH_PIPE global RPC_INIT try: - from torch.distributed.pipeline.sync import Pipe # noqa + from torch.distributed.pipeline.sync import Pipe # noqa + global Pipe from torch.distributed.pipeline.sync.utils import partition_model + global partition_model from torch.distributed import rpc import tempfile + TORCH_PIPE = True # Initialize single process RPC agent since TORCH_PIPE requires # RRef. 
RRef depends on RPC being initialized and as a result we initialize @@ -64,14 +68,15 @@ def import_pipe(): world_size=1, rpc_backend_options=rpc.TensorPipeRpcBackendOptions( init_method="file://{}".format(tmpfile.name), - ) + ), ) RPC_INIT = True - logger.info('Using torch pipe') + logger.info("Using torch pipe") except ImportError: try: - from fairscale.nn import Pipe # noqa - logger.info('Using fairscale pipe') + from fairscale.nn import Pipe # noqa + + logger.info("Using fairscale pipe") except ImportError: raise ImportError("Please install fairscale with: pip install fairscale") @@ -153,9 +158,14 @@ class PipelineParallelTransformerModel(BaseFairseqModel): decoder_module_list.append(module) module_count += 1 self.model = None - self.encoder = TransformerEncoder(cfg.distributed_training, None, None, encoder_module_list) + self.encoder = TransformerEncoder( + cfg.distributed_training, None, None, encoder_module_list + ) self.decoder = TransformerDecoder( - cfg.distributed_training, None, None, decoder_module_list=decoder_module_list + cfg.distributed_training, + None, + None, + decoder_module_list=decoder_module_list, ) @staticmethod @@ -471,7 +481,9 @@ class TransformerEncoder(FairseqEncoder): self.use_pipeline = encoder_module_list is not None if not self.use_pipeline: self.embedding_layer = TransformerEncoderEmbedding(args, embed_tokens) - self.encoder_layers = nn.Sequential(*[TransformerEncoderLayer(args) for i in range(args.encoder_layers)]) + self.encoder_layers = nn.Sequential( + *[TransformerEncoderLayer(args) for i in range(args.encoder_layers)] + ) if isinstance(embed_tokens, nn.ModuleList): emb_dim = sum(e.embedding_dim for e in embed_tokens) else: @@ -490,7 +502,11 @@ class TransformerEncoder(FairseqEncoder): ) if TORCH_PIPE: self.model = Pipe( - module=partition_model(nn.Sequential(*encoder_module_list), encoder_balance, encoder_devices), + module=partition_model( + nn.Sequential(*encoder_module_list), + encoder_balance, + encoder_devices, + ), chunks=args.pipeline_chunks, checkpoint=args.pipeline_checkpoint, ) @@ -614,10 +630,12 @@ class TransformerDecoder(FairseqDecoder): self.use_pipeline = decoder_module_list is not None if not self.use_pipeline: self.embedding_layer = TransformerDecoderEmbedding(args, embed_tokens) - self.decoder_layers = nn.Sequential(*[ - TransformerDecoderLayer(args, no_encoder_attn) - for _ in range(args.decoder_layers) - ]) + self.decoder_layers = nn.Sequential( + *[ + TransformerDecoderLayer(args, no_encoder_attn) + for _ in range(args.decoder_layers) + ] + ) self.decoder_output_layer = TransformerDecoderOutputLayer( args, embed_tokens, dictionary ) @@ -634,7 +652,11 @@ class TransformerDecoder(FairseqDecoder): ) if TORCH_PIPE: self.model = Pipe( - module=partition_model(nn.Sequential(*decoder_module_list), decoder_balance, decoder_devices), + module=partition_model( + nn.Sequential(*decoder_module_list), + decoder_balance, + decoder_devices, + ), chunks=args.pipeline_chunks, checkpoint=args.pipeline_checkpoint, ) diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index dc52f6e8..a7ca5c9f 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -4,11 +4,11 @@ # LICENSE file in the root directory of this source tree. 
import torch.nn as nn + from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer_lm import TransformerLanguageModel - try: from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding @@ -22,7 +22,6 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024 @register_model("model_parallel_transformer_lm") class ModelParallelTransformerLanguageModel(TransformerLanguageModel): - @staticmethod def add_args(parser): TransformerLanguageModel.add_args(parser) @@ -72,10 +71,6 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel): ) return cls(decoder) - @staticmethod - def add_args(parser): - TransformerLanguageModel.add_args(parser) - @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): def _vocab_init(tensor, **kwargs): diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 337c77ac..320f5e17 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -98,9 +98,7 @@ def build_model(cfg: FairseqDataclass, task): assert model is not None, ( f"Could not infer model type from {cfg}. " - "Available models: {}".format( - MODEL_DATACLASS_REGISTRY.keys() - ) + "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys()) + f" Requested model type: {model_type}" ) diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index 4d47d975..6b647c96 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -100,8 +100,8 @@ class BARTHubInterface(GeneratorHubInterface): raise NotImplementedError("prefix generation not implemented for BART") res = [] for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs): - src_tokens = batch['net_input']['src_tokens'] - inference_step_args["prefix_tokens"] =src_tokens.new_full( + src_tokens = batch["net_input"]["src_tokens"] + inference_step_args["prefix_tokens"] = src_tokens.new_full( (src_tokens.size(0), 1), fill_value=self.task.source_dictionary.bos() ).to(device=self.device) results = super().generate( @@ -111,7 +111,7 @@ class BARTHubInterface(GeneratorHubInterface): skip_invalid_size_inputs=skip_invalid_size_inputs, **kwargs ) - for id, hypos in zip(batch['id'].tolist(), results): + for id, hypos in zip(batch["id"].tolist(), results): res.append((id, hypos)) res = [hypos for _, hypos in sorted(res, key=lambda x: x[0])] return res @@ -177,32 +177,35 @@ class BARTHubInterface(GeneratorHubInterface): match_source_len: bool = True, **generate_kwargs ): - masked_token = '' + masked_token = "" batch_tokens = [] for masked_input in masked_inputs: - assert masked_token in masked_input, \ - "please add one {} token for the input".format(masked_token) + assert ( + masked_token in masked_input + ), "please add one {} token for the input".format(masked_token) text_spans = masked_input.split(masked_token) - text_spans_bpe = (' {0} '.format(masked_token)).join( - [self.bpe.encode(text_span.rstrip()) for text_span in text_spans] - ).strip() + text_spans_bpe = ( + (" {0} ".format(masked_token)) + .join([self.bpe.encode(text_span.rstrip()) for text_span in text_spans]) + .strip() + ) tokens = self.task.source_dictionary.encode_line( - ' ' + text_spans_bpe + ' ', + " " + text_spans_bpe + " ", append_eos=False, add_if_not_exist=False, ).long() batch_tokens.append(tokens) # ensure beam size is at least as big as topk - generate_kwargs['beam'] = max( + generate_kwargs["beam"] = max( topk, - 
generate_kwargs.get('beam', -1), + generate_kwargs.get("beam", -1), ) - generate_kwargs['match_source_len'] = match_source_len + generate_kwargs["match_source_len"] = match_source_len batch_hypos = self.generate(batch_tokens, **generate_kwargs) return [ - [(self.decode(hypo['tokens']), hypo['score']) for hypo in hypos[:topk]] + [(self.decode(hypo["tokens"]), hypo["score"]) for hypo in hypos[:topk]] for hypos in batch_hypos ] diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 71d0b27c..bdb12b02 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -90,7 +90,7 @@ class BARTModel(TransformerModel): src_tokens, src_lengths=src_lengths, token_embeddings=token_embeddings, - return_all_hiddens=return_all_hiddens + return_all_hiddens=return_all_hiddens, ) x, extra = self.decoder( prev_output_tokens, @@ -103,9 +103,9 @@ class BARTModel(TransformerModel): ) eos: int = self.eos if classification_head_name is not None: - sentence_representation = x[ - src_tokens.eq(eos), : - ].view(x.size(0), -1, x.size(-1))[:, -1, :] + sentence_representation = x[src_tokens.eq(eos), :].view( + x.size(0), -1, x.size(-1) + )[:, -1, :] for k, head in self.classification_heads.items(): # for torch script only supports iteration if k == classification_head_name: diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index de8d6ac1..7c4ab558 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -25,7 +25,10 @@ logger = logging.getLogger(__name__) _SLOWMO_DDP_DISABLED = False try: - from fairscale.experimental.nn.data_parallel import SlowMoBaseAlgorithm, SlowMoDistributedDataParallel + from fairscale.experimental.nn.data_parallel import ( + SlowMoBaseAlgorithm, + SlowMoDistributedDataParallel, + ) except ImportError: _SLOWMO_DDP_DISABLED = True diff --git a/fairseq/models/ema/ema.py b/fairseq/models/ema/ema.py index 6c0af693..c43d9693 100644 --- a/fairseq/models/ema/ema.py +++ b/fairseq/models/ema/ema.py @@ -22,6 +22,7 @@ import copy import logging import torch + from fairseq import checkpoint_utils @@ -78,7 +79,9 @@ class EMA(object): self.fp32_params = {} if self.config.ema_seed_model is not None: - state = checkpoint_utils.load_ema_from_checkpoint(self.config.ema_seed_model) + state = checkpoint_utils.load_ema_from_checkpoint( + self.config.ema_seed_model + ) self.model.load_state_dict(state["model"], strict=True) if device is not None: @@ -119,7 +122,7 @@ class EMA(object): self.fp32_params[param_key] = _to_float(state_dict[param_key]) def restore(self, state_dict, build_fp32_params=False): - """ Load data from a model spec into EMA model """ + """Load data from a model spec into EMA model""" self.model.load_state_dict(state_dict, strict=False) if build_fp32_params: self.build_fp32_params(state_dict) @@ -131,16 +134,20 @@ class EMA(object): return self.decay def _step_internal(self, new_model, updates=None): - """ One update of the EMA model based on new model weights """ + """One update of the EMA model based on new model weights""" decay = self.decay ema_state_dict = {} - ema_params = self.fp32_params if self.config.ema_fp32 else self.model.state_dict() + ema_params = ( + self.fp32_params if self.config.ema_fp32 else self.model.state_dict() + ) for key, param in new_model.state_dict().items(): try: ema_param = ema_params[key] except KeyError: - ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) + ema_param = ( + param.float().clone() 
if param.ndim == 1 else copy.deepcopy(param) + ) if param.shape != ema_param.shape: raise ValueError( @@ -151,7 +158,7 @@ class EMA(object): # Do not decay a model.version pytorch param continue ema_param.mul_(decay) - ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1-decay) + ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay) ema_state_dict[key] = ema_param self.restore(ema_state_dict, build_fp32_params=False) @@ -168,8 +175,7 @@ class EMA(object): """ self._set_decay( 0 - if updates is not None - and updates < self.config.ema_start_update + if updates is not None and updates < self.config.ema_start_update else self.config.ema_decay ) if updates is not None and self.config.ema_update_freq > 1: diff --git a/fairseq/models/fairseq_decoder.py b/fairseq/models/fairseq_decoder.py index 4f1e8b52..13b73d63 100644 --- a/fairseq/models/fairseq_decoder.py +++ b/fairseq/models/fairseq_decoder.py @@ -19,7 +19,6 @@ class FairseqDecoder(nn.Module): self.onnx_trace = False self.adaptive_softmax = None - def forward(self, prev_output_tokens, encoder_out=None, **kwargs): """ Args: diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index e55c7ba1..42f9134a 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -29,8 +29,9 @@ logger = logging.getLogger(__name__) def check_type(module, expected_type): if hasattr(module, "unwrapped_module"): - assert isinstance(module.unwrapped_module, expected_type), \ - f"{type(module.unwrapped_module)} != {expected_type}" + assert isinstance( + module.unwrapped_module, expected_type + ), f"{type(module.unwrapped_module)} != {expected_type}" else: assert isinstance(module, expected_type), f"{type(module)} != {expected_type}" @@ -114,7 +115,9 @@ class BaseFairseqModel(nn.Module): """ if model_cfg is None and args is not None: - logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + logger.warn( + "using 'args' is deprecated, please update your code to use dataclass config" + ) model_cfg = convert_namespace_to_omegaconf(args).model self.upgrade_state_dict(state_dict) @@ -454,7 +457,9 @@ class FairseqMultiModel(BaseFairseqModel): """ if model_cfg is None and args is not None: - logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + logger.warn( + "using 'args' is deprecated, please update your code to use dataclass config" + ) model_cfg = convert_namespace_to_omegaconf(args).model self.upgrade_state_dict(state_dict) diff --git a/fairseq/models/hubert/hubert.py b/fairseq/models/hubert/hubert.py index 232a5e40..4306af03 100644 --- a/fairseq/models/hubert/hubert.py +++ b/fairseq/models/hubert/hubert.py @@ -30,9 +30,7 @@ from omegaconf import II logger = logging.getLogger(__name__) EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) -MASKING_DISTRIBUTION_CHOICES = ChoiceEnum( - ["static", "uniform", "normal", "poisson"] -) +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) @dataclass @@ -86,9 +84,7 @@ class HubertConfig(FairseqDataclass): ) dropout_features: float = field( default=0.0, - metadata={ - "help": "dropout to apply to the features (after feat extr)" - }, + metadata={"help": "dropout to apply to the features (after feat extr)"}, ) final_dim: int = field( @@ -150,9 +146,7 @@ class HubertConfig(FairseqDataclass): ) mask_min_space: int = field( default=1, - metadata={ - "help": "min space between spans (if no overlap is enabled)" - }, + metadata={"help": "min space between 
spans (if no overlap is enabled)"}, ) # channel masking @@ -182,23 +176,17 @@ class HubertConfig(FairseqDataclass): ) mask_channel_min_space: int = field( default=1, - metadata={ - "help": "min space between spans (if no overlap is enabled)" - }, + metadata={"help": "min space between spans (if no overlap is enabled)"}, ) # positional embeddings conv_pos: int = field( default=128, - metadata={ - "help": "number of filters for convolutional positional embeddings" - }, + metadata={"help": "number of filters for convolutional positional embeddings"}, ) conv_pos_groups: int = field( default=16, - metadata={ - "help": "number of groups for convolutional positional embedding" - }, + metadata={"help": "number of groups for convolutional positional embedding"}, ) latent_temp: Tuple[float, float, float] = field( @@ -238,9 +226,7 @@ class HubertModel(BaseFairseqModel): conv_bias=cfg.conv_bias, ) feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) - self.feat2tar_ratio = ( - cfg.label_rate * feature_ds_rate / task_cfg.sample_rate - ) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate self.post_extract_proj = ( nn.Linear(self.embed, cfg.encoder_embed_dim) @@ -270,9 +256,7 @@ class HubertModel(BaseFairseqModel): self.skip_masked = cfg.skip_masked self.skip_nomask = cfg.skip_nomask - final_dim = ( - cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim - ) + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim self.mask_emb = nn.Parameter( torch.FloatTensor(cfg.encoder_embed_dim).uniform_() @@ -297,9 +281,7 @@ class HubertModel(BaseFairseqModel): # modules below are not needed during fine-tuning if any([d is None for d in dictionaries]): - logger.info( - "cannot find dictionary. assume will be used for fine-tuning" - ) + logger.info("cannot find dictionary. 
assume will be used for fine-tuning") else: self.num_classes = [len(d) for d in dictionaries] self.label_embs_concat = nn.Parameter( @@ -365,9 +347,7 @@ class HubertModel(BaseFairseqModel): pos = pos.unsqueeze(0) targets = torch.cat([pos, negs], dim=0) - logits = torch.cosine_similarity( - x.float(), targets.float(), dim=-1 - ).type_as(x) + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) logits /= self.logit_temp if neg_is_pos.any(): logits[1:][neg_is_pos] = float("-inf") @@ -385,7 +365,9 @@ class HubertModel(BaseFairseqModel): return features def forward_targets( - self, features: torch.Tensor, target_list: List[torch.Tensor], + self, + features: torch.Tensor, + target_list: List[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Trim features to ensure labels exist and then get aligned labels feat_tsz = features.size(2) @@ -398,14 +380,14 @@ class HubertModel(BaseFairseqModel): return features, target_list def forward_padding_mask( - self, features: torch.Tensor, padding_mask: torch.Tensor, + self, + features: torch.Tensor, + padding_mask: torch.Tensor, ) -> torch.Tensor: extra = padding_mask.size(1) % features.size(1) if extra > 0: padding_mask = padding_mask[:, :-extra] - padding_mask = padding_mask.view( - padding_mask.size(0), features.size(1), -1 - ) + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) padding_mask = padding_mask.all(-1) return padding_mask @@ -439,9 +421,7 @@ class HubertModel(BaseFairseqModel): unmasked_features = self.dropout_features(unmasked_features) if mask: - x, mask_indices = self.apply_mask( - features, padding_mask, target_list - ) + x, mask_indices = self.apply_mask(features, padding_mask, target_list) else: x = features mask_indices = None @@ -454,7 +434,7 @@ class HubertModel(BaseFairseqModel): x, _ = self.encoder( x, padding_mask=padding_mask, - layer=None if output_layer is None else output_layer - 1 + layer=None if output_layer is None else output_layer - 1, ) if features_only: @@ -483,9 +463,7 @@ class HubertModel(BaseFairseqModel): proj_x_m_list = [proj_x_m for _ in range(len(target_list))] logit_m_list = [ compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) - for i, (proj_x_m, t) in enumerate( - zip(proj_x_m_list, target_list) - ) + for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list)) ] else: logit_m_list = [None for _ in target_list] @@ -500,9 +478,7 @@ class HubertModel(BaseFairseqModel): logit_u_list = [ compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) - for i, (proj_x_u, t) in enumerate( - zip(proj_x_u_list, target_list) - ) + for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list)) ] else: logit_u_list = [None for _ in target_list] @@ -543,9 +519,7 @@ class HubertModel(BaseFairseqModel): def get_targets(self, net_output, is_masked=True): logits_list = self.get_logits(net_output, is_masked) - targets_list = [ - x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list - ] + targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list] return targets_list def get_extra_losses(self, net_output): diff --git a/fairseq/models/hubert/hubert_asr.py b/fairseq/models/hubert/hubert_asr.py index dce899c9..e336370b 100644 --- a/fairseq/models/hubert/hubert_asr.py +++ b/fairseq/models/hubert/hubert_asr.py @@ -21,9 +21,7 @@ from omegaconf import II, MISSING @dataclass class HubertAsrConfig(FairseqDataclass): - w2v_path: str = field( - default=MISSING, metadata={"help": "path to hubert model"} - ) + w2v_path: str = 
field(default=MISSING, metadata={"help": "path to hubert model"}) no_pretrained_weights: bool = field( default=False, metadata={"help": "if true, does not load pretrained weights"}, @@ -34,9 +32,7 @@ class HubertAsrConfig(FairseqDataclass): ) final_dropout: float = field( default=0.0, - metadata={ - "help": "dropout after transformer and before final projection" - }, + metadata={"help": "dropout after transformer and before final projection"}, ) dropout: float = field( default=0.0, @@ -45,15 +41,13 @@ class HubertAsrConfig(FairseqDataclass): attention_dropout: float = field( default=0.0, metadata={ - "help": "dropout probability for attention weights " - "inside hubert model" + "help": "dropout probability for attention weights " "inside hubert model" }, ) activation_dropout: float = field( default=0.0, metadata={ - "help": "dropout probability after activation in FFN " - "inside hubert model" + "help": "dropout probability after activation in FFN " "inside hubert model" }, ) @@ -184,9 +178,7 @@ class HubertSeq2SeqConfig(HubertAsrConfig): decoder_ffn_embed_dim: int = field( default=3072, metadata={"help": "decoder embedding dimension for FFN"} ) - decoder_layers: int = field( - default=6, metadata={"help": "num of decoder layers"} - ) + decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"}) decoder_layerdrop: float = field( default=0.0, metadata={"help": "decoder layerdrop chance"} ) @@ -204,8 +196,7 @@ class HubertSeq2SeqConfig(HubertAsrConfig): no_token_positional_embeddings: bool = field( default=False, metadata={ - "help": "if set, disables positional embeddings " - "(outside self attention)" + "help": "if set, disables positional embeddings " "(outside self attention)" }, ) decoder_dropout: float = field( @@ -214,15 +205,13 @@ class HubertSeq2SeqConfig(HubertAsrConfig): decoder_attention_dropout: float = field( default=0.0, metadata={ - "help": "dropout probability for attention weights " - "inside the decoder" + "help": "dropout probability for attention weights " "inside the decoder" }, ) decoder_activation_dropout: float = field( default=0.0, metadata={ - "help": "dropout probability after activation in FFN " - "inside the decoder" + "help": "dropout probability after activation in FFN " "inside the decoder" }, ) max_target_positions: int = field( @@ -258,9 +247,7 @@ class HubertEncoder(FairseqEncoder): } if cfg.w2v_args is None: - state = checkpoint_utils.load_checkpoint_to_cpu( - cfg.w2v_path, arg_overrides - ) + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides) w2v_args = state.get("cfg", None) if w2v_args is None: w2v_args = convert_namespace_to_omegaconf(state["args"]) @@ -269,9 +256,7 @@ class HubertEncoder(FairseqEncoder): state = None w2v_args = cfg.w2v_args if isinstance(w2v_args, Namespace): - cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf( - w2v_args - ) + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) assert cfg.normalize == w2v_args.task.normalize, ( "Fine-tuning works best when data normalization is the same. 
" @@ -344,9 +329,9 @@ class HubertEncoder(FairseqEncoder): def reorder_encoder_out(self, encoder_out, new_order): if encoder_out["encoder_out"] is not None: - encoder_out["encoder_out"] = encoder_out[ - "encoder_out" - ].index_select(1, new_order) + encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select( + 1, new_order + ) if encoder_out["encoder_padding_mask"] is not None: encoder_out["encoder_padding_mask"] = encoder_out[ "encoder_padding_mask" diff --git a/fairseq/models/lstm.py b/fairseq/models/lstm.py index e1e66a7d..8a291562 100644 --- a/fairseq/models/lstm.py +++ b/fairseq/models/lstm.py @@ -225,10 +225,10 @@ class LSTMEncoder(FairseqEncoder): super().__init__(dictionary) self.num_layers = num_layers self.dropout_in_module = FairseqDropout( - dropout_in*1.0, module_name=self.__class__.__name__ + dropout_in * 1.0, module_name=self.__class__.__name__ ) self.dropout_out_module = FairseqDropout( - dropout_out*1.0, module_name=self.__class__.__name__ + dropout_out * 1.0, module_name=self.__class__.__name__ ) self.bidirectional = bidirectional self.hidden_size = hidden_size @@ -329,7 +329,9 @@ class LSTMEncoder(FairseqEncoder): out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous() return out.view(self.num_layers, bsz, -1) - def reorder_encoder_out(self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order): + def reorder_encoder_out( + self, encoder_out: Tuple[Tensor, Tensor, Tensor, Tensor], new_order + ): return tuple( ( encoder_out[0].index_select(1, new_order), @@ -402,10 +404,10 @@ class LSTMDecoder(FairseqIncrementalDecoder): ): super().__init__(dictionary) self.dropout_in_module = FairseqDropout( - dropout_in*1.0, module_name=self.__class__.__name__ + dropout_in * 1.0, module_name=self.__class__.__name__ ) self.dropout_out_module = FairseqDropout( - dropout_out*1.0, module_name=self.__class__.__name__ + dropout_out * 1.0, module_name=self.__class__.__name__ ) self.hidden_size = hidden_size self.share_input_output_embed = share_input_output_embed diff --git a/fairseq/models/nat/fairseq_nat_model.py b/fairseq/models/nat/fairseq_nat_model.py index b0939411..a5594a4e 100644 --- a/fairseq/models/nat/fairseq_nat_model.py +++ b/fairseq/models/nat/fairseq_nat_model.py @@ -18,7 +18,10 @@ def ensemble_encoder(func): def wrapper(self, *args, **kwargs): if self.ensemble_models is None or len(self.ensemble_models) == 1: return func(self, *args, **kwargs) - encoder_outs = [func(model, *args, **kwargs, return_all_hiddens=True) for model in self.ensemble_models] + encoder_outs = [ + func(model, *args, **kwargs, return_all_hiddens=True) + for model in self.ensemble_models + ] _encoder_out = encoder_outs[0].copy() def stack(key): @@ -56,8 +59,7 @@ def ensemble_decoder(func): model, normalize=normalize, encoder_out=_replace( - encoder_out, - encoder_out["encoder_out"][0][:, :, :, i] + encoder_out, encoder_out["encoder_out"][0][:, :, :, i] ), *args, **kwargs diff --git a/fairseq/models/nat/nonautoregressive_ensembles.py b/fairseq/models/nat/nonautoregressive_ensembles.py index 705a04fb..0a0221f9 100644 --- a/fairseq/models/nat/nonautoregressive_ensembles.py +++ b/fairseq/models/nat/nonautoregressive_ensembles.py @@ -85,7 +85,8 @@ class EnsembleLevT(BasicEnsembleModel): else: if not encoder_outs[0]["encoder_padding_mask"]: src_lens = ( - encoder_outs[0]["encoder_out"][0].new(bsz) + encoder_outs[0]["encoder_out"][0] + .new(bsz) .fill_(encoder_outs[0]["encoder_out"][0].size(1)) ) else: diff --git a/fairseq/models/roberta/model.py 
b/fairseq/models/roberta/model.py index bb205b91..6e59ece5 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -183,7 +183,7 @@ class RobertaModel(FairseqEncoderModel): "communication less efficient due to smaller input sizes. This option " "is set to 0 (i.e., always wrap) when --checkpoint-activations or " "--offload-activations are passed." - ) + ), ) @classmethod @@ -542,7 +542,9 @@ def base_architecture(args): args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", True) args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", True) args.activation_fn = safe_getattr(args, "activation_fn", "gelu") - args.encoder_normalize_before = safe_getattr(args, "encoder_normalize_before", False) + args.encoder_normalize_before = safe_getattr( + args, "encoder_normalize_before", False + ) args.pooler_activation_fn = safe_getattr(args, "pooler_activation_fn", "tanh") args.untie_weights_roberta = safe_getattr(args, "untie_weights_roberta", False) diff --git a/fairseq/models/roberta/model_gottbert.py b/fairseq/models/roberta/model_gottbert.py index 2e8c6635..dc7a019b 100644 --- a/fairseq/models/roberta/model_gottbert.py +++ b/fairseq/models/roberta/model_gottbert.py @@ -12,26 +12,26 @@ from .hub_interface import RobertaHubInterface from .model import RobertaModel -@register_model('gottbert') +@register_model("gottbert") class GottbertModel(RobertaModel): - @classmethod def hub_models(cls): return { - 'gottbert-base': 'https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz', + "gottbert-base": "https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz", } @classmethod - def from_pretrained(cls, - model_name_or_path, - checkpoint_file='model.pt', - data_name_or_path='.', - bpe='hf_byte_bpe', - bpe_vocab='vocab.json', - bpe_merges='merges.txt', - bpe_add_prefix_space=False, - **kwargs - ): + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file="model.pt", + data_name_or_path=".", + bpe="hf_byte_bpe", + bpe_vocab="vocab.json", + bpe_merges="merges.txt", + bpe_add_prefix_space=False, + **kwargs + ): from fairseq import hub_utils x = hub_utils.from_pretrained( @@ -46,4 +46,4 @@ class GottbertModel(RobertaModel): bpe_add_prefix_space=bpe_add_prefix_space, **kwargs, ) - return RobertaHubInterface(x['args'], x['task'], x['models'][0]) + return RobertaHubInterface(x["args"], x["task"], x["models"][0]) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index aff9d0ff..cc108485 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -202,10 +202,10 @@ class S2TTransformerModel(FairseqEncoderDecoderModel): help="model to take encoder weights from (for initialization)", ) parser.add_argument( - '--encoder-freezing-updates', + "--encoder-freezing-updates", type=int, - metavar='N', - help='freeze encoder for first N updates' + metavar="N", + help="freeze encoder for first N updates", ) @classmethod @@ -329,7 +329,9 @@ class S2TTransformerEncoder(FairseqEncoder): return { "encoder_out": [x], # T x B x C - "encoder_padding_mask": [encoder_padding_mask] if encoder_padding_mask.any() else [], # B x T + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask.any() + else [], # B x T "encoder_embedding": [], # B x T x C "encoder_states": encoder_states, # List[T x B x C] "src_tokens": [], @@ -339,27 +341,37 @@ class S2TTransformerEncoder(FairseqEncoder): def forward(self, src_tokens, src_lengths, 
return_all_hiddens=False): if self.num_updates < self.encoder_freezing_updates: with torch.no_grad(): - x = self._forward(src_tokens, src_lengths, - return_all_hiddens=return_all_hiddens) + x = self._forward( + src_tokens, src_lengths, return_all_hiddens=return_all_hiddens + ) else: - x = self._forward(src_tokens, src_lengths, - return_all_hiddens=return_all_hiddens) + x = self._forward( + src_tokens, src_lengths, return_all_hiddens=return_all_hiddens + ) return x def reorder_encoder_out(self, encoder_out, new_order): new_encoder_out = ( - [] if len(encoder_out["encoder_out"]) == 0 + [] + if len(encoder_out["encoder_out"]) == 0 else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] ) new_encoder_padding_mask = ( - [] if len(encoder_out["encoder_padding_mask"]) == 0 - else [x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]] + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] ) new_encoder_embedding = ( - [] if len(encoder_out["encoder_embedding"]) == 0 - else [x.index_select(0, new_order) for x in encoder_out["encoder_embedding"]] + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] ) encoder_states = encoder_out["encoder_states"] diff --git a/fairseq/models/speech_to_text/xm_transformer.py b/fairseq/models/speech_to_text/xm_transformer.py index 5eecbfa2..c2cc86bb 100644 --- a/fairseq/models/speech_to_text/xm_transformer.py +++ b/fairseq/models/speech_to_text/xm_transformer.py @@ -9,8 +9,12 @@ import copy from typing import Dict, List, Optional, Tuple from fairseq import utils, checkpoint_utils -from fairseq.models import (FairseqEncoderDecoderModel, FairseqEncoder, - register_model, register_model_architecture) +from fairseq.models import ( + FairseqEncoderDecoderModel, + FairseqEncoder, + register_model, + register_model_architecture, +) from fairseq.models.transformer import Embedding, TransformerDecoder from fairseq.models.wav2vec import Wav2VecEncoder from fairseq.modules.layer_norm import LayerNorm @@ -24,18 +28,23 @@ logger = logging.getLogger(__name__) class Conv1dAdaptor(nn.Module): - def __init__(self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2, - add_layernorm=False): + def __init__( + self, in_dim, out_dim, n_layers=3, kernel_size=3, stride=2, add_layernorm=False + ): super().__init__() self.layers = nn.ModuleList( - nn.Conv1d(in_dim if i == 0 else out_dim, out_dim * 2, kernel_size, - stride=stride, padding=kernel_size // 2) + nn.Conv1d( + in_dim if i == 0 else out_dim, + out_dim * 2, + kernel_size, + stride=stride, + padding=kernel_size // 2, + ) for i in range(n_layers) ) self.layernorms = None if add_layernorm: - self.layernorms = nn.ModuleList(LayerNorm(out_dim) - for _ in range(n_layers)) + self.layernorms = nn.ModuleList(LayerNorm(out_dim) for _ in range(n_layers)) self.stride = stride @classmethod @@ -43,7 +52,7 @@ class Conv1dAdaptor(nn.Module): parser.add_argument("--adaptor-n-layers", type=int) parser.add_argument("--adaptor-kernel-size", type=int) parser.add_argument("--adaptor-stride", type=int) - parser.add_argument("--adaptor-layernorm", action='store_true') + parser.add_argument("--adaptor-layernorm", action="store_true") def get_out_seq_lens_tensor(self, in_seq_lens_tensor): out = in_seq_lens_tensor.clone() @@ -197,15 +206,18 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder): encoder_out_dim = 
self.w2v_encoder.w2v_model.encoder.embedding_dim # Projection + 8x shrinking self.adaptor = Conv1dAdaptor( - encoder_out_dim, args.decoder_embed_dim, + encoder_out_dim, + args.decoder_embed_dim, n_layers=args.adaptor_n_layers, - kernel_size=args.adaptor_kernel_size, stride=args.adaptor_stride, - add_layernorm=args.adaptor_layernorm + kernel_size=args.adaptor_kernel_size, + stride=args.adaptor_stride, + add_layernorm=args.adaptor_layernorm, ) for k, p in self.w2v_encoder.w2v_model.named_parameters(): # Freeze pretrained models by default - if safe_hasattr(args, 'finetune_w2v_params') and XMTransformerModel.finetune_params( - args.finetune_w2v_params, k): + if safe_hasattr( + args, "finetune_w2v_params" + ) and XMTransformerModel.finetune_params(args.finetune_w2v_params, k): p.requires_grad = True else: p.requires_grad = False @@ -214,11 +226,16 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder): def add_args(cls, parser): add_wav2vec_asr_args(parser) parser.add_argument( - "--normalize", action="store_true", + "--normalize", + action="store_true", help="if set, normalizes input to have 0 mean and unit variance", ) - parser.add_argument("--finetune-w2v-params", type=str, metavar="STR", - help="comma-separated param strings to finetune.") + parser.add_argument( + "--finetune-w2v-params", + type=str, + metavar="STR", + help="comma-separated param strings to finetune.", + ) Conv1dAdaptor.add_args(parser) def forward(self, src_tokens, src_lengths=None, **kwargs): @@ -227,13 +244,17 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder): x = out["encoder_out"] enc_padding_mask = None if out["encoder_padding_mask"] is not None: - enc_padding_mask = out["encoder_padding_mask"].transpose(0, 1) # T X B --> B X T + enc_padding_mask = out["encoder_padding_mask"].transpose( + 0, 1 + ) # T X B --> B X T x, enc_padding_mask = self.adaptor(x, enc_padding_mask) return { "encoder_out": [x], # T x B x C - "encoder_padding_mask": [enc_padding_mask] if enc_padding_mask.any() else [], # B x T + "encoder_padding_mask": [enc_padding_mask] + if enc_padding_mask.any() + else [], # B x T "encoder_embedding": [], # B x T x C "encoder_states": [], # List[T x B x C] "src_tokens": [], @@ -242,20 +263,26 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder): def reorder_encoder_out(self, encoder_out, new_order): new_encoder_out = ( - [] if len(encoder_out["encoder_out"]) == 0 + [] + if len(encoder_out["encoder_out"]) == 0 else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] ) new_encoder_padding_mask = ( - [] if len(encoder_out["encoder_padding_mask"]) == 0 - else [x.index_select(0, new_order) for x in - encoder_out["encoder_padding_mask"]] + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] ) new_encoder_embedding = ( - [] if len(encoder_out["encoder_embedding"]) == 0 - else [x.index_select(0, new_order) for x in - encoder_out["encoder_embedding"]] + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] ) encoder_states = encoder_out["encoder_states"] @@ -274,38 +301,71 @@ class Wav2VecEncoderWithAdaptor(FairseqEncoder): def add_decoder_args(parser): - parser.add_argument("--activation-fn", type=str, default='relu', - choices=utils.get_available_activation_fns(), - help="activation function to use") - parser.add_argument("--decoder-dropout", type=float, metavar="D", - help="dropout probability") - 
parser.add_argument("--decoder-attention-dropout", type=float, - metavar="D", - help="dropout probability for attention weights") - parser.add_argument("--decoder-activation-dropout", type=float, - metavar="D", - help="dropout probability after activation in FFN.") - parser.add_argument("--decoder-embed-dim", type=int, metavar="N", - help="decoder embedding dimension") - parser.add_argument("--decoder-ffn-embed-dim", type=int, metavar="N", - help="decoder embedding dimension for FFN") - parser.add_argument("--decoder-layers", type=int, metavar="N", - help="num decoder layers") - parser.add_argument("--decoder-attention-heads", type=int, metavar="N", - help="num decoder attention heads") - parser.add_argument("--decoder-normalize-before", action="store_true", - help="apply layernorm before each decoder block") - parser.add_argument("--layernorm-embedding", action="store_true", - help="add layernorm to embedding") - parser.add_argument("--no-scale-embedding", action="store_true", - help="if True, dont scale embeddings") parser.add_argument( - "--load-pretrained-decoder-from", type=str, metavar="STR", - help="model to take decoder weights from (for initialization)" + "--activation-fn", + type=str, + default="relu", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--decoder-dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--decoder-attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--decoder-activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN.", + ) + parser.add_argument( + "--decoder-embed-dim", type=int, metavar="N", help="decoder embedding dimension" + ) + parser.add_argument( + "--decoder-ffn-embed-dim", + type=int, + metavar="N", + help="decoder embedding dimension for FFN", + ) + parser.add_argument( + "--decoder-layers", type=int, metavar="N", help="num decoder layers" + ) + parser.add_argument( + "--decoder-attention-heads", + type=int, + metavar="N", + help="num decoder attention heads", + ) + parser.add_argument( + "--decoder-normalize-before", + action="store_true", + help="apply layernorm before each decoder block", + ) + parser.add_argument( + "--layernorm-embedding", action="store_true", help="add layernorm to embedding" + ) + parser.add_argument( + "--no-scale-embedding", + action="store_true", + help="if True, dont scale embeddings", + ) + parser.add_argument( + "--load-pretrained-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + parser.add_argument( + "--finetune-decoder-params", + type=str, + metavar="STR", + help="comma-separated param strings to finetune.", ) - parser.add_argument("--finetune-decoder-params", type=str, - metavar="STR", - help="comma-separated param strings to finetune.") parser.add_argument("--checkpoint-activations", action="store_true") @@ -342,16 +402,16 @@ class XMTransformerModel(FairseqEncoderDecoderModel): _args.activation_dropout = args.decoder_activation_dropout _args.max_target_positions = 1024 - decoder = TransformerDecoder(_args, task.target_dictionary, - embed_tokens) + decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens) if getattr(args, "load_pretrained_decoder_from", None): decoder = checkpoint_utils.load_pretrained_component_from_model( component=decoder, checkpoint=args.load_pretrained_decoder_from ) for k, p in 
decoder.named_parameters(): # Freeze pretrained models by default - if safe_hasattr(args, 'finetune_decoder_params') and XMTransformerModel.finetune_params( - args.finetune_decoder_params, k): + if safe_hasattr( + args, "finetune_decoder_params" + ) and XMTransformerModel.finetune_params(args.finetune_decoder_params, k): p.requires_grad = True else: p.requires_grad = False @@ -369,8 +429,9 @@ class XMTransformerModel(FairseqEncoderDecoderModel): padding_idx = dictionary.pad() return Embedding(num_embeddings, embed_dim, padding_idx) - decoder_embed_tokens = build_embedding(task.target_dictionary, - args.decoder_embed_dim) + decoder_embed_tokens = build_embedding( + task.target_dictionary, args.decoder_embed_dim + ) encoder = cls.build_encoder(args) decoder = cls.build_decoder(args, task, decoder_embed_tokens) return cls(encoder, decoder) @@ -382,8 +443,7 @@ class XMTransformerModel(FairseqEncoderDecoderModel): sample: Optional[Dict[str, Tensor]] = None, ): # net_output['encoder_out'] is a (B, T, D) tensor - lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, - sample) + lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) lprobs.batch_first = True return lprobs @@ -393,17 +453,19 @@ class XMTransformerModel(FairseqEncoderDecoderModel): argument in its input, which is not supported in torchscript. This method overrites the forward method definition without **kwargs. """ - encoder_out = self.encoder(src_tokens=src_tokens, - src_lengths=src_lengths, **kwargs) - decoder_out = self.decoder(prev_output_tokens=prev_output_tokens, - encoder_out=encoder_out) + encoder_out = self.encoder( + src_tokens=src_tokens, src_lengths=src_lengths, **kwargs + ) + decoder_out = self.decoder( + prev_output_tokens=prev_output_tokens, encoder_out=encoder_out + ) return decoder_out def upgrade_state_dict(self, state_dict): for k, _ in state_dict.items(): - if 'adaptor.layers' in state_dict: + if "adaptor.layers" in state_dict: print(k) - new = k.replace('adaptor.layers', 'adaptor_layers') + new = k.replace("adaptor.layers", "adaptor_layers") state_dict[new] = state_dict[k] del state_dict[k] @@ -435,11 +497,9 @@ def set_default_w2v_encoder_args(args): args.mask_channel_length = getattr(args, "mask_channel_length", 10) args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) args.mask_channel_before = getattr(args, "mask_channel_before", False) - args.mask_channel_selection = getattr(args, "mask_channel_selection", - "static") + args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") args.mask_channel_other = getattr(args, "mask_channel_other", 0) - args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", - False) + args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0) args.feature_grad_mult = 0.1 @@ -456,49 +516,43 @@ def set_default_adaptor_args(args): def set_default_mbart_decoder_args(args): - args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) - args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024) - args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', - 4 * 1024) - args.decoder_layers = getattr(args, 'decoder_layers', 12) - args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) - args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', - True) - args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True) + args.decoder_embed_path = 
getattr(args, "decoder_embed_path", None) + args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) + args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * 1024) + args.decoder_layers = getattr(args, "decoder_layers", 12) + args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) + args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) + args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True) args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) args.adaptive_input = getattr(args, "adaptive_input", False) - args.decoder_attention_dropout = getattr(args, 'decoder_attention_dropout', - 0.) - args.decoder_activation_dropout = getattr(args, - 'decoder_activation_dropout', 0.) - args.decoder_dropout = getattr(args, 'decoder_dropout', 0.1) - args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', - None) - args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) + args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0.0) + args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0.0) + args.decoder_dropout = getattr(args, "decoder_dropout", 0.1) + args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) + args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) args.share_decoder_input_output_embed = getattr( - args, 'share_decoder_input_output_embed', True + args, "share_decoder_input_output_embed", True ) args.no_token_positional_embeddings = getattr( args, "no_token_positional_embeddings", False ) - args.decoder_output_dim = getattr(args, 'decoder_output_dim', - args.decoder_embed_dim) - args.decoder_input_dim = getattr(args, 'decoder_input_dim', - args.decoder_embed_dim) + args.decoder_output_dim = getattr( + args, "decoder_output_dim", args.decoder_embed_dim + ) + args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) - args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) + args.no_scale_embedding = getattr(args, "no_scale_embedding", False) args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) - args.layernorm_embedding = getattr(args, 'layernorm_embedding', True) + args.layernorm_embedding = getattr(args, "layernorm_embedding", True) - args.activation_fn = getattr(args, 'activation_fn', 'gelu') - args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') - args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) args.checkpoint_activations = getattr(args, "checkpoint_activations", False) -@register_model_architecture(model_name="xm_transformer", - arch_name="xm_transformer") +@register_model_architecture(model_name="xm_transformer", arch_name="xm_transformer") def base_architecture(args): set_default_w2v_encoder_args(args) set_default_adaptor_args(args) diff --git a/fairseq/models/text_to_speech/fastspeech2.py b/fairseq/models/text_to_speech/fastspeech2.py index 4fe9cc4d..f2a0792b 100644 --- a/fairseq/models/text_to_speech/fastspeech2.py +++ b/fairseq/models/text_to_speech/fastspeech2.py @@ -8,10 +8,17 @@ import logging import torch from torch import nn -from fairseq.models import (FairseqEncoder, FairseqEncoderModel, register_model, - register_model_architecture) +from fairseq.models import ( + FairseqEncoder, + 
FairseqEncoderModel, + register_model, + register_model_architecture, +) from fairseq.modules import ( - LayerNorm, PositionalEmbedding, FairseqDropout, MultiheadAttention + LayerNorm, + PositionalEmbedding, + FairseqDropout, + MultiheadAttention, ) from fairseq import utils from fairseq.data.data_utils import lengths_to_padding_mask @@ -36,11 +43,19 @@ class PositionwiseFeedForward(nn.Module): def __init__(self, in_dim, hidden_dim, kernel_size, dropout): super().__init__() self.ffn = nn.Sequential( - nn.Conv1d(in_dim, hidden_dim, kernel_size=kernel_size, - padding=(kernel_size - 1) // 2), + nn.Conv1d( + in_dim, + hidden_dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), nn.ReLU(), - nn.Conv1d(hidden_dim, in_dim, kernel_size=kernel_size, - padding=(kernel_size - 1) // 2) + nn.Conv1d( + hidden_dim, + in_dim, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), ) self.layer_norm = LayerNorm(in_dim) self.dropout = self.dropout_module = FairseqDropout( @@ -57,8 +72,7 @@ class PositionwiseFeedForward(nn.Module): class FFTLayer(torch.nn.Module): def __init__( - self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, - attention_dropout + self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, attention_dropout ): super().__init__() self.self_attn = MultiheadAttention( @@ -74,8 +88,7 @@ class FFTLayer(torch.nn.Module): residual = x x = x.transpose(0, 1) x, _ = self.self_attn( - query=x, key=x, value=x, key_padding_mask=padding_mask, - need_weights=False + query=x, key=x, value=x, key_padding_mask=padding_mask, need_weights=False ) x = x.transpose(0, 1) x = self.layer_norm(x + residual) @@ -106,11 +119,12 @@ class VariancePredictor(nn.Module): super().__init__() self.conv1 = nn.Sequential( nn.Conv1d( - args.encoder_embed_dim, args.var_pred_hidden_dim, + args.encoder_embed_dim, + args.var_pred_hidden_dim, kernel_size=args.var_pred_kernel_size, - padding=(args.var_pred_kernel_size - 1) // 2 + padding=(args.var_pred_kernel_size - 1) // 2, ), - nn.ReLU() + nn.ReLU(), ) self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim) self.dropout_module = FairseqDropout( @@ -118,10 +132,12 @@ class VariancePredictor(nn.Module): ) self.conv2 = nn.Sequential( nn.Conv1d( - args.var_pred_hidden_dim, args.var_pred_hidden_dim, - kernel_size=args.var_pred_kernel_size, padding=1 + args.var_pred_hidden_dim, + args.var_pred_hidden_dim, + kernel_size=args.var_pred_kernel_size, + padding=1, ), - nn.ReLU() + nn.ReLU(), ) self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim) self.proj = nn.Linear(args.var_pred_hidden_dim, 1) @@ -171,8 +187,15 @@ class VarianceAdaptor(nn.Module): return out, emb def forward( - self, x, padding_mask, durations=None, pitches=None, energies=None, - d_factor=1.0, p_factor=1.0, e_factor=1.0 + self, + x, + padding_mask, + durations=None, + pitches=None, + energies=None, + d_factor=1.0, + p_factor=1.0, + e_factor=1.0, ): # x: B x T x C log_dur_out = self.duration_predictor(x) @@ -205,8 +228,7 @@ class FastSpeech2Encoder(FairseqEncoder): self.spk_emb_proj = None if embed_speaker is not None: self.spk_emb_proj = nn.Linear( - args.encoder_embed_dim + args.speaker_embed_dim, - args.encoder_embed_dim + args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim ) self.dropout_module = FairseqDropout( @@ -224,9 +246,12 @@ class FastSpeech2Encoder(FairseqEncoder): self.encoder_fft_layers = nn.ModuleList( FFTLayer( - args.encoder_embed_dim, args.encoder_attention_heads, - args.fft_hidden_dim, args.fft_kernel_size, - dropout=args.dropout, 
attention_dropout=args.attention_dropout + args.encoder_embed_dim, + args.encoder_attention_heads, + args.fft_hidden_dim, + args.fft_kernel_size, + dropout=args.dropout, + attention_dropout=args.attention_dropout, ) for _ in range(args.encoder_layers) ) @@ -235,9 +260,12 @@ class FastSpeech2Encoder(FairseqEncoder): self.decoder_fft_layers = nn.ModuleList( FFTLayer( - args.decoder_embed_dim, args.decoder_attention_heads, - args.fft_hidden_dim, args.fft_kernel_size, - dropout=args.dropout, attention_dropout=args.attention_dropout + args.decoder_embed_dim, + args.decoder_attention_heads, + args.fft_hidden_dim, + args.fft_kernel_size, + dropout=args.dropout, + attention_dropout=args.attention_dropout, ) for _ in range(args.decoder_layers) ) @@ -247,15 +275,25 @@ class FastSpeech2Encoder(FairseqEncoder): self.postnet = None if args.add_postnet: self.postnet = Postnet( - self.out_dim, args.postnet_conv_dim, + self.out_dim, + args.postnet_conv_dim, args.postnet_conv_kernel_size, - args.postnet_layers, args.postnet_dropout + args.postnet_layers, + args.postnet_dropout, ) self.apply(model_init) - def forward(self, src_tokens, src_lengths=None, speaker=None, - durations=None, pitches=None, energies=None, **kwargs): + def forward( + self, + src_tokens, + src_lengths=None, + speaker=None, + durations=None, + pitches=None, + energies=None, + **kwargs + ): x = self.embed_tokens(src_tokens) enc_padding_mask = src_tokens.eq(self.padding_idx) @@ -270,8 +308,9 @@ class FastSpeech2Encoder(FairseqEncoder): emb = self.embed_speaker(speaker).expand(bsz, seq_len, -1) x = self.spk_emb_proj(torch.cat([x, emb], dim=2)) - x, out_lens, log_dur_out, pitch_out, energy_out = \ - self.var_adaptor(x, enc_padding_mask, durations, pitches, energies) + x, out_lens, log_dur_out, pitch_out, energy_out = self.var_adaptor( + x, enc_padding_mask, durations, pitches, energies + ) dec_padding_mask = lengths_to_padding_mask(out_lens) x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask) @@ -326,7 +365,7 @@ class FastSpeech2Model(FairseqEncoderModel): out_dim = args.output_frame_dim * args.n_frames_per_step self.ctc_proj = None - if getattr(args, "ctc_weight", 0.) 
> 0.: + if getattr(args, "ctc_weight", 0.0) > 0.0: self.ctc_proj = nn.Linear(out_dim, len(src_dict)) @classmethod diff --git a/fairseq/models/text_to_speech/hifigan.py b/fairseq/models/text_to_speech/hifigan.py index edc7db60..e30fe77f 100644 --- a/fairseq/models/text_to_speech/hifigan.py +++ b/fairseq/models/text_to_speech/hifigan.py @@ -119,7 +119,7 @@ class Generator(torch.nn.Module): self.ups = nn.ModuleList() for i, (u, k) in enumerate( - zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"]) + zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"]) ): self.ups.append( weight_norm( @@ -137,7 +137,7 @@ class Generator(torch.nn.Module): for i in range(len(self.ups)): ch = cfg["upsample_initial_channel"] // (2 ** (i + 1)) for k, d in zip( - cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"] + cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"] ): self.resblocks.append(ResBlock(ch, k, d)) diff --git a/fairseq/models/text_to_speech/tacotron2.py b/fairseq/models/text_to_speech/tacotron2.py index bb327e81..4df40756 100644 --- a/fairseq/models/text_to_speech/tacotron2.py +++ b/fairseq/models/text_to_speech/tacotron2.py @@ -9,9 +9,13 @@ import torch from torch import nn from torch.nn import functional as F -from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel, - FairseqIncrementalDecoder, register_model, - register_model_architecture) +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, +) from fairseq.modules import LSTMCellWithZoneOut, LocationAttention @@ -31,29 +35,36 @@ class Tacotron2Encoder(FairseqEncoder): self.spk_emb_proj = None if embed_speaker is not None: self.spk_emb_proj = nn.Linear( - args.encoder_embed_dim + args.speaker_embed_dim, - args.encoder_embed_dim + args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim ) - self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, - padding_idx=self.padding_idx) + self.embed_tokens = nn.Embedding( + len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx + ) - assert(args.encoder_conv_kernel_size % 2 == 1) + assert args.encoder_conv_kernel_size % 2 == 1 self.convolutions = nn.ModuleList( nn.Sequential( - nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim, - kernel_size=args.encoder_conv_kernel_size, - padding=((args.encoder_conv_kernel_size - 1) // 2)), + nn.Conv1d( + args.encoder_embed_dim, + args.encoder_embed_dim, + kernel_size=args.encoder_conv_kernel_size, + padding=((args.encoder_conv_kernel_size - 1) // 2), + ), nn.BatchNorm1d(args.encoder_embed_dim), nn.ReLU(), - nn.Dropout(args.encoder_dropout) + nn.Dropout(args.encoder_dropout), ) for _ in range(args.encoder_conv_layers) ) - self.lstm = nn.LSTM(args.encoder_embed_dim, args.encoder_embed_dim // 2, - num_layers=args.encoder_lstm_layers, - batch_first=True, bidirectional=True) + self.lstm = nn.LSTM( + args.encoder_embed_dim, + args.encoder_embed_dim // 2, + num_layers=args.encoder_lstm_layers, + batch_first=True, + bidirectional=True, + ) self.apply(encoder_init) @@ -78,7 +89,7 @@ class Tacotron2Encoder(FairseqEncoder): return { "encoder_out": [x], # B x T x C - "encoder_padding_mask": encoder_padding_mask, # B x T + "encoder_padding_mask": encoder_padding_mask, # B x T } @@ -86,8 +97,7 @@ class Prenet(nn.Module): def __init__(self, in_dim, n_layers, n_units, dropout): super().__init__() self.layers = nn.ModuleList( - nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units), - 
nn.ReLU()) + nn.Sequential(nn.Linear(in_dim if i == 0 else n_units, n_units), nn.ReLU()) for i in range(n_layers) ) self.dropout = dropout @@ -102,20 +112,24 @@ class Postnet(nn.Module): def __init__(self, in_dim, n_channels, kernel_size, n_layers, dropout): super(Postnet, self).__init__() self.convolutions = nn.ModuleList() - assert(kernel_size % 2 == 1) + assert kernel_size % 2 == 1 for i in range(n_layers): - cur_layers = [ - nn.Conv1d(in_dim if i == 0 else n_channels, - n_channels if i < n_layers - 1 else in_dim, - kernel_size=kernel_size, - padding=((kernel_size - 1) // 2)), - nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim) - ] + ([nn.Tanh()] if i < n_layers - 1 else []) + [nn.Dropout(dropout)] + cur_layers = ( + [ + nn.Conv1d( + in_dim if i == 0 else n_channels, + n_channels if i < n_layers - 1 else in_dim, + kernel_size=kernel_size, + padding=((kernel_size - 1) // 2), + ), + nn.BatchNorm1d(n_channels if i < n_layers - 1 else in_dim), + ] + + ([nn.Tanh()] if i < n_layers - 1 else []) + + [nn.Dropout(dropout)] + ) nn.init.xavier_uniform_( cur_layers[0].weight, - torch.nn.init.calculate_gain( - "tanh" if i < n_layers - 1 else "linear" - ) + torch.nn.init.calculate_gain("tanh" if i < n_layers - 1 else "linear"), ) self.convolutions.append(nn.Sequential(*cur_layers)) @@ -138,21 +152,25 @@ class Tacotron2Decoder(FairseqIncrementalDecoder): self.n_frames_per_step = args.n_frames_per_step self.out_dim = args.output_frame_dim * args.n_frames_per_step - self.prenet = Prenet(self.out_dim, args.prenet_layers, args.prenet_dim, - args.prenet_dropout) + self.prenet = Prenet( + self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout + ) # take prev_context, prev_frame, (speaker embedding) as input self.attention_lstm = LSTMCellWithZoneOut( args.zoneout, args.prenet_dim + args.encoder_embed_dim, - args.decoder_lstm_dim + args.decoder_lstm_dim, ) # take attention_lstm output, attention_state, encoder_out as input self.attention = LocationAttention( - args.attention_dim, args.encoder_embed_dim, args.decoder_lstm_dim, + args.attention_dim, + args.encoder_embed_dim, + args.decoder_lstm_dim, (1 + int(args.attention_use_cumprob)), - args.attention_conv_dim, args.attention_conv_kernel_size + args.attention_conv_dim, + args.attention_conv_kernel_size, ) # take attention_lstm output, context, (gated_latent) as input @@ -160,7 +178,7 @@ class Tacotron2Decoder(FairseqIncrementalDecoder): LSTMCellWithZoneOut( args.zoneout, args.encoder_embed_dim + args.decoder_lstm_dim, - args.decoder_lstm_dim + args.decoder_lstm_dim, ) for i in range(args.decoder_lstm_layers) ) @@ -169,12 +187,16 @@ class Tacotron2Decoder(FairseqIncrementalDecoder): self.feat_proj = nn.Linear(proj_in_dim, self.out_dim) self.eos_proj = nn.Linear(proj_in_dim, 1) - self.postnet = Postnet(self.out_dim, args.postnet_conv_dim, - args.postnet_conv_kernel_size, - args.postnet_layers, args.postnet_dropout) + self.postnet = Postnet( + self.out_dim, + args.postnet_conv_dim, + args.postnet_conv_kernel_size, + args.postnet_layers, + args.postnet_dropout, + ) self.ctc_proj = None - if getattr(args, "ctc_weight", 0.) 
> 0.: + if getattr(args, "ctc_weight", 0.0) > 0.0: self.ctc_proj = nn.Linear(self.out_dim, len(src_dict)) self.apply(decoder_init) @@ -190,12 +212,16 @@ class Tacotron2Decoder(FairseqIncrementalDecoder): lstm_h = self.get_incremental_state(incremental_state, "lstm_h") if lstm_h is None: - lstm_h = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) - for _ in range(self.args.decoder_lstm_layers)] + lstm_h = [ + enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + for _ in range(self.args.decoder_lstm_layers) + ] lstm_c = self.get_incremental_state(incremental_state, "lstm_c") if lstm_c is None: - lstm_c = [enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) - for _ in range(self.args.decoder_lstm_layers)] + lstm_c = [ + enc_out.new_zeros(bsz, self.args.decoder_lstm_dim) + for _ in range(self.args.decoder_lstm_layers) + ] attn_w = self.get_incremental_state(incremental_state, "attn_w") if attn_w is None: @@ -216,8 +242,14 @@ class Tacotron2Decoder(FairseqIncrementalDecoder): else: raise ValueError(f"{self.args.init_attn_c} not supported") - def forward(self, prev_output_tokens, encoder_out=None, - incremental_state=None, target_lengths=None, **kwargs): + def forward( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + target_lengths=None, + **kwargs, + ): enc_mask = encoder_out["encoder_padding_mask"] enc_out = encoder_out["encoder_out"][0] in_len = enc_out.size(1) @@ -227,8 +259,9 @@ class Tacotron2Decoder(FairseqIncrementalDecoder): bsz, out_len, _ = prev_output_tokens.size() prenet_out = self.prenet(prev_output_tokens) - (alstm_h, alstm_c, lstm_h, lstm_c, - attn_w, attn_w_cum) = self._get_states(incremental_state, enc_out) + (alstm_h, alstm_c, lstm_h, lstm_c, attn_w, attn_w_cum) = self._get_states( + incremental_state, enc_out + ) attn_ctx = self._get_init_attn_c(enc_out, enc_mask) attn_out = enc_out.new_zeros(bsz, in_len, out_len) @@ -241,9 +274,7 @@ class Tacotron2Decoder(FairseqIncrementalDecoder): attn_state = attn_w.unsqueeze(1) if self.args.attention_use_cumprob: attn_state = torch.stack((attn_w, attn_w_cum), dim=1) - attn_ctx, attn_w = self.attention( - enc_out, enc_mask, alstm_h, attn_state - ) + attn_ctx, attn_w = self.attention(enc_out, enc_mask, alstm_h, attn_state) attn_w_cum = attn_w_cum + attn_w attn_out[:, :, t] = attn_w @@ -297,7 +328,7 @@ class Tacotron2Model(FairseqEncoderDecoderModel): parser.add_argument("--postnet-conv-dim", type=int) parser.add_argument("--postnet-conv-kernel-size", type=int) parser.add_argument("--init-attn-c", type=str) - parser.add_argument("--attention-use-cumprob", action='store_true') + parser.add_argument("--attention-use-cumprob", action="store_true") parser.add_argument("--zoneout", type=float) parser.add_argument("--decoder-lstm-layers", type=int) parser.add_argument("--decoder-lstm-dim", type=int) @@ -333,8 +364,7 @@ def base_architecture(args): # decoder args.attention_dim = getattr(args, "attention_dim", 128) args.attention_conv_dim = getattr(args, "attention_conv_dim", 32) - args.attention_conv_kernel_size = getattr(args, - "attention_conv_kernel_size", 15) + args.attention_conv_kernel_size = getattr(args, "attention_conv_kernel_size", 15) args.prenet_dropout = getattr(args, "prenet_dropout", 0.5) args.prenet_layers = getattr(args, "prenet_layers", 2) args.prenet_dim = getattr(args, "prenet_dim", 256) diff --git a/fairseq/models/text_to_speech/tts_transformer.py b/fairseq/models/text_to_speech/tts_transformer.py index ff7af78b..32bdd374 100644 --- a/fairseq/models/text_to_speech/tts_transformer.py +++ 
b/fairseq/models/text_to_speech/tts_transformer.py @@ -9,12 +9,14 @@ from typing import List, Optional import torch from torch import nn -from fairseq.models import (FairseqEncoder, FairseqEncoderDecoderModel, - FairseqIncrementalDecoder, register_model, - register_model_architecture) -from fairseq.modules import ( - TransformerEncoderLayer, TransformerDecoderLayer +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, + register_model_architecture, ) +from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer from fairseq.models.text_to_speech.tacotron2 import Prenet, Postnet from fairseq.modules import LayerNorm, PositionalEmbedding, FairseqDropout from fairseq.data.data_utils import lengths_to_padding_mask @@ -42,30 +44,31 @@ class TTSTransformerEncoder(FairseqEncoder): self.spk_emb_proj = None if embed_speaker is not None: self.spk_emb_proj = nn.Linear( - args.encoder_embed_dim + args.speaker_embed_dim, - args.encoder_embed_dim + args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim ) self.dropout_module = FairseqDropout( p=args.dropout, module_name=self.__class__.__name__ ) - self.embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, - padding_idx=self.padding_idx) - assert(args.encoder_conv_kernel_size % 2 == 1) + self.embed_tokens = nn.Embedding( + len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx + ) + assert args.encoder_conv_kernel_size % 2 == 1 self.prenet = nn.ModuleList( nn.Sequential( - nn.Conv1d(args.encoder_embed_dim, args.encoder_embed_dim, - kernel_size=args.encoder_conv_kernel_size, - padding=((args.encoder_conv_kernel_size - 1) // 2)), + nn.Conv1d( + args.encoder_embed_dim, + args.encoder_embed_dim, + kernel_size=args.encoder_conv_kernel_size, + padding=((args.encoder_conv_kernel_size - 1) // 2), + ), nn.BatchNorm1d(args.encoder_embed_dim), nn.ReLU(), nn.Dropout(args.encoder_dropout), ) for _ in range(args.encoder_conv_layers) ) - self.prenet_proj = nn.Linear( - args.encoder_embed_dim, args.encoder_embed_dim - ) + self.prenet_proj = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim) self.embed_positions = PositionalEmbedding( args.max_source_positions, args.encoder_embed_dim, self.padding_idx ) @@ -112,7 +115,9 @@ class TTSTransformerEncoder(FairseqEncoder): return { "encoder_out": [x], # T x B x C - "encoder_padding_mask": [padding_mask] if padding_mask.any() else [], # B x T + "encoder_padding_mask": [padding_mask] + if padding_mask.any() + else [], # B x T "encoder_embedding": [], # B x T x C "encoder_states": [], # List[T x B x C] "src_tokens": [], @@ -143,15 +148,15 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder): ) self.pos_emb_alpha = nn.Parameter(torch.ones(1)) self.prenet = nn.Sequential( - Prenet(self.out_dim, args.prenet_layers, args.prenet_dim, - args.prenet_dropout), + Prenet( + self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout + ), nn.Linear(args.prenet_dim, args.decoder_embed_dim), ) self.n_transformer_layers = args.decoder_transformer_layers self.transformer_layers = nn.ModuleList( - TransformerDecoderLayer(args) - for _ in range(self.n_transformer_layers) + TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers) ) if args.decoder_normalize_before: self.layer_norm = LayerNorm(args.decoder_embed_dim) @@ -161,19 +166,28 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder): self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim) self.eos_proj = 
nn.Linear(args.decoder_embed_dim, 1) - self.postnet = Postnet(self.out_dim, args.postnet_conv_dim, - args.postnet_conv_kernel_size, - args.postnet_layers, args.postnet_dropout) + self.postnet = Postnet( + self.out_dim, + args.postnet_conv_dim, + args.postnet_conv_kernel_size, + args.postnet_layers, + args.postnet_dropout, + ) self.ctc_proj = None - if getattr(args, "ctc_weight", 0.) > 0.: + if getattr(args, "ctc_weight", 0.0) > 0.0: self.ctc_proj = nn.Linear(self.out_dim, len(src_dict)) self.apply(decoder_init) def extract_features( - self, prev_outputs, encoder_out=None, incremental_state=None, - target_lengths=None, speaker=None, **kwargs + self, + prev_outputs, + encoder_out=None, + incremental_state=None, + target_lengths=None, + speaker=None, + **kwargs ): alignment_layer = self.n_transformer_layers - 1 self_attn_padding_mask = lengths_to_padding_mask(target_lengths) @@ -212,8 +226,8 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder): else None, encoder_out["encoder_padding_mask"][0] if ( - encoder_out is not None - and len(encoder_out["encoder_padding_mask"]) > 0 + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 ) else None, incremental_state, @@ -239,13 +253,22 @@ class TTSTransformerDecoder(FairseqIncrementalDecoder): return x, {"attn": attn, "inner_states": inner_states} - def forward(self, prev_output_tokens, encoder_out=None, - incremental_state=None, target_lengths=None, speaker=None, - **kwargs): + def forward( + self, + prev_output_tokens, + encoder_out=None, + incremental_state=None, + target_lengths=None, + speaker=None, + **kwargs + ): x, extra = self.extract_features( - prev_output_tokens, encoder_out=encoder_out, - incremental_state=incremental_state, target_lengths=target_lengths, - speaker=speaker, **kwargs + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + target_lengths=target_lengths, + speaker=speaker, + **kwargs ) attn = extra["attn"] feat_out = self.feat_proj(x) @@ -328,8 +351,9 @@ class TTSTransformerModel(FairseqEncoderDecoderModel): return cls(encoder, decoder) def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs): - return self.encoder(src_tokens, src_lengths=src_lengths, - speaker=speaker, **kwargs) + return self.encoder( + src_tokens, src_lengths=src_lengths, speaker=speaker, **kwargs + ) def set_num_updates(self, num_updates): super().set_num_updates(num_updates) @@ -348,7 +372,9 @@ def base_architecture(args): # encoder transformer layers args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim) + args.encoder_ffn_embed_dim = getattr( + args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim + ) args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) args.attention_dropout = getattr(args, "attention_dropout", 0.0) @@ -366,6 +392,8 @@ def base_architecture(args): # decoder transformer layers args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6) args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim) + args.decoder_ffn_embed_dim = getattr( + args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim + ) args.decoder_normalize_before = getattr(args, 
"decoder_normalize_before", False) args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) diff --git a/fairseq/models/text_to_speech/vocoder.py b/fairseq/models/text_to_speech/vocoder.py index 65d9f9f0..af3042ee 100644 --- a/fairseq/models/text_to_speech/vocoder.py +++ b/fairseq/models/text_to_speech/vocoder.py @@ -13,7 +13,10 @@ from torch import nn import torch.nn.functional as F from fairseq.data.audio.audio_utils import ( - get_window, get_fourier_basis, get_mel_filters, TTSSpectrogram + get_window, + get_fourier_basis, + get_mel_filters, + TTSSpectrogram, ) from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig from fairseq.models.text_to_speech.hifigan import Generator as HiFiGANModel @@ -25,11 +28,9 @@ class PseudoInverseMelScale(torch.nn.Module): def __init__(self, n_stft, n_mels, sample_rate, f_min, f_max) -> None: super(PseudoInverseMelScale, self).__init__() self.n_mels = n_mels - basis = get_mel_filters( - sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max - ) + basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max) basis = torch.pinverse(basis) # F x F_mel - self.register_buffer('basis', basis) + self.register_buffer("basis", basis) def forward(self, melspec: torch.Tensor) -> torch.Tensor: # pack batch @@ -48,8 +49,12 @@ class PseudoInverseMelScale(torch.nn.Module): class GriffinLim(torch.nn.Module): def __init__( - self, n_fft: int, win_length: int, hop_length: int, n_iter: int, - window_fn=torch.hann_window + self, + n_fft: int, + win_length: int, + hop_length: int, + n_iter: int, + window_fn=torch.hann_window, ): super(GriffinLim, self).__init__() self.transform = TTSSpectrogram( @@ -59,7 +64,7 @@ class GriffinLim(torch.nn.Module): basis = get_fourier_basis(n_fft) basis = torch.pinverse(n_fft / hop_length * basis).T[:, None, :] basis *= get_window(window_fn, n_fft, win_length) - self.register_buffer('basis', basis) + self.register_buffer("basis", basis) self.n_fft = n_fft self.win_length = win_length @@ -70,33 +75,33 @@ class GriffinLim(torch.nn.Module): @classmethod def get_window_sum_square( - cls, n_frames, hop_length, win_length, n_fft, - window_fn=torch.hann_window + cls, n_frames, hop_length, win_length, n_fft, window_fn=torch.hann_window ) -> torch.Tensor: w_sq = get_window(window_fn, n_fft, win_length) ** 2 n = n_fft + hop_length * (n_frames - 1) x = torch.zeros(n, dtype=torch.float32) for i in range(n_frames): ofst = i * hop_length - x[ofst: min(n, ofst + n_fft)] += w_sq[:max(0, min(n_fft, n - ofst))] + x[ofst : min(n, ofst + n_fft)] += w_sq[: max(0, min(n_fft, n - ofst))] return x def inverse(self, magnitude: torch.Tensor, phase) -> torch.Tensor: x = torch.cat( - [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], - dim=1 + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 ) x = F.conv_transpose1d(x, self.basis, stride=self.hop_length) win_sum_sq = self.get_window_sum_square( - magnitude.shape[-1], hop_length=self.hop_length, - win_length=self.win_length, n_fft=self.n_fft + magnitude.shape[-1], + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.n_fft, ).to(magnitude.device) # remove modulation effects approx_nonzero_indices = win_sum_sq > self.tiny x[:, :, approx_nonzero_indices] /= win_sum_sq[approx_nonzero_indices] x *= self.n_fft / self.hop_length - x = x[:, :, self.n_fft // 2:] - x = x[:, :, :-self.n_fft // 2:] + x = x[:, :, self.n_fft // 2 :] + x = x[:, :, : -self.n_fft // 2 :] return x def forward(self, specgram: torch.Tensor) -> torch.Tensor: @@ -111,18 
+116,33 @@ class GriffinLim(torch.nn.Module): class GriffinLimVocoder(nn.Module): - def __init__(self, sample_rate, win_size, hop_size, n_fft, - n_mels, f_min, f_max, window_fn, - spec_bwd_max_iter=32, - fp16=False): + def __init__( + self, + sample_rate, + win_size, + hop_size, + n_fft, + n_mels, + f_min, + f_max, + window_fn, + spec_bwd_max_iter=32, + fp16=False, + ): super().__init__() self.inv_mel_transform = PseudoInverseMelScale( - n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate, - f_min=f_min, f_max=f_max + n_stft=n_fft // 2 + 1, + n_mels=n_mels, + sample_rate=sample_rate, + f_min=f_min, + f_max=f_max, ) self.gl_transform = GriffinLim( - n_fft=n_fft, win_length=win_size, hop_length=hop_size, - window_fn=window_fn, n_iter=spec_bwd_max_iter + n_fft=n_fft, + win_length=win_size, + hop_length=hop_size, + window_fn=window_fn, + n_iter=spec_bwd_max_iter, ) if fp16: self.half() @@ -151,17 +171,19 @@ class GriffinLimVocoder(nn.Module): sample_rate=feat_cfg["sample_rate"], win_size=int(feat_cfg["win_len_t"] * feat_cfg["sample_rate"]), hop_size=int(feat_cfg["hop_len_t"] * feat_cfg["sample_rate"]), - n_fft=feat_cfg["n_fft"], n_mels=feat_cfg["n_mels"], - f_min=feat_cfg["f_min"], f_max=feat_cfg["f_max"], - window_fn=window_fn, spec_bwd_max_iter=args.spec_bwd_max_iter, - fp16=args.fp16 + n_fft=feat_cfg["n_fft"], + n_mels=feat_cfg["n_mels"], + f_min=feat_cfg["f_min"], + f_max=feat_cfg["f_max"], + window_fn=window_fn, + spec_bwd_max_iter=args.spec_bwd_max_iter, + fp16=args.fp16, ) class HiFiGANVocoder(nn.Module): def __init__( - self, checkpoint_path: str, model_cfg: Dict[str, str], - fp16: bool = False + self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False ) -> None: super().__init__() self.model = HiFiGANModel(model_cfg) diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py index 49e37917..5046fa24 100644 --- a/fairseq/models/transformer/transformer_decoder.py +++ b/fairseq/models/transformer/transformer_decoder.py @@ -29,8 +29,8 @@ from torch import Tensor # rewrite name for backward compatibility in `make_generation_fast_` def module_name_fordropout(module_name: str) -> str: - if module_name == 'TransformerDecoderBase': - return 'TransformerDecoder' + if module_name == "TransformerDecoderBase": + return "TransformerDecoder" else: return module_name diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py index f007776a..578f03d9 100644 --- a/fairseq/models/transformer/transformer_encoder.py +++ b/fairseq/models/transformer/transformer_encoder.py @@ -29,8 +29,8 @@ from fairseq.models.transformer import ( # rewrite name for backward compatibility in `make_generation_fast_` def module_name_fordropout(module_name: str) -> str: - if module_name == 'TransformerEncoderBase': - return 'TransformerEncoder' + if module_name == "TransformerEncoderBase": + return "TransformerEncoder" else: return module_name @@ -232,7 +232,12 @@ class TransformerEncoderBase(FairseqEncoder): # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. 
- src_lengths = src_tokens.ne(self.padding_idx).sum(dim=1, dtype=torch.int32).reshape(-1, 1).contiguous() + src_lengths = ( + src_tokens.ne(self.padding_idx) + .sum(dim=1, dtype=torch.int32) + .reshape(-1, 1) + .contiguous() + ) return { "encoder_out": [x], # T x B x C "encoder_padding_mask": [encoder_padding_mask], # B x T diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 14cbb089..f029cf05 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -15,7 +15,9 @@ from fairseq.models import ( register_model_architecture, ) from fairseq.models.transformer import ( - DEFAULT_MIN_PARAMS_TO_WRAP, Embedding, TransformerDecoder + DEFAULT_MIN_PARAMS_TO_WRAP, + Embedding, + TransformerDecoder, ) from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder from fairseq.utils import safe_getattr, safe_hasattr @@ -179,7 +181,7 @@ class TransformerLanguageModelConfig(FairseqDataclass): "is set to 0 (i.e., always wrap) when --checkpoint-activations or " "--offload-activations are passed." ) - } + }, ) # config for "BASE Layers: Simplifying Training of Large, Sparse Models" base_layers: Optional[int] = field( @@ -189,13 +191,25 @@ class TransformerLanguageModelConfig(FairseqDataclass): default=1, metadata={"help": "number of sublayers in each BASE layer"} ) base_shuffle: Optional[int] = field( - default=1, metadata={"help": "shuffle tokens between workers before computing assignment"} + default=1, + metadata={"help": "shuffle tokens between workers before computing assignment"}, ) # NormFormer - scale_fc: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm between fully connected layers'}) - scale_attn: Optional[bool] = field(default=False, metadata={"help": 'Insert LayerNorm after attention'}) - scale_heads: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each attention head'}) - scale_resids: Optional[bool] = field(default=False, metadata={"help": 'Learn a scale coefficient for each residual connection'}) + scale_fc: Optional[bool] = field( + default=False, + metadata={"help": "Insert LayerNorm between fully connected layers"}, + ) + scale_attn: Optional[bool] = field( + default=False, metadata={"help": "Insert LayerNorm after attention"} + ) + scale_heads: Optional[bool] = field( + default=False, + metadata={"help": "Learn a scale coefficient for each attention head"}, + ) + scale_resids: Optional[bool] = field( + default=False, + metadata={"help": "Learn a scale coefficient for each residual connection"}, + ) # options from other parts of the config add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") @@ -345,7 +359,9 @@ def base_lm_architecture(args): args.decoder_output_dim = safe_getattr( args, "decoder_output_dim", args.decoder_embed_dim ) - args.decoder_input_dim = safe_getattr(args, "decoder_input_dim", args.decoder_embed_dim) + args.decoder_input_dim = safe_getattr( + args, "decoder_input_dim", args.decoder_embed_dim + ) # Model training is not stable without this args.decoder_normalize_before = True @@ -362,10 +378,10 @@ def base_lm_architecture(args): args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False) args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False) args.offload_activations = safe_getattr(args, "offload_activations", False) - args.scale_fc = safe_getattr(args, 'scale_fc', False) - args.scale_attn = safe_getattr(args, 'scale_attn', False) - args.scale_heads 
= safe_getattr(args, 'scale_heads', False) - args.scale_resids = safe_getattr(args, 'scale_resids', False) + args.scale_fc = safe_getattr(args, "scale_fc", False) + args.scale_attn = safe_getattr(args, "scale_attn", False) + args.scale_heads = safe_getattr(args, "scale_heads", False) + args.scale_resids = safe_getattr(args, "scale_resids", False) if args.offload_activations: args.checkpoint_activations = True @@ -387,7 +403,9 @@ def transformer_lm_baevski_wiki103(args): args.dropout = safe_getattr(args, "dropout", 0.3) args.adaptive_input = safe_getattr(args, "adaptive_input", True) args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", True) - args.adaptive_input_cutoff = safe_getattr(args, "adaptive_input_cutoff", "20000,60000") + args.adaptive_input_cutoff = safe_getattr( + args, "adaptive_input_cutoff", "20000,60000" + ) args.adaptive_softmax_cutoff = safe_getattr( args, "adaptive_softmax_cutoff", "20000,60000" ) @@ -472,7 +490,9 @@ def transformer_lm_gpt2_big(args): def base_gpt3_architecture(args): args.decoder_input_dim = args.decoder_embed_dim args.decoder_output_dim = args.decoder_embed_dim - args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4) + args.decoder_ffn_embed_dim = safe_getattr( + args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4 + ) # GPT-3 used learned positional embeddings, rather than sinusoidal args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", True) args.dropout = safe_getattr(args, "dropout", 0.0) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index af722a53..6431ccb9 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -232,9 +232,11 @@ class Wav2Vec2Config(FairseqDataclass): ) checkpoint_activations: bool = field( - default=False, metadata={"help": "recompute activations and save memory for extra compute"} + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, ) + @register_model("wav2vec2", dataclass=Wav2Vec2Config) class Wav2Vec2Model(BaseFairseqModel): def __init__(self, cfg: Wav2Vec2Config): @@ -844,14 +846,14 @@ class TransformerEncoder(nn.Module): layers = [] for _ in range(args.encoder_layers): layer = TransformerSentenceEncoderLayer( - embedding_dim=self.embedding_dim, - ffn_embedding_dim=args.encoder_ffn_embed_dim, - num_attention_heads=args.encoder_attention_heads, - dropout=self.dropout, - attention_dropout=args.attention_dropout, - activation_dropout=args.activation_dropout, - activation_fn=args.activation_fn, - layer_norm_first=args.layer_norm_first, + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, ) if args.checkpoint_activations: layer = fsdp_wrap(layer) diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index 612db5d9..ed40f6d8 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -152,10 +152,12 @@ class Wav2Vec2AsrConfig(FairseqDataclass): w2v_args: Any = None checkpoint_activations: bool = field( - default=False, metadata={"help": "recompute activations and save memory for extra compute"} + default=False, + metadata={"help": "recompute activations and save memory for extra 
compute"}, ) ddp_backend: str = II("distributed_training.ddp_backend") + @dataclass class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): blank_weight: float = 0 @@ -268,6 +270,7 @@ class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig): ) autoregressive: bool = II("task.autoregressive") + @register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig) class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel): def __init__(self, encoder, decoder): @@ -394,12 +397,17 @@ class Wav2VecEncoder(FairseqEncoder): def load_model_weights(self, state, model, cfg): if cfg.ddp_backend == "fully_sharded": from fairseq.distributed import FullyShardedDataParallel + for name, module in model.named_modules(): if "encoder.layers" in name and len(name.split(".")) == 3: # Only for layers, we do a special handling and load the weights one by one # We dont load all weights together as that wont be memory efficient and may # cause oom - new_dict = {k.replace(name+".", "") : v for (k, v) in state["model"].items() if name+"." in k} + new_dict = { + k.replace(name + ".", ""): v + for (k, v) in state["model"].items() + if name + "." in k + } assert isinstance(module, FullyShardedDataParallel) with module.summon_full_params(): module.load_state_dict(new_dict, strict=True) @@ -409,7 +417,9 @@ class Wav2VecEncoder(FairseqEncoder): r = re.compile("encoder.layers.\d.") filtered_list = list(filter(r.match, state["model"].keys())) - new_big_dict = {k: v for (k, v) in state["model"].items() if k not in filtered_list} + new_big_dict = { + k: v for (k, v) in state["model"].items() if k not in filtered_list + } model.load_state_dict(new_big_dict, strict=False) else: @@ -462,9 +472,9 @@ class Wav2VecEncoder(FairseqEncoder): 1, new_order ) if encoder_out["padding_mask"] is not None: - encoder_out["padding_mask"] = encoder_out[ - "padding_mask" - ].index_select(0, new_order) + encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select( + 0, new_order + ) return encoder_out def max_positions(self): @@ -640,7 +650,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None, - self_attn_padding_mask=self_attn_padding_mask + self_attn_padding_mask=self_attn_padding_mask, ) inner_states.append(x) diff --git a/fairseq/modules/base_layer.py b/fairseq/modules/base_layer.py index e7ef155b..e823f7ba 100644 --- a/fairseq/modules/base_layer.py +++ b/fairseq/modules/base_layer.py @@ -12,14 +12,17 @@ from fairseq.modules.layer_norm import LayerNorm class BaseLayer(nn.Module): - def __init__(self, args): super().__init__() self.num_workers = distributed_utils.get_data_parallel_world_size() expert_centroids = torch.empty(self.num_workers, args.decoder_embed_dim) torch.nn.init.orthogonal_(expert_centroids, gain=0.1) - self.register_parameter("expert_centroids", torch.nn.Parameter(expert_centroids)) - self.expert_network = nn.Sequential(*([BaseSublayer(args) for _ in range(args.base_sublayers)])) + self.register_parameter( + "expert_centroids", torch.nn.Parameter(expert_centroids) + ) + self.expert_network = nn.Sequential( + *([BaseSublayer(args) for _ in range(args.base_sublayers)]) + ) self.expert_id = distributed_utils.get_data_parallel_rank() self.shuffle = args.base_shuffle self.cpp = self.load_assignment() @@ -39,20 +42,34 @@ class BaseLayer(nn.Module): with torch.no_grad(): # Compute similarity of each token to each expert, for routing - token_expert_affinities = features.matmul(self.expert_centroids.transpose(0, 1)) + token_expert_affinities = features.matmul( + 
self.expert_centroids.transpose(0, 1) + ) # Compute which token goes to which expert - sort_by_expert, input_splits, output_splits = self.balanced_assignment(token_expert_affinities) \ - if is_training else self.greedy_assignment(token_expert_affinities) + sort_by_expert, input_splits, output_splits = ( + self.balanced_assignment(token_expert_affinities) + if is_training + else self.greedy_assignment(token_expert_affinities) + ) # Swap these tokens for the right ones for our expert - routed_features = All2All.apply(features[sort_by_expert], output_splits, input_splits) + routed_features = All2All.apply( + features[sort_by_expert], output_splits, input_splits + ) if routed_features.size(0) > 0: # Mix in the expert network based on how appropriate it is for these tokens - alpha = torch.sigmoid(routed_features.mv(self.expert_centroids[self.expert_id])).unsqueeze(1) - routed_features = alpha * self.expert_network(routed_features) + (1 - alpha) * routed_features + alpha = torch.sigmoid( + routed_features.mv(self.expert_centroids[self.expert_id]) + ).unsqueeze(1) + routed_features = ( + alpha * self.expert_network(routed_features) + + (1 - alpha) * routed_features + ) # Return to original worker and ordering - result = All2All.apply(routed_features, input_splits, output_splits)[self.inverse_sort(sort_by_expert)] + result = All2All.apply(routed_features, input_splits, output_splits)[ + self.inverse_sort(sort_by_expert) + ] if self.shuffle and is_training: # Undo shuffling @@ -63,7 +80,9 @@ class BaseLayer(nn.Module): def inverse_sort(self, order): # Creates an index that undoes a sort: xs==xs[order][inverse_sort(order)] - return torch.empty_like(order).scatter_(0, order, torch.arange(0, order.size(0), device=order.device)) + return torch.empty_like(order).scatter_( + 0, order, torch.arange(0, order.size(0), device=order.device) + ) def balanced_assignment(self, scores): ok = scores.isfinite() @@ -79,7 +98,9 @@ class BaseLayer(nn.Module): worker2token = sort_ordering // k # Find how many tokens we're sending to each other worker (being careful for sending 0 tokens to some workers) - output_splits = torch.zeros((self.num_workers,), dtype=torch.long, device=scores.device) + output_splits = torch.zeros( + (self.num_workers,), dtype=torch.long, device=scores.device + ) workers, counts = torch.unique_consecutive(token_to_workers, return_counts=True) output_splits[workers] = counts # Tell other workers how many tokens to expect from us @@ -103,7 +124,7 @@ class BaseSublayer(nn.Module): def __init__(self, args): super().__init__() self.activation_fn = utils.get_activation_fn( - activation=getattr(args, 'activation_fn', 'relu') or "relu" + activation=getattr(args, "activation_fn", "relu") or "relu" ) self.norm = LayerNorm(args.decoder_embed_dim, export=False) self.ff1 = torch.nn.Linear(args.decoder_embed_dim, args.decoder_ffn_embed_dim) @@ -121,15 +142,29 @@ class All2All(torch.autograd.Function): ctx.input_splits = input_splits ctx.output_splits = output_splits - ys = torch.empty_like(xs) if output_splits is None else \ - xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:])) - torch.distributed.all_to_all_single(ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits) + ys = ( + torch.empty_like(xs) + if output_splits is None + else xs.new_empty(size=[sum(output_splits)] + list(xs.size()[1:])) + ) + torch.distributed.all_to_all_single( + ys, xs, output_split_sizes=output_splits, input_split_sizes=input_splits + ) return ys @staticmethod def backward(ctx, grad_output): - result = 
torch.empty_like(grad_output) if ctx.input_splits is None else \ - grad_output.new_empty(size=[sum(ctx.input_splits)] + list(grad_output.size()[1:])) - torch.distributed.all_to_all_single(result, grad_output, - output_split_sizes=ctx.input_splits, input_split_sizes=ctx.output_splits) + result = ( + torch.empty_like(grad_output) + if ctx.input_splits is None + else grad_output.new_empty( + size=[sum(ctx.input_splits)] + list(grad_output.size()[1:]) + ) + ) + torch.distributed.all_to_all_single( + result, + grad_output, + output_split_sizes=ctx.input_splits, + input_split_sizes=ctx.output_splits, + ) return result, None, None diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index dc73c662..aa0b5929 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -166,7 +166,9 @@ class CheckpointFunction(torch.autograd.Function): if parent_ctx_dict["offload"]: ctx.fwd_device = tuple(x.device for x in tensor_inputs) ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs) - tensor_inputs = tuple(x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs) + tensor_inputs = tuple( + x.to(torch.device("cpu"), non_blocking=True) for x in tensor_inputs + ) else: ctx.fwd_device, ctx.grad_requirements = None, None @@ -199,7 +201,8 @@ class CheckpointFunction(torch.autograd.Function): tensor_inputs = checkpoint.detach_variable(tensor_inputs) if ctx.fwd_device is not None: tensor_inputs = [ - t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs) + t.to(ctx.fwd_device[i], non_blocking=True) + for i, t in enumerate(tensor_inputs) ] for i, need_grad in enumerate(ctx.grad_requirements): tensor_inputs[i].requires_grad = need_grad diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py index 71134388..a0e40c36 100644 --- a/fairseq/modules/gumbel_vector_quantizer.py +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -75,6 +75,7 @@ class GumbelVectorQuantizer(nn.Module): if isinstance(temp, str): import ast + temp = ast.literal_eval(temp) assert len(temp) == 3, f"{temp}, {len(temp)}" diff --git a/fairseq/modules/kmeans_attention.py b/fairseq/modules/kmeans_attention.py index 11a7debc..ca506301 100644 --- a/fairseq/modules/kmeans_attention.py +++ b/fairseq/modules/kmeans_attention.py @@ -47,11 +47,12 @@ def cache_fn(f): return cache cache = f(*args, **kwargs) return cache + return cached_fn def to(t): - return {'device': t.device, 'dtype': t.dtype} + return {"device": t.device, "dtype": t.dtype} def find_modules(nn_module, type): @@ -102,7 +103,7 @@ def reshape_dim(t, dim, split_dims): shape = list(t.shape) num_dims = len(shape) dim = (dim + num_dims) % num_dims - shape[dim:dim+1] = split_dims + shape[dim : dim + 1] = split_dims return t.reshape(shape) @@ -118,6 +119,7 @@ def ema_inplace(moving_avg, new, decay): return moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + # helper classes @@ -173,6 +175,7 @@ class ScaleNorm(nn.Module): def norm(t): n = torch.norm(t, dim=-1, keepdim=True).clamp(min=self.eps) return t / n * self.g + return map_first_tuple_or_el(x, norm) @@ -202,51 +205,62 @@ class MatrixMultiply(nn.Module): tensor = tensor.t() return x @ tensor + # positional embeddings class DepthWiseConv1d(nn.Module): def __init__(self, dim_in, dim_out, kernel_size, stride=1, bias=True, causal=False): super().__init__() - self.padding = ((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2) + self.padding 
= ( + ((kernel_size - 1), 0) if causal else (kernel_size // 2, kernel_size // 2) + ) self.net = nn.Sequential( - nn.Conv1d(dim_in, dim_in, kernel_size=kernel_size, groups=dim_in, stride=stride, bias=bias), - nn.Conv1d(dim_in, dim_out, 1, bias=bias) + nn.Conv1d( + dim_in, + dim_in, + kernel_size=kernel_size, + groups=dim_in, + stride=stride, + bias=bias, + ), + nn.Conv1d(dim_in, dim_out, 1, bias=bias), ) def forward(self, x): - x = F.pad(x, self.padding, value=0.) + x = F.pad(x, self.padding, value=0.0) return self.net(x) class FixedPositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() - inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) position = torch.arange(0, max_seq_len, dtype=torch.float) sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) - self.register_buffer('emb', emb) + self.register_buffer("emb", emb) def forward(self, x): - return self.emb[None, :x.shape[1], :].to(x) + return self.emb[None, : x.shape[1], :].to(x) def rotate_every_two(x): - x = rearrange(x, '... (d j) -> ... d j', j=2) + x = rearrange(x, "... (d j) -> ... d j", j=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) - return rearrange(x, '... d j -> ... (d j)') + return rearrange(x, "... d j -> ... (d j)") def apply_rotary_pos_emb(q, k, sinu_pos): - sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j=2) + sinu_pos = rearrange(sinu_pos, "() n (j d) -> n j d", j=2) sin, cos = sinu_pos.unbind(dim=-2) - sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j=2), (sin, cos)) + sin, cos = map(lambda t: repeat(t, "b n -> b (n j)", j=2), (sin, cos)) q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k)) return q, k + # kmeans related function and class @@ -261,7 +275,7 @@ def update_kmeans_on_backwards(module): def similarity(x, means): - return torch.einsum('bhld,hcd->bhlc', x, means) + return torch.einsum("bhld,hcd->bhlc", x, means) def dists_and_buckets(x, means): @@ -303,13 +317,15 @@ def distribution(dists, window_size): class Kmeans(nn.Module): - def __init__(self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4): + def __init__( + self, num_heads, head_dim, num_clusters, ema_decay=0.999, commitment=1e-4 + ): super().__init__() self.commitment = commitment self.ema_decay = ema_decay - self.register_buffer('means', torch.randn(num_heads, num_clusters, head_dim)) - self.register_buffer('initted', torch.tensor(False)) + self.register_buffer("means", torch.randn(num_heads, num_clusters, head_dim)) + self.register_buffer("initted", torch.tensor(False)) self.num_new_means = 0 self.new_means = None @@ -341,7 +357,7 @@ class Kmeans(nn.Module): @torch.no_grad() def update(self, new_means=None): new_means = default(new_means, self.new_means) - assert exists(new_means), 'new kmeans has not been supplied' + assert exists(new_means), "new kmeans has not been supplied" ema_inplace(self.means, new_means, self.ema_decay) del self.new_means @@ -364,16 +380,33 @@ class Kmeans(nn.Module): if update_means: with torch.no_grad(): means = kmeans_iter(x, means, buckets) - self.new_means = ema(self.new_means, means, self.num_new_means / (self.num_new_means + 1)) + self.new_means = ema( + self.new_means, means, self.num_new_means / (self.num_new_means + 1) + ) self.num_new_means += 1 return dists, loss + # kmeans attention class class KmeansAttention(nn.Module): - def __init__(self, num_clusters, 
window_size, num_heads, head_dim, causal=False, dropout=0., ema_decay=0.999, commitment=1e-4, context_window_size=None, receives_context=False, num_mem_kv=0, shared_qk=False): + def __init__( + self, + num_clusters, + window_size, + num_heads, + head_dim, + causal=False, + dropout=0.0, + ema_decay=0.999, + commitment=1e-4, + context_window_size=None, + receives_context=False, + num_mem_kv=0, + shared_qk=False, + ): super().__init__() self.num_heads = num_heads self.num_clusters = num_clusters @@ -389,18 +422,32 @@ class KmeansAttention(nn.Module): self.dropout = nn.Dropout(dropout) self.num_mem_kv = max(num_mem_kv, 1 if causal and not shared_qk else 0) - self.mem_key = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)) - self.mem_value = nn.Parameter(torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim)) + self.mem_key = nn.Parameter( + torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim) + ) + self.mem_value = nn.Parameter( + torch.randn(num_heads, num_clusters, self.num_mem_kv, head_dim) + ) def forward(self, q, k, v, query_mask=None, key_mask=None, **kwargs): - b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = *q.shape, k.shape[2], self.window_size, self.context_window_size, self.num_clusters, q.device, q.dtype - is_reverse = kwargs.pop('_reverse', False) + b, h, t, d, kv_t, wsz, c_wsz, nc, device, dtype = ( + *q.shape, + k.shape[2], + self.window_size, + self.context_window_size, + self.num_clusters, + q.device, + q.dtype, + ) + is_reverse = kwargs.pop("_reverse", False) out = torch.zeros_like(q, dtype=dtype) update_kmeans = self.training and not is_reverse - key_mask = default(key_mask, query_mask) if not self.receives_context else key_mask + key_mask = ( + default(key_mask, query_mask) if not self.receives_context else key_mask + ) kv_wsz = wsz if not self.receives_context else c_wsz wsz = min(wsz, t) @@ -424,16 +471,22 @@ class KmeansAttention(nn.Module): reshape_with_window = lambda x: x.reshape(b, h, nc, -1, d) q, k, v = map(reshape_with_window, (q, k, v)) - m_k, m_v = map(lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value)) + m_k, m_v = map( + lambda x: expand_dim(x, 0, b).to(q), (self.mem_key, self.mem_value) + ) k, v = map(lambda x: torch.cat(x, dim=3), ((m_k, k), (m_v, v))) - dots = torch.einsum('bhnid,bhnjd->bhnij', q, k) * (d ** -0.5) + dots = torch.einsum("bhnid,bhnjd->bhnij", q, k) * (d ** -0.5) mask_value = max_neg_value(dots) if exists(query_mask) or exists(key_mask): - query_mask = default(query_mask, lambda: torch.ones((b, t), device=device).bool()) - key_mask = default(key_mask, lambda: torch.ones((b, kv_t), device=device).bool()) + query_mask = default( + query_mask, lambda: torch.ones((b, t), device=device).bool() + ) + key_mask = default( + key_mask, lambda: torch.ones((b, kv_t), device=device).bool() + ) q_mask = expand_dim(query_mask, 1, h).gather(2, indices) kv_mask = expand_dim(key_mask, 1, h).gather(2, kv_indices) @@ -444,14 +497,18 @@ class KmeansAttention(nn.Module): del mask if self.causal: - q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)) + q_mask, kv_mask = map( + lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices) + ) mask = q_mask[:, :, :, :, None] >= kv_mask[:, :, :, None, :] mask = F.pad(mask, (self.num_mem_kv, 0), value=1) dots.masked_fill_(~mask, mask_value) del mask if self.shared_qk: - q_mask, kv_mask = map(lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices)) + q_mask, kv_mask = map( + lambda t: t.reshape(b, h, nc, -1), (indices, kv_indices) + 
) mask = q_mask[:, :, :, :, None] == kv_mask[:, :, :, None, :] mask = F.pad(mask, (self.num_mem_kv, 0), value=0) dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE) @@ -460,24 +517,32 @@ class KmeansAttention(nn.Module): dots = dots.softmax(dim=-1) dots = self.dropout(dots) - bo = torch.einsum('bhcij,bhcjd->bhcid', dots, v) + bo = torch.einsum("bhcij,bhcjd->bhcid", dots, v) so = torch.reshape(bo, (b, h, -1, bo.shape[-1])).type(dtype) out = scatter_mean(out, so, indices.unsqueeze(-1).expand_as(so), -2) return out, aux_loss + # feedforward class GELU_(nn.Module): def forward(self, x): - return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + return ( + 0.5 + * x + * ( + 1 + + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))) + ) + ) -GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_ +GELU = nn.GELU if hasattr(nn, "GELU") else GELU_ class FeedForward(nn.Module): - def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False): + def __init__(self, dim, mult=4, dropout=0.0, activation=None, glu=False): super().__init__() activation = default(activation, GELU) @@ -499,17 +564,49 @@ class FeedForward(nn.Module): x = self.w2(x) return x + # self attention class SelfAttention(nn.Module): - def __init__(self, dim, max_seq_len, heads, local_attn_heads, window_size, dim_head=None, local_attn_window_size=None, local_attn_radius_blocks=1, causal=False, attn_dropout=0., dropout=0., kmeans_ema_decay=0.999, commitment_factor=1e-4, receives_context=False, context_window_size=None, rel_pos_emb=True, num_mem_kv=0, shared_qk=False, conv_query_kernel=9): + def __init__( + self, + dim, + max_seq_len, + heads, + local_attn_heads, + window_size, + dim_head=None, + local_attn_window_size=None, + local_attn_radius_blocks=1, + causal=False, + attn_dropout=0.0, + dropout=0.0, + kmeans_ema_decay=0.999, + commitment_factor=1e-4, + receives_context=False, + context_window_size=None, + rel_pos_emb=True, + num_mem_kv=0, + shared_qk=False, + conv_query_kernel=9, + ): super().__init__() - assert dim_head or (dim % heads) == 0, 'hidden dimension must be divisible by number of heads' - assert (max_seq_len % window_size) == 0, 'maximum sequence length must be divisible by the target window size' - assert local_attn_heads <= heads, 'number of local attention heads must be less than total heads' - assert not (receives_context and local_attn_heads > 0), 'local attention cannot be used for self attention with context' - assert not (receives_context and causal), 'contextual attention layer cannot be causal' + assert ( + dim_head or (dim % heads) == 0 + ), "hidden dimension must be divisible by number of heads" + assert ( + max_seq_len % window_size + ) == 0, "maximum sequence length must be divisible by the target window size" + assert ( + local_attn_heads <= heads + ), "number of local attention heads must be less than total heads" + assert not ( + receives_context and local_attn_heads > 0 + ), "local attention cannot be used for self attention with context" + assert not ( + receives_context and causal + ), "contextual attention layer cannot be causal" local_attn_window_size = default(local_attn_window_size, window_size) context_window_size = default(context_window_size, window_size) @@ -535,7 +632,15 @@ class SelfAttention(nn.Module): if self.local_attn_heads > 0: rel_pos_emb_config = (dim_head, local_attn_heads) if rel_pos_emb else None - self.local_attn = LocalAttention(dim=dim_head, window_size=local_attn_window_size, causal=causal, dropout=attn_dropout, 
rel_pos_emb_config=rel_pos_emb_config, look_backward=local_attn_radius_blocks, look_forward=0 if causal else local_attn_radius_blocks) + self.local_attn = LocalAttention( + dim=dim_head, + window_size=local_attn_window_size, + causal=causal, + dropout=attn_dropout, + rel_pos_emb_config=rel_pos_emb_config, + look_backward=local_attn_radius_blocks, + look_forward=0 if causal else local_attn_radius_blocks, + ) self.local_to_qkv = nn.Linear(dim, 3 * local_dim_heads) # global @@ -543,12 +648,24 @@ class SelfAttention(nn.Module): global_dim_heads = dim_head * self.global_attn_heads if self.global_attn_heads > 0: - self.global_attn = KmeansAttention(num_clusters, window_size, self.global_attn_heads, dim_head, causal=causal, dropout=attn_dropout, ema_decay=kmeans_ema_decay, commitment=commitment_factor, receives_context=receives_context, num_mem_kv=num_mem_kv, shared_qk=shared_qk) + self.global_attn = KmeansAttention( + num_clusters, + window_size, + self.global_attn_heads, + dim_head, + causal=causal, + dropout=attn_dropout, + ema_decay=kmeans_ema_decay, + commitment=commitment_factor, + receives_context=receives_context, + num_mem_kv=num_mem_kv, + shared_qk=shared_qk, + ) self.to_q = nn.Sequential( - Rearrange('b n c -> b c n'), + Rearrange("b n c -> b c n"), DepthWiseConv1d(dim, global_dim_heads, conv_query_kernel, causal=causal), - Rearrange('b c n -> b n c') + Rearrange("b c n -> b n c"), ) self.to_v = nn.Linear(dim, global_dim_heads, bias=False) @@ -561,14 +678,30 @@ class SelfAttention(nn.Module): self.to_out = nn.Linear(dim_heads, dim, bias=False) self.dropout = nn.Dropout(dropout) - def forward(self, query, key, value, context=None, key_padding_mask=None, context_mask=None, pos_emb=None, **kwargs): - assert not (self.receives_context and not exists(context)), 'context must be passed if self attention is set to receive context' + def forward( + self, + query, + key, + value, + context=None, + key_padding_mask=None, + context_mask=None, + pos_emb=None, + **kwargs + ): + assert not ( + self.receives_context and not exists(context) + ), "context must be passed if self attention is set to receive context" input_mask = key_padding_mask x = query.transpose(0, 1) b, t, _, h, dh = *x.shape, self.heads, self.dim_head - has_local, has_global = map(lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads)) + has_local, has_global = map( + lambda x: x > 0, (self.local_attn_heads, self.global_attn_heads) + ) - split_heads = lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous() + split_heads = ( + lambda v: reshape_dim(v, -1, (-1, dh)).transpose(1, 2).contiguous() + ) if has_local: local_qkv = self.local_to_qkv(x).chunk(3, dim=-1) @@ -587,7 +720,7 @@ class SelfAttention(nn.Module): q, k, v = map(split_heads, (q, k, v)) out = [] - total_loss = torch.tensor(0., requires_grad=True, **to(x)) + total_loss = torch.tensor(0.0, requires_grad=True, **to(x)) if has_local: local_out = self.local_attn(lq, lk, lv, input_mask=input_mask) @@ -597,7 +730,9 @@ class SelfAttention(nn.Module): if not self.receives_context and exists(pos_emb): q, k = apply_rotary_pos_emb(q, k, pos_emb) - global_out, loss = self.global_attn(q, k, v, query_mask=input_mask, key_mask=context_mask) + global_out, loss = self.global_attn( + q, k, v, query_mask=input_mask, key_mask=context_mask + ) total_loss = total_loss + loss out.append(global_out) diff --git a/fairseq/modules/linearized_convolution.py b/fairseq/modules/linearized_convolution.py index f7e156cb..1c7a9f09 100644 --- a/fairseq/modules/linearized_convolution.py 
+++ b/fairseq/modules/linearized_convolution.py @@ -13,6 +13,7 @@ from .conv_tbc import ConvTBC from typing import Dict, Optional from torch import Tensor + @with_incremental_state class LinearizedConvolution(ConvTBC): """An optimized version of nn.Conv1d. @@ -41,7 +42,11 @@ class LinearizedConvolution(ConvTBC): del state_dict[prefix + "_linearized_weight"] @torch.jit.export - def forward(self, input, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None): + def forward( + self, + input, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + ): """ Args: incremental_state: Used to buffer signal; if not None, then input is @@ -80,18 +85,28 @@ class LinearizedConvolution(ConvTBC): return output.view(bsz, 1, -1) @torch.jit.unused - def reorder_incremental_state(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_order): + def reorder_incremental_state( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): input_buffer = self._get_input_buffer(incremental_state) if input_buffer is not None: input_buffer = input_buffer.index_select(0, new_order) self._set_input_buffer(incremental_state, input_buffer) @torch.jit.unused - def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]): + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ): return utils.get_incremental_state(self, incremental_state, "input_buffer") @torch.jit.unused - def _set_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], new_buffer): + def _set_input_buffer( + self, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], + new_buffer, + ): return utils.set_incremental_state( self, incremental_state, "input_buffer", new_buffer ) diff --git a/fairseq/modules/location_attention.py b/fairseq/modules/location_attention.py index a970876b..dbbbfb9f 100644 --- a/fairseq/modules/location_attention.py +++ b/fairseq/modules/location_attention.py @@ -20,9 +20,16 @@ class LocationAttention(nn.Module): :param int conv_kernel_size: filter size of attention convolution """ - def __init__(self, attn_dim, encoder_dim, decoder_dim, - attn_state_kernel_size, conv_dim, conv_kernel_size, - scaling=2.0): + def __init__( + self, + attn_dim, + encoder_dim, + decoder_dim, + attn_state_kernel_size, + conv_dim, + conv_kernel_size, + scaling=2.0, + ): super(LocationAttention, self).__init__() self.attn_dim = attn_dim self.decoder_dim = decoder_dim @@ -30,9 +37,13 @@ class LocationAttention(nn.Module): self.proj_enc = nn.Linear(encoder_dim, attn_dim) self.proj_dec = nn.Linear(decoder_dim, attn_dim, bias=False) self.proj_attn = nn.Linear(conv_dim, attn_dim, bias=False) - self.conv = nn.Conv1d(attn_state_kernel_size, conv_dim, - 2 * conv_kernel_size + 1, - padding=conv_kernel_size, bias=False) + self.conv = nn.Conv1d( + attn_state_kernel_size, + conv_dim, + 2 * conv_kernel_size + 1, + padding=conv_kernel_size, + bias=False, + ) self.proj_out = nn.Sequential(nn.Tanh(), nn.Linear(attn_dim, 1)) self.proj_enc_out = None # cache diff --git a/fairseq/modules/lstm_cell_with_zoneout.py b/fairseq/modules/lstm_cell_with_zoneout.py index f04e5db2..27330895 100644 --- a/fairseq/modules/lstm_cell_with_zoneout.py +++ b/fairseq/modules/lstm_cell_with_zoneout.py @@ -12,20 +12,20 @@ class LSTMCellWithZoneOut(nn.Module): https://arxiv.org/abs/1606.01305 """ - def __init__(self, prob: float, input_size: int, hidden_size: int, - 
bias: bool = True): + def __init__( + self, prob: float, input_size: int, hidden_size: int, bias: bool = True + ): super(LSTMCellWithZoneOut, self).__init__() self.lstm_cell = nn.LSTMCell(input_size, hidden_size, bias=bias) self.prob = prob if prob > 1.0 or prob < 0.0: - raise ValueError("zoneout probability must be in the range from " - "0.0 to 1.0.") + raise ValueError( + "zoneout probability must be in the range from " "0.0 to 1.0." + ) def zoneout(self, h, next_h, prob): if isinstance(h, tuple): - return tuple( - [self.zoneout(h[i], next_h[i], prob) for i in range(len(h))] - ) + return tuple([self.zoneout(h[i], next_h[i], prob) for i in range(len(h))]) if self.training: mask = h.new_zeros(*h.size()).bernoulli_(prob) diff --git a/fairseq/modules/quantization/pq/utils.py b/fairseq/modules/quantization/pq/utils.py index 14c015b7..eceeef8b 100644 --- a/fairseq/modules/quantization/pq/utils.py +++ b/fairseq/modules/quantization/pq/utils.py @@ -60,7 +60,9 @@ def quantize_model_( to layers_to_quantize[step] """ - quantized_layers = get_layers(model, layers_to_quantize[step], remove_weights=remove_weights) + quantized_layers = get_layers( + model, layers_to_quantize[step], remove_weights=remove_weights + ) for layer in quantized_layers: @@ -108,8 +110,8 @@ def quantize_model_( centroids = torch.rand(centroids.size()) centroids.cuda() # Get counts and assignment keys from layer in loaded checkpoint. - counts_key = layer+"."+"counts" - assignment_key = layer+"."+"assignments" + counts_key = layer + "." + "counts" + assignment_key = layer + "." + "assignments" # Get number of different bins to include. counts = list(state_dict[counts_key].shape)[0] print(layer) @@ -122,7 +124,7 @@ def quantize_model_( print(num_assignments) print(num_extra) assignments_bins = torch.arange(counts) - assignments_rand = torch.randint(0, counts-1, (num_extra, )) + assignments_rand = torch.randint(0, counts - 1, (num_extra,)) assignments = torch.cat((assignments_bins, assignments_rand), 0) # assignments = assignments.type(torch.IntTensor) assignments.cuda() diff --git a/fairseq/modules/quantization/scalar/utils.py b/fairseq/modules/quantization/scalar/utils.py index 2ec6af3f..d4b1cc25 100644 --- a/fairseq/modules/quantization/scalar/utils.py +++ b/fairseq/modules/quantization/scalar/utils.py @@ -16,7 +16,9 @@ from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear MAPPING = {nn.Linear: IntLinear, nn.Embedding: IntEmbedding, nn.Conv2d: IntConv2d} -def quantize_model_(model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False): +def quantize_model_( + model, p=0.2, bits=8, update_step=3000, method="histogram", remove_weights=False +): """ Replaces all modules with their scalar quantized counterpart and registers hooks to quantize the post-ativations of those modules. 
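
The hunks above and below this point are dominated by a few mechanical black rewrites: single quotes become double quotes, bare float literals such as 0. become 0.0, non-trivial slices gain spaces around the colon (as in shape[dim : dim + 1]), and any signature, call, or literal that overflows the line length is exploded onto one element per line with a trailing comma. A minimal hand-written sketch of those patterns, using made-up names rather than anything from fairseq, and only approximating black's exact output:

# Style of the removed "-" lines: single quotes, bare float literal, tight
# slice, everything on one line regardless of length.
def pad_window_before(xs, start, size, pad_value=0., label='window', return_length=False, strict=False):
    window = xs[start:start + size]
    return {label: window + [pad_value] * (size - len(window)), 'length': len(window) if return_length else None, 'strict': strict}


# Style of the added "+" lines: double quotes, 0.0, a spaced slice, and the
# overlong signature and dict literal split one element per line with a
# trailing comma.
def pad_window_after(
    xs, start, size, pad_value=0.0, label="window", return_length=False, strict=False
):
    window = xs[start : start + size]
    return {
        label: window + [pad_value] * (size - len(window)),
        "length": len(window) if return_length else None,
        "strict": strict,
    }

Beyond these patterns, the remaining hunks mostly normalize blank lines and import grouping around top-level definitions; since black only reformats, the changes in this patch should be behavior-preserving.
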
diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 3ad2be95..d2b57ac8 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -132,8 +132,7 @@ class TransformerEncoderLayerBase(nn.Module): # will become -inf, which results in NaN in model parameters if attn_mask is not None: attn_mask = attn_mask.masked_fill( - attn_mask.to(torch.bool), - -1e8 if x.dtype == torch.float32 else -1e4 + attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4 ) residual = x @@ -213,11 +212,19 @@ class TransformerDecoderLayerBase(nn.Module): add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, ) - self.attn_ln = LayerNorm(self.embed_dim) if utils.safe_getattr(cfg, 'scale_attn', False) else None + self.attn_ln = ( + LayerNorm(self.embed_dim) + if utils.safe_getattr(cfg, "scale_attn", False) + else None + ) self.nh = self.self_attn.num_heads self.head_dim = self.self_attn.head_dim - scale_heads = utils.safe_getattr(cfg, 'scale_heads', False) - self.c_attn = nn.Parameter(torch.ones((self.nh,)), requires_grad=True) if scale_heads else None + scale_heads = utils.safe_getattr(cfg, "scale_heads", False) + self.c_attn = ( + nn.Parameter(torch.ones((self.nh,)), requires_grad=True) + if scale_heads + else None + ) self.activation_fn = utils.get_activation_fn(activation=cfg.activation_fn) activation_dropout_p = cfg.activation_dropout @@ -238,8 +245,21 @@ class TransformerDecoderLayerBase(nn.Module): self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) - self.ffn_layernorm = LayerNorm(cfg.decoder.ffn_embed_dim) if utils.safe_getattr(cfg, 'scale_fc', False) else None - self.w_resid = nn.Parameter(torch.ones(self.embed_dim, ), requires_grad=True) if utils.safe_getattr(cfg, 'scale_resids', False) else None + self.ffn_layernorm = ( + LayerNorm(cfg.decoder.ffn_embed_dim) + if utils.safe_getattr(cfg, "scale_fc", False) + else None + ) + self.w_resid = ( + nn.Parameter( + torch.ones( + self.embed_dim, + ), + requires_grad=True, + ) + if utils.safe_getattr(cfg, "scale_resids", False) + else None + ) self.fc1 = self.build_fc1( self.embed_dim, @@ -297,7 +317,6 @@ class TransformerDecoderLayerBase(nn.Module): def residual_connection(self, x, residual): return residual + x - def forward( self, x, @@ -377,7 +396,7 @@ class TransformerDecoderLayerBase(nn.Module): if self.c_attn is not None: tgt_len, bsz = x.size(0), x.size(1) x = x.view(tgt_len, bsz, self.nh, self.head_dim) - x = torch.einsum('tbhd,h->tbhd', x, self.c_attn) + x = torch.einsum("tbhd,h->tbhd", x, self.c_attn) x = x.reshape(tgt_len, bsz, self.embed_dim) if self.attn_ln is not None: x = self.attn_ln(x) diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index d0540d69..5d2db91a 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -35,9 +35,7 @@ def init_bert_params(module): def normal_(data): # with FSDP, module params will be on CUDA, so we cast them back to CPU # so that the RNG is consistent with and without FSDP - data.copy_( - data.cpu().normal_(mean=0.0, std=0.02).to(data.device) - ) + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) if isinstance(module, nn.Linear): normal_(module.weight.data) @@ -276,7 +274,9 @@ class TransformerSentenceEncoder(nn.Module): inner_states.append(x) for layer in self.layers: - x, _ = layer(x, 
self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask) + x, _ = layer( + x, self_attn_padding_mask=padding_mask, self_attn_mask=attn_mask + ) if not last_state_only: inner_states.append(x) diff --git a/fairseq/ngram_repeat_block.py b/fairseq/ngram_repeat_block.py index 85412514..98e707d1 100644 --- a/fairseq/ngram_repeat_block.py +++ b/fairseq/ngram_repeat_block.py @@ -2,13 +2,13 @@ # Licensed under the MIT License. """ Wrapper for ngram_repeat_block cuda extension """ +import math +import warnings +from typing import Dict, List, Optional + import torch from torch import nn -import math -from typing import Dict, List, Optional -import warnings - try: from fairseq import ngram_repeat_block_cuda @@ -37,7 +37,7 @@ def is_cuda_extension_usable() -> bool: class NGramRepeatBlock(nn.Module): - """ Wrapper class for calling ngram_repeat_block cuda extension """ + """Wrapper class for calling ngram_repeat_block cuda extension""" def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True): super().__init__() diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index d3ae9e64..678ec7c6 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -67,13 +67,13 @@ class FairseqAdam(FairseqOptimizer): elif use_fused_adam: logger.info("using FusedAdam") self._optimizer = fused_adam_cls( - params, - use_fp16_stats=self.cfg.fp16_adam_stats, - **self.optimizer_config + params, use_fp16_stats=self.cfg.fp16_adam_stats, **self.optimizer_config ) else: if self.cfg.fp16_adam_stats: - raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1") + raise NotImplementedError( + "--fp16-adam-stats is only supported with FusedAdamV1" + ) self._optimizer = Adam(params, **self.optimizer_config) @property diff --git a/fairseq/optim/amp_optimizer.py b/fairseq/optim/amp_optimizer.py index 3b7958e5..cfe57d07 100644 --- a/fairseq/optim/amp_optimizer.py +++ b/fairseq/optim/amp_optimizer.py @@ -63,8 +63,9 @@ class AMPOptimizer(optim.FairseqOptimizer): ).format(self.min_loss_scale, new_loss_scale) ) else: - logger.info("AMP: overflow detected, setting scale to " - f"to {new_loss_scale}") + logger.info( + "AMP: overflow detected, setting scale to " f"to {new_loss_scale}" + ) return grad_norm @property diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py index a5366d62..63701ee8 100644 --- a/fairseq/optim/composite.py +++ b/fairseq/optim/composite.py @@ -23,7 +23,9 @@ class OptimizerAndSchedulerConfig(FairseqDataclass): optimizer: Any = None lr_scheduler: Optional[Any] = None lr: List = II("optimization.lr") - lr_float: Optional[float] = None # this makes it easier to sweep on learning rate with auto sweepers + lr_float: Optional[ + float + ] = None # this makes it easier to sweep on learning rate with auto sweepers @dataclass diff --git a/fairseq/optim/cpu_adam.py b/fairseq/optim/cpu_adam.py index b2f893ae..b218934e 100644 --- a/fairseq/optim/cpu_adam.py +++ b/fairseq/optim/cpu_adam.py @@ -16,6 +16,7 @@ from omegaconf import II, DictConfig try: import deepspeed + has_deepspeed = True except ImportError as e: has_deepspeed = False @@ -24,12 +25,15 @@ except ImportError as e: def _get_cpu_adam(): try: from deepspeed.ops.op_builder import CPUAdamBuilder + return CPUAdamBuilder().load() except ImportError: # fbcode from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam + return ds_opt_adam + @dataclass class FairseqCPUAdamConfig(FairseqDataclass): adam_betas: str = field( diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 
c59b21cf..f8af2883 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -64,9 +64,9 @@ class _FP16OptimizerMixin(object): fp32_params = [] for p in params: p32 = torch.nn.Parameter(p.data.float()) - if hasattr(p, 'expert'): + if hasattr(p, "expert"): p32.expert = True - elif hasattr(p, 'base_expert'): + elif hasattr(p, "base_expert"): p32.base_expert = True p32.grad = torch.zeros_like(p32.data) if hasattr(p, "param_group"): @@ -209,7 +209,9 @@ class _FP16OptimizerMixin(object): self._sync_fp16_grads_to_fp32() if getattr(self, "supports_step_with_scale", False): - self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups) + self.fp32_optimizer.step( + closure, scale=(1.0 / self._multiply_factor), groups=groups + ) else: self._unscale_grads() self.fp32_optimizer.step(closure, groups=groups) @@ -434,7 +436,9 @@ class _MemoryEfficientFP16OptimizerMixin(object): """Performs a single optimization step.""" if getattr(self, "supports_step_with_scale", False): # NOTE(msb) optimizer divides by scale factor - self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups) + self.wrapped_optimizer.step( + closure, scale=(1.0 / self._multiply_factor), groups=groups + ) else: self._unscale_grads() self.wrapped_optimizer.step(closure, groups=groups) diff --git a/fairseq/optim/fused_adam.py b/fairseq/optim/fused_adam.py index 7a6d1f73..da872033 100644 --- a/fairseq/optim/fused_adam.py +++ b/fairseq/optim/fused_adam.py @@ -179,7 +179,7 @@ class FusedAdamV1(torch.optim.Optimizer): if p.device.type == "cpu": p_data_fp32 = p.data.cuda(non_blocking=True).float() - out_p = torch.tensor([], dtype = torch.float) + out_p = torch.tensor([], dtype=torch.float) else: p_data_fp32 = p.data.float() out_p = p.data @@ -234,6 +234,7 @@ class FusedAdamV1(torch.optim.Optimizer): p.data.copy_(p_data_fp32, non_blocking=True) if self.use_fp16_stats: + def inf_norm(t): return torch.norm(t, float("inf")) @@ -262,7 +263,9 @@ try: def __init__(self, *args, use_fp16_stats=False, **kwargs): if use_fp16_stats: - raise NotImplementedError("--fp16-adam-stats is only supported with FusedAdamV1") + raise NotImplementedError( + "--fp16-adam-stats is only supported with FusedAdamV1" + ) super().__init__(*args, **kwargs) if not hasattr(self, "multi_tensor_adam"): raise Exception( diff --git a/fairseq/optim/lr_scheduler/manual_lr_scheduler.py b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py index 0269a1e2..57edc256 100644 --- a/fairseq/optim/lr_scheduler/manual_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py @@ -32,7 +32,7 @@ class ManualSchedule(LegacyFairseqLRScheduler): self.optimizer.set_lr(self.lr) # Set the beginning of the epoch. def parse_manuallr_args(self, lr_args_str): - lr_dict = ast.literal_eval(lr_args_str.replace(' ', '')) + lr_dict = ast.literal_eval(lr_args_str.replace(" ", "")) if not isinstance(lr_dict, dict): raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict") @@ -84,9 +84,14 @@ class ManualSchedule(LegacyFairseqLRScheduler): if manual_keys: manual_lr = self.epoch2lr[max(manual_keys)] else: - logger.warning("@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format( - epoch, list(self.epoch2lr.items())[:min(10, len(self.epoch2lr.keys())-1)] - )) + logger.warning( + "@@@ epoch={} does not exist in manual lr input. 
epoch2lr={}...".format( + epoch, + list(self.epoch2lr.items())[ + : min(10, len(self.epoch2lr.keys()) - 1) + ], + ) + ) manual_lr = self.optimizer.get_lr() return manual_lr @@ -102,8 +107,14 @@ class ManualSchedule(LegacyFairseqLRScheduler): if manual_keys: manual_lr = self.update2lr[max(manual_keys)] else: - logger.warning("epoch={} does not exist in manual lr input update2lr={}...".format( - num_updates, list(self.update2lr.items())[:min(10, len(self.update2lr.keys())-1)])) + logger.warning( + "epoch={} does not exist in manual lr input update2lr={}...".format( + num_updates, + list(self.update2lr.items())[ + : min(10, len(self.update2lr.keys()) - 1) + ], + ) + ) manual_lr = self.optimizer.get_lr() self.optimizer.set_lr(manual_lr) diff --git a/fairseq/optim/lr_scheduler/step_lr_scheduler.py b/fairseq/optim/lr_scheduler/step_lr_scheduler.py index 8cb20068..db99d4ee 100644 --- a/fairseq/optim/lr_scheduler/step_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/step_lr_scheduler.py @@ -36,8 +36,7 @@ class StepLRScheduleConfig(FairseqDataclass): @register_lr_scheduler("step", dataclass=StepLRScheduleConfig) class StepLRSchedule(FairseqLRScheduler): - """Decay learning rate every k updates by a fixed factor - """ + """Decay learning rate every k updates by a fixed factor""" def __init__(self, cfg: StepLRScheduleConfig, fairseq_optimizer): super().__init__(cfg, fairseq_optimizer) @@ -50,16 +49,16 @@ class StepLRSchedule(FairseqLRScheduler): cfg.warmup_init_lr if cfg.warmup_init_lr >= 0 else self.min_lr ) - assert(self.lr_deacy_period > 0) - assert(self.lr_decay <= 1) - assert(self.min_lr >= 0) - assert(self.max_lr > self.min_lr) + assert self.lr_deacy_period > 0 + assert self.lr_decay <= 1 + assert self.min_lr >= 0 + assert self.max_lr > self.min_lr if cfg.warmup_updates > 0: # linearly warmup for the first cfg.warmup_updates self.warmup_lr_step = ( - (self.max_lr - self.warmup_init_lr) / self.warmup_updates - ) + self.max_lr - self.warmup_init_lr + ) / self.warmup_updates else: self.warmup_lr_step = 1 diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 2e61140d..bfa791a0 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -171,7 +171,9 @@ class SequenceGenerator(nn.Module): yield id, src, ref, hypos[i] @torch.no_grad() - def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]: + def generate( + self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs + ) -> List[List[Dict[str, Tensor]]]: """Generate translations. Match the api of other fairseq generators. Args: @@ -223,7 +225,10 @@ class SequenceGenerator(nn.Module): else torch.tensor(src_tokens.size(-1)).to(src_tokens) ) else: - raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys())) + raise Exception( + "expected src_tokens or source in net input. input keys: " + + str(net_input.keys()) + ) # bsz: total number of sentences in beam # Note that src_tokens may have more than 2 dimensions (i.e. 
audio features) @@ -328,7 +333,9 @@ class SequenceGenerator(nn.Module): encoder_outs = self.model.reorder_encoder_out( encoder_outs, reorder_state ) - with torch.autograd.profiler.record_function("EnsembleModel: forward_decoder"): + with torch.autograd.profiler.record_function( + "EnsembleModel: forward_decoder" + ): lprobs, avg_attn_scores = self.model.forward_decoder( tokens[:, : step + 1], encoder_outs, @@ -751,7 +758,14 @@ class EnsembleModel(nn.Module): return self.has_incremental def max_decoder_positions(self): - return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize]) + return min( + [ + m.max_decoder_positions() + for m in self.models + if hasattr(m, "max_decoder_positions") + ] + + [sys.maxsize] + ) @torch.jit.export def forward_encoder(self, net_input: Dict[str, Tensor]): diff --git a/fairseq/speech_generator.py b/fairseq/speech_generator.py index 8276335e..d75338ec 100644 --- a/fairseq/speech_generator.py +++ b/fairseq/speech_generator.py @@ -35,8 +35,12 @@ class SpeechGenerator(object): class AutoRegressiveSpeechGenerator(SpeechGenerator): def __init__( - self, model, vocoder, data_cfg, max_iter: int = 6000, - eos_prob_threshold: float = 0.5, + self, + model, + vocoder, + data_cfg, + max_iter: int = 6000, + eos_prob_threshold: float = 0.5, ): super().__init__(model, vocoder, data_cfg) self.max_iter = max_iter @@ -54,8 +58,9 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator): raw_dim = out_dim // n_frames_per_step # initialize - encoder_out = model.forward_encoder(src_tokens, src_lengths, - speaker=sample["speaker"]) + encoder_out = model.forward_encoder( + src_tokens, src_lengths, speaker=sample["speaker"] + ) incremental_state = {} feat, attn, eos_prob = [], [], [] finished = src_tokens.new_zeros((bsz,)).bool() @@ -66,21 +71,24 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator): cur_out_lens = out_lens.clone() cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1) _, cur_eos_out, cur_extra = model.forward_decoder( - prev_feat_out, encoder_out=encoder_out, + prev_feat_out, + encoder_out=encoder_out, incremental_state=incremental_state, - target_lengths=cur_out_lens, speaker=sample["speaker"], **kwargs + target_lengths=cur_out_lens, + speaker=sample["speaker"], + **kwargs ) cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) - feat.append(cur_extra['feature_out']) - attn.append(cur_extra['attn']) + feat.append(cur_extra["feature_out"]) + attn.append(cur_extra["attn"]) eos_prob.append(cur_eos_prob) - cur_finished = (cur_eos_prob.squeeze(1) > self.eos_prob_threshold) + cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold out_lens.masked_fill_((~finished) & cur_finished, step + 1) finished = finished | cur_finished if finished.sum().item() == bsz: break - prev_feat_out = cur_extra['feature_out'] + prev_feat_out = cur_extra["feature_out"] feat = torch.cat(feat, dim=1) feat = model.decoder.postnet(feat) + feat @@ -98,11 +106,11 @@ class AutoRegressiveSpeechGenerator(SpeechGenerator): finalized = [ { - 'feature': feat[b, :out_len], - 'eos_prob': eos_prob[b, :out_len], - 'attn': attn[b, :, :out_len], - 'alignment': alignment[b, :out_len], - 'waveform': self.get_waveform(feat[b, :out_len]), + "feature": feat[b, :out_len], + "eos_prob": eos_prob[b, :out_len], + "attn": attn[b, :, :out_len], + "alignment": alignment[b, :out_len], + "waveform": self.get_waveform(feat[b, :out_len]), } for b, out_len in zip(range(bsz), out_lens) ] @@ -134,7 +142,7 @@ class 
NonAutoregressiveSpeechGenerator(SpeechGenerator): prev_output_tokens=sample["net_input"]["prev_output_tokens"], incremental_state=None, target_lengths=sample["target_lengths"], - speaker=sample["speaker"] + speaker=sample["speaker"], ) if feat_post is not None: feat = feat_post @@ -142,9 +150,7 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator): feat = feat.view(bsz, -1, raw_dim) feat = self.gcmvn_denormalize(feat) - dur_out = torch.clamp( - torch.round(torch.exp(log_dur_out) - 1).long(), min=0 - ) + dur_out = torch.clamp(torch.round(torch.exp(log_dur_out) - 1).long(), min=0) def get_dur_plot_data(d): r = [] @@ -155,11 +161,11 @@ class NonAutoregressiveSpeechGenerator(SpeechGenerator): out_lens = out_lens * n_frames_per_step finalized = [ { - 'feature': feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]), - 'waveform': self.get_waveform( + "feature": feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]), + "waveform": self.get_waveform( feat[b, :l] if l > 0 else feat.new_zeros([1, raw_dim]) ), - 'attn': feat.new_tensor(get_dur_plot_data(dur_out[b])), + "attn": feat.new_tensor(get_dur_plot_data(dur_out[b])), } for b, l in zip(range(bsz), out_lens) ] @@ -188,8 +194,12 @@ class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator) bsz = src_tokens.shape[0] feat, eos_prob, extra = model( - src_tokens, src_lens, prev_out_tokens, incremental_state=None, - target_lengths=tgt_lens, speaker=sample["speaker"] + src_tokens, + src_lens, + prev_out_tokens, + incremental_state=None, + target_lengths=tgt_lens, + speaker=sample["speaker"], ) attn = extra["attn"] # B x T_s x T_t @@ -203,11 +213,11 @@ class TeacherForcingAutoRegressiveSpeechGenerator(AutoRegressiveSpeechGenerator) finalized = [ { - 'feature': feat[b, :tgt_len], - 'eos_prob': eos_prob[b, :tgt_len], - 'attn': attn[b, :, :tgt_len], - 'alignment': alignment[b, :tgt_len], - 'waveform': self.get_waveform(feat[b, :tgt_len]), + "feature": feat[b, :tgt_len], + "eos_prob": eos_prob[b, :tgt_len], + "attn": attn[b, :, :tgt_len], + "alignment": alignment[b, :tgt_len], + "waveform": self.get_waveform(feat[b, :tgt_len]), } for b, tgt_len in zip(range(bsz), tgt_lens) ] diff --git a/fairseq/tasks/audio_finetuning.py b/fairseq/tasks/audio_finetuning.py index 4ef87c60..70aa6a8d 100644 --- a/fairseq/tasks/audio_finetuning.py +++ b/fairseq/tasks/audio_finetuning.py @@ -67,31 +67,31 @@ class AudioFinetuningConfig(AudioPretrainingConfig): default=False, metadata={"help": "evaluation with BLEU scores"} ) eval_bleu_detok: Optional[str] = field( - default=None, metadata={ + default=None, + metadata={ "help": "detokenize before computing BLEU (e.g., 'moses'); " - "required if using --eval-bleu; use 'space' to disable " - "detokenization; see fairseq.data.encoders for other options" - } + "required if using --eval-bleu; use 'space' to disable " + "detokenization; see fairseq.data.encoders for other options" + }, ) eval_bleu_detok_args: str = field( - default="{}", - metadata={"help": "args for building the tokenizer, if needed"} + default="{}", metadata={"help": "args for building the tokenizer, if needed"} ) eval_tokenized_bleu: bool = field( - default=False, - metadata={"help": "compute tokenized BLEU instead of sacrebleu"} + default=False, metadata={"help": "compute tokenized BLEU instead of sacrebleu"} ) eval_bleu_remove_bpe: Optional[str] = field( default=None, metadata={"help": "remove BPE before computing BLEU"} ) eval_bleu_args: str = field( default="{}", - metadata={"help": "generation args for BLUE scoring, e.g., " - 
"'{\"beam\": 4, \"lenpen\": 0.6}'"} + metadata={ + "help": "generation args for BLUE scoring, e.g., " + '\'{"beam": 4, "lenpen": 0.6}\'' + }, ) eval_bleu_print_samples: bool = field( - default=False, - metadata={"help": "print sample generations during validation"} + default=False, metadata={"help": "print sample generations during validation"} ) autoregressive: bool = field( default=False, @@ -123,7 +123,9 @@ class AudioFinetuningTask(AudioPretrainingTask): return Dictionary.load(dict_path) return None - def load_dataset(self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs): + def load_dataset( + self, split: str, task_cfg: AudioFinetuningConfig = None, **kwargs + ): super().load_dataset(split, task_cfg, **kwargs) task_cfg = task_cfg or self.cfg @@ -138,7 +140,8 @@ class AudioFinetuningTask(AudioPretrainingTask): with open(label_path, "r") as f: labels = [ text_compressor.compress(l) - for i, l in enumerate(f) if i not in skipped_indices + for i, l in enumerate(f) + if i not in skipped_indices ] assert len(labels) == len(self.datasets[split]), ( @@ -157,7 +160,7 @@ class AudioFinetuningTask(AudioPretrainingTask): process_label=process_label, label_len_fn=label_len_fn, add_to_input=task_cfg.get("autoregressive", False), - text_compression_level=text_compression_level + text_compression_level=text_compression_level, ) @property @@ -176,8 +179,8 @@ class AudioFinetuningTask(AudioPretrainingTask): logging_output["_num_words"] = metrics["num_words"] if self.cfg.eval_bleu and self.cfg.autoregressive: metrics = self._inference_with_bleu(self.sequence_generator, sample, model) - logging_output['_bleu_sys_len'] = metrics.sys_len - logging_output['_bleu_ref_len'] = metrics.ref_len + logging_output["_bleu_sys_len"] = metrics.sys_len + logging_output["_bleu_ref_len"] = metrics.ref_len # we split counts into separate entries so that they can be # summed efficiently across workers using fast-stat-sync assert len(metrics.counts) == 4 @@ -200,9 +203,9 @@ class AudioFinetuningTask(AudioPretrainingTask): self.tokenizer = None if self.cfg.eval_bleu and self.cfg.autoregressive: assert self.cfg.eval_bleu_detok is not None, ( - '--eval-bleu-detok is required if using --eval-bleu; ' - 'try --eval-bleu-detok=moses (or --eval-bleu-detok=space ' - 'to disable detokenization, e.g., when using sentencepiece)' + "--eval-bleu-detok is required if using --eval-bleu; " + "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " + "to disable detokenization, e.g., when using sentencepiece)" ) detok_args = json.loads(self.cfg.eval_bleu_detok_args) self.tokenizer = encoders.build_tokenizer( @@ -261,9 +264,7 @@ class AudioFinetuningTask(AudioPretrainingTask): # BLEU scores. Instead, we use a somewhat more verbose # alternative that is unlikely to appear in the real # reference, but doesn't get split into multiple tokens. 
- unk_string=( - "UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP" - ), + unk_string=("UNKNOWNTOKENINREF" if is_ref else "UNKNOWNTOKENINHYP"), ) if self.tokenizer: s = self.tokenizer.decode(s) @@ -272,21 +273,18 @@ class AudioFinetuningTask(AudioPretrainingTask): gen_out = self.inference_step(generator, [model], sample) hyps, refs = [], [] for i in range(len(gen_out)): - hyps.append(decode(gen_out[i][0]['tokens'], is_ref=False)) + hyps.append(decode(gen_out[i][0]["tokens"], is_ref=False)) refs.append( decode( - utils.strip_pad( - sample['target'][i], - self.target_dictionary.pad() - ), + utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), is_ref=True, # don't count as matches to the hypo ) ) if self.cfg.eval_bleu_print_samples: - logger.info('H-{} {}'.format(sample["id"][0], hyps[0])) - logger.info('T-{} {}'.format(sample["id"][0], refs[0])) + logger.info("H-{} {}".format(sample["id"][0], hyps[0])) + logger.info("T-{} {}".format(sample["id"][0], refs[0])) - eval_tokenization = 'none' if self.cfg.eval_tokenized_bleu else '13a' + eval_tokenization = "none" if self.cfg.eval_tokenized_bleu else "13a" return sacrebleu.corpus_bleu(hyps, [refs], tokenize=eval_tokenization) def reduce_metrics(self, logging_outputs, criterion): @@ -329,18 +327,17 @@ class AudioFinetuningTask(AudioPretrainingTask): count_keys = [f"_bleu_counts_{i}" for i in range(4)] total_keys = [f"_bleu_totals_{i}" for i in range(4)] for k in len_keys + count_keys + total_keys: - metrics.log_scalar( - k, sum(log.get(k, 0) for log in logging_outputs) - ) + metrics.log_scalar(k, sum(log.get(k, 0) for log in logging_outputs)) import sacrebleu + metrics.log_derived( - 'bleu', + "bleu", lambda meters: sacrebleu.compute_bleu( correct=[meters[k].sum for k in count_keys], total=[meters[k].sum for k in total_keys], - sys_len=meters['_bleu_sys_len'].sum, - ref_len=meters['_bleu_ref_len'].sum, - smooth_method="exp" - ).score + sys_len=meters["_bleu_sys_len"].sum, + ref_len=meters["_bleu_ref_len"].sum, + smooth_method="exp", + ).score, ) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index cc310088..debb269d 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -50,8 +50,7 @@ class AudioPretrainingConfig(FairseqDataclass): data: str = field(default=MISSING, metadata={"help": "path to data directory"}) labels: Optional[str] = field( default=None, - metadata={ - "help": "extension of the label file to load, used for fine-tuning"}, + metadata={"help": "extension of the label file to load, used for fine-tuning"}, ) binarized_dataset: bool = field( default=False, @@ -102,8 +101,8 @@ class AudioPretrainingConfig(FairseqDataclass): default="none", metadata={ "help": "compression level for texts (e.g. audio filenames, " - "target texts): none/low/high (default: none). " - } + "target texts): none/low/high (default: none). 
" + }, ) diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py index d1dff26c..1d4f84c0 100644 --- a/fairseq/tasks/denoising.py +++ b/fairseq/tasks/denoising.py @@ -135,7 +135,6 @@ class DenoisingTask(LegacyFairseqTask): 'e.g., "train,valid" (default: all dataset splits)', ) - def __init__(self, args, dictionary): super().__init__(args) self.dictionary = dictionary diff --git a/fairseq/tasks/frm_text_to_speech.py b/fairseq/tasks/frm_text_to_speech.py index 1fa9b0f8..667f5f8e 100644 --- a/fairseq/tasks/frm_text_to_speech.py +++ b/fairseq/tasks/frm_text_to_speech.py @@ -11,20 +11,19 @@ from fairseq.tasks.text_to_speech import TextToSpeechTask logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, ) logger = logging.getLogger(__name__) -@register_task('frm_text_to_speech') +@register_task("frm_text_to_speech") class FrmTextToSpeechTask(TextToSpeechTask): @staticmethod def add_args(parser): TextToSpeechTask.add_args(parser) - parser.add_argument( - "--do_chunk", action="store_true", help="train on chunks" - ) + parser.add_argument("--do_chunk", action="store_true", help="train on chunks") parser.add_argument("--chunk_bound", default=-1, type=int) parser.add_argument("--chunk_init", default=50, type=int) parser.add_argument("--chunk_incr", default=5, type=int) @@ -52,5 +51,5 @@ class FrmTextToSpeechTask(TextToSpeechTask): chunk_incr=self.args.chunk_incr, add_eos=self.args.add_eos, dedup=self.args.dedup, - ref_fpu=self.args.ref_fpu + ref_fpu=self.args.ref_fpu, ) diff --git a/fairseq/tasks/hubert_pretraining.py b/fairseq/tasks/hubert_pretraining.py index f756080d..b8667d42 100644 --- a/fairseq/tasks/hubert_pretraining.py +++ b/fairseq/tasks/hubert_pretraining.py @@ -28,15 +28,15 @@ class LabelEncoder(object): def __call__(self, label: str) -> List[str]: return self.dictionary.encode_line( - label, append_eos=False, add_if_not_exist=False, + label, + append_eos=False, + add_if_not_exist=False, ) @dataclass class HubertPretrainingConfig(FairseqDataclass): - data: str = field( - default=MISSING, metadata={"help": "path to data directory"} - ) + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) fine_tuning: bool = field( default=False, metadata={"help": "set to true if fine-tuning Hubert"} ) @@ -68,9 +68,7 @@ class HubertPretrainingConfig(FairseqDataclass): ) normalize: bool = field( default=False, - metadata={ - "help": "if set, normalizes input to have 0 mean and unit variance" - }, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, ) enable_padding: bool = field( default=False, @@ -91,8 +89,7 @@ class HubertPretrainingConfig(FairseqDataclass): single_target: Optional[bool] = field( default=False, metadata={ - "help": "if set, AddTargetDatasets outputs same keys " - "as AddTargetDataset" + "help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset" }, ) random_crop: Optional[bool] = field( @@ -149,7 +146,10 @@ class HubertPretrainingTask(FairseqTask): def load_dictionaries(self): label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir - dictionaries = [Dictionary.load(f"{label_dir}/dict.{label}.txt") for label in self.cfg.labels] + dictionaries = [ + Dictionary.load(f"{label_dir}/dict.{label}.txt") + for label in self.cfg.labels + ] return dictionaries[0] if self.cfg.fine_tuning else 
dictionaries def get_label_dir(self) -> str: @@ -163,9 +163,7 @@ class HubertPretrainingTask(FairseqTask): pad_list = [dict.pad() for dict in dicts] eos_list = [dict.eos() for dict in dicts] procs = [LabelEncoder(dict) for dict in dicts] - paths = [ - f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels - ] + paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels] # hubert v1: pad_audio=True, random_crop=False; self.datasets[split] = HubertDataset( @@ -189,7 +187,5 @@ class HubertPretrainingTask(FairseqTask): def max_positions(self) -> Tuple[int, int]: return (sys.maxsize, sys.maxsize) - def filter_indices_by_size( - self, indices: np.array, *args, **kwargs - ) -> np.array: + def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array: return indices diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 4b76a51c..aa397de9 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -85,10 +85,12 @@ class LanguageModelingConfig(FairseqDataclass): }, ) pad_to_fixed_length: Optional[bool] = field( - default=False, metadata={"help": "pad to fixed length"}, + default=False, + metadata={"help": "pad to fixed length"}, ) pad_to_fixed_bsz: Optional[bool] = field( - default=False, metadata={"help": "boolean to pad to fixed batch size"}, + default=False, + metadata={"help": "boolean to pad to fixed batch size"}, ) # TODO common vars below add to parent @@ -247,7 +249,9 @@ class LanguageModelingTask(LegacyFairseqTask): pad_to_bsz = None if self.args.pad_to_fixed_bsz: - pad_to_bsz = self.args.batch_size_valid if 'valid' in split else self.args.batch_size + pad_to_bsz = ( + self.args.batch_size_valid if "valid" in split else self.args.batch_size + ) self.datasets[split] = MonolingualDataset( dataset=dataset, diff --git a/fairseq/tasks/simultaneous_translation.py b/fairseq/tasks/simultaneous_translation.py index 11c7dc1e..9576b268 100644 --- a/fairseq/tasks/simultaneous_translation.py +++ b/fairseq/tasks/simultaneous_translation.py @@ -6,12 +6,11 @@ import logging from fairseq.tasks import register_task from fairseq.tasks.speech_to_text import SpeechToTextTask -from fairseq.tasks.translation import ( - TranslationTask, TranslationConfig -) +from fairseq.tasks.translation import TranslationTask, TranslationConfig try: - import examples.simultaneous_translation # noqa + import examples.simultaneous_translation # noqa + import_successful = True except BaseException: import_successful = False @@ -35,7 +34,7 @@ class SimulSpeechToTextTask(SpeechToTextTask): super().__init__(args, tgt_dict) -@register_task("simul_text_to_text", dataclass=TranslationConfig) +@register_task("simul_text_to_text", dataclass=TranslationConfig) class SimulTextToTextTask(TranslationTask): def __init__(self, cfg, src_dict, tgt_dict): check_import(import_successful) diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 06e29210..3e568052 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -12,7 +12,7 @@ from fairseq.data.audio.speech_to_text_dataset import ( S2TDataConfig, SpeechToTextDataset, SpeechToTextDatasetCreator, - get_features_or_waveform + get_features_or_waveform, ) from fairseq.tasks import LegacyFairseqTask, register_task @@ -101,7 +101,7 @@ class SpeechToTextTask(LegacyFairseqTask): is_train_split=is_train_split, epoch=epoch, seed=self.args.seed, - speaker_to_id=self.speaker_to_id + speaker_to_id=self.speaker_to_id, ) @property @@ -143,8 
+143,7 @@ class SpeechToTextTask(LegacyFairseqTask): extra_gen_cls_kwargs = {} extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids return super().build_generator( - models, args, seq_gen_cls=None, - extra_gen_cls_kwargs=extra_gen_cls_kwargs + models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) def build_tokenizer(self, args): diff --git a/fairseq/tasks/text_to_speech.py b/fairseq/tasks/text_to_speech.py index 5646e41d..bdbf87f6 100644 --- a/fairseq/tasks/text_to_speech.py +++ b/fairseq/tasks/text_to_speech.py @@ -15,13 +15,15 @@ from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDatasetCreator from fairseq.tasks import register_task from fairseq.tasks.speech_to_text import SpeechToTextTask from fairseq.speech_generator import ( - AutoRegressiveSpeechGenerator, NonAutoregressiveSpeechGenerator, - TeacherForcingAutoRegressiveSpeechGenerator + AutoRegressiveSpeechGenerator, + NonAutoregressiveSpeechGenerator, + TeacherForcingAutoRegressiveSpeechGenerator, ) logging.basicConfig( - format='%(asctime)s | %(levelname)s | %(name)s | %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, ) logger = logging.getLogger(__name__) @@ -33,21 +35,31 @@ except ImportError: SummaryWriter = None -@register_task('text_to_speech') +@register_task("text_to_speech") class TextToSpeechTask(SpeechToTextTask): @staticmethod def add_args(parser): - parser.add_argument('data', help='manifest root path') + parser.add_argument("data", help="manifest root path") parser.add_argument( - '--config-yaml', type=str, default='config.yaml', - help='Configuration YAML filename (under manifest root)' + "--config-yaml", + type=str, + default="config.yaml", + help="Configuration YAML filename (under manifest root)", + ) + parser.add_argument( + "--max-source-positions", + default=1024, + type=int, + metavar="N", + help="max number of tokens in the source sequence", + ) + parser.add_argument( + "--max-target-positions", + default=1200, + type=int, + metavar="N", + help="max number of tokens in the target sequence", ) - parser.add_argument('--max-source-positions', default=1024, type=int, - metavar='N', - help='max number of tokens in the source sequence') - parser.add_argument('--max-target-positions', default=1200, type=int, - metavar='N', - help='max number of tokens in the target sequence') parser.add_argument("--n-frames-per-step", type=int, default=1) parser.add_argument("--eos-prob-threshold", type=float, default=0.5) parser.add_argument("--eval-inference", action="store_true") @@ -63,19 +75,24 @@ class TextToSpeechTask(SpeechToTextTask): self.tensorboard_writer = None self.tensorboard_dir = "" if args.tensorboard_logdir and SummaryWriter is not None: - self.tensorboard_dir = os.path.join(args.tensorboard_logdir, - "valid_extra") + self.tensorboard_dir = os.path.join(args.tensorboard_logdir, "valid_extra") def load_dataset(self, split, epoch=1, combine=False, **kwargs): - is_train_split = split.startswith('train') + is_train_split = split.startswith("train") pre_tokenizer = self.build_tokenizer(self.args) bpe_tokenizer = self.build_bpe(self.args) self.datasets[split] = TextToSpeechDatasetCreator.from_tsv( - self.args.data, self.data_cfg, split, self.src_dict, - pre_tokenizer, bpe_tokenizer, is_train_split=is_train_split, - epoch=epoch, seed=self.args.seed, + self.args.data, + self.data_cfg, + split, + self.src_dict, + pre_tokenizer, + 
bpe_tokenizer, + is_train_split=is_train_split, + epoch=epoch, + seed=self.args.seed, n_frames_per_step=self.args.n_frames_per_step, - speaker_to_id=self.speaker_to_id + speaker_to_id=self.speaker_to_id, ) @property @@ -106,7 +123,8 @@ class TextToSpeechTask(SpeechToTextTask): speaker_emb_mat = np.load(args.speaker_emb_path) assert speaker_emb_mat.shape[1] == args.speaker_embed_dim embed_speaker = torch.nn.Embedding.from_pretrained( - torch.from_numpy(speaker_emb_mat), freeze=True, + torch.from_numpy(speaker_emb_mat), + freeze=True, ) logger.info( f"load speaker embeddings from {args.speaker_emb_path}. " @@ -132,22 +150,23 @@ class TextToSpeechTask(SpeechToTextTask): vocoder = self.build_default_vocoder() model = models[0] if getattr(model, "NON_AUTOREGRESSIVE", False): - return NonAutoregressiveSpeechGenerator( - model, vocoder, self.data_cfg - ) + return NonAutoregressiveSpeechGenerator(model, vocoder, self.data_cfg) else: generator = AutoRegressiveSpeechGenerator if getattr(cfg, "teacher_forcing", False): generator = TeacherForcingAutoRegressiveSpeechGenerator logger.info("Teacher forcing mode for generation") return generator( - model, vocoder, self.data_cfg, + model, + vocoder, + self.data_cfg, max_iter=self.args.max_target_positions, - eos_prob_threshold=self.args.eos_prob_threshold + eos_prob_threshold=self.args.eos_prob_threshold, ) def build_default_vocoder(self): from fairseq.models.text_to_speech.vocoder import get_vocoder + vocoder = get_vocoder(self.args, self.data_cfg) if torch.cuda.is_available() and not self.args.cpu: vocoder = vocoder.cuda() @@ -156,25 +175,23 @@ class TextToSpeechTask(SpeechToTextTask): return vocoder def valid_step(self, sample, model, criterion): - loss, sample_size, logging_output = super().valid_step( - sample, model, criterion - ) + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) if getattr(self.args, "eval_inference", False): hypos, inference_losses = self.valid_step_with_inference( sample, model, self.generator ) for k, v in inference_losses.items(): - assert(k not in logging_output) + assert k not in logging_output logging_output[k] = v picked_id = 0 if self.tensorboard_dir and (sample["id"] == picked_id).any(): self.log_tensorboard( sample, - hypos[:self.args.eval_tb_nsample], + hypos[: self.args.eval_tb_nsample], model._num_updates, - is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False) + is_na_model=getattr(model, "NON_AUTOREGRESSIVE", False), ) return loss, sample_size, logging_output @@ -182,17 +199,17 @@ class TextToSpeechTask(SpeechToTextTask): hypos = generator.generate(model, sample, has_targ=True) losses = { - "mcd_loss": 0., - "targ_frames": 0., - "pred_frames": 0., - "nins": 0., - "ndel": 0., + "mcd_loss": 0.0, + "targ_frames": 0.0, + "pred_frames": 0.0, + "nins": 0.0, + "ndel": 0.0, } rets = batch_mel_cepstral_distortion( [hypo["targ_waveform"] for hypo in hypos], [hypo["waveform"] for hypo in hypos], self.sr, - normalize_type=None + normalize_type=None, ) for d, extra in rets: pathmap = extra[-1] @@ -218,41 +235,40 @@ class TextToSpeechTask(SpeechToTextTask): if is_na_model: data = plot_tts_output( [targ.transpose(0, 1), pred.transpose(0, 1)], - [f"target (idx={idx})", "output"], attn, - "alignment", ret_np=True, suptitle=text, + [f"target (idx={idx})", "output"], + attn, + "alignment", + ret_np=True, + suptitle=text, ) else: eos_prob = hypos[b]["eos_prob"] data = plot_tts_output( [targ.transpose(0, 1), pred.transpose(0, 1), attn], - [f"target (idx={idx})", "output", "alignment"], eos_prob, - "eos 
prob", ret_np=True, suptitle=text, + [f"target (idx={idx})", "output", "alignment"], + eos_prob, + "eos prob", + ret_np=True, + suptitle=text, ) tb_writer.add_image( - f"inference_sample_{b}", data, num_updates, - dataformats="HWC" + f"inference_sample_{b}", data, num_updates, dataformats="HWC" ) if hypos[b]["waveform"] is not None: targ_wave = hypos[b]["targ_waveform"].detach().cpu().float() pred_wave = hypos[b]["waveform"].detach().cpu().float() tb_writer.add_audio( - f"inference_targ_{b}", - targ_wave, - num_updates, - sample_rate=self.sr + f"inference_targ_{b}", targ_wave, num_updates, sample_rate=self.sr ) tb_writer.add_audio( - f"inference_pred_{b}", - pred_wave, - num_updates, - sample_rate=self.sr + f"inference_pred_{b}", pred_wave, num_updates, sample_rate=self.sr ) def save_figure_to_numpy(fig): - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) return data @@ -261,8 +277,15 @@ DEFAULT_V_MIN = np.log(1e-5) def plot_tts_output( - data_2d, title_2d, data_1d, title_1d, figsize=(24, 4), - v_min=DEFAULT_V_MIN, v_max=3, ret_np=False, suptitle="" + data_2d, + title_2d, + data_1d, + title_1d, + figsize=(24, 4), + v_min=DEFAULT_V_MIN, + v_max=3, + ret_np=False, + suptitle="", ): try: import matplotlib.pyplot as plt @@ -271,8 +294,8 @@ def plot_tts_output( raise ImportError("Please install Matplotlib: pip install matplotlib") data_2d = [ - x.detach().cpu().float().numpy() - if isinstance(x, torch.Tensor) else x for x in data_2d + x.detach().cpu().float().numpy() if isinstance(x, torch.Tensor) else x + for x in data_2d ] fig, axes = plt.subplots(1, len(data_2d) + 1, figsize=figsize) if suptitle: @@ -281,12 +304,15 @@ def plot_tts_output( for ax, x, name in zip(axes, data_2d, title_2d): ax.set_title(name) divider = make_axes_locatable(ax) - cax = divider.append_axes('right', size='5%', pad=0.05) + cax = divider.append_axes("right", size="5%", pad=0.05) im = ax.imshow( - x, origin="lower", aspect="auto", vmin=max(x.min(), v_min), - vmax=min(x.max(), v_max) + x, + origin="lower", + aspect="auto", + vmin=max(x.min(), v_min), + vmax=min(x.max(), v_max), ) - fig.colorbar(im, cax=cax, orientation='vertical') + fig.colorbar(im, cax=cax, orientation="vertical") if isinstance(data_1d, torch.Tensor): data_1d = data_1d.detach().cpu().numpy() @@ -349,9 +375,12 @@ def batch_dynamic_time_warping(distance, shapes=None): for offset in range(2, m + n - 1): ind = antidiag_indices(offset, 1, m, 1, n) c = torch.stack( - [cumdist[:, ind[0], ind[1] - 1], cumdist[:, ind[0] - 1, ind[1] - 1], - cumdist[:, ind[0] - 1, ind[1]], ], - dim=2 + [ + cumdist[:, ind[0], ind[1] - 1], + cumdist[:, ind[0] - 1, ind[1] - 1], + cumdist[:, ind[0] - 1, ind[1]], + ], + dim=2, ) v, b = c.min(axis=-1) backptr[:, ind[0], ind[1]] = b.int() @@ -364,7 +393,7 @@ def batch_dynamic_time_warping(distance, shapes=None): j = n - 1 if shapes is None else (shapes[b][1] - 1).item() dtwpath = [(i, j)] while (i != 0 or j != 0) and len(dtwpath) < 10000: - assert (i >= 0 and j >= 0) + assert i >= 0 and j >= 0 di, dj = ptr2dij[backptr[b, i, j].item()] i, j = i + di, j + dj dtwpath.append((i, j)) @@ -401,7 +430,7 @@ def get_divisor(pathmap, normalize_type): def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type): d, s, x1, x2 = [], [], [], [] for cur_y1, cur_y2 in zip(y1, y2): - assert (cur_y1.ndim == 1 and cur_y2.ndim == 1) + assert cur_y1.ndim == 1 and cur_y2.ndim == 1 cur_x1 
= feat_fn(cur_y1) cur_x2 = feat_fn(cur_y2) x1.append(cur_x1) @@ -432,9 +461,7 @@ def batch_compute_distortion(y1, y2, sr, feat_fn, dist_fn, normalize_type): return rets -def batch_mel_cepstral_distortion( - y1, y2, sr, normalize_type="path", mfcc_fn=None -): +def batch_mel_cepstral_distortion(y1, y2, sr, normalize_type="path", mfcc_fn=None): """ https://arxiv.org/pdf/2011.03568.pdf @@ -454,14 +481,21 @@ def batch_mel_cepstral_distortion( if mfcc_fn is None or mfcc_fn.sample_rate != sr: melkwargs = { - "n_fft": int(0.05 * sr), "win_length": int(0.05 * sr), - "hop_length": int(0.0125 * sr), "f_min": 20, - "n_mels": 80, "window_fn": torch.hann_window + "n_fft": int(0.05 * sr), + "win_length": int(0.05 * sr), + "hop_length": int(0.0125 * sr), + "f_min": 20, + "n_mels": 80, + "window_fn": torch.hann_window, } mfcc_fn = torchaudio.transforms.MFCC( sr, n_mfcc=13, log_mels=True, melkwargs=melkwargs ).to(y1[0].device) return batch_compute_distortion( - y1, y2, sr, lambda y: mfcc_fn(y).transpose(-1, -2), compute_rms_dist, - normalize_type + y1, + y2, + sr, + lambda y: mfcc_fn(y).transpose(-1, -2), + compute_rms_dist, + normalize_type, ) diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 86473608..f5a3cf66 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -399,6 +399,7 @@ class TranslationTask(FairseqTask): def sum_logs(key): import torch + result = sum(log.get(key, 0) for log in logging_outputs) if torch.is_tensor(result): result = result.cpu() @@ -418,12 +419,15 @@ class TranslationTask(FairseqTask): def compute_bleu(meters): import inspect + try: from sacrebleu.metrics import BLEU + comp_bleu = BLEU.compute_bleu except ImportError: # compatibility API for sacrebleu 1.x import sacrebleu + comp_bleu = sacrebleu.compute_bleu fn_sig = inspect.getfullargspec(comp_bleu)[0] @@ -436,7 +440,7 @@ class TranslationTask(FairseqTask): total=meters["_bleu_totals"].sum, sys_len=meters["_bleu_sys_len"].sum, ref_len=meters["_bleu_ref_len"].sum, - **smooth + **smooth, ) return round(bleu.score, 2) diff --git a/fairseq/tasks/translation_lev.py b/fairseq/tasks/translation_lev.py index 04127930..b45fecd1 100644 --- a/fairseq/tasks/translation_lev.py +++ b/fairseq/tasks/translation_lev.py @@ -9,21 +9,25 @@ from fairseq import utils from fairseq.data import LanguagePairDataset from fairseq.dataclass import ChoiceEnum from fairseq.tasks import register_task -from fairseq.tasks.translation import TranslationConfig, TranslationTask, load_langpair_dataset +from fairseq.tasks.translation import ( + TranslationConfig, + TranslationTask, + load_langpair_dataset, +) from fairseq.utils import new_arange NOISE_CHOICES = ChoiceEnum(["random_delete", "random_mask", "no_noise", "full_mask"]) + @dataclass class TranslationLevenshteinConfig(TranslationConfig): noise: NOISE_CHOICES = field( default="random_delete", - metadata={ - "help": "type of noise" - }, + metadata={"help": "type of noise"}, ) + @register_task("translation_lev", dataclass=TranslationLevenshteinConfig) class TranslationLevenshteinTask(TranslationTask): """ diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index e64ab9a6..2ba012e8 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -197,7 +197,11 @@ class TranslationMultiSimpleEpochTask(LegacyFairseqTask): return dataset def build_generator( - self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, + self, + models, + args, + 
seq_gen_cls=None, + extra_gen_cls_kwargs=None, ): if not getattr(args, "keep_inference_langtok", False): _, tgt_langtok_spec = self.args.langtoks["main"] diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 30e12dcc..ac24edb2 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -337,7 +337,10 @@ class Trainer(object): ) if self.cfg.optimization.use_bmuf: - self._optimizer = optim.FairseqBMUF(self.cfg.bmuf, self._optimizer,) + self._optimizer = optim.FairseqBMUF( + self.cfg.bmuf, + self._optimizer, + ) if self.cfg.distributed_training.zero_sharding == "os": if ( @@ -355,7 +358,8 @@ class Trainer(object): # We should initialize the learning rate scheduler immediately after # building the optimizer, so that the initial learning rate is set. self._lr_scheduler = lr_scheduler.build_lr_scheduler( - self.cfg.lr_scheduler, self.optimizer, + self.cfg.lr_scheduler, + self.optimizer, ) self._lr_scheduler.step_update(0) @@ -652,7 +656,9 @@ class Trainer(object): return batch_iterator def get_valid_iterator( - self, subset, disable_iterator_cache=False, + self, + subset, + disable_iterator_cache=False, ): """Return an EpochBatchIterator over given validation subset for a given epoch.""" batch_iterator = self.task.get_batch_iterator( @@ -660,7 +666,8 @@ class Trainer(object): max_tokens=self.cfg.dataset.max_tokens_valid, max_sentences=self.cfg.dataset.batch_size_valid, max_positions=utils.resolve_max_positions( - self.task.max_positions(), self.model.max_positions(), + self.task.max_positions(), + self.model.max_positions(), ), ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, @@ -809,7 +816,11 @@ class Trainer(object): train_time = self._local_cumulative_training_time() ( logging_outputs, - (sample_size, ooms, total_train_time,), + ( + sample_size, + ooms, + total_train_time, + ), ) = self._aggregate_logging_outputs( logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch ) @@ -924,7 +935,8 @@ class Trainer(object): if self.cfg.ema.store_ema: # Step EMA forward with new model. 
self.ema.step( - self.get_model(), self.get_num_updates(), + self.get_model(), + self.get_num_updates(), ) metrics.log_scalar( "ema_decay", @@ -1058,7 +1070,9 @@ class Trainer(object): # gather logging outputs from all replicas if self.data_parallel_world_size > 1: logging_outputs, (sample_size,) = self._aggregate_logging_outputs( - logging_outputs, sample_size, ignore=is_dummy_batch, + logging_outputs, + sample_size, + ignore=is_dummy_batch, ) # log validation stats @@ -1260,9 +1274,10 @@ class Trainer(object): return False elif self.cfg.optimization.use_bmuf: return ( - (self.get_num_updates() + 1) % self.cfg.bmuf.global_sync_iter == 0 - and (self.get_num_updates() + 1) > self.cfg.bmuf.warmup_iterations - ) + self.get_num_updates() + 1 + ) % self.cfg.bmuf.global_sync_iter == 0 and ( + self.get_num_updates() + 1 + ) > self.cfg.bmuf.warmup_iterations else: return True @@ -1275,7 +1290,10 @@ class Trainer(object): sys.stderr.flush() def _aggregate_logging_outputs( - self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, ): if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()): return self._fast_stat_sync_sum( @@ -1287,7 +1305,10 @@ class Trainer(object): ) def _all_gather_list_sync( - self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, ): """ Sync logging outputs across workers. all_gather_list_sync is @@ -1312,7 +1333,10 @@ class Trainer(object): return logging_outputs, extra_stats_to_sum def _fast_stat_sync_sum( - self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum, ignore=False, + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, ): """ Sync logging outputs across workers. 
fast_stat_sync_sum is diff --git a/fairseq/utils.py b/fairseq/utils.py index 94114ce1..0f848961 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -85,7 +85,9 @@ def apply_to_sample(f, sample): return f(x) elif isinstance(x, collections.OrderedDict): # OrderedDict has attributes that needs to be preserved - od = collections.OrderedDict((key, _apply(value)) for key, value in x.items()) + od = collections.OrderedDict( + (key, _apply(value)) for key, value in x.items() + ) od.__dict__ = x.__dict__ return od elif isinstance(x, dict): @@ -536,6 +538,7 @@ def deprecation_warning(message, stacklevel=3): # don't use DeprecationWarning, since it's ignored by default warnings.warn(message, stacklevel=stacklevel) + def relu_squared(x: torch.Tensor): return F.relu(x).pow(2) diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 7e887e88..b8757835 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -17,11 +17,12 @@ from itertools import chain import numpy as np import torch +from omegaconf import DictConfig + from fairseq import checkpoint_utils, options, scoring, tasks, utils from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter -from omegaconf import DictConfig def main(cfg: DictConfig): @@ -81,7 +82,6 @@ def _main(cfg: DictConfig, output_file): # Load dataset splits task = tasks.setup_task(cfg.task) - # Set dictionaries try: src_dict = getattr(task, "source_dictionary", None) @@ -316,10 +316,7 @@ def _main(cfg: DictConfig, output_file): "A-{}\t{}".format( sample_id, " ".join( - [ - ",".join(src_probs) - for src_probs in alignment - ] + [",".join(src_probs) for src_probs in alignment] ), ), file=output_file, @@ -348,7 +345,10 @@ def _main(cfg: DictConfig, output_file): # Score only the top hypothesis if has_target and j == 0: - if align_dict is not None or cfg.common_eval.post_process is not None: + if ( + align_dict is not None + or cfg.common_eval.post_process is not None + ): # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.encode_line( target_str, add_if_not_exist=True @@ -402,9 +402,12 @@ def cli_main(): parser = options.get_generation_parser() # TODO: replace this workaround with refactoring of `AudioPretraining` parser.add_argument( - '--arch', '-a', metavar='ARCH', default="wav2vec2", - help='Model architecture. For constructing tasks that rely on ' - 'model args (e.g. `AudioPretraining`)' + "--arch", + "-a", + metavar="ARCH", + default="wav2vec2", + help="Model architecture. For constructing tasks that rely on " + "model args (e.g. 
`AudioPretraining`)", ) args = options.parse_args_and_arch(parser) main(args) diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py index 6555ab41..607340af 100644 --- a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -7,18 +7,17 @@ import logging import os -from fairseq.dataclass.initialize import add_defaults, hydra_init -from fairseq_cli.train import main as pre_main -from fairseq import distributed_utils, metrics -from fairseq.dataclass.configs import FairseqConfig -from fairseq.dataclass.utils import omegaconf_no_object_check -from fairseq.utils import reset_logging - import hydra -from hydra.core.hydra_config import HydraConfig import torch +from hydra.core.hydra_config import HydraConfig from omegaconf import OmegaConf, open_dict +from fairseq import distributed_utils, metrics +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.initialize import add_defaults, hydra_init +from fairseq.dataclass.utils import omegaconf_no_object_check +from fairseq.utils import reset_logging +from fairseq_cli.train import main as pre_main logger = logging.getLogger("fairseq_cli.hydra_train") @@ -38,10 +37,14 @@ def _hydra_main(cfg: FairseqConfig, **kwargs) -> float: if HydraConfig.initialized(): with open_dict(cfg): # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) - cfg.job_logging_cfg = OmegaConf.to_container(HydraConfig.get().job_logging, resolve=True) + cfg.job_logging_cfg = OmegaConf.to_container( + HydraConfig.get().job_logging, resolve=True + ) with omegaconf_no_object_check(): - cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) + cfg = OmegaConf.create( + OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + ) OmegaConf.set_struct(cfg, True) try: diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index cadef282..03265d00 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -19,13 +19,13 @@ from collections import namedtuple import numpy as np import torch + from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq_cli.generate import get_symbols_to_strip_from_output - logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", @@ -249,7 +249,7 @@ def main(cfg: FairseqConfig): # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): - src_str = '' + src_str = "" if src_dict is not None: src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) print("S-{}\t{}".format(id_, src_str)) @@ -257,7 +257,8 @@ def main(cfg: FairseqConfig): for constraint in info["constraints"]: print( "C-{}\t{}".format( - id_, tgt_dict.string(constraint, cfg.common_eval.post_process) + id_, + tgt_dict.string(constraint, cfg.common_eval.post_process), ) ) diff --git a/fairseq_cli/preprocess.py b/fairseq_cli/preprocess.py index 4ee9a1e3..6f24983d 100644 --- a/fairseq_cli/preprocess.py +++ b/fairseq_cli/preprocess.py @@ -41,7 +41,9 @@ def main(args): ) logger.info(args) - assert args.dataset_impl != "huffman", "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly." 
+ assert ( + args.dataset_impl != "huffman" + ), "preprocessing.py doesn't support Huffman yet, use HuffmanCodeBuilder directly." task = tasks.get_task(args.task) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 369a8a82..a707add7 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -12,7 +12,7 @@ import logging import math import os import sys -from typing import Dict, Optional, Any, List, Tuple, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple # We need to setup root logger before importing any fairseq libraries. logging.basicConfig( @@ -25,23 +25,19 @@ logger = logging.getLogger("fairseq_cli.train") import numpy as np import torch -from fairseq import ( - checkpoint_utils, - options, - quantization_utils, - tasks, - utils, -) -from fairseq.data import iterators, data_utils +from omegaconf import DictConfig, OmegaConf + +from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils +from fairseq.data import data_utils, iterators from fairseq.data.plasma_utils import PlasmaStore from fairseq.dataclass.configs import FairseqConfig from fairseq.dataclass.utils import convert_namespace_to_omegaconf -from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils +from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap +from fairseq.distributed import utils as distributed_utils from fairseq.file_io import PathManager from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer from fairseq.trainer import Trainer -from omegaconf import DictConfig, OmegaConf def main(cfg: FairseqConfig) -> None: @@ -156,7 +152,8 @@ def main(cfg: FairseqConfig) -> None: ) logger.info( "max tokens per device = {} and max sentences per device = {}".format( - cfg.dataset.max_tokens, cfg.dataset.batch_size, + cfg.dataset.max_tokens, + cfg.dataset.batch_size, ) ) @@ -259,7 +256,9 @@ def train( else cfg.optimization.update_freq[-1] ) itr = iterators.GroupedIterator( - itr, update_freq, skip_remainder_batch=cfg.optimization.skip_remainder_batch, + itr, + update_freq, + skip_remainder_batch=cfg.optimization.skip_remainder_batch, ) if cfg.common.tpu: itr = utils.tpu_data_loader(itr) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 22b93e9a..4617b6d5 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -11,12 +11,12 @@ from argparse import Namespace from itertools import chain import torch +from omegaconf import DictConfig + from fairseq import checkpoint_utils, distributed_utils, options, utils from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import metrics, progress_bar from fairseq.utils import reset_logging -from omegaconf import DictConfig - logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -142,9 +142,7 @@ def cli_main(): # only override args that are explicitly given on the command line override_parser = options.get_validation_parser() - override_args = options.parse_args_and_arch( - override_parser, suppress_defaults=True - ) + override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) distributed_utils.call_main( convert_namespace_to_omegaconf(args), main, override_args=override_args diff --git a/setup.py b/setup.py index 4bf1e60d..c5591915 100644 --- a/setup.py +++ b/setup.py @@ -7,11 +7,9 @@ import os import subprocess import sys -from setuptools import setup, find_packages, Extension from setuptools import Extension, 
find_packages, setup - if sys.version_info < (3, 6): sys.exit("Sorry, Python >= 3.6 is required for fairseq.") @@ -277,7 +275,8 @@ if __name__ == "__main__": package_data = { "fairseq": ( - get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config")) + get_files(fairseq_examples) + + get_files(os.path.join("fairseq", "config")) ) } do_setup(package_data) diff --git a/tests/distributed/test_bmuf.py b/tests/distributed/test_bmuf.py index 8b7cadb0..2a0f20d0 100644 --- a/tests/distributed/test_bmuf.py +++ b/tests/distributed/test_bmuf.py @@ -11,9 +11,10 @@ from multiprocessing import Manager import torch import torch.nn as nn +from omegaconf import OmegaConf + from fairseq import optim from fairseq.distributed import utils as distributed_utils -from omegaconf import OmegaConf class Model(nn.Module): @@ -42,10 +43,7 @@ def setup_model_loss_criterion(cfg, args, rank, is_cuda): loss_fn = loss_fn.cuda() optimizer = optim.sgd.SGD(args, model.parameters()) - optimizer = optim.FairseqBMUF( - cfg=cfg.bmuf, - optimizer=optimizer - ) + optimizer = optim.FairseqBMUF(cfg=cfg.bmuf, optimizer=optimizer) return model, loss_fn, optimizer diff --git a/tests/distributed/test_distributed_timeout_wrapper.py b/tests/distributed/test_distributed_timeout_wrapper.py index 27908b9d..996093cb 100644 --- a/tests/distributed/test_distributed_timeout_wrapper.py +++ b/tests/distributed/test_distributed_timeout_wrapper.py @@ -15,7 +15,6 @@ from fairseq.distributed import DistributedTimeoutWrapper class ModuleWithDelay(nn.Module): - def __init__(self, delay): super().__init__() self.delay = delay @@ -26,7 +25,6 @@ class ModuleWithDelay(nn.Module): class TestDistributedTimeoutWrapper(unittest.TestCase): - def setUp(self): logging.disable(logging.CRITICAL) diff --git a/tests/distributed/test_module_proxy_wrapper.py b/tests/distributed/test_module_proxy_wrapper.py index 2803a044..2ac1a877 100644 --- a/tests/distributed/test_module_proxy_wrapper.py +++ b/tests/distributed/test_module_proxy_wrapper.py @@ -38,7 +38,6 @@ class Model(nn.Module): class TestModuleProxyWrapper(unittest.TestCase): - def _get_module(self): module = Model() wrapped_module = MockDDPWrapper(module) diff --git a/tests/distributed/utils.py b/tests/distributed/utils.py index c8040392..be4e19cd 100644 --- a/tests/distributed/utils.py +++ b/tests/distributed/utils.py @@ -15,7 +15,10 @@ def spawn_and_init(fn, world_size, args=None): with tempfile.NamedTemporaryFile(delete=False) as tmp_file: torch.multiprocessing.spawn( fn=functools.partial(init_and_run, fn, args), - args=(world_size, tmp_file.name,), + args=( + world_size, + tmp_file.name, + ), nprocs=world_size, join=True, ) diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index 99eb7f55..550e751b 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -4,14 +4,15 @@ # LICENSE file in the root directory of this source tree. 
import contextlib -import logging import json +import logging import os import tempfile import unittest from io import StringIO import torch + from fairseq import options from fairseq_cli import train from tests.utils import ( @@ -32,20 +33,17 @@ class TestTranslationGPU(unittest.TestCase): logging.disable(logging.NOTSET) def test_fp16_multigpu(self): - self._test_multigpu( - "test_fp16", ["--fp16"] - ) + self._test_multigpu("test_fp16", ["--fp16"]) def test_slowmo_multigpu(self): self._test_multigpu( - "test_slowmo", - ["--ddp-backend", "slowmo", "--nprocs-per-node", "1"] + "test_slowmo", ["--ddp-backend", "slowmo", "--nprocs-per-node", "1"] ) def test_slowmo_single_node_multigpu(self): self._test_multigpu( "test_slowmo_single_node", - ["--ddp-backend", "slowmo", "--nprocs-per-node", "2"] + ["--ddp-backend", "slowmo", "--nprocs-per-node", "2"], ) def _test_multigpu(self, test_name, test_args): @@ -77,7 +75,9 @@ class TestTranslationGPU(unittest.TestCase): self._test_resume_training(["--ddp-backend", "fully_sharded"]) def test_resume_training_fsdp_sharded_state(self): - self._test_resume_training(["--ddp-backend", "fully_sharded", "--use-sharded-state"]) + self._test_resume_training( + ["--ddp-backend", "fully_sharded", "--use-sharded-state"] + ) def test_resume_training_noc10d(self): self._test_resume_training([]) @@ -101,7 +101,10 @@ class TestTranslationGPU(unittest.TestCase): create_dummy_data(data_dir) preprocess_translation_data(data_dir) train_translation_model( - data_dir, arch, flags + ["--log-file", log], world_size=world_size, + data_dir, + arch, + flags + ["--log-file", log], + world_size=world_size, ) log2 = os.path.join(data_dir, "resume.log") restore_file = os.path.join(data_dir, "checkpoint_1_2.pt") @@ -261,7 +264,13 @@ class TestTranslationGPU(unittest.TestCase): train_translation_model( data_dir, "fconv_iwslt_de_en", - ["--log-file", log, "--ddp-backend", "fully_sharded", "--use-sharded-state"], + [ + "--log-file", + log, + "--ddp-backend", + "fully_sharded", + "--use-sharded-state", + ], world_size=world_size, ) generate_main(data_dir, ["--checkpoint-shard-count", str(world_size)]) diff --git a/tests/gpu/test_ema_gpu.py b/tests/gpu/test_ema_gpu.py index 337107d6..34d9ccb7 100644 --- a/tests/gpu/test_ema_gpu.py +++ b/tests/gpu/test_ema_gpu.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import Optional import torch + from fairseq.models.ema import EMA @@ -45,9 +46,7 @@ class TestEMAGPU(unittest.TestCase): other_norm = torch.norm(y.float()) if msg is None: - msg = "|input - other| > {} + {} * |other|".format( - atol, rtol - ) + msg = "|input - other| > {} + {} * |other|".format(atol, rtol) self.assertLessEqual( diff_norm, @@ -104,9 +103,7 @@ class TestEMAGPU(unittest.TestCase): for key, param in model2.state_dict().items(): ema_param = ema_state_dict[key] - self.assertTrue( - torch.allclose(ema_param, param) - ) + self.assertTrue(torch.allclose(ema_param, param)) def test_ema_fp32(self): model = DummyModule().cuda().half() @@ -136,17 +133,27 @@ class TestEMAGPU(unittest.TestCase): # closer to the EMA update done in fp32 than in fp16. 
self.assertLessEqual( torch.norm( - ema_param.float() - - (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ema_param.float() + - ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ) + .half() + .float() ), torch.norm( - ema_param.float() - - (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ema_param.float() + - ( + config.ema_decay * prev_param + (1 - config.ema_decay) * param + ).float() ), ) self.assertTorchAllClose( ema_param, - (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(), + ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ).half(), ) def test_ema_fp16(self): @@ -179,12 +186,19 @@ class TestEMAGPU(unittest.TestCase): # closer to the EMA update done in fp16 than in fp32. self.assertLessEqual( torch.norm( - ema_param.float() - - (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ema_param.float() + - ( + config.ema_decay * prev_param + (1 - config.ema_decay) * param + ).float() ), torch.norm( - ema_param.float() - - (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ema_param.float() + - ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ) + .half() + .float() ), ) self.assertTorchAllClose( diff --git a/tests/test_amp_optimizer.py b/tests/test_amp_optimizer.py index 3a785e18..4d6073a9 100644 --- a/tests/test_amp_optimizer.py +++ b/tests/test_amp_optimizer.py @@ -8,7 +8,8 @@ import copy import unittest import torch -from torch.cuda.amp import autocast, GradScaler +from torch.cuda.amp import GradScaler, autocast + from fairseq.optim import build_optimizer @@ -58,15 +59,11 @@ class TestGradientScalingAMP(unittest.TestCase): self.scaler.update() self.assertEqual( model.weight, - torch.tensor( - [[3.1]], device="cuda:0", requires_grad=True - ), + torch.tensor([[3.1]], device="cuda:0", requires_grad=True), ) self.assertEqual( model.bias, - torch.tensor( - [5.1], device="cuda:0", requires_grad=True - ), + torch.tensor([5.1], device="cuda:0", requires_grad=True), ) self.assertEqual(self.scaler.get_scale(), 2.0) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index bc233192..1ab92f5f 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -4,30 +4,31 @@ # LICENSE file in the root directory of this source tree. 
import contextlib -import logging import json +import logging import os import random import sys import tempfile import unittest from io import StringIO -from typing import List, Dict +from typing import Dict, List + import torch + from fairseq import options from fairseq_cli import eval_lm, train from tests.utils import ( create_dummy_data, + create_laser_data_and_config_json, generate_main, preprocess_lm_data, preprocess_summarization_data, preprocess_translation_data, - create_laser_data_and_config_json, - train_translation_model, train_language_model, + train_translation_model, ) - try: import transformers # noqa @@ -1161,7 +1162,7 @@ class TestLanguageModeling(unittest.TestCase): train_language_model( data_dir, "transformer_lm", - ["--add-bos-token", '--nval', '1'], + ["--add-bos-token", "--nval", "1"], run_validation=True, ) eval_lm_main(data_dir) @@ -1186,7 +1187,15 @@ class TestLanguageModeling(unittest.TestCase): train_language_model( data_dir, "transformer_lm", - ["--add-bos-token", '--nval', '1', '--scale-fc', '--scale-heads', '--scale-attn', '--scale-fc'], + [ + "--add-bos-token", + "--nval", + "1", + "--scale-fc", + "--scale-heads", + "--scale-attn", + "--scale-fc", + ], run_validation=True, ) eval_lm_main(data_dir) @@ -1202,6 +1211,7 @@ class TestLanguageModeling(unittest.TestCase): "500", ], ) + def test_transformer_lm_with_adaptive_softmax(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory( diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index 0f282226..23ba034f 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -11,9 +11,9 @@ import unittest from io import StringIO from unittest.mock import patch -from fairseq import checkpoint_utils from omegaconf import OmegaConf +from fairseq import checkpoint_utils from tests.utils import ( create_dummy_data, preprocess_translation_data, @@ -56,23 +56,23 @@ class TestCheckpointUtils(unittest.TestCase): def test_load_model_ensemble_and_task(self): # with contextlib.redirect_stdout(StringIO()): - with self._train_transformer(seed=123) as model1: - with self._train_transformer(seed=456) as model2: - ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( - filenames=[model1, model2] - ) - self.assertEqual(len(ensemble), 2) + with self._train_transformer(seed=123) as model1: + with self._train_transformer(seed=456) as model2: + ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + filenames=[model1, model2] + ) + self.assertEqual(len(ensemble), 2) - # after Transformer has been migrated to Hydra, this will probably - # become cfg.common.seed - self.assertEqual(ensemble[0].args.seed, 123) - self.assertEqual(ensemble[1].args.seed, 456) + # after Transformer has been migrated to Hydra, this will probably + # become cfg.common.seed + self.assertEqual(ensemble[0].args.seed, 123) + self.assertEqual(ensemble[1].args.seed, 456) - # the task from the first model should be returned - self.assertTrue("seed123" in task.cfg.data) + # the task from the first model should be returned + self.assertTrue("seed123" in task.cfg.data) - # last cfg is saved - self.assertEqual(cfg.common.seed, 456) + # last cfg is saved + self.assertEqual(cfg.common.seed, 456) def test_prune_state_dict(self): with contextlib.redirect_stdout(StringIO()): @@ -94,7 +94,9 @@ class TestCheckpointUtils(unittest.TestCase): filename = "async_checkpoint.pt" with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: - with 
patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save: + with patch( + f"{checkpoint_utils.__name__}._torch_persistent_save" + ) as mock_save: checkpoint_utils.torch_persistent_save( state_dict, filename, async_write=True ) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 2acfc8dc..c48d02c5 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -6,8 +6,8 @@ import unittest import numpy as np -from fairseq.data.data_utils_fast import batch_by_size_fn -from fairseq.data.data_utils_fast import batch_by_size_vec + +from fairseq.data.data_utils_fast import batch_by_size_fn, batch_by_size_vec class TestBatchBySize(unittest.TestCase): @@ -20,7 +20,7 @@ class TestBatchBySize(unittest.TestCase): max_sentences, bsz_mult, ): - """Simple, reliable and slow implementation of batch by size """ + """Simple, reliable and slow implementation of batch by size""" batches = [] start = 0 while start < len(indices): diff --git a/tests/test_dataclass_utils.py b/tests/test_dataclass_utils.py index 45fc391a..231f86b6 100644 --- a/tests/test_dataclass_utils.py +++ b/tests/test_dataclass_utils.py @@ -41,7 +41,7 @@ class TestDataclassUtils(unittest.TestCase): def test_argparse_convert_basic(self): parser = ArgumentParser() gen_parser_from_dataclass(parser, A(), True) - args = parser.parse_args(["--num-layers", '10', "the/data/path"]) + args = parser.parse_args(["--num-layers", "10", "the/data/path"]) self.assertEqual(args.num_layers, 10) self.assertEqual(args.data, "the/data/path") diff --git a/tests/test_ema.py b/tests/test_ema.py index 88ea65a4..e6f10ce9 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from typing import Optional import torch + from fairseq.models.ema import EMA @@ -44,9 +45,7 @@ class TestEMAGPU(unittest.TestCase): other_norm = torch.norm(y.float()) if msg is None: - msg = "|input - other| > {} + {} * |other|".format( - atol, rtol - ) + msg = "|input - other| > {} + {} * |other|".format(atol, rtol) self.assertLessEqual( diff_norm, @@ -103,9 +102,7 @@ class TestEMAGPU(unittest.TestCase): for key, param in model2.state_dict().items(): ema_param = ema_state_dict[key] - self.assertTrue( - torch.allclose(ema_param, param) - ) + self.assertTrue(torch.allclose(ema_param, param)) def test_ema_fp32(self): model = DummyModule().half() @@ -135,17 +132,27 @@ class TestEMAGPU(unittest.TestCase): # closer to the EMA update done in fp32 than in fp16. self.assertLessEqual( torch.norm( - ema_param.float() - - (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ema_param.float() + - ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ) + .half() + .float() ), torch.norm( - ema_param.float() - - (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ema_param.float() + - ( + config.ema_decay * prev_param + (1 - config.ema_decay) * param + ).float() ), ) self.assertTorchAllClose( ema_param, - (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half(), + ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ).half(), ) def test_ema_fp16(self): @@ -178,12 +185,19 @@ class TestEMAGPU(unittest.TestCase): # closer to the EMA update done in fp16 than in fp32. 
self.assertLessEqual( torch.norm( - ema_param.float() - - (config.ema_decay * prev_param + (1 - config.ema_decay) * param).float() + ema_param.float() + - ( + config.ema_decay * prev_param + (1 - config.ema_decay) * param + ).float() ), torch.norm( - ema_param.float() - - (config.ema_decay * prev_param.float() + (1 - config.ema_decay) * param.float()).half().float() + ema_param.float() + - ( + config.ema_decay * prev_param.float() + + (1 - config.ema_decay) * param.float() + ) + .half() + .float() ), ) self.assertTorchAllClose( diff --git a/tests/test_export.py b/tests/test_export.py index b380697b..3e9a48d1 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -9,12 +9,12 @@ import tempfile import unittest import torch + from fairseq.data.dictionary import Dictionary from fairseq.models.transformer import TransformerModel from fairseq.modules import multihead_attention, sinusoidal_positional_embedding from fairseq.tasks.fairseq_task import LegacyFairseqTask - DEFAULT_TEST_VOCAB_SIZE = 100 @@ -116,6 +116,5 @@ class TestExportModels(unittest.TestCase): _test_save_and_load(scripted) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_file_io.py b/tests/test_file_io.py index 425812bf..d1f33dad 100644 --- a/tests/test_file_io.py +++ b/tests/test_file_io.py @@ -50,6 +50,7 @@ class TestFileIO(unittest.TestCase): # ioPath `PathManager` is initialized after the first `opena` call. try: from fairseq.file_io import IOPathManager, PathManager + _asyncfile = os.path.join(self._tmpdir, "async.txt") f = PathManager.opena(_asyncfile, "wb") f.close() diff --git a/tests/test_iopath.py b/tests/test_iopath.py index 908261a6..48230a63 100644 --- a/tests/test_iopath.py +++ b/tests/test_iopath.py @@ -8,7 +8,6 @@ from unittest import mock class TestIOPath(unittest.TestCase): - def test_no_iopath(self): from .test_reproducibility import TestReproducibility diff --git a/tests/test_lm_context_window.py b/tests/test_lm_context_window.py index 7415e86a..f8d7e720 100644 --- a/tests/test_lm_context_window.py +++ b/tests/test_lm_context_window.py @@ -6,23 +6,25 @@ import unittest import torch + from fairseq.data import MonolingualDataset -from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig +from fairseq.tasks.language_modeling import LanguageModelingConfig, LanguageModelingTask from tests import utils as test_utils class TestLMContextWindow(unittest.TestCase): - def test_eval_dataloader(self): dictionary = test_utils.dummy_dictionary(10) assert len(dictionary) == 14 # 4 extra special symbols assert dictionary.pad() == 1 - dataset = test_utils.TestDataset([ - torch.tensor([4, 5, 6, 7], dtype=torch.long), - torch.tensor([8, 9, 10, 11], dtype=torch.long), - torch.tensor([12, 13], dtype=torch.long), - ]) + dataset = test_utils.TestDataset( + [ + torch.tensor([4, 5, 6, 7], dtype=torch.long), + torch.tensor([8, 9, 10, 11], dtype=torch.long), + torch.tensor([12, 13], dtype=torch.long), + ] + ) dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary) config = LanguageModelingConfig(tokens_per_sample=4) diff --git a/tests/test_multi_corpus_dataset.py b/tests/test_multi_corpus_dataset.py index 278bdb73..79900abf 100644 --- a/tests/test_multi_corpus_dataset.py +++ b/tests/test_multi_corpus_dataset.py @@ -7,6 +7,7 @@ import unittest from collections import OrderedDict import torch + from fairseq.data import LanguagePairDataset, TokenBlockDataset from fairseq.data.multi_corpus_dataset import MultiCorpusDataset from tests.test_train import mock_dict 
@@ -69,8 +70,10 @@ class TestMultiCorpusDataset(unittest.TestCase): ) self.assertEqual( len(items), - int(min(len(self.dataset_1), len(indices) * distribution[0]) - + min(len(self.dataset_1), len(indices) * distribution[1])) + int( + min(len(self.dataset_1), len(indices) * distribution[0]) + + min(len(self.dataset_1), len(indices) * distribution[1]) + ), ) print(distribution) diff --git a/tests/test_noising.py b/tests/test_noising.py index b3d0d123..1956f6ad 100644 --- a/tests/test_noising.py +++ b/tests/test_noising.py @@ -6,8 +6,9 @@ import unittest from typing import Dict, List -import tests.utils as test_utils import torch + +import tests.utils as test_utils from fairseq import utils from fairseq.data import ( Dictionary, @@ -138,7 +139,7 @@ class TestDataNoising(unittest.TestCase): return x, torch.LongTensor(src_len) def assert_eos_at_end(self, x, x_len, eos): - """Asserts last token of every sentence in x is EOS """ + """Asserts last token of every sentence in x is EOS""" for i in range(len(x_len)): self.assertEqual( x[x_len[i] - 1][i], @@ -373,7 +374,7 @@ class TestDataNoising(unittest.TestCase): ) def assert_no_eos_at_end(self, x, x_len, eos): - """Asserts that the last token of each sentence in x is not EOS """ + """Asserts that the last token of each sentence in x is not EOS""" for i in range(len(x_len)): self.assertNotEqual( x[x_len[i] - 1][i], diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py index 92731919..b9f91ffa 100644 --- a/tests/test_sequence_generator.py +++ b/tests/test_sequence_generator.py @@ -4,22 +4,21 @@ # LICENSE file in the root directory of this source tree. import argparse +import math import tempfile import unittest -import math -import numpy as np +import numpy as np +import torch import tests.utils as test_utils -import torch from fairseq import search from fairseq.data.dictionary import Dictionary from fairseq.models.transformer import TransformerModel -from fairseq.sequence_generator import EnsembleModel, SequenceGenerator from fairseq.ngram_repeat_block import NGramRepeatBlock +from fairseq.sequence_generator import EnsembleModel, SequenceGenerator from fairseq.tasks.fairseq_task import LegacyFairseqTask - DEFAULT_TEST_VOCAB_SIZE = 100 @@ -590,9 +589,11 @@ class TestPrefixBeamSearch(TestSequenceGeneratorBase): # prefix step 0: torch.FloatTensor( [ - # eos - [0.0, unk] + [1.0 / vocab_size] * vocab_size # beam 1 - ] * self.beam_size + # eos + [0.0, unk] + + [1.0 / vocab_size] * vocab_size # beam 1 + ] + * self.beam_size ), ] * vocab_size @@ -617,6 +618,7 @@ class TestPrefixBeamSearch(TestSequenceGeneratorBase): # make sure test sample doesn't break any assertion generator.forward(sample, prefix_tokens=self.tokens[:, :-1]) + class TestTopPSamplingSearch(TestSequenceGeneratorBase): def setUp(self): # construct dummy dictionary diff --git a/tests/utils.py b/tests/utils.py index 6e0c7095..ce2e361d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,9 +12,12 @@ from io import StringIO import torch import torch.nn.functional as F + +import fairseq.distributed.utils as distributed_utils from fairseq import options, utils from fairseq.data import Dictionary from fairseq.data.language_pair_dataset import collate +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.models import ( FairseqEncoder, FairseqEncoderDecoderModel, @@ -23,8 +26,6 @@ from fairseq.models import ( from fairseq.models.fairseq_encoder import EncoderOut from fairseq.tasks import LegacyFairseqTask from fairseq_cli import generate, 
interactive, preprocess, train, validate -import fairseq.distributed.utils as distributed_utils -from fairseq.dataclass.utils import convert_namespace_to_omegaconf def dummy_dictionary(vocab_size, prefix="token_"): @@ -37,7 +38,10 @@ def dummy_dictionary(vocab_size, prefix="token_"): def dummy_dataloader( - samples, padding_idx=1, eos_idx=2, batch_size=None, + samples, + padding_idx=1, + eos_idx=2, + batch_size=None, ): if batch_size is None: batch_size = len(samples)