diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md
new file mode 100644
index 000000000..9b77dd835
--- /dev/null
+++ b/docs/hydra_integration.md
@@ -0,0 +1,113 @@
+
+
+## Hydra
+
+Hydra is an open-source Python framework that simplifies the development of research and other complex applications. Its key feature is the ability to dynamically create a hierarchical configuration by composition and to override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads.
+
+## Train models with the hydra interface
+
+#### Provide parameters in `.yaml` files
+For example, if we'd like to train a language model with a transformer architecture, we can provide the parameters in yaml files. Note that the modules used in training (task, model, criterion, optimizer, lr scheduler) must already be migrated to the hydra interface (see the section below).
+
+- Provide the top-level choices of which generic parameter file and which modules to use in `config/config.yaml`. For example, it will look like:
+
+```
+defaults:
+  - params: training_params
+  - task: language_modeling
+  - model: transformer_lm
+  - criterion: cross_entropy
+  - optimizer: adam
+  - lr_scheduler: inverse_sqrt
+```
+
+- Provide generic parameters common across different training jobs: `config/params/training_params.yaml`
+- Provide task parameters: `config/task/language_modeling.yaml`
+- Provide model parameters: `config/model/transformer_lm.yaml`
+- Provide criterion parameters: `config/criterion/cross_entropy.yaml`
+- Provide optimizer parameters: `config/optimizer/adam.yaml`
+- Provide lr_scheduler parameters: `config/lr_scheduler/inverse_sqrt.yaml`
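+
+As an illustration, a module parameter file such as `config/optimizer/adam.yaml` simply lists the fields of that module's dataclass with their values. The sketch below is hypothetical rather than copied from the repository; the exact field names are dictated by the optimizer's dataclass (the values here mirror the command-line example further down):
+
+```
+adam_betas: "(0.9, 0.98)"
+adam_eps: 1e-8
+weight_decay: 0.01
+```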
+
+#### Command line overriding
+`train_hydra.py` is the main entry point for training with the hydra interface. If we specify all the parameters we want in `.yaml` files, we can simply run:
+
+```
+# task.data is a required field, marked by `???` in the yaml
+python fairseq_cli/train_hydra.py \
+task.data=/private/home/abaevski/data/wiki103
+```
+
+Alternatively, if we need to override certain parameters from the command line, we can do so as below (note the structure of where each parameter sits):
+
+```
+python fairseq_cli/train_hydra.py \
+params=training_params \
+task=language_modeling \
+task.data=/private/home/abaevski/data/wiki103 \
+task.tokens_per_sample=512 \
+task.sample_break_mode=none \
+model=transformer_lm \
+model.share_decoder_input_output_embed=true \
+model.dropout=0.1 \
+optimizer=adam \
+optimizer.adam_betas="'(0.9, 0.98)'" \
+optimizer.weight_decay=0.01 \
+lr_scheduler=inverse_sqrt \
+lr_scheduler.warmup_updates=4000 \
+lr_scheduler.warmup_init_lr=1e-07 \
+criterion=cross_entropy \
+params.common.fp16=true \
+params.common.log_format=json \
+params.common.log_interval=1 \
+params.dataset.max_tokens=1024 \
+params.dataset.num_workers=4 \
+params.optimization.update_freq=[16] \
+params.optimization.max_update=50000 \
+params.optimization.clip_norm=0.0 \
+params.optimization.lr=[0.0005] \
+params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
+params.checkpoint.save_interval_updates=10
+```
+
+## Migrate existing or create new modules with the hydra interface
+
+For each module that we want to migrate to (or create with) the hydra interface, we fundamentally need to do three things (a minimal sketch follows this list):
+
+- Provide a dataclass that lays out the parameters used in the module.
+
+- Modify the builder and/or constructor that previously took an `argparse.Namespace` argument `args` so that it takes an `omegaconf.DictConfig` config object instead. For the moment we allow `Union[omegaconf.DictConfig, argparse.Namespace]` for backward compatibility.
+
+- For `add_args()`, we need to extract the arguments from the dataclass defined in the same file and append them to `parser`. This is also for backward compatibility, and is supported by the `gen_parser_from_dataclass` API; see the example files below.
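+
+To make the three steps concrete, here is a minimal sketch of a migrated module. It is illustrative only: `MyModule` and `MyModuleConfig` are hypothetical names, in fairseq the dataclass would subclass `FairseqDataclass` rather than being a plain dataclass, and import paths may differ between versions:
+
+```
+from argparse import Namespace
+from dataclasses import dataclass, field
+from typing import Union
+
+from omegaconf import DictConfig
+
+from fairseq.dataclass.utils import gen_parser_from_dataclass
+
+
+@dataclass
+class MyModuleConfig:
+    # (1) a dataclass that lays out the parameters used in the module
+    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
+
+
+class MyModule:
+    def __init__(self, cfg: Union[DictConfig, Namespace]):
+        # (2) the constructor takes an omegaconf.DictConfig instead of an
+        # argparse.Namespace; Namespace is still accepted for compatibility
+        self.cfg = cfg
+        self.dropout = cfg.dropout
+
+    @staticmethod
+    def add_args(parser):
+        # (3) extract the dataclass fields and append them to `parser`,
+        # so the legacy argparse path keeps working
+        gen_parser_from_dataclass(parser, MyModuleConfig())
+```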
+
+#### Migrated examples:
+
+- Task: `fairseq/tasks/language_modeling.py`
+
+- Model: `fairseq/models/transformer_lm.py`
+
+- Criterion: `fairseq/criterions/adaptive_loss.py` and `fairseq/criterions/cross_entropy.py`
+
+- Optimizer: `fairseq/optim/adam.py` and `fairseq/optim/nag.py`
+
+- LR scheduler: `fairseq/optim/lr_scheduler/cosine_lr_scheduler.py` and `fairseq/optim/lr_scheduler/inverse_square_root_schedule.py`
+
+
+## Interpolate parameters across different places
+
+## Support for the legacy interface
+If you would still like to pass legacy-style arguments on the command line, `fairseq_cli/train.py` supports this. Internally it converts `args` into hydra config objects wherever there are migrated modules to align with.
+
+```
+python fairseq_cli/train.py --task language_modeling \
+/private/home/abaevski/data/wiki103 \
+--save-dir /checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
+--arch transformer_lm --share-decoder-input-output-embed \
+--dropout 0.1 \
+--optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
+--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
+--tokens-per-sample 512 --sample-break-mode none \
+--max-tokens 1024 --update-freq 16 \
+--fp16 \
+--max-update 50000 --log-format json --log-interval 1 --num-workers 4 \
+--save-interval-updates 10
+```
diff --git a/examples/roberta/commonsense_qa/commonsense_qa_task.py b/examples/roberta/commonsense_qa/commonsense_qa_task.py index 274e8d39a..7ed2bc36a 100644 --- a/examples/roberta/commonsense_qa/commonsense_qa_task.py +++ b/examples/roberta/commonsense_qa/commonsense_qa_task.py @@ -22,11 +22,11 @@ from fairseq.data import ( RightPadDataset, SortDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask @register_task('commonsense_qa') -class CommonsenseQATask(FairseqTask): +class CommonsenseQATask(LegacyFairseqTask): """Task to finetune RoBERTa for Commonsense QA.""" @staticmethod
diff --git a/examples/roberta/wsc/wsc_task.py b/examples/roberta/wsc/wsc_task.py index fbba0d896..058e3eea2 100644 --- a/examples/roberta/wsc/wsc_task.py +++ b/examples/roberta/wsc/wsc_task.py @@ -24,13 +24,13 @@ from fairseq.data import ( PadDataset, SortDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from . import wsc_utils @register_task('wsc') -class WSCTask(FairseqTask): +class WSCTask(LegacyFairseqTask): """Task to finetune RoBERTa for Winograd Schemas.""" @staticmethod
diff --git a/examples/speech_recognition/tasks/speech_recognition.py b/examples/speech_recognition/tasks/speech_recognition.py index e5717c0ef..dde0b1257 100644 --- a/examples/speech_recognition/tasks/speech_recognition.py +++ b/examples/speech_recognition/tasks/speech_recognition.py @@ -10,7 +10,7 @@ import sys import torch from fairseq.data import Dictionary -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from examples.speech_recognition.data import AsrDataset from examples.speech_recognition.data.replabels import replabel_symbol @@ -66,7 +66,7 @@ def get_asr_dataset_from_json(data_json_path, tgt_dict): @register_task("speech_recognition") -class SpeechRecognitionTask(FairseqTask): +class SpeechRecognitionTask(LegacyFairseqTask): """ Task for training speech recognition model. """
diff --git a/fairseq/benchmark/dummy_lm.py b/fairseq/benchmark/dummy_lm.py index 92e9dc8df..f33a1adcf 100644 --- a/fairseq/benchmark/dummy_lm.py +++ b/fairseq/benchmark/dummy_lm.py @@ -9,14 +9,14 @@ import numpy as np import torch from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask logger = logging.getLogger(__name__) @register_task('dummy_lm') -class DummyLMTask(FairseqTask): +class DummyLMTask(LegacyFairseqTask): @staticmethod def add_args(parser):
diff --git a/fairseq/benchmark/dummy_masked_lm.py b/fairseq/benchmark/dummy_masked_lm.py index f2e459caa..3b0bdc51f 100644 --- a/fairseq/benchmark/dummy_masked_lm.py +++ b/fairseq/benchmark/dummy_masked_lm.py @@ -9,14 +9,14 @@ import numpy as np import torch from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask logger = logging.getLogger(__name__) @register_task('dummy_masked_lm') -class DummyMaskedLMTask(FairseqTask): +class DummyMaskedLMTask(LegacyFairseqTask): @staticmethod def add_args(parser):
diff --git a/fairseq/benchmark/dummy_mt.py b/fairseq/benchmark/dummy_mt.py index 9fba9bb52..0371b3e75 100644 --- a/fairseq/benchmark/dummy_mt.py +++ b/fairseq/benchmark/dummy_mt.py @@ -9,14 +9,14 @@ import numpy as np import torch from fairseq.data import Dictionary, FairseqDataset -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask logger = logging.getLogger(__name__) @register_task('dummy_mt') -class DummyMTTask(FairseqTask): +class DummyMTTask(LegacyFairseqTask): @staticmethod def add_args(parser):
diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index dff140d58..b172b270a 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -7,7 +7,7 @@ import importlib import os from fairseq import registry -from fairseq.optim.fairseq_optimizer import FairseqOptimizer +from fairseq.optim.fairseq_optimizer import FairseqOptimizer, LegacyFairseqOptimizer # noqa from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer from fairseq.optim.bmuf import FairseqBMUF # noqa from fairseq.optim.shard import shard_
diff --git a/fairseq/optim/adadelta.py b/fairseq/optim/adadelta.py index 0a76e27fe..9b311ae38 100644 --- a/fairseq/optim/adadelta.py +++ b/fairseq/optim/adadelta.py @@ -5,11 +5,11 @@ import torch.optim
-from . import FairseqOptimizer, register_optimizer +from . import register_optimizer, LegacyFairseqOptimizer @register_optimizer('adadelta') -class Adadelta(FairseqOptimizer): +class Adadelta(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config) diff --git a/fairseq/optim/adafactor.py b/fairseq/optim/adafactor.py index b0fb3a9f5..ab69e0e58 100644 --- a/fairseq/optim/adafactor.py +++ b/fairseq/optim/adafactor.py @@ -7,11 +7,11 @@ import math import torch import torch.optim -from . import FairseqOptimizer, register_optimizer +from . import register_optimizer, LegacyFairseqOptimizer @register_optimizer('adafactor') -class FairseqAdafactor(FairseqOptimizer): +class FairseqAdafactor(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = Adafactor(params, **self.optimizer_config) diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py index 57f83258c..505675277 100644 --- a/fairseq/optim/adagrad.py +++ b/fairseq/optim/adagrad.py @@ -5,11 +5,11 @@ import torch.optim -from . import FairseqOptimizer, register_optimizer +from . import register_optimizer, LegacyFairseqOptimizer @register_optimizer('adagrad') -class Adagrad(FairseqOptimizer): +class Adagrad(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) diff --git a/fairseq/optim/adamax.py b/fairseq/optim/adamax.py index 856215a3b..195e7a90d 100644 --- a/fairseq/optim/adamax.py +++ b/fairseq/optim/adamax.py @@ -6,11 +6,11 @@ import torch import torch.optim -from . import FairseqOptimizer, register_optimizer +from . import register_optimizer, LegacyFairseqOptimizer @register_optimizer('adamax') -class FairseqAdamax(FairseqOptimizer): +class FairseqAdamax(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = Adamax(params, **self.optimizer_config) diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index e00a04dd1..18c26a3a3 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -140,3 +140,9 @@ class FairseqOptimizer(object): def average_params(self): pass + + +class LegacyFairseqOptimizer(FairseqOptimizer): + + def __init__(self, args): + self.args = args diff --git a/fairseq/optim/fused_lamb.py b/fairseq/optim/fused_lamb.py index f9b0409c5..d48ecbc8e 100644 --- a/fairseq/optim/fused_lamb.py +++ b/fairseq/optim/fused_lamb.py @@ -3,11 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from fairseq.optim import FairseqOptimizer, register_optimizer +from fairseq.optim import register_optimizer, LegacyFairseqOptimizer @register_optimizer('lamb') -class FairseqLAMB(FairseqOptimizer): +class FairseqLAMB(LegacyFairseqOptimizer): """LAMB optimizer.""" def __init__(self, args, params): diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index edd0a6a13..76c535718 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -7,7 +7,7 @@ import importlib import os from fairseq import registry -from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler +from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler, LegacyFairseqLRScheduler # noqa build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry( diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index 8b7884829..5569de3db 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from .. import FairseqOptimizer +from argparse import Namespace class FairseqLRScheduler(object): @@ -40,3 +41,13 @@ class FairseqLRScheduler(object): def step_update(self, num_updates): """Update the learning rate after each update.""" return self.optimizer.get_lr() + + +class LegacyFairseqLRScheduler(FairseqLRScheduler): + + def __init__(self, args: Namespace, optimizer): + if not isinstance(optimizer, FairseqOptimizer): + raise ValueError('optimizer must be an instance of FairseqOptimizer') + self.args = args + self.optimizer = optimizer + self.best = None diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py index 1c3edd004..9a30195fa 100644 --- a/fairseq/optim/lr_scheduler/fixed_schedule.py +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -3,11 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import FairseqLRScheduler, register_lr_scheduler +from . import register_lr_scheduler, LegacyFairseqLRScheduler @register_lr_scheduler('fixed') -class FixedSchedule(FairseqLRScheduler): +class FixedSchedule(LegacyFairseqLRScheduler): """Decay the LR on a fixed schedule.""" def __init__(self, args, optimizer): diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py index aff57f9b9..73e8b170b 100644 --- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -3,11 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import FairseqLRScheduler, register_lr_scheduler +from . 
import register_lr_scheduler, LegacyFairseqLRScheduler @register_lr_scheduler('polynomial_decay') -class PolynomialDecaySchedule(FairseqLRScheduler): +class PolynomialDecaySchedule(LegacyFairseqLRScheduler): """Decay the LR on a fixed schedule.""" def __init__(self, args, optimizer): diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index 65ac2e307..5199b09a3 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -5,11 +5,11 @@ import torch.optim.lr_scheduler -from . import FairseqLRScheduler, register_lr_scheduler +from . import register_lr_scheduler, LegacyFairseqLRScheduler @register_lr_scheduler('reduce_lr_on_plateau') -class ReduceLROnPlateau(FairseqLRScheduler): +class ReduceLROnPlateau(LegacyFairseqLRScheduler): """ Decay the LR by a factor every time the validation loss plateaus. Also comes with optional warmup phase, where we linearly increase diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py index 3460fa122..95c5576f2 100644 --- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -3,12 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import FairseqLRScheduler, register_lr_scheduler +from . import register_lr_scheduler, LegacyFairseqLRScheduler import math @register_lr_scheduler('tri_stage') -class TriStageLRSchedule(FairseqLRScheduler): +class TriStageLRSchedule(LegacyFairseqLRScheduler): """Tristage learning rate schedulr Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py index fed0cf7ef..67e1df65e 100644 --- a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -5,11 +5,11 @@ import math -from . import FairseqLRScheduler, register_lr_scheduler +from . import register_lr_scheduler, LegacyFairseqLRScheduler @register_lr_scheduler('triangular') -class TriangularSchedule(FairseqLRScheduler): +class TriangularSchedule(LegacyFairseqLRScheduler): """Assign LR based on a triangular cyclical schedule. See https://arxiv.org/pdf/1506.01186.pdf for details. diff --git a/fairseq/optim/sgd.py b/fairseq/optim/sgd.py index 8c4e3e0a8..b558f41ab 100644 --- a/fairseq/optim/sgd.py +++ b/fairseq/optim/sgd.py @@ -5,11 +5,11 @@ import torch.optim -from . import FairseqOptimizer, register_optimizer +from . 
import register_optimizer, LegacyFairseqOptimizer @register_optimizer('sgd') -class SGD(FairseqOptimizer): +class SGD(LegacyFairseqOptimizer): def __init__(self, args, params): super().__init__(args) self._optimizer = torch.optim.SGD(params, **self.optimizer_config) diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index b1bb404f1..69231a852 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -7,7 +7,7 @@ import argparse import importlib import os -from .fairseq_task import FairseqTask +from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa TASK_REGISTRY = {} TASK_CLASS_NAMES = set() diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 2a51279eb..75bcfaa8d 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -9,7 +9,7 @@ import os import sys from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset -from . import FairseqTask, register_task +from . import LegacyFairseqTask, register_task class LabelEncoder(object): @@ -23,7 +23,7 @@ class LabelEncoder(object): @register_task("audio_pretraining") -class AudioPretrainingTask(FairseqTask): +class AudioPretrainingTask(LegacyFairseqTask): """ """ diff --git a/fairseq/tasks/cross_lingual_lm.py b/fairseq/tasks/cross_lingual_lm.py index 3589492f1..a7ce1f1ad 100644 --- a/fairseq/tasks/cross_lingual_lm.py +++ b/fairseq/tasks/cross_lingual_lm.py @@ -21,14 +21,14 @@ from fairseq.data import ( ) from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq import utils logger = logging.getLogger(__name__) @register_task('cross_lingual_lm') -class CrossLingualLMTask(FairseqTask): +class CrossLingualLMTask(LegacyFairseqTask): """ Task for training cross-lingual language models. diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py index 28beb517f..ea6db45c7 100644 --- a/fairseq/tasks/denoising.py +++ b/fairseq/tasks/denoising.py @@ -16,7 +16,7 @@ from fairseq.data import ( TokenBlockDataset, ) from fairseq.data.encoders.utils import get_whole_word_mask -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq import utils @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) @register_task('denoising') -class DenoisingTask(FairseqTask): +class DenoisingTask(LegacyFairseqTask): """ Denoising task for applying sequence to sequence denoising. (ie. BART) """ diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index c7b39f5b6..8da07bf8b 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -12,6 +12,7 @@ import torch from fairseq import metrics, search, tokenizer, utils from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary +from argparse import Namespace logger = logging.getLogger(__name__) @@ -486,3 +487,56 @@ class FairseqTask(object): """Return the target :class:`~fairseq.data.Dictionary` (if applicable for this task).""" raise NotImplementedError + + +class LegacyFairseqTask(FairseqTask): + + def __init__(self, args: Namespace): + self.args = args + self.datasets = {} + self.dataset_to_epoch_iter = {} + + @classmethod + def setup_task(cls, args: Namespace, **kwargs): + """Setup the task (e.g., load dictionaries). 
+ + Args: + args (argparse.Namespace): parsed command-line arguments + """ + return cls(args, **kwargs) + + def has_sharded_data(self, split): + return (os.pathsep in getattr(self.args, 'data', '')) + + def build_model(self, args: Namespace): + """ + Build the :class:`~fairseq.models.BaseFairseqModel` instance for this + task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.models.BaseFairseqModel` instance + """ + from fairseq import models, quantization_utils + model = models.build_model(args, self) + if getattr(args, 'tpu', False): + model.prepare_for_tpu_() + model = quantization_utils.quantize_model_scalar(model, args) + return model + + def build_criterion(self, args: Namespace): + """ + Build the :class:`~fairseq.criterions.FairseqCriterion` instance for + this task. + + Args: + args (argparse.Namespace): parsed command-line arguments + + Returns: + a :class:`~fairseq.criterions.FairseqCriterion` instance + """ + from fairseq import criterions + + return criterions.build_criterion(args, self) diff --git a/fairseq/tasks/legacy_masked_lm.py b/fairseq/tasks/legacy_masked_lm.py index 40e272495..4e0390cdc 100644 --- a/fairseq/tasks/legacy_masked_lm.py +++ b/fairseq/tasks/legacy_masked_lm.py @@ -20,7 +20,7 @@ from fairseq.data import Dictionary from fairseq.data.legacy.block_pair_dataset import BlockPairDataset from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset from fairseq.data.legacy.masked_lm_dictionary import BertDictionary -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq import utils @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) @register_task('legacy_masked_lm') -class LegacyMaskedLMTask(FairseqTask): +class LegacyMaskedLMTask(LegacyFairseqTask): """ Task for training Masked LM (BERT) model. 
Args: diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 4a6e6a2d3..10b234a96 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -21,8 +21,8 @@ from fairseq.data import ( SortDataset, TokenBlockDataset, ) +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.tasks import FairseqTask, register_task from fairseq.data.encoders.utils import get_whole_word_mask from fairseq import utils @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) @register_task('masked_lm') -class MaskedLMTask(FairseqTask): +class MaskedLMTask(LegacyFairseqTask): """Task for training masked language models (e.g., BERT, RoBERTa).""" @staticmethod diff --git a/fairseq/tasks/multilingual_masked_lm.py b/fairseq/tasks/multilingual_masked_lm.py index 5d96a608b..110e580a7 100644 --- a/fairseq/tasks/multilingual_masked_lm.py +++ b/fairseq/tasks/multilingual_masked_lm.py @@ -26,7 +26,7 @@ from fairseq.data import ( SortDataset, TokenBlockDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq import utils @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) @register_task('multilingual_masked_lm') -class MultiLingualMaskedLMTask(FairseqTask): +class MultiLingualMaskedLMTask(LegacyFairseqTask): """Task for training masked language models (e.g., BERT, RoBERTa).""" @staticmethod diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index 272bcf1ae..784b438ca 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -6,11 +6,11 @@ from collections import OrderedDict import logging import os - +from fairseq import options import contextlib import torch -from fairseq import metrics, options +from fairseq import metrics, utils from fairseq.data import ( Dictionary, LanguagePairDataset, @@ -20,8 +20,7 @@ from fairseq.data import ( from fairseq.models import FairseqMultiModel from fairseq.tasks.translation import load_langpair_dataset -from . import FairseqTask, register_task -from fairseq import utils +from . import register_task, LegacyFairseqTask logger = logging.getLogger(__name__) @@ -39,7 +38,7 @@ def _lang_token_index(dic: Dictionary, lang: str): @register_task('multilingual_translation') -class MultilingualTranslationTask(FairseqTask): +class MultilingualTranslationTask(LegacyFairseqTask): """A task for training multiple translation models simultaneously. We iterate round-robin over batches from multiple language pairs, ordered diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index cf5eae38b..fec19e0a7 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -25,15 +25,15 @@ from fairseq.data import ( SortDataset, StripTokenDataset, ) +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.tasks import FairseqTask, register_task logger = logging.getLogger(__name__) @register_task('sentence_prediction') -class SentencePredictionTask(FairseqTask): +class SentencePredictionTask(LegacyFairseqTask): """ Sentence (or sentence pair) prediction (classification or regression) task. 
diff --git a/fairseq/tasks/sentence_ranking.py b/fairseq/tasks/sentence_ranking.py index ea4b50a29..a1d332a3c 100644 --- a/fairseq/tasks/sentence_ranking.py +++ b/fairseq/tasks/sentence_ranking.py @@ -23,15 +23,15 @@ from fairseq.data import ( SortDataset, TruncateDataset ) +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.shorten_dataset import maybe_shorten_dataset -from fairseq.tasks import FairseqTask, register_task logger = logging.getLogger(__name__) @register_task('sentence_ranking') -class SentenceRankingTask(FairseqTask): +class SentenceRankingTask(LegacyFairseqTask): """ Ranking task on multiple sentences. diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index a01768ecb..6eac29365 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -8,10 +8,10 @@ import json import itertools import logging import os - +from fairseq import options import numpy as np -from fairseq import metrics, options, utils +from fairseq import metrics, utils from fairseq.data import ( AppendTokenDataset, ConcatDataset, @@ -24,7 +24,7 @@ from fairseq.data import ( TruncateDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask EVAL_BLEU_ORDER = 4 @@ -133,7 +133,7 @@ def load_langpair_dataset( @register_task('translation') -class TranslationTask(FairseqTask): +class TranslationTask(LegacyFairseqTask): """ Translate from one (source) language to another (target) language. diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index e13c9fd88..94f1fd32a 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -16,7 +16,7 @@ from fairseq.data import ( ListDataset, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import register_task, LegacyFairseqTask from fairseq.data.multilingual.sampling_method import SamplingMethod from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) @register_task('translation_multi_simple_epoch') -class TranslationMultiSimpleEpochTask(FairseqTask): +class TranslationMultiSimpleEpochTask(LegacyFairseqTask): """ Translate from one (source) language to another (target) language. 
diff --git a/tests/speech_recognition/asr_test_base.py b/tests/speech_recognition/asr_test_base.py index 7482858ff..4f3d3fceb 100644 --- a/tests/speech_recognition/asr_test_base.py +++ b/tests/speech_recognition/asr_test_base.py @@ -17,7 +17,7 @@ from fairseq.models import ( FairseqEncoderModel, FairseqModel, ) -from fairseq.tasks.fairseq_task import FairseqTask +from fairseq.tasks.fairseq_task import LegacyFairseqTask from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask @@ -37,7 +37,7 @@ def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE): return dummy_dict -class DummyTask(FairseqTask): +class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() diff --git a/tests/test_export.py b/tests/test_export.py index 7b0e7fcf1..87e52bd7c 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -12,13 +12,13 @@ import torch from fairseq.data.dictionary import Dictionary from fairseq.models.transformer import TransformerModel from fairseq.modules import multihead_attention, sinusoidal_positional_embedding -from fairseq.tasks.fairseq_task import FairseqTask +from fairseq.tasks.fairseq_task import LegacyFairseqTask DEFAULT_TEST_VOCAB_SIZE = 100 -class DummyTask(FairseqTask): +class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() diff --git a/tests/test_lstm_jitable.py b/tests/test_lstm_jitable.py index d0d812cea..d97652fb7 100644 --- a/tests/test_lstm_jitable.py +++ b/tests/test_lstm_jitable.py @@ -10,13 +10,13 @@ import unittest import torch from fairseq.data.dictionary import Dictionary from fairseq.models.lstm import LSTMModel -from fairseq.tasks.fairseq_task import FairseqTask +from fairseq.tasks.fairseq_task import LegacyFairseqTask DEFAULT_TEST_VOCAB_SIZE = 100 -class DummyTask(FairseqTask): +class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() diff --git a/tests/test_sequence_generator.py b/tests/test_sequence_generator.py index 36560bcca..517aa77d5 100644 --- a/tests/test_sequence_generator.py +++ b/tests/test_sequence_generator.py @@ -14,13 +14,13 @@ from fairseq.data.dictionary import Dictionary from fairseq.models.transformer import TransformerModel from fairseq.sequence_generator import SequenceGenerator, EnsembleModel -from fairseq.tasks.fairseq_task import FairseqTask +from fairseq.tasks.fairseq_task import LegacyFairseqTask DEFAULT_TEST_VOCAB_SIZE = 100 -class DummyTask(FairseqTask): +class DummyTask(LegacyFairseqTask): def __init__(self, args): super().__init__(args) self.dictionary = get_dummy_dictionary() diff --git a/tests/utils.py b/tests/utils.py index 869a70c5e..ef546fa58 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,6 +20,7 @@ from fairseq.models import ( FairseqIncrementalDecoder, ) from fairseq.models.fairseq_encoder import EncoderOut +from fairseq.tasks import LegacyFairseqTask from fairseq.tasks import FairseqTask from fairseq_cli import ( generate, @@ -284,7 +285,7 @@ class TestDataset(torch.utils.data.Dataset): return len(self.data) -class TestTranslationTask(FairseqTask): +class TestTranslationTask(LegacyFairseqTask): def __init__(self, args, src_dict, tgt_dict, model): super().__init__(args)