hydra fairseq 3 - inherit from legacy for fairseq classes

Summary: hydra fairseq 3 - inherit from legacy for fairseq classes

Reviewed By: alexeib

Differential Revision: D23375457

fbshipit-source-id: ef9d19f2d02f2326eea44a70f1f6e1668b420840
Mu Tian 2020-09-09 17:00:56 -07:00 committed by Facebook GitHub Bot
parent df45f42efd
commit 42c5dcbd18
40 changed files with 257 additions and 73 deletions

docs/hydra_integration.md
View File

@ -0,0 +1,113 @@
## Hydra
Hydra is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads.
## Train models with the hydra interface
#### Provide parameters in `.yaml` files
For example, if we'd like to train a transformer language model, we can provide the parameters in yaml files. Note that the modules used in training (task, model, criterion, optimizer, lr scheduler) must already be migrated to the hydra interface (see the section below).
- Provide the top-level choices of which generic parameter file and which modules to use in `config/config.yaml`, which will look like, for example:
```
defaults:
  - params: training_params
  - task: language_modeling
  - model: transformer_lm
  - criterion: cross_entropy
  - optimizer: adam
  - lr_scheduler: inverse_sqrt
```
- Provide generic parameters common across different training jobs: `config/params/training_params.yaml`
- Provide task parameters: `config/task/language_modeling.yaml`
- Provide model parameters: `config/model/transformer_lm.yaml`
- Provide criterion parameters: `config/criterion/cross_entropy.yaml`
- Provide optimizer parameters: `config/optimizer/adam.yaml`
- Provide lr_scheduler parameters: `config/lr_scheduler/inverse_sqrt.yaml`
#### Command line overriding
`train_hydra.py` is the main entry point for training with the hydra interface. If we specify all the parameters we want in `.yaml` files, we can simply run:
```
# task.data is a required field, marked by `???` in the yaml
python fairseq_cli/train_hydra.py \
task.data=/private/home/abaevski/data/wiki103
```
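For orientation, a hydra entry point is a function decorated with `@hydra.main` that receives the composed config. The following is only a minimal sketch of that pattern (it is not the actual contents of `fairseq_cli/train_hydra.py`, and the `config_path`/`config_name` values are assumptions for illustration):
```
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="config", config_name="config")
def main(cfg: DictConfig) -> None:
    # At this point hydra has composed config/config.yaml with the selected
    # groups (params, task, model, criterion, optimizer, lr_scheduler) and
    # applied any command-line overrides such as task.data=...
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```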
Alternatively, if we need to override certain params from the command line, we can do so as below (note the structure of where each parameter sits):
```
python fairseq_cli/train_hydra.py \
params=training_params \
task=language_modeling \
task.data=/private/home/abaevski/data/wiki103 \
task.tokens_per_sample=512 \
task.sample_break_mode=none \
model=transformer_lm \
model.share_decoder_input_output_embed=true \
model.dropout=0.1 \
optimizer=adam \
optimizer.adam_betas="'(0.9, 0.98)'" \
optimizer.weight_decay=0.01 \
lr_scheduler=inverse_sqrt \
lr_scheduler.warmup_updates=4000 \
lr_scheduler.warmup_init_lr=1e-07 \
criterion=cross_entropy \
params.common.fp16=true \
params.common.log_format=json \
params.common.log_interval=1 \
params.dataset.max_tokens=1024 \
params.dataset.num_workers=4 \
params.optimization.update_freq=[16] \
params.optimization.max_update=50000 \
params.optimization.clip_norm=0.0 \
params.optimization.lr=[0.0005] \
params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
params.checkpoint.save_interval_updates=10
```
## Migrating existing modules / creating new modules with the hydra interface
In each module we want to migrate to or create with the hydra interface, we fundamentally need to:
- Provide a dataclass that lays out the parameters used in the module.
- Modify the builder and/or constructor that previously took an `argparse.Namespace` argument `args` to instead take an `omegaconf.DictConfig` config object. For the moment we allow `Union[omegaconf.DictConfig, argparse.Namespace]` for backward compatibility.
- For `add_args()`, extract the arguments from the dataclass defined in the same file and append them to `parser`. This is also for backward compatibility and is handled by the `gen_parser_from_dataclass` API; see the example files below and the sketch after this list.
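To make the three steps above concrete, here is a minimal, hypothetical sketch of a migrated module. The class, its config fields, and the import path of `gen_parser_from_dataclass` are assumptions for illustration and may not match the repository exactly:
```
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Union

from omegaconf import DictConfig

# assumed import path for the helper mentioned above
from fairseq.dataclass.utils import gen_parser_from_dataclass


@dataclass
class MyModuleConfig:
    # 1. Lay out the parameters used in the module.
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    sentence_avg: bool = field(default=False, metadata={"help": "average over sentences"})


class MyModule:
    # 2. The constructor accepts either the new DictConfig or the legacy Namespace.
    def __init__(self, cfg: Union[DictConfig, Namespace]):
        self.dropout = getattr(cfg, "dropout", 0.1)
        self.sentence_avg = getattr(cfg, "sentence_avg", False)

    # 3. add_args() keeps argparse support by generating arguments from the dataclass.
    @staticmethod
    def add_args(parser):
        gen_parser_from_dataclass(parser, MyModuleConfig())
```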
#### Migrated examples:
- Task: `fairseq/tasks/language_modeling.py`
- Model: `fairseq/models/transformer_lm.py`
- Criterion: `fairseq/criterions/adaptive_loss.py` and `fairseq/criterions/cross_entropy.py`
- Optimizer: `fairseq/optim/adam.py` and `fairseq/optim/nag.py`
- LR scheduler: `fairseq/optim/lr_scheduler/cosine_lr_scheduler.py` and `fairseq/optim/lr_scheduler/inverse_square_root_schedule.py`
## Interpolate parameters across different places
## Support of legacy interface
If you would still like to pass legacy-style arguments on the command line, `fairseq_cli/train.py` supports this. Internally, it converts `args` into hydra config objects for whichever modules have already been migrated.
```
python fairseq_cli/train.py --task language_modeling \
/private/home/abaevski/data/wiki103 \
--save-dir /checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
--arch transformer_lm --share-decoder-input-output-embed \
--dropout 0.1 \
--optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
--tokens-per-sample 512 --sample-break-mode none \
--max-tokens 1024 --update-freq 16 \
--fp16 \
--max-update 50000 --log-format json --log-interval 1 --num-workers 4 \
--save-interval-updates 10
```
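Conceptually, that conversion boils down to wrapping the parsed `argparse.Namespace` in an `omegaconf` config object for the modules that have been migrated. The helper below is only an illustration of the idea, not fairseq's actual conversion code:
```
from argparse import Namespace

from omegaconf import DictConfig, OmegaConf


def namespace_to_config(args: Namespace) -> DictConfig:
    # vars(args) is a flat dict of all parsed arguments; OmegaConf.create
    # turns it into a DictConfig that migrated modules can consume.
    return OmegaConf.create(vars(args))


args = Namespace(arch="transformer_lm", dropout=0.1, lr=[0.0005])
cfg = namespace_to_config(args)
print(cfg.dropout)  # 0.1
```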

View File

@ -22,11 +22,11 @@ from fairseq.data import (
RightPadDataset,
SortDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
@register_task('commonsense_qa')
class CommonsenseQATask(FairseqTask):
class CommonsenseQATask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Commonsense QA."""
@staticmethod

View File

@ -24,13 +24,13 @@ from fairseq.data import (
PadDataset,
SortDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
from . import wsc_utils
@register_task('wsc')
class WSCTask(FairseqTask):
class WSCTask(LegacyFairseqTask):
"""Task to finetune RoBERTa for Winograd Schemas."""
@staticmethod

View File

@ -10,7 +10,7 @@ import sys
import torch
from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
from examples.speech_recognition.data import AsrDataset
from examples.speech_recognition.data.replabels import replabel_symbol
@ -66,7 +66,7 @@ def get_asr_dataset_from_json(data_json_path, tgt_dict):
@register_task("speech_recognition")
class SpeechRecognitionTask(FairseqTask):
class SpeechRecognitionTask(LegacyFairseqTask):
"""
Task for training speech recognition model.
"""

View File

@ -9,14 +9,14 @@ import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
logger = logging.getLogger(__name__)
@register_task('dummy_lm')
class DummyLMTask(FairseqTask):
class DummyLMTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):

View File

@ -9,14 +9,14 @@ import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
logger = logging.getLogger(__name__)
@register_task('dummy_masked_lm')
class DummyMaskedLMTask(FairseqTask):
class DummyMaskedLMTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):

View File

@ -9,14 +9,14 @@ import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
logger = logging.getLogger(__name__)
@register_task('dummy_mt')
class DummyMTTask(FairseqTask):
class DummyMTTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):

View File

@ -7,7 +7,7 @@ import importlib
import os
from fairseq import registry
from fairseq.optim.fairseq_optimizer import FairseqOptimizer
from fairseq.optim.fairseq_optimizer import FairseqOptimizer, LegacyFairseqOptimizer # noqa
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
from fairseq.optim.bmuf import FairseqBMUF # noqa
from fairseq.optim.shard import shard_

View File

@ -5,11 +5,11 @@
import torch.optim
from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer
@register_optimizer('adadelta')
class Adadelta(FairseqOptimizer):
class Adadelta(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config)

View File

@ -7,11 +7,11 @@ import math
import torch
import torch.optim
from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer
@register_optimizer('adafactor')
class FairseqAdafactor(FairseqOptimizer):
class FairseqAdafactor(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = Adafactor(params, **self.optimizer_config)

View File

@ -5,11 +5,11 @@
import torch.optim
from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer
@register_optimizer('adagrad')
class Adagrad(FairseqOptimizer):
class Adagrad(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)

View File

@ -6,11 +6,11 @@
import torch
import torch.optim
from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer
@register_optimizer('adamax')
class FairseqAdamax(FairseqOptimizer):
class FairseqAdamax(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = Adamax(params, **self.optimizer_config)

View File

@ -140,3 +140,9 @@ class FairseqOptimizer(object):
def average_params(self):
pass
class LegacyFairseqOptimizer(FairseqOptimizer):
def __init__(self, args):
self.args = args

View File

@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from fairseq.optim import FairseqOptimizer, register_optimizer
from fairseq.optim import register_optimizer, LegacyFairseqOptimizer
@register_optimizer('lamb')
class FairseqLAMB(FairseqOptimizer):
class FairseqLAMB(LegacyFairseqOptimizer):
"""LAMB optimizer."""
def __init__(self, args, params):

View File

@ -7,7 +7,7 @@ import importlib
import os
from fairseq import registry
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import FairseqLRScheduler, LegacyFairseqLRScheduler # noqa
build_lr_scheduler, register_lr_scheduler, LR_SCHEDULER_REGISTRY = registry.setup_registry(

View File

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from .. import FairseqOptimizer
from argparse import Namespace
class FairseqLRScheduler(object):
@ -40,3 +41,13 @@ class FairseqLRScheduler(object):
def step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.optimizer.get_lr()
class LegacyFairseqLRScheduler(FairseqLRScheduler):
def __init__(self, args: Namespace, optimizer):
if not isinstance(optimizer, FairseqOptimizer):
raise ValueError('optimizer must be an instance of FairseqOptimizer')
self.args = args
self.optimizer = optimizer
self.best = None

View File

@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler
@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
class FixedSchedule(LegacyFairseqLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer):

View File

@ -3,11 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler
@register_lr_scheduler('polynomial_decay')
class PolynomialDecaySchedule(FairseqLRScheduler):
class PolynomialDecaySchedule(LegacyFairseqLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer):

View File

@ -5,11 +5,11 @@
import torch.optim.lr_scheduler
from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler
@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
class ReduceLROnPlateau(LegacyFairseqLRScheduler):
"""
Decay the LR by a factor every time the validation loss plateaus.
Also comes with optional warmup phase, where we linearly increase

View File

@ -3,12 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler
import math
@register_lr_scheduler('tri_stage')
class TriStageLRSchedule(FairseqLRScheduler):
class TriStageLRSchedule(LegacyFairseqLRScheduler):
"""Tristage learning rate schedulr
Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf

View File

@ -5,11 +5,11 @@
import math
from . import FairseqLRScheduler, register_lr_scheduler
from . import register_lr_scheduler, LegacyFairseqLRScheduler
@register_lr_scheduler('triangular')
class TriangularSchedule(FairseqLRScheduler):
class TriangularSchedule(LegacyFairseqLRScheduler):
"""Assign LR based on a triangular cyclical schedule.
See https://arxiv.org/pdf/1506.01186.pdf for details.

View File

@ -5,11 +5,11 @@
import torch.optim
from . import FairseqOptimizer, register_optimizer
from . import register_optimizer, LegacyFairseqOptimizer
@register_optimizer('sgd')
class SGD(FairseqOptimizer):
class SGD(LegacyFairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

View File

@ -7,7 +7,7 @@ import argparse
import importlib
import os
from .fairseq_task import FairseqTask
from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()

View File

@ -9,7 +9,7 @@ import os
import sys
from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset
from . import FairseqTask, register_task
from . import LegacyFairseqTask, register_task
class LabelEncoder(object):
@ -23,7 +23,7 @@ class LabelEncoder(object):
@register_task("audio_pretraining")
class AudioPretrainingTask(FairseqTask):
class AudioPretrainingTask(LegacyFairseqTask):
"""
"""

View File

@ -21,14 +21,14 @@ from fairseq.data import (
)
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.multi_corpus_sampled_dataset import MultiCorpusSampledDataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq import utils
logger = logging.getLogger(__name__)
@register_task('cross_lingual_lm')
class CrossLingualLMTask(FairseqTask):
class CrossLingualLMTask(LegacyFairseqTask):
"""
Task for training cross-lingual language models.

View File

@ -16,7 +16,7 @@ from fairseq.data import (
TokenBlockDataset,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq import utils
@ -24,7 +24,7 @@ logger = logging.getLogger(__name__)
@register_task('denoising')
class DenoisingTask(FairseqTask):
class DenoisingTask(LegacyFairseqTask):
"""
Denoising task for applying sequence to sequence denoising. (ie. BART)
"""

View File

@ -12,6 +12,7 @@ import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
from argparse import Namespace
logger = logging.getLogger(__name__)
@ -486,3 +487,56 @@ class FairseqTask(object):
"""Return the target :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
raise NotImplementedError
class LegacyFairseqTask(FairseqTask):
def __init__(self, args: Namespace):
self.args = args
self.datasets = {}
self.dataset_to_epoch_iter = {}
@classmethod
def setup_task(cls, args: Namespace, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
return cls(args, **kwargs)
def has_sharded_data(self, split):
return (os.pathsep in getattr(self.args, 'data', ''))
def build_model(self, args: Namespace):
"""
Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
task.
Args:
args (argparse.Namespace): parsed command-line arguments
Returns:
a :class:`~fairseq.models.BaseFairseqModel` instance
"""
from fairseq import models, quantization_utils
model = models.build_model(args, self)
if getattr(args, 'tpu', False):
model.prepare_for_tpu_()
model = quantization_utils.quantize_model_scalar(model, args)
return model
def build_criterion(self, args: Namespace):
"""
Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
this task.
Args:
args (argparse.Namespace): parsed command-line arguments
Returns:
a :class:`~fairseq.criterions.FairseqCriterion` instance
"""
from fairseq import criterions
return criterions.build_criterion(args, self)

View File

@ -20,7 +20,7 @@ from fairseq.data import Dictionary
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.legacy.masked_lm_dataset import MaskedLMDataset
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq import utils
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
@register_task('legacy_masked_lm')
class LegacyMaskedLMTask(FairseqTask):
class LegacyMaskedLMTask(LegacyFairseqTask):
"""
Task for training Masked LM (BERT) model.
Args:

View File

@ -21,8 +21,8 @@ from fairseq.data import (
SortDataset,
TokenBlockDataset,
)
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import FairseqTask, register_task
from fairseq.data.encoders.utils import get_whole_word_mask
from fairseq import utils
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
@register_task('masked_lm')
class MaskedLMTask(FairseqTask):
class MaskedLMTask(LegacyFairseqTask):
"""Task for training masked language models (e.g., BERT, RoBERTa)."""
@staticmethod

View File

@ -26,7 +26,7 @@ from fairseq.data import (
SortDataset,
TokenBlockDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq import utils
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
@register_task('multilingual_masked_lm')
class MultiLingualMaskedLMTask(FairseqTask):
class MultiLingualMaskedLMTask(LegacyFairseqTask):
"""Task for training masked language models (e.g., BERT, RoBERTa)."""
@staticmethod

View File

@ -6,11 +6,11 @@
from collections import OrderedDict
import logging
import os
from fairseq import options
import contextlib
import torch
from fairseq import metrics, options
from fairseq import metrics, utils
from fairseq.data import (
Dictionary,
LanguagePairDataset,
@ -20,8 +20,7 @@ from fairseq.data import (
from fairseq.models import FairseqMultiModel
from fairseq.tasks.translation import load_langpair_dataset
from . import FairseqTask, register_task
from fairseq import utils
from . import register_task, LegacyFairseqTask
logger = logging.getLogger(__name__)
@ -39,7 +38,7 @@ def _lang_token_index(dic: Dictionary, lang: str):
@register_task('multilingual_translation')
class MultilingualTranslationTask(FairseqTask):
class MultilingualTranslationTask(LegacyFairseqTask):
"""A task for training multiple translation models simultaneously.
We iterate round-robin over batches from multiple language pairs, ordered

View File

@ -25,15 +25,15 @@ from fairseq.data import (
SortDataset,
StripTokenDataset,
)
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import FairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task('sentence_prediction')
class SentencePredictionTask(FairseqTask):
class SentencePredictionTask(LegacyFairseqTask):
"""
Sentence (or sentence pair) prediction (classification or regression) task.

View File

@ -23,15 +23,15 @@ from fairseq.data import (
SortDataset,
TruncateDataset
)
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.tasks import FairseqTask, register_task
logger = logging.getLogger(__name__)
@register_task('sentence_ranking')
class SentenceRankingTask(FairseqTask):
class SentenceRankingTask(LegacyFairseqTask):
"""
Ranking task on multiple sentences.

View File

@ -8,10 +8,10 @@ import json
import itertools
import logging
import os
from fairseq import options
import numpy as np
from fairseq import metrics, options, utils
from fairseq import metrics, utils
from fairseq.data import (
AppendTokenDataset,
ConcatDataset,
@ -24,7 +24,7 @@ from fairseq.data import (
TruncateDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
EVAL_BLEU_ORDER = 4
@ -133,7 +133,7 @@ def load_langpair_dataset(
@register_task('translation')
class TranslationTask(FairseqTask):
class TranslationTask(LegacyFairseqTask):
"""
Translate from one (source) language to another (target) language.

View File

@ -16,7 +16,7 @@ from fairseq.data import (
ListDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import register_task, LegacyFairseqTask
from fairseq.data.multilingual.sampling_method import SamplingMethod
from fairseq.data.multilingual.multilingual_data_manager import MultilingualDatasetManager
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
@register_task('translation_multi_simple_epoch')
class TranslationMultiSimpleEpochTask(FairseqTask):
class TranslationMultiSimpleEpochTask(LegacyFairseqTask):
"""
Translate from one (source) language to another (target) language.

View File

@ -17,7 +17,7 @@ from fairseq.models import (
FairseqEncoderModel,
FairseqModel,
)
from fairseq.tasks.fairseq_task import FairseqTask
from fairseq.tasks.fairseq_task import LegacyFairseqTask
from examples.speech_recognition.data.data_utils import lengths_to_encoder_padding_mask
@ -37,7 +37,7 @@ def get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE):
return dummy_dict
class DummyTask(FairseqTask):
class DummyTask(LegacyFairseqTask):
def __init__(self, args):
super().__init__(args)
self.dictionary = get_dummy_dictionary()

View File

@ -12,13 +12,13 @@ import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.transformer import TransformerModel
from fairseq.modules import multihead_attention, sinusoidal_positional_embedding
from fairseq.tasks.fairseq_task import FairseqTask
from fairseq.tasks.fairseq_task import LegacyFairseqTask
DEFAULT_TEST_VOCAB_SIZE = 100
class DummyTask(FairseqTask):
class DummyTask(LegacyFairseqTask):
def __init__(self, args):
super().__init__(args)
self.dictionary = get_dummy_dictionary()

View File

@ -10,13 +10,13 @@ import unittest
import torch
from fairseq.data.dictionary import Dictionary
from fairseq.models.lstm import LSTMModel
from fairseq.tasks.fairseq_task import FairseqTask
from fairseq.tasks.fairseq_task import LegacyFairseqTask
DEFAULT_TEST_VOCAB_SIZE = 100
class DummyTask(FairseqTask):
class DummyTask(LegacyFairseqTask):
def __init__(self, args):
super().__init__(args)
self.dictionary = get_dummy_dictionary()

View File

@ -14,13 +14,13 @@ from fairseq.data.dictionary import Dictionary
from fairseq.models.transformer import TransformerModel
from fairseq.sequence_generator import SequenceGenerator, EnsembleModel
from fairseq.tasks.fairseq_task import FairseqTask
from fairseq.tasks.fairseq_task import LegacyFairseqTask
DEFAULT_TEST_VOCAB_SIZE = 100
class DummyTask(FairseqTask):
class DummyTask(LegacyFairseqTask):
def __init__(self, args):
super().__init__(args)
self.dictionary = get_dummy_dictionary()

View File

@ -20,6 +20,7 @@ from fairseq.models import (
FairseqIncrementalDecoder,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.tasks import LegacyFairseqTask
from fairseq.tasks import FairseqTask
from fairseq_cli import (
generate,
@ -284,7 +285,7 @@ class TestDataset(torch.utils.data.Dataset):
return len(self.data)
class TestTranslationTask(FairseqTask):
class TestTranslationTask(LegacyFairseqTask):
def __init__(self, args, src_dict, tgt_dict, model):
super().__init__(args)