From 3b27ed7996b0315f471c795cf9b7dfcc18467cbe Mon Sep 17 00:00:00 2001 From: alexeib Date: Tue, 20 Oct 2020 00:31:00 -0700 Subject: [PATCH] Enable Hydra configs in fairseq (#1343) (#1510) Summary: Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1510 this is the main pr that switches on hydra functionality in fairseq we migrate "args" object into omegaconf "DictConfig" at all legacy entry points in addition this migrates various components from secondary registries (like bpe encoders and tokenizers) to make the migration smoother i am going through code that references migrated fairseq components and changing it to inherit from "Legacy*" components instead. hopefully tests will catch most of this Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1343 Reviewed By: myleott Differential Revision: D23973928 Pulled By: alexeib fbshipit-source-id: dd9554981fff51ea75c1ff343874d1d6e61793c9 --- config/config.yaml | 116 ++++- config/config_eval_lm.yaml | 7 - config/criterion/adaptive_loss.yaml | 4 +- config/criterion/cross_entropy.yaml | 3 +- config/params/eval_lm_params.yaml | 105 ---- config/params/training_params.yaml | 95 ---- docs/hydra_integration.md | 26 +- docs/tutorial_classifying_names.rst | 2 +- examples/noisychannel/rerank.py | 8 +- examples/roberta/wsc/wsc_criterion.py | 4 +- .../unsupervised_quality_estimation/README.md | 2 +- fairseq/checkpoint_utils.py | 247 +++++---- fairseq/criterions/__init__.py | 6 +- fairseq/criterions/adaptive_loss.py | 12 +- fairseq/criterions/cross_entropy.py | 2 +- fairseq/criterions/ctc.py | 18 +- fairseq/criterions/fairseq_criterion.py | 11 +- fairseq/data/encoders/byte_bpe.py | 23 +- fairseq/data/encoders/bytes.py | 2 +- fairseq/data/encoders/characters.py | 2 +- fairseq/data/encoders/fastbpe.py | 23 +- fairseq/data/encoders/gpt2_bpe.py | 36 +- fairseq/data/encoders/hf_bert_bpe.py | 32 +- fairseq/data/encoders/hf_byte_bpe.py | 31 +- fairseq/data/encoders/moses_tokenizer.py | 48 +- fairseq/data/encoders/nltk_tokenizer.py | 2 +- fairseq/data/encoders/sentencepiece_bpe.py | 23 +- fairseq/data/encoders/space_tokenizer.py | 2 +- fairseq/data/encoders/subword_nmt_bpe.py | 28 +- fairseq/dataclass/constants.py | 2 + fairseq/dataclass/data_class.py | 487 +++++++++++------- fairseq/dataclass/utils.py | 174 +++++-- fairseq/distributed_utils.py | 228 ++++---- fairseq/hub_utils.py | 28 +- fairseq/model_parallel/megatron_trainer.py | 5 +- .../pipeline_parallel_transformer/model.py | 8 +- .../model_parallel/models/transformer_lm.py | 4 + fairseq/models/__init__.py | 24 +- fairseq/models/bart/hub_interface.py | 16 +- fairseq/models/bart/model.py | 4 +- fairseq/models/fairseq_model.py | 53 +- fairseq/models/multilingual_transformer.py | 4 +- fairseq/models/roberta/hub_interface.py | 6 +- fairseq/models/roberta/model.py | 2 +- fairseq/models/transformer.py | 11 +- fairseq/models/transformer_lm.py | 2 +- fairseq/modules/transformer_layer.py | 14 +- fairseq/optim/__init__.py | 8 +- fairseq/optim/adam.py | 25 +- fairseq/optim/bmuf.py | 23 +- fairseq/optim/fairseq_optimizer.py | 4 +- fairseq/optim/fp16_optimizer.py | 84 +-- fairseq/optim/lr_scheduler/__init__.py | 6 +- .../optim/lr_scheduler/cosine_lr_scheduler.py | 57 +- .../lr_scheduler/fairseq_lr_scheduler.py | 4 +- .../inverse_square_root_schedule.py | 35 +- fairseq/optim/nag.py | 17 +- fairseq/optim/shard.py | 2 +- fairseq/options.py | 123 +---- fairseq/quantization_utils.py | 5 +- fairseq/registry.py | 58 +-- fairseq/scoring/__init__.py | 23 +- fairseq/scoring/bleu.py | 56 +- fairseq/scoring/tokenizer.py | 6 +- fairseq/scoring/wer.py | 45 +- fairseq/tasks/__init__.py | 13 +- fairseq/tasks/audio_pretraining.py | 2 +- fairseq/tasks/fairseq_task.py | 29 +- fairseq/tasks/language_modeling.py | 12 +- fairseq/tasks/multilingual_translation.py | 10 +- fairseq/tasks/speech_to_text.py | 4 +- fairseq/trainer.py | 170 +++--- fairseq_cli/eval_lm.py | 129 +++-- fairseq_cli/generate.py | 158 +++--- fairseq_cli/interactive.py | 105 ++-- fairseq_cli/score.py | 8 +- fairseq_cli/train.py | 211 ++++---- fairseq_cli/validate.py | 64 ++- tests/speech_recognition/asr_test_base.py | 5 +- tests/test_bmuf.py | 72 ++- tests/test_fp16_optimizer.py | 35 +- tests/test_inference_dropout.py | 10 +- tests/test_memory_efficient_fp16.py | 40 +- tests/test_train.py | 63 ++- tests/utils.py | 2 +- 85 files changed, 2034 insertions(+), 1681 deletions(-) delete mode 100644 config/config_eval_lm.yaml delete mode 100644 config/params/eval_lm_params.yaml delete mode 100644 config/params/training_params.yaml diff --git a/config/config.yaml b/config/config.yaml index 66723e70..b9ee6c74 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,7 +1,111 @@ +# @package _group_ +common: + no_progress_bar: false + log_interval: 100 + log_format: null + tensorboard_logdir: null + seed: 1 + cpu: false + tpu: false + bf16: false + fp16: false + memory_efficient_fp16: false + memory_efficient_bf16: false + fp16_no_flatten_grads: false + fp16_init_scale: 128 + fp16_scale_window: null + fp16_scale_tolerance: 0.0 + min_loss_scale: 1.0e-4 + threshold_loss_scale: null + user_dir: null + empty_cache_freq: 0 + all_gather_list_size: 16384 + model_parallel_size: 1 + quantization_config_path: null + profile: false +distributed_training: + distributed_rank: 0 + distributed_backend: "nccl" + distributed_init_method: null + distributed_port: -1 + device_id: 0 + local_rank: 0 + distributed_no_spawn: false + ddp_backend: "c10d" + bucket_cap_mb: 25 + fix_batches_to_gpus: false + find_unused_parameters: false + fast_stat_sync: false + broadcast_buffers: false + distributed_wrapper: "DDP" + slowmo_momentum: null + slowmo_algorithm: "LocalSGD" + localsgd_frequency: 3 +dataset: + num_workers: 1 + skip_invalid_size_inputs_valid_test: false + max_tokens: null + batch_size: null + required_batch_size_multiple: 8 + dataset_impl: null + data_buffer_size: 10 + train_subset: "train" + valid_subset: "valid" + validate_interval: 1 + fixed_validation_seed: null + disable_validation: false + curriculum: 0 + gen_subset: "test" + num_shards: 1 + shard_id: 0 + max_tokens_valid: ${dataset.max_tokens} + batch_size_valid: ${dataset.batch_size} +optimization: + max_epoch: 0 + max_update: 0 + clip_norm: 25.0 + sentence_avg: false + update_freq: [ 1 ] + lr: [ 0.25 ] + min_lr: -1.0 + use_bmuf: false +checkpoint: + save_dir: "checkpoints" + restore_file: "checkpoint_last.pt" + reset_dataloader: false + reset_lr_scheduler: false + reset_meters: false + reset_optimizer: false + optimizer_overrides: "{}" + save_interval: 1 + save_interval_updates: 0 + keep_interval_updates: -1 + keep_last_epochs: -1 + keep_best_checkpoints: -1 + no_save: false + no_epoch_checkpoints: false + no_last_checkpoints: false + no_save_optimizer_state: false + best_checkpoint_metric: "loss" + maximize_best_checkpoint_metric: false + patience: -1 + checkpoint_suffix: "" +bmuf: + block_lr: 1 + block_momentum: 0.875 + global_sync_iter: 50 + warmup_iterations: 500 + use_nbm: false + average_sync: false defaults: - - params: training_params - - task: language_modeling - - model: transformer_lm - - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: inverse_sqrt + - task: language_modeling + - model: null + - criterion: null + - optimizer: null + - lr_scheduler: null + - bpe: null + - tokenizer: null + - scoring: null + - generation: null + - common_eval: null + - eval_lm: null diff --git a/config/config_eval_lm.yaml b/config/config_eval_lm.yaml deleted file mode 100644 index 5a93cb5d..00000000 --- a/config/config_eval_lm.yaml +++ /dev/null @@ -1,7 +0,0 @@ -defaults: - - params: eval_lm_params - - task: language_modeling - - model: transformer_lm - - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: inverse_sqrt diff --git a/config/criterion/adaptive_loss.yaml b/config/criterion/adaptive_loss.yaml index a85a7eed..7997b076 100644 --- a/config/criterion/adaptive_loss.yaml +++ b/config/criterion/adaptive_loss.yaml @@ -1,3 +1,3 @@ # @package _group_ -sentence_avg: ${params.optimization.sentence_avg} -ddp_backend: ${params.distributed_training.ddp_backend} +sentence_avg: ${optimization.sentence_avg} +ddp_backend: ${distributed_training.ddp_backend} diff --git a/config/criterion/cross_entropy.yaml b/config/criterion/cross_entropy.yaml index a85a7eed..ad3d4148 100644 --- a/config/criterion/cross_entropy.yaml +++ b/config/criterion/cross_entropy.yaml @@ -1,3 +1,2 @@ # @package _group_ -sentence_avg: ${params.optimization.sentence_avg} -ddp_backend: ${params.distributed_training.ddp_backend} +sentence_avg: ${optimization.sentence_avg} diff --git a/config/params/eval_lm_params.yaml b/config/params/eval_lm_params.yaml deleted file mode 100644 index 6f27055d..00000000 --- a/config/params/eval_lm_params.yaml +++ /dev/null @@ -1,105 +0,0 @@ -# @package _group_ -common: - no_progress_bar: false - log_interval: 100 - log_format: null - tensorboard_logdir: null - seed: 1 - cpu: false - fp16: false - memory_efficient_fp16: false - fp16_no_flatten_grads: false - fp16_init_scale: 128 - fp16_scale_window: null - fp16_scale_tolerance: 0.0 - min_loss_scale: 1.0e-4 - threshold_loss_scale: null - user_dir: null - empty_cache_freq: 0 - all_gather_list_size: 16384 - model_parallel_size: 1 - checkpoint_suffix: "" - quantization_config_path: null -distributed_training: - distributed_rank: 0 - distributed_backend: "nccl" - distributed_init_method: null - distributed_port: -1 - device_id: 0 - local_rank: 0 - distributed_no_spawn: false - ddp_backend: "c10d" - bucket_cap_mb: 25 - fix_batches_to_gpus: false - find_unused_parameters: false - fast_stat_sync: false - broadcast_buffers: false - distributed_wrapper: "DDP" - slowmo_momentum: null - slowmo_algorithm: "LocalSGD" - localsgd_frequency: 3 -dataset: - num_workers: 1 - skip_invalid_size_inputs_valid_test: false - max_tokens: null - batch_size: ${params.dataset.batch_size} - required_batch_size_multiple: 8 - dataset_impl: null - data_buffer_size: 10 - train_subset: "train" - valid_subset: "valid" - validate_interval: 1 - fixed_validation_seed: null - disable_validation: false - curriculum: 0 - gen_subset: "test" - num_shards: 1 - shard_id: 0 - max_tokens_valid: ${params.dataset.max_tokens} - batch_size_valid: ${params.dataset.batch_size} -optimization: - max_epoch: 0 - max_update: 0 - clip_norm: 25.0 - sentence_avg: false - update_freq: [1] - lr: [0.25] - min_lr: -1.0 - use_bmuf: false -checkpoint: - save_dir: "checkpoints" - restore_file: "checkpoint_last.pt" - reset_dataloader: false - reset_lr_scheduler: false - reset_meters: false - reset_optimizer: false - optimizer_overrides: "{}" - save_interval: 1 - save_interval_updates: 0 - keep_interval_updates: -1 - keep_last_epochs: -1 - keep_best_checkpoints: -1 - no_save: false - no_epoch_checkpoints: false - no_last_checkpoints: false - no_save_optimizer_state: false - best_checkpoint_metric: "loss" - maximize_best_checkpoint_metric: false - patience: -1 -common_eval: - path: null - remove_bpe: null - quiet: false - model_overrides: "{}" - results_path: null -eval_lm: - output_word_probs: false - output_word_stats: false - context_window: 0 -bmuf: - block_lr: 1 - block_momentum: 0.875 - global_sync_iter: 50 - warmup_iterations: 500 - use_nbm: false - average_sync: false diff --git a/config/params/training_params.yaml b/config/params/training_params.yaml deleted file mode 100644 index 2ce94f92..00000000 --- a/config/params/training_params.yaml +++ /dev/null @@ -1,95 +0,0 @@ -# @package _group_ -common: - no_progress_bar: false - log_interval: 100 - log_format: null - tensorboard_logdir: null - seed: 1 - cpu: false - fp16: false - memory_efficient_fp16: false - fp16_no_flatten_grads: false - fp16_init_scale: 128 - fp16_scale_window: null - fp16_scale_tolerance: 0.0 - min_loss_scale: 1.0e-4 - threshold_loss_scale: null - user_dir: null - empty_cache_freq: 0 - all_gather_list_size: 16384 - model_parallel_size: 1 - checkpoint_suffix: "" - quantization_config_path: null -distributed_training: - distributed_rank: 0 - distributed_backend: "nccl" - distributed_init_method: null - distributed_port: -1 - device_id: 0 - local_rank: 0 - distributed_no_spawn: false - ddp_backend: "c10d" - bucket_cap_mb: 25 - fix_batches_to_gpus: false - find_unused_parameters: false - fast_stat_sync: false - broadcast_buffers: false - distributed_wrapper: "DDP" - slowmo_momentum: null - slowmo_algorithm: "LocalSGD" - localsgd_frequency: 3 -dataset: - num_workers: 1 - skip_invalid_size_inputs_valid_test: false - max_tokens: null - batch_size: ${params.dataset.batch_size} - required_batch_size_multiple: 8 - dataset_impl: null - data_buffer_size: 10 - train_subset: "train" - valid_subset: "valid" - validate_interval: 1 - fixed_validation_seed: null - disable_validation: false - curriculum: 0 - gen_subset: "test" - num_shards: 1 - shard_id: 0 - max_tokens_valid: ${params.dataset.max_tokens} - batch_size_valid: ${params.dataset.batch_size} -optimization: - max_epoch: 0 - max_update: 0 - clip_norm: 25.0 - sentence_avg: false - update_freq: [1] - lr: [0.25] - min_lr: -1.0 - use_bmuf: false -checkpoint: - save_dir: "checkpoints" - restore_file: "checkpoint_last.pt" - reset_dataloader: false - reset_lr_scheduler: false - reset_meters: false - reset_optimizer: false - optimizer_overrides: "{}" - save_interval: 1 - save_interval_updates: 0 - keep_interval_updates: -1 - keep_last_epochs: -1 - keep_best_checkpoints: -1 - no_save: false - no_epoch_checkpoints: false - no_last_checkpoints: false - no_save_optimizer_state: false - best_checkpoint_metric: "loss" - maximize_best_checkpoint_metric: false - patience: -1 -bmuf: - block_lr: 1 - block_momentum: 0.875 - global_sync_iter: 50 - warmup_iterations: 500 - use_nbm: false - average_sync: false diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 9b77dd83..0973cd27 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -13,7 +13,6 @@ For example, if we'd like to train a language model with transformer, we could p ``` defaults: - - params: training_params - task: language_modeling - model: transformer_lm - criterion: cross_entropy @@ -21,7 +20,7 @@ defaults: - lr_scheduler: inverse_sqrt ``` -- Provide generic parameters common across different training jobs: `config/params/training_params.yaml` +- Provide generic parameters common across different jobs: `config.yaml` - Provide task parameters: `config/task/language_modeling.yaml` - Provide model parameters: `config/model/transformer_lm.yaml` - Provide criterion parameters: `config/criterion/cross_entropy.yaml` @@ -41,7 +40,6 @@ Alternatively, if we need to override certain params from the command line, we c ``` python fairseq_cli/train_hydra.py -params=training_params \ task=language_modeling \ task.data=/private/home/abaevski/data/wiki103 \ task.tokens_per_sample=512 \ @@ -56,17 +54,17 @@ lr_scheduler=inverse_sqrt \ lr_scheduler.warmup_updates=4000 \ lr_scheduler.warmup_init_lr=1e-07 \ criterion=cross_entropy \ -params.common.fp16=true \ -params.common.log_format=json \ -params.common.log_interval=1 \ -params.dataset.max_tokens=1024 \ -params.dataset.num_workers=4 \ -params.optimization.update_freq=[16] \ -params.optimization.max_update=50000 \ -params.optimization.clip_norm=0.0 \ -params.optimization.lr=[0.0005] \ -params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ -params.checkpoint.save_interval_updates=10 +common.fp16=true \ +common.log_format=json \ +common.log_interval=1 \ +dataset.max_tokens=1024 \ +dataset.num_workers=4 \ +optimization.update_freq=[16] \ +optimization.max_update=50000 \ +optimization.clip_norm=0.0 \ +optimization.lr=[0.0005] \ +checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ +checkpoint.save_interval_updates=10 ``` ## Migrate existing/Creating new modules to hydra interface diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst index 40a3cb6f..b02fec04 100644 --- a/docs/tutorial_classifying_names.rst +++ b/docs/tutorial_classifying_names.rst @@ -212,7 +212,7 @@ following contents:: @register_task('simple_classification') - class SimpleClassificationTask(FairseqTask): + class SimpleClassificationTask(LegacyFairseqTask): @staticmethod def add_args(parser): diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py index 4df424e6..13036926 100644 --- a/examples/noisychannel/rerank.py +++ b/examples/noisychannel/rerank.py @@ -27,7 +27,13 @@ def score_target_hypo( print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c) gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) dict = dictionary.Dictionary() - scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + scorer = scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) ordered_hypos = {} ordered_targets = {} diff --git a/examples/roberta/wsc/wsc_criterion.py b/examples/roberta/wsc/wsc_criterion.py index 1a590123..ed0251fd 100644 --- a/examples/roberta/wsc/wsc_criterion.py +++ b/examples/roberta/wsc/wsc_criterion.py @@ -20,8 +20,8 @@ class WSCCriterion(LegacyFairseqCriterion): self.prediction_h = open(self.args.save_predictions, "w") else: self.prediction_h = None - self.bpe = encoders.build_bpe(args) - self.tokenizer = encoders.build_tokenizer(args) + self.bpe = encoders.build_bpe(args.bpe) + self.tokenizer = encoders.build_tokenizer(args.tokenizer) def __del__(self): if self.prediction_h is not None: diff --git a/examples/unsupervised_quality_estimation/README.md b/examples/unsupervised_quality_estimation/README.md index 809a58e4..aeb96a14 100644 --- a/examples/unsupervised_quality_estimation/README.md +++ b/examples/unsupervised_quality_estimation/README.md @@ -85,7 +85,7 @@ Produce model scores for the generated translations using `--retain-dropout` opt ``` CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout - --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer + --retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]' TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 75e2c68c..c036e129 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -3,36 +3,42 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import ast import collections import logging import os import re import traceback from collections import OrderedDict -from typing import Union +from typing import Optional, Union import torch +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + overwrite_args_by_name, +) from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig, open_dict from torch.serialization import default_restore_location logger = logging.getLogger(__name__) -def save_checkpoint(args, trainer, epoch_itr, val_loss): - from fairseq import distributed_utils, meters +def save_checkpoint(cfg: DictConfig, trainer, epoch_itr, val_loss): + from fairseq import meters # only one worker should attempt to create the required dir - if args.distributed_rank == 0: - os.makedirs(args.save_dir, exist_ok=True) + if cfg.distributed_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) prev_best = getattr(save_checkpoint, "best", val_loss) if val_loss is not None: - best_function = max if args.maximize_best_checkpoint_metric else min + best_function = max if cfg.maximize_best_checkpoint_metric else min save_checkpoint.best = best_function(val_loss, prev_best) - if args.no_save: + if cfg.no_save: return trainer.consolidate_optimizer() @@ -41,7 +47,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): return def is_better(a, b): - return a >= b if args.maximize_best_checkpoint_metric else a <= b + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b write_timer = meters.StopwatchMeter() write_timer.start() @@ -50,38 +56,36 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): end_of_epoch = epoch_itr.end_of_epoch() updates = trainer.get_num_updates() - suffix = getattr(args, "checkpoint_suffix", "") + suffix = cfg.checkpoint_suffix or "" checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( - end_of_epoch - and not args.no_epoch_checkpoints - and epoch % args.save_interval == 0 + end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 ) checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( not end_of_epoch - and args.save_interval_updates > 0 - and updates % args.save_interval_updates == 0 + and cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 ) checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best) ) - if val_loss is not None and args.keep_best_checkpoints > 0: + if val_loss is not None and cfg.keep_best_checkpoints > 0: checkpoint_conds[ - "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss) + "checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss) ] = not hasattr(save_checkpoint, "best") or is_better( val_loss, save_checkpoint.best ) checkpoint_conds[ "checkpoint_last{}.pt".format(suffix) - ] = not args.no_last_checkpoints + ] = not cfg.no_last_checkpoints extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} if hasattr(save_checkpoint, "best"): extra_state.update({"best": save_checkpoint.best}) checkpoints = [ - os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) @@ -95,51 +99,52 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): ) ) - if not end_of_epoch and args.keep_interval_updates > 0: + if not end_of_epoch and cfg.keep_interval_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" ) - for old_chk in checkpoints[args.keep_interval_updates :]: + for old_chk in checkpoints[cfg.keep_interval_updates :]: if os.path.lexists(old_chk): os.remove(old_chk) - if args.keep_last_epochs > 0: + if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt") - for old_chk in checkpoints[args.keep_last_epochs :]: + checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt") + for old_chk in checkpoints[cfg.keep_last_epochs :]: if os.path.lexists(old_chk): os.remove(old_chk) - if args.keep_best_checkpoints > 0: + if cfg.keep_best_checkpoints > 0: # only keep the best N checkpoints according to validation metric checkpoints = checkpoint_paths( - args.save_dir, + cfg.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( - args.best_checkpoint_metric + cfg.best_checkpoint_metric ), ) - if not args.maximize_best_checkpoint_metric: + if not cfg.maximize_best_checkpoint_metric: checkpoints = checkpoints[::-1] - for old_chk in checkpoints[args.keep_best_checkpoints :]: + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: if os.path.lexists(old_chk): os.remove(old_chk) -def load_checkpoint(args, trainer, **passthrough_args): +def load_checkpoint(cfg: DictConfig, trainer, **passthrough_args): """ Load a checkpoint and restore the training iterator. *passthrough_args* will be passed through to ``trainer.get_train_iterator``. """ - reset_optimizer = args.reset_optimizer - reset_lr_scheduler = args.reset_lr_scheduler - optimizer_overrides = eval(args.optimizer_overrides) - reset_meters = args.reset_meters - reset_dataloader = args.reset_dataloader - if getattr(args, "finetune_from_model", None) is not None and ( + reset_optimizer = cfg.reset_optimizer + reset_lr_scheduler = cfg.reset_lr_scheduler + optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides) + reset_meters = cfg.reset_meters + reset_dataloader = cfg.reset_dataloader + + if cfg.finetune_from_model is not None and ( reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader ): raise ValueError( @@ -147,19 +152,19 @@ def load_checkpoint(args, trainer, **passthrough_args): " or reset_lr_scheduler or reset_meters or reset_dataloader" ) - suffix = getattr(args, "checkpoint_suffix", "") + suffix = cfg.checkpoint_suffix if ( - args.restore_file == "checkpoint_last.pt" + cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' checkpoint_path = os.path.join( - args.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) ) first_launch = not PathManager.exists(checkpoint_path) - if getattr(args, "finetune_from_model", None) is not None and first_launch: + if cfg.finetune_from_model is not None and first_launch: # if there is no last checkpoint to restore, start the finetune from pretrained model # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. - if PathManager.exists(args.finetune_from_model): - checkpoint_path = args.finetune_from_model + if PathManager.exists(cfg.finetune_from_model): + checkpoint_path = cfg.finetune_from_model reset_optimizer = True reset_lr_scheduler = True reset_meters = True @@ -170,19 +175,17 @@ def load_checkpoint(args, trainer, **passthrough_args): ) else: raise ValueError( - f"--funetune-from-model {args.finetune_from_model} does not exist" + f"--funetune-from-model {cfg.finetune_from_model} does not exist" ) - elif getattr(args, "model_parallel_size", 1) > 1: - checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt") + elif cfg.model_parallel_size > 1: + checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") else: - checkpoint_path = args.restore_file + checkpoint_path = cfg.restore_file - if args.restore_file != "checkpoint_last.pt" and getattr( - args, "finetune_from_model", None - ): + if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model: raise ValueError( "--finetune-from-model and --restore-file (non-default value) " - "can not be specified together: " + str(args) + "can not be specified together: " + str(cfg) ) extra_state = trainer.load_checkpoint( @@ -225,10 +228,14 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): f, map_location=lambda s, l: default_restore_location(s, "cpu") ) - args = state["args"] - if arg_overrides is not None: + if "args" in state and state["args"] is not None and arg_overrides is not None: + args = state["args"] for arg_name, arg_val in arg_overrides.items(): setattr(args, arg_name, arg_val) + + if "cfg" in state and state["cfg"] is not None and arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) + state = _upgrade_state_dict(state) return state @@ -274,19 +281,28 @@ def load_model_ensemble_and_task( filename = filename.replace(".pt", suffix + ".pt") else: filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + if not PathManager.exists(filename): raise IOError("Model file not found: {}".format(filename)) state = load_checkpoint_to_cpu(filename, arg_overrides) - if shard_idx == 0: - args = state["args"] - if task is None: - task = tasks.setup_task(args) + if "args" in state and state["args"] is not None: + cfg = convert_namespace_to_omegaconf(state["args"]) + elif "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + else: + raise RuntimeError( + f"Neither args nor cfg exist in state keys = {state.keys()}" + ) - # build model for ensemble - model = task.build_model(args) - model.load_state_dict(state["model"], strict=strict, args=args) + if task is None: + task = tasks.setup_task(cfg.task) + + # build model for ensemble + model = task.build_model(cfg.model) + + model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) ensemble.append(model) - return ensemble, args, task + return ensemble, cfg, task def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): @@ -323,7 +339,7 @@ def torch_persistent_save(obj, f): def save_state( filename, - args, + cfg: DictConfig, model_state_dict, criterion, optimizer, @@ -331,6 +347,7 @@ def save_state( num_updates, optim_history=None, extra_state=None, + **kwargs, ): from fairseq import utils @@ -339,7 +356,8 @@ def save_state( if extra_state is None: extra_state = {} state_dict = { - "args": args, + "cfg": cfg, + "args": kwargs.get("args", None), "model": model_state_dict or {}, "optimizer_history": optim_history + [ @@ -354,11 +372,17 @@ def save_state( } if utils.has_parameters(criterion): state_dict["criterion"] = criterion.state_dict() - if not args.no_save_optimizer_state: - state_dict["last_optimizer_state"] = optimizer.state_dict() - # convert all state to CPU - state_dict = utils.move_to_cpu(state_dict) + if cfg is None: + cfg = state_dict["args"] + assert cfg is not None, "must provide cfg or args" + + if isinstance(cfg, DictConfig): + no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state + else: + no_save_optimizer_state = cfg.no_save_optimizer_state + if not no_save_optimizer_state: + state_dict["last_optimizer_state"] = optimizer.state_dict() with PathManager.open(filename, "wb") as f: torch_persistent_save(state_dict, f) @@ -403,46 +427,49 @@ def _upgrade_state_dict(state): # keep track of number of updates if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 - # old model checkpoints may not have separate source/target positions - if hasattr(state["args"], "max_positions") and not hasattr( - state["args"], "max_source_positions" - ): - state["args"].max_source_positions = state["args"].max_positions - state["args"].max_target_positions = state["args"].max_positions # use stateful training data iterator if "train_iterator" not in state["extra_state"]: state["extra_state"]["train_iterator"] = { "epoch": state["extra_state"]["epoch"], "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), } - # default to translation task - if not hasattr(state["args"], "task"): - state["args"].task = "translation" - # --raw-text and --lazy-load are deprecated - if getattr(state["args"], "raw_text", False): - state["args"].dataset_impl = "raw" - elif getattr(state["args"], "lazy_load", False): - state["args"].dataset_impl = "lazy" - # epochs start at 1 - if state["extra_state"]["train_iterator"] is not None: - state["extra_state"]["train_iterator"]["epoch"] = max( - state["extra_state"]["train_iterator"].get("epoch", 1), - 1, - ) - # set any missing default values in the task, model or other registries - registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task]) - registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch]) - for registry_name, REGISTRY in registry.REGISTRIES.items(): - choice = getattr(state["args"], registry_name, None) - if choice is not None: - cls = REGISTRY["registry"][choice] - registry.set_defaults(state["args"], cls) + # old model checkpoints may not have separate source/target positions + # backward compatibility, cfg updates + if "args" in state and state["args"] is not None: + # default to translation task + if not hasattr(state["args"], "task"): + state["args"].task = "translation" + # --raw-text and --lazy-load are deprecated + if getattr(state["args"], "raw_text", False): + state["args"].dataset_impl = "raw" + elif getattr(state["args"], "lazy_load", False): + state["args"].dataset_impl = "lazy" + # epochs start at 1 + if state["extra_state"]["train_iterator"] is not None: + state["extra_state"]["train_iterator"]["epoch"] = max( + state["extra_state"]["train_iterator"].get("epoch", 1), 1 + ) + + state["cfg"] = convert_namespace_to_omegaconf(state["args"]) + + if "cfg" in state and state["cfg"] is not None: + with open_dict(state["cfg"]): + if state["cfg"].task is not None: + if hasattr(state["cfg"].task, "max_positions") and not hasattr( + state["cfg"].task, "max_source_positions" + ): + state["cfg"].task.max_source_positions = state[ + "cfg" + ].task.max_positions + state["cfg"].task.max_target_positions = state[ + "cfg" + ].task.max_positions return state -def prune_state_dict(state_dict, args): +def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): """Prune the given state_dict if desired for LayerDrop (https://arxiv.org/abs/1909.11556). @@ -453,16 +480,20 @@ def prune_state_dict(state_dict, args): It's called by functions that load models from checkpoints and does not need to be called directly. """ - if not args or args.arch == "ptt_transformer": + arch = None + if model_cfg is not None: + arch = ( + model_cfg._name + if isinstance(model_cfg, DictConfig) + else getattr(model_cfg, "arch", None) + ) + + if not model_cfg or arch is None or arch == "ptt_transformer": # args should not be none, but don't crash if it is. return state_dict - encoder_layers_to_keep = ( - args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None - ) - decoder_layers_to_keep = ( - args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None - ) + encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None) + decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None) if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict @@ -474,7 +505,7 @@ def prune_state_dict(state_dict, args): def create_pruning_pass(layers_to_keep, layer_name): keep_layers = sorted( - [int(layer_string) for layer_string in layers_to_keep.split(",")] + int(layer_string) for layer_string in layers_to_keep.split(",") ) mapping_dict = {} for i in range(len(keep_layers)): @@ -518,10 +549,12 @@ def prune_state_dict(state_dict, args): # Since layers are now pruned, *_layers_to_keep are no longer needed. # This is more of "It would make it work fix" rather than a proper fix. - if "encoder_layers_to_keep" in vars(args): - args.encoder_layers_to_keep = None - if "decoder_layers_to_keep" in vars(args): - args.decoder_layers_to_keep = None + + with open_dict(model_cfg): + if hasattr(model_cfg, "encoder_layers_to_keep"): + model_cfg.encoder_layers_to_keep = None + if hasattr(model_cfg, "decoder_layers_to_keep"): + model_cfg.decoder_layers_to_keep = None return new_state_dict diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index a7eb5f6f..8cc6c0f0 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.criterions.fairseq_criterion import ( # noqa @@ -27,8 +25,8 @@ from omegaconf import DictConfig ) -def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task): - return build_criterion_(criterion_cfg, task) +def build_criterion(cfg: DictConfig, task): + return build_criterion_(cfg, task) # automatically import any Python files in the criterions/ directory diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 74ba37c3..04832295 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -11,13 +11,13 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.constants import DDP_BACKEND_CHOICES -from omegaconf import II +from omegaconf import II, DictConfig @dataclass class AdaptiveLossConfig(FairseqDataclass): - sentence_avg: bool = II("params.optimization.sentence_avg") - ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") + sentence_avg: bool = II("optimization.sentence_avg") + ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend") @register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig) @@ -31,14 +31,14 @@ class AdaptiveLoss(FairseqCriterion): self.sentence_avg = sentence_avg @classmethod - def build_criterion(cls, args, task): - if getattr(args, "ddp_backend", None) == "c10d": + def build_criterion(cls, cfg: DictConfig, task): + if cfg.ddp_backend == "c10d": raise Exception( "AdaptiveLoss is not compatible with the c10d " "version of DistributedDataParallel. Please use " "`--ddp-backend=no_c10d` instead." ) - return cls(task, args.sentence_avg) + return cls(task, cfg.sentence_avg) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py index 91b58545..758e7276 100644 --- a/fairseq/criterions/cross_entropy.py +++ b/fairseq/criterions/cross_entropy.py @@ -15,7 +15,7 @@ from omegaconf import II @dataclass class CrossEntropyCriterionConfig(FairseqDataclass): - sentence_avg: bool = II("params.optimization.sentence_avg") + sentence_avg: bool = II("optimization.sentence_avg") @register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 4f93b3cb..9310024f 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -10,24 +10,24 @@ from argparse import Namespace import torch import torch.nn.functional as F from fairseq import metrics, utils -from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions import LegacyFairseqCriterion, register_criterion from fairseq.data.data_utils import post_process from fairseq.logging.meters import safe_round @register_criterion("ctc") -class CtcCriterion(FairseqCriterion): - def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe): - super().__init__(task) +class CtcCriterion(LegacyFairseqCriterion): + def __init__(self, args, task): + super().__init__(args, task) self.blank_idx = task.target_dictionary.bos() self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() - self.post_process = remove_bpe if remove_bpe else "letter" + self.post_process = args.remove_bpe if args.remove_bpe else "letter" - if wer_args is not None: + if args.wer_args is not None: from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder - wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args) + wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args) dec_args = Namespace() dec_args.nbest = 1 @@ -46,8 +46,8 @@ class CtcCriterion(FairseqCriterion): else: self.w2l_decoder = None - self.zero_infinity = zero_infinity - self.sentence_avg = sentence_avg + self.zero_infinity = args.zero_infinity + self.sentence_avg = args.sentence_avg @staticmethod def add_args(parser): diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index ef94a863..b2eda1a7 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List from fairseq import metrics, utils from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import DictConfig from torch.nn.modules.loss import _Loss @@ -27,10 +28,8 @@ class FairseqCriterion(_Loss): gen_parser_from_dataclass(parser, dc()) @classmethod - def build_criterion(cls, args, task): + def build_criterion(cls, cfg: DictConfig, task): """Construct a criterion from command-line args.""" - # Criterions can override this, but for convenience we also try - # to automatically map argparse.Namespace keys to corresponding # arguments in the __init__. init_args = {} for p in inspect.signature(cls).parameters.values(): @@ -47,8 +46,8 @@ class FairseqCriterion(_Loss): if p.name == "task": init_args["task"] = task - elif hasattr(args, p.name): - init_args[p.name] = getattr(args, p.name) + elif hasattr(cfg, p.name): + init_args[p.name] = getattr(cfg, p.name) elif p.default != p.empty: pass # we'll use the default value else: @@ -70,7 +69,7 @@ class FairseqCriterion(_Loss): @staticmethod def aggregate_logging_outputs( - logging_outputs: List[Dict[str, Any]], + logging_outputs: List[Dict[str, Any]] ) -> Dict[str, Any]: """Aggregate logging outputs from data parallel training.""" utils.deprecation_warning( diff --git a/fairseq/data/encoders/byte_bpe.py b/fairseq/data/encoders/byte_bpe.py index 0d2da3ea..31e3a062 100644 --- a/fairseq/data/encoders/byte_bpe.py +++ b/fairseq/data/encoders/byte_bpe.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe from fairseq.data.encoders.byte_utils import ( @@ -12,19 +14,20 @@ from fairseq.data.encoders.byte_utils import ( byte_encode, smart_byte_decode, ) +from fairseq.dataclass import FairseqDataclass -@register_bpe("byte_bpe") +@dataclass +class ByteBpeConfig(FairseqDataclass): + sentencepiece_model_path: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) + + +@register_bpe("byte_bpe", dataclass=ByteBpeConfig) class ByteBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--sentencepiece-model-path', type=str, - help='path to sentencepiece model') - # fmt: on - - def __init__(self, args): - vocab = file_utils.cached_path(args.sentencepiece_model_path) + def __init__(self, cfg): + vocab = file_utils.cached_path(cfg.sentencepiece_model_path) try: import sentencepiece as spm diff --git a/fairseq/data/encoders/bytes.py b/fairseq/data/encoders/bytes.py index bb9554ed..f88f8f69 100644 --- a/fairseq/data/encoders/bytes.py +++ b/fairseq/data/encoders/bytes.py @@ -15,7 +15,7 @@ from fairseq.data.encoders.byte_utils import ( @register_bpe("bytes") class Bytes(object): - def __init__(self, args): + def __init__(self, *unused): pass @staticmethod diff --git a/fairseq/data/encoders/characters.py b/fairseq/data/encoders/characters.py index cffc5751..494ea219 100644 --- a/fairseq/data/encoders/characters.py +++ b/fairseq/data/encoders/characters.py @@ -13,7 +13,7 @@ SPACE_ESCAPE = chr(9601) @register_bpe("characters") class Characters(object): - def __init__(self, args): + def __init__(self, *unused): pass @staticmethod diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py index 74d4ad85..f7c21039 100644 --- a/fairseq/data/encoders/fastbpe.py +++ b/fairseq/data/encoders/fastbpe.py @@ -3,23 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("fastbpe") +@dataclass +class fastBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"}) + + +@register_bpe("fastbpe", dataclass=fastBPEConfig) class fastBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-codes', type=str, - help='path to fastBPE BPE') - # fmt: on - - def __init__(self, args): - if args.bpe_codes is None: + def __init__(self, cfg): + if cfg.bpe_codes is None: raise ValueError("--bpe-codes is required for --bpe=fastbpe") - codes = file_utils.cached_path(args.bpe_codes) + codes = file_utils.cached_path(cfg.bpe_codes) try: import fastBPE diff --git a/fairseq/data/encoders/gpt2_bpe.py b/fairseq/data/encoders/gpt2_bpe.py index 8ac099a6..e661426a 100644 --- a/fairseq/data/encoders/gpt2_bpe.py +++ b/fairseq/data/encoders/gpt2_bpe.py @@ -3,8 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass from .gpt2_bpe_utils import get_encoder @@ -13,26 +16,21 @@ DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder. DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" -@register_bpe("gpt2") -class GPT2BPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--gpt2-encoder-json', type=str, - default=DEFAULT_ENCODER_JSON, - help='path to encoder.json') - parser.add_argument('--gpt2-vocab-bpe', type=str, - default=DEFAULT_VOCAB_BPE, - help='path to vocab.bpe') - # fmt: on +@dataclass +class GPT2BPEConfig(FairseqDataclass): + gpt2_encoder_json: str = field( + default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"} + ) + gpt2_vocab_bpe: str = field( + default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"} + ) - def __init__(self, args): - encoder_json = file_utils.cached_path( - getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON) - ) - vocab_bpe = file_utils.cached_path( - getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE) - ) + +@register_bpe("gpt2", dataclass=GPT2BPEConfig) +class GPT2BPE(object): + def __init__(self, cfg): + encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json) + vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe) self.bpe = get_encoder(encoder_json, vocab_bpe) def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/hf_bert_bpe.py b/fairseq/data/encoders/hf_bert_bpe.py index a968fe88..a41c0593 100644 --- a/fairseq/data/encoders/hf_bert_bpe.py +++ b/fairseq/data/encoders/hf_bert_bpe.py @@ -3,22 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field +from typing import Optional + from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("bert") +@dataclass +class BertBPEConfig(FairseqDataclass): + bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"}) + bpe_vocab_file: Optional[str] = field( + default=None, metadata={"help": "bpe vocab file"} + ) + + +@register_bpe("bert", dataclass=BertBPEConfig) class BertBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-cased', action='store_true', - help='set for cased BPE', - default=False) - parser.add_argument('--bpe-vocab-file', type=str, - help='bpe vocab file.') - # fmt: on - - def __init__(self, args): + def __init__(self, cfg): try: from transformers import BertTokenizer except ImportError: @@ -26,13 +28,13 @@ class BertBPE(object): "Please install transformers with: pip install transformers" ) - if "bpe_vocab_file" in args: + if cfg.bpe_vocab_file: self.bert_tokenizer = BertTokenizer( - args.bpe_vocab_file, do_lower_case=not args.bpe_cased + cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased ) else: vocab_file_name = ( - "bert-base-cased" if args.bpe_cased else "bert-base-uncased" + "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased" ) self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py index 544d4082..92d2c392 100644 --- a/fairseq/data/encoders/hf_byte_bpe.py +++ b/fairseq/data/encoders/hf_byte_bpe.py @@ -3,21 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("hf_byte_bpe") +@dataclass +class HuggingFaceByteLevelBPEConfig(FairseqDataclass): + bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"}) + bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"}) + bpe_add_prefix_space: bool = field( + default=False, metadata={"help": "add prefix space before encoding"} + ) + + +@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig) class HuggingFaceByteLevelBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-merges', help='path to merges.txt') - parser.add_argument('--bpe-vocab', help='path to vocab.json') - parser.add_argument('--bpe-add-prefix-space', action='store_true', - help='add prefix space before encoding') - # fmt: on - - def __init__(self, args): + def __init__(self, cfg): try: from tokenizers import ByteLevelBPETokenizer except ImportError: @@ -26,9 +29,9 @@ class HuggingFaceByteLevelBPE(object): ) self.bpe = ByteLevelBPETokenizer( - args.bpe_vocab, - args.bpe_merges, - add_prefix_space=getattr(args, "bpe_add_prefix_space", False), + cfg.bpe_vocab, + cfg.bpe_merges, + add_prefix_space=cfg.bpe_add_prefix_space, ) def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py index 8c248442..fa004dd4 100644 --- a/fairseq/data/encoders/moses_tokenizer.py +++ b/fairseq/data/encoders/moses_tokenizer.py @@ -3,37 +3,35 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass -@register_tokenizer("moses") +@dataclass +class MosesTokenizerConfig(FairseqDataclass): + source_lang: str = field(default="en", metadata={"help": "source language"}) + target_lang: str = field(default="en", metadata={"help": "target language"}) + moses_no_dash_splits: bool = field( + default=False, metadata={"help": "don't apply dash split rules"} + ) + moses_no_escape: bool = field( + default=False, + metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."}, + ) + + +@register_tokenizer("moses", dataclass=MosesTokenizerConfig) class MosesTokenizer(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--moses-source-lang', metavar='SRC', - help='source language') - parser.add_argument('--moses-target-lang', metavar='TARGET', - help='target language') - parser.add_argument('--moses-no-dash-splits', action='store_true', default=False, - help='don\'t apply dash split rules') - parser.add_argument('--moses-no-escape', action='store_true', default=False, - help='don\'t perform HTML escaping on apostrophy, quotes, etc.') - # fmt: on - - def __init__(self, args): - self.args = args - - if getattr(args, "moses_source_lang", None) is None: - args.moses_source_lang = getattr(args, "source_lang", "en") - if getattr(args, "moses_target_lang", None) is None: - args.moses_target_lang = getattr(args, "target_lang", "en") + def __init__(self, cfg): + self.cfg = cfg try: from sacremoses import MosesTokenizer, MosesDetokenizer - self.tok = MosesTokenizer(args.moses_source_lang) - self.detok = MosesDetokenizer(args.moses_target_lang) + self.tok = MosesTokenizer(cfg.source_lang) + self.detok = MosesDetokenizer(cfg.target_lang) except ImportError: raise ImportError( "Please install Moses tokenizer with: pip install sacremoses" @@ -42,9 +40,9 @@ class MosesTokenizer(object): def encode(self, x: str) -> str: return self.tok.tokenize( x, - aggressive_dash_splits=(not self.args.moses_no_dash_splits), + aggressive_dash_splits=(not self.cfg.moses_no_dash_splits), return_str=True, - escape=(not self.args.moses_no_escape), + escape=(not self.cfg.moses_no_escape), ) def decode(self, x: str) -> str: diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py index 3b617e73..ee164710 100644 --- a/fairseq/data/encoders/nltk_tokenizer.py +++ b/fairseq/data/encoders/nltk_tokenizer.py @@ -8,7 +8,7 @@ from fairseq.data.encoders import register_tokenizer @register_tokenizer("nltk") class NLTKTokenizer(object): - def __init__(self, source_lang=None, target_lang=None): + def __init__(self, *unused): try: from nltk.tokenize import word_tokenize diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py index b25c6cae..a76d46a2 100644 --- a/fairseq/data/encoders/sentencepiece_bpe.py +++ b/fairseq/data/encoders/sentencepiece_bpe.py @@ -3,21 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("sentencepiece") +@dataclass +class SentencepieceConfig(FairseqDataclass): + sentencepiece_model: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) + + +@register_bpe("sentencepiece", dataclass=SentencepieceConfig) class SentencepieceBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--sentencepiece-model', type=str, - help='path to sentencepiece model') - # fmt: on - - def __init__(self, args): - sentencepiece_model = file_utils.cached_path(args.sentencepiece_model) + def __init__(self, cfg): + sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model) try: import sentencepiece as spm diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py index 3bc7ce49..7c7f644d 100644 --- a/fairseq/data/encoders/space_tokenizer.py +++ b/fairseq/data/encoders/space_tokenizer.py @@ -10,7 +10,7 @@ from fairseq.data.encoders import register_tokenizer @register_tokenizer("space") class SpaceTokenizer(object): - def __init__(self, source_lang=None, target_lang=None): + def __init__(self, *unused): self.space_tok = re.compile(r"\s+") def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq/data/encoders/subword_nmt_bpe.py index e85f99af..5d724d27 100644 --- a/fairseq/data/encoders/subword_nmt_bpe.py +++ b/fairseq/data/encoders/subword_nmt_bpe.py @@ -3,25 +3,25 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("subword_nmt") +@dataclass +class SubwordNMTBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"}) + bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"}) + + +@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig) class SubwordNMTBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-codes', type=str, - help='path to subword NMT BPE') - parser.add_argument('--bpe-separator', default='@@', - help='BPE separator') - # fmt: on - - def __init__(self, args): - if args.bpe_codes is None: + def __init__(self, cfg): + if cfg.bpe_codes is None: raise ValueError("--bpe-codes is required for --bpe=subword_nmt") - codes = file_utils.cached_path(args.bpe_codes) + codes = file_utils.cached_path(cfg.bpe_codes) try: from subword_nmt import apply_bpe @@ -31,7 +31,7 @@ class SubwordNMTBPE(object): "--codes", codes, "--separator", - args.bpe_separator, + cfg.bpe_separator, ] ) self.bpe = apply_bpe.BPE( diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 21b36450..2fd87f5f 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -9,5 +9,7 @@ from fairseq.dataclass.utils import ChoiceEnum LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) +GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) +GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(["unigram", "ensemble", "vote", "dp", "bs"]) ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index ed1d12d8..b0c17ba0 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -3,32 +3,37 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import sys from argparse import Namespace -from dataclasses import dataclass, field +from dataclasses import _MISSING_TYPE, dataclass, field from typing import Any, Dict, List, Optional, Tuple, Type import torch -from fairseq.criterions import CRITERION_DATACLASS_REGISTRY from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.dataclass.constants import ( DDP_BACKEND_CHOICES, DISTRIBUTED_WRAPPER_CHOICES, + GENERATION_CONSTRAINTS_CHOICES, + GENERATION_DECODING_FORMAT_CHOICES, LOG_FORMAT_CHOICES, PIPELINE_CHECKPOINT_CHOICES, ZERO_SHARDING_CHOICES, ) from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY -from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY from fairseq.optim.bmuf import FairseqBMUFConfig -from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY +from fairseq.registry import REGISTRIES from fairseq.tasks import TASK_DATACLASS_REGISTRY from hydra.core.config_store import ConfigStore +from omegaconf import II + + +logger = logging.getLogger(__name__) @dataclass -class CommonParams(FairseqDataclass): +class CommonConfig(FairseqDataclass): # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc. no_progress_bar: bool = field( @@ -109,18 +114,6 @@ class CommonParams(FairseqDataclass): model_parallel_size: int = field( default=1, metadata={"help": "total number of GPUs to parallelize model over"} ) - checkpoint_suffix: str = field( - default="", metadata={"help": "suffix to add to the checkpoint file name"} - ) - checkpoint_shard_count: int = field( - default=1, - metadata={ - "help": "Number of shards containing the checkpoint - " - "if the checkpoint is over 300GB, it is preferable " - "to split it into shards to prevent OOM on CPU while loading " - "the checkpoint" - }, - ) quantization_config_path: Optional[str] = field( default=None, metadata={"help": "path to quantization config file"} ) @@ -130,7 +123,7 @@ class CommonParams(FairseqDataclass): @dataclass -class DistributedTrainingParams(FairseqDataclass): +class DistributedTrainingConfig(FairseqDataclass): distributed_world_size: int = field( default=max(1, torch.cuda.device_count()), metadata={ @@ -229,7 +222,7 @@ class DistributedTrainingParams(FairseqDataclass): default=False, metadata={"help": "if set, use pipeline model parallelism across GPUs"}, ) - pipeline_balance: str = field( + pipeline_balance: Optional[str] = field( default=None, metadata={ "help": "partition the model into N_K pieces, where each piece " @@ -237,7 +230,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of layers in the model" }, ) - pipeline_devices: str = field( + pipeline_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -245,10 +238,10 @@ class DistributedTrainingParams(FairseqDataclass): "equal the length of the --pipeline-balance argument" }, ) - pipeline_chunks: int = field( + pipeline_chunks: Optional[int] = field( default=0, metadata={"help": "microbatch count for pipeline model parallelism"} ) - pipeline_encoder_balance: str = field( + pipeline_encoder_balance: Optional[str] = field( default=None, metadata={ "help": "partition the pipeline parallel encoder into N_K pieces, where each piece " @@ -256,7 +249,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of encoder layers in the model" }, ) - pipeline_encoder_devices: str = field( + pipeline_encoder_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -264,7 +257,7 @@ class DistributedTrainingParams(FairseqDataclass): "equal the length of the --pipeline-encoder-balance argument" }, ) - pipeline_decoder_balance: str = field( + pipeline_decoder_balance: Optional[str] = field( default=None, metadata={ "help": "partition the pipeline parallel decoder into N_K pieces, where each piece " @@ -272,7 +265,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of decoder layers in the model" }, ) - pipeline_decoder_devices: str = field( + pipeline_decoder_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -287,10 +280,11 @@ class DistributedTrainingParams(FairseqDataclass): zero_sharding: ZERO_SHARDING_CHOICES = field( default="none", metadata={"help": "ZeRO sharding"} ) + tpu: bool = II("common.tpu") @dataclass -class DatasetParams(FairseqDataclass): +class DatasetConfig(FairseqDataclass): num_workers: int = field( default=1, metadata={"help": "how many subprocesses to use for data loading"} ) @@ -374,7 +368,7 @@ class DatasetParams(FairseqDataclass): @dataclass -class OptimizationParams(FairseqDataclass): +class OptimizationConfig(FairseqDataclass): max_epoch: int = field( default=0, metadata={"help": "force stop training at specified epoch"} ) @@ -421,7 +415,7 @@ class OptimizationParams(FairseqDataclass): @dataclass -class CheckpointParams(FairseqDataclass): +class CheckpointConfig(FairseqDataclass): save_dir: str = field( default="checkpoints", metadata={"help": "path to save checkpoints"} ) @@ -514,12 +508,217 @@ class CheckpointParams(FairseqDataclass): ) }, ) + checkpoint_suffix: str = field( + default="", metadata={"help": "suffix to add to the checkpoint file name"} + ) + checkpoint_shard_count: int = field( + default=1, + metadata={ + "help": "Number of shards containing the checkpoint - " + "if the checkpoint is over 300GB, it is preferable " + "to split it into shards to prevent OOM on CPU while loading " + "the checkpoint" + }, + ) + model_parallel_size: int = II("common.model_parallel_size") + distributed_rank: int = II("distributed_training.distributed_rank") @dataclass -class CommonEvalParams(FairseqDataclass): +class GenerationConfig(FairseqDataclass): + beam: int = field( + default=5, + metadata={"help": "beam size"}, + ) + nbest: int = field( + default=1, + metadata={"help": "number of hypotheses to output"}, + ) + max_len_a: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + max_len_b: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + min_len: int = field( + default=1, + metadata={"help": "minimum generation length"}, + ) + match_source_len: bool = field( + default=False, + metadata={"help": "generations should match the source length"}, + ) + unnormalized: bool = field( + default=False, + metadata={"help": "compare unnormalized hypothesis scores"}, + ) + no_early_stop: bool = field( + default=False, + metadata={"help": "deprecated"}, + ) + no_beamable_mm: bool = field( + default=False, + metadata={"help": "don't use BeamableMM in attention layers"}, + ) + lenpen: float = field( + default=1, + metadata={ + "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) + unkpen: float = field( + default=0, + metadata={ + "help": "unknown word penalty: <0 produces more unks, >0 produces fewer" + }, + ) + replace_unk: Optional[str] = field( + default=None, + metadata={ + "help": "perform unknown replacement (optionally with alignment dictionary)", + "argparse_const": "@@ ", + }, + ) + sacrebleu: bool = field( + default=False, + metadata={"help": "score with sacrebleu"}, + ) + score_reference: bool = field( + default=False, + metadata={"help": "just score the reference translation"}, + ) + prefix_size: int = field( + default=0, + metadata={"help": "initialize generation by target prefix of given length"}, + ) + no_repeat_ngram_size: int = field( + default=0, + metadata={ + "help": "ngram blocking such that this size ngram cannot be repeated in the generation" + }, + ) + sampling: bool = field( + default=False, + metadata={"help": "sample hypotheses instead of using beam search"}, + ) + sampling_topk: int = field( + default=-1, + metadata={"help": "sample from top K likely next words instead of all words"}, + ) + sampling_topp: float = field( + default=-1.0, + metadata={ + "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words" + }, + ) + constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field( + default=None, + metadata={ + "help": "enables lexically constrained decoding", + "argparse_const": "ordered", + }, + ) + temperature: float = field( + default=1.0, + metadata={"help": "temperature for generation"}, + ) + diverse_beam_groups: int = field( + default=-1, + metadata={"help": "number of groups for Diverse Beam Search"}, + ) + diverse_beam_strength: float = field( + default=0.5, + metadata={"help": "strength of diversity penalty for Diverse Beam Search"}, + ) + diversity_rate: float = field( + default=-1.0, + metadata={"help": "strength of diversity penalty for Diverse Siblings Search"}, + ) + print_alignment: bool = field( + default=False, + metadata={ + "help": "if set, uses attention feedback to compute and print alignment to source tokens" + }, + ) + print_step: bool = field( + default=False, + metadata={"help": "print steps"}, + ) + lm_path: Optional[str] = field( + default=None, + metadata={"help": "path to lm checkpoint for lm fusion"}, + ) + lm_weight: float = field( + default=0.0, + metadata={"help": "weight for lm probs for lm fusion"}, + ) + + # arguments for iterative refinement generator + iter_decode_eos_penalty: float = field( + default=0.0, + metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, + ) + iter_decode_max_iter: int = field( + default=10, + metadata={"help": "maximum iterations for iterative refinement."}, + ) + iter_decode_force_max_iter: bool = field( + default=False, + metadata={ + "help": "if set, run exact the maximum number of iterations without early stop" + }, + ) + iter_decode_with_beam: int = field( + default=1, + metadata={ + "help": "if > 1, model will generate translations varying by the lengths." + }, + ) + iter_decode_with_external_reranker: bool = field( + default=False, + metadata={ + "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations" + }, + ) + retain_iter_history: bool = field( + default=False, + metadata={ + "help": "if set, decoding returns the whole history of iterative refinement" + }, + ) + retain_dropout: bool = field( + default=False, + metadata={"help": "Use dropout at inference time"}, + ) + retain_dropout_modules: Optional[List[str]] = field( + default=None, + metadata={ + "help": "if set, only retain dropout for the specified modules; " + "if not set, then dropout will be retained for all modules" + }, + ) + # special decoding format for advanced decoding. + decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( + default=None, + metadata={"help": "special decoding format for advanced decoding."}, + ) + no_seed_provided: bool = field( + default=False, + metadata={"help": "if set, dont use seed for initializing random generators"}, + ) + + +@dataclass +class CommonEvalConfig(FairseqDataclass): path: Optional[str] = field( - default=None, metadata={"help": "path(s) to model file(s), colon separated"} + default=None, + metadata={"help": "path(s) to model file(s), colon separated"}, ) remove_bpe: Optional[str] = field( default=None, @@ -541,7 +740,7 @@ class CommonEvalParams(FairseqDataclass): @dataclass -class EvalLMParams(FairseqDataclass): +class EvalLMConfig(FairseqDataclass): output_word_probs: bool = field( default=False, metadata={ @@ -569,37 +768,31 @@ class EvalLMParams(FairseqDataclass): @dataclass -class TrainingConfig(FairseqDataclass): - """Config for training, a composition of training params""" - - common: CommonParams = CommonParams() - distributed_training: DistributedTrainingParams = DistributedTrainingParams() - dataset: DatasetParams = DatasetParams() - optimization: OptimizationParams = OptimizationParams() - checkpoint: CheckpointParams = CheckpointParams() - bmuf: FairseqBMUFConfig = FairseqBMUFConfig() +class InteractiveConfig(FairseqDataclass): + buffer_size: int = field( + default=0, + metadata={ + "help": "read this many sentences into a buffer before processing them" + }, + ) + input: str = field( + default="-", + metadata={"help": "file to read from; use - for stdin"}, + ) -@dataclass -class EvalLMConfig(FairseqDataclass): - """Config for eval lm, a composition of eval_lm params""" - - common: CommonParams = CommonParams() - distributed_training: DistributedTrainingParams = DistributedTrainingParams() - dataset: DatasetParams = DatasetParams() - optimization: OptimizationParams = OptimizationParams() - checkpoint: CheckpointParams = CheckpointParams() - bmuf: FairseqBMUFConfig = FairseqBMUFConfig() - common_eval: CommonEvalParams = CommonEvalParams() - eval_lm: EvalLMParams = EvalLMParams() - - -def register_params_dataclass( - cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass] -) -> None: - """register params dataclass in config store""" - node_ = data_class(_name=data_class.name()) - cs.store(name=name, group=group, node=node_) +CONFIGS = { + "common": CommonConfig, + "common_eval": CommonEvalConfig, + "distributed_training": DistributedTrainingConfig, + "dataset": DatasetConfig, + "optimization": OptimizationConfig, + "checkpoint": CheckpointConfig, + "bmuf": FairseqBMUFConfig, + "generation": GenerationConfig, + "eval_lm": EvalLMConfig, + "interactive": InteractiveConfig, +} def register_module_dataclass( @@ -608,100 +801,67 @@ def register_module_dataclass( """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" # note that if `group == model`, we register all model archs, not the model name. for k, v in registry.items(): - if v is not None: - node_ = v(_name=k) - cs.store(name=k, group=group, node=node_) + node_ = v() + node_._name = k + cs.store(name=k, group=group, node=node_, provider="fairseq") -def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: +def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: """cs: config store instance, register common training configs""" - register_params_dataclass( - cs, name="training_params", group="params", data_class=TrainingConfig - ) + for k, v in CONFIGS.items(): + try: + cs.store(name=k, node=v()) + except BaseException: + logger.error(f"{k} - {v()}") + raise register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") - register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") - register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") - register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") - -def register_eval_lm_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: - """cs: config store instance, register common training configs""" - - register_params_dataclass( - cs, name="eval_lm_params", group="params", data_class=EvalLMConfig - ) - - register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") - register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") - register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") - register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") + for k, v in REGISTRIES.items(): + register_module_dataclass(cs, v["dataclass_registry"], k) def _override_attr( sub_node: str, data_class: Type[FairseqDataclass], args: Namespace ) -> List[str]: overrides = [] - for k in data_class.__dataclass_fields__.keys(): - if k == "_name": + + def get_default(f): + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + for k, v in data_class.__dataclass_fields__.items(): + if k.startswith("_"): # private member, skip continue - if not hasattr(args, k): - # print(f"cannot override {sub_node}.{k} since args does not have attribute {k}") - continue - if getattr(args, k) is None: + + val = get_default(v) if not hasattr(args, k) else getattr(args, k) + + if val is None: overrides.append("{}.{}=null".format(sub_node, k)) - elif getattr(args, k) == "": + elif val == "": overrides.append("{}.{}=''".format(sub_node, k)) - elif isinstance(getattr(args, k), str): - if ( - getattr(args, k).startswith("[") - or getattr(args, k).startswith("(") - or getattr(args, k).startswith("{") - or ("," in getattr(args, k)) - ): - overrides.append("{}.{}='{}'".format(sub_node, k, getattr(args, k))) - else: - overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) + elif isinstance(val, str): + overrides.append("{}.{}='{}'".format(sub_node, k, val)) else: - overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) + overrides.append("{}.{}={}".format(sub_node, k, val)) return overrides -def override_training_args(args: Namespace) -> Tuple[List[str], List[str]]: - overrides = [] - - overrides.extend(_override_attr("params.common", CommonParams, args)) - overrides.extend(_override_attr("params.dataset", DatasetParams, args)) - overrides.extend( - _override_attr("params.distributed_training", DistributedTrainingParams, args) - ) - overrides.extend(_override_attr("params.optimization", OptimizationParams, args)) - overrides.extend(_override_attr("params.checkpoint", CheckpointParams, args)) - overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) - module_overrides, module_deletes = override_module_args(args) - overrides.extend(module_overrides) - - return overrides, module_deletes - - -def override_eval_lm_args(args: Namespace) -> Tuple[List[str], List[str]]: - overrides = [] - - overrides.extend(_override_attr("params.common", CommonParams, args)) - overrides.extend(_override_attr("params.dataset", DatasetParams, args)) - overrides.extend( - _override_attr("params.distributed_training", DistributedTrainingParams, args) - ) - overrides.extend(_override_attr("params.common_eval", CommonEvalParams, args)) - overrides.extend(_override_attr("params.eval_lm", EvalLMParams, args)) - overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) - module_overrides, module_deletes = override_module_args(args) - overrides.extend(module_overrides) - - return overrides, module_deletes +def migrate_registry( + name, value, registry, args, overrides, deletes, use_name_as_val=False +): + if value in registry: + overrides.append("{}={}".format(name, value)) + overrides.append("{}._name={}".format(name, value)) + overrides.extend(_override_attr(name, registry[value], args)) + elif use_name_as_val and value is not None: + overrides.append("{}={}".format(name, value)) + else: + deletes.append(name) def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: @@ -709,53 +869,34 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: overrides = [] deletes = [] + for k, v in CONFIGS.items(): + overrides.extend(_override_attr(k, v, args)) + if args is not None: - assert ( - hasattr(args, "task") - and hasattr(args, "criterion") - and hasattr(args, "optimizer") - and hasattr(args, "lr_scheduler") - ) - if args.task in TASK_DATACLASS_REGISTRY: - overrides.append("task={}".format(args.task)) - overrides.append("task._name={}".format(args.task)) - overrides.extend( - _override_attr("task", TASK_DATACLASS_REGISTRY[args.task], args) + if hasattr(args, "task"): + migrate_registry( + "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes ) else: deletes.append("task") - if args.criterion in CRITERION_DATACLASS_REGISTRY: - overrides.append("criterion={}".format(args.criterion)) - overrides.append("criterion._name={}".format(args.criterion)) - overrides.extend( - _override_attr( - "criterion", CRITERION_DATACLASS_REGISTRY[args.criterion], args - ) - ) - else: - deletes.append("criterion") - if args.optimizer in OPTIMIZER_DATACLASS_REGISTRY: - overrides.append("optimizer={}".format(args.optimizer)) - overrides.append("optimizer._name={}".format(args.optimizer)) - overrides.extend( - _override_attr( - "optimizer", OPTIMIZER_DATACLASS_REGISTRY[args.optimizer], args - ) - ) - else: - deletes.append("optimizer") - if args.lr_scheduler in LR_SCHEDULER_DATACLASS_REGISTRY: - overrides.append("lr_scheduler={}".format(args.lr_scheduler)) - overrides.append("lr_scheduler._name={}".format(args.lr_scheduler)) - overrides.extend( - _override_attr( - "lr_scheduler", - LR_SCHEDULER_DATACLASS_REGISTRY[args.lr_scheduler], + + # these options will be set to "None" if they have not yet been migrated + # so we can populate them with the entire flat args + CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} + + for k, v in REGISTRIES.items(): + if hasattr(args, k): + migrate_registry( + k, + getattr(args, k), + v["dataclass_registry"], args, + overrides, + deletes, + use_name_as_val=k not in CORE_REGISTRIES, ) - ) - else: - deletes.append("lr_scheduler") + else: + deletes.append(k) no_dc = True if hasattr(args, "arch"): diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 599cc2b4..bcfe2329 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -3,17 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from argparse import ArgumentParser -from dataclasses import MISSING, dataclass +import ast +from argparse import ArgumentParser, Namespace +from dataclasses import _MISSING_TYPE, MISSING, dataclass from enum import Enum from typing import Any, Dict, List, Optional +from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose, initialize +from omegaconf import DictConfig, OmegaConf, open_dict + def eval_str_list(x, x_type=float): if x is None: return None if isinstance(x, str): - x = eval(x) + if len(x) == 0: + return [] + x = ast.literal_eval(x) try: return list(map(x_type, x)) except TypeError: @@ -70,22 +77,11 @@ class FairseqDataclass: != self.__dataclass_fields__[attribute_name].default ): return getattr(self, attribute_name) - return self.__dataclass_fields__[attribute_name].default - def _get_default_factory(self, attribute_name: str) -> Any: - if hasattr(self, attribute_name): - if str(getattr(self, attribute_name)).startswith("${"): - return str(getattr(self, attribute_name)) - elif str(self.__dataclass_fields__[attribute_name].default).startswith( - "${" - ): - return str(self.__dataclass_fields__[attribute_name].default) - elif ( - getattr(self, attribute_name) - != self.__dataclass_fields__[attribute_name].default_factory() - ): - return getattr(self, attribute_name) - return self.__dataclass_fields__[attribute_name].default_factory() + f = self.__dataclass_fields__[attribute_name] + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default def _get_type(self, attribute_name: str) -> Any: return self.__dataclass_fields__[attribute_name].type @@ -119,7 +115,7 @@ def gen_parser_from_dataclass( def interpret_dc_type(field_type): if isinstance(field_type, str): - raise RuntimeError() + raise RuntimeError("field should be a type") typestring = str(field_type) if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): return field_type.__args__[0] @@ -129,12 +125,13 @@ def gen_parser_from_dataclass( dataclass_instance: FairseqDataclass, k: str ) -> Dict[str, Any]: """k: dataclass attributes""" + + kwargs = {} + field_type = dataclass_instance._get_type(k) inter_type = interpret_dc_type(field_type) - if isinstance(inter_type, type) and issubclass(inter_type, List): - field_default = dataclass_instance._get_default_factory(k) - else: - field_default = dataclass_instance._get_default(k) + + field_default = dataclass_instance._get_default(k) if isinstance(inter_type, type) and issubclass(inter_type, Enum): field_choices = [t.value for t in list(inter_type)] @@ -143,7 +140,7 @@ def gen_parser_from_dataclass( field_help = dataclass_instance._get_help(k) field_const = dataclass_instance._get_argparse_const(k) - kwargs = {} + if isinstance(field_default, str) and field_default.startswith("${"): kwargs["default"] = field_default else: @@ -163,7 +160,11 @@ def gen_parser_from_dataclass( else: raise NotImplementedError() if field_default is not MISSING: - kwargs["default"] = ",".join(map(str, field_default)) + kwargs["default"] = ( + ",".join(map(str, field_default)) + if field_default is not None + else None + ) elif ( isinstance(inter_type, type) and issubclass(inter_type, Enum) ) or "Enum" in str(inter_type): @@ -187,6 +188,7 @@ def gen_parser_from_dataclass( if field_const is not None: kwargs["const"] = field_const kwargs["nargs"] = "?" + return kwargs for k in dataclass_instance._get_all_attributes(): @@ -194,8 +196,122 @@ def gen_parser_from_dataclass( if field_name is None: continue kwargs = get_kwargs_from_dc(dataclass_instance, k) - if isinstance(kwargs["default"], str) and kwargs["default"].startswith("${"): - continue - if delete_default: - del kwargs["default"] + + if "default" in kwargs: + if isinstance(kwargs["default"], str) and kwargs["default"].startswith( + "${" + ): + continue + if delete_default: + del kwargs["default"] parser.add_argument(field_name, **kwargs) + + +def _set_legacy_defaults(args, cls): + """Helper to set default arguments based on *add_args*.""" + if not hasattr(cls, "add_args"): + return + + import argparse + + parser = argparse.ArgumentParser( + argument_default=argparse.SUPPRESS, allow_abbrev=False + ) + cls.add_args(parser) + # copied from argparse.py: + defaults = argparse.Namespace() + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if not hasattr(defaults, action.dest): + if action.default is not argparse.SUPPRESS: + setattr(defaults, action.dest, action.default) + for key, default_value in vars(defaults).items(): + if not hasattr(args, key): + setattr(args, key, default_value) + + +def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: + from fairseq.dataclass.data_class import override_module_args + + # Here we are using field values provided in args to override counterparts inside config object + overrides, deletes = override_module_args(args) + + cfg_name = "config" + cfg_path = f"../../{cfg_name}" + + if not GlobalHydra().is_initialized(): + initialize(config_path=cfg_path) + + composed_cfg = compose(cfg_name, overrides=overrides, strict=False) + for k in deletes: + composed_cfg[k] = None + + cfg = OmegaConf.create( + OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True) + ) + + # hack to be able to set Namespace in dict config. this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + if cfg.task is None and getattr(args, "task", None): + cfg.task = Namespace(**vars(args)) + from fairseq.tasks import TASK_REGISTRY + + _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) + cfg.task._name = args.task + if cfg.model is None and getattr(args, "arch", None): + cfg.model = Namespace(**vars(args)) + from fairseq.models import ARCH_MODEL_REGISTRY + + _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) + cfg.model._name = args.arch + if cfg.optimizer is None and getattr(args, "optimizer", None): + cfg.optimizer = Namespace(**vars(args)) + from fairseq.optim import OPTIMIZER_REGISTRY + + _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) + cfg.optimizer._name = args.optimizer + if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): + cfg.lr_scheduler = Namespace(**vars(args)) + from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY + + _set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]) + cfg.lr_scheduler._name = args.lr_scheduler + if cfg.criterion is None and getattr(args, "criterion", None): + cfg.criterion = Namespace(**vars(args)) + from fairseq.criterions import CRITERION_REGISTRY + + _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) + cfg.criterion._name = args.criterion + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(cfg, True) + return cfg + + +def populate_dataclass( + args: Namespace, dataclass: FairseqDataclass +) -> FairseqDataclass: + for k in dataclass.__dataclass_fields__.keys(): + if k.startswith("_"): + # private member, skip + continue + if hasattr(args, k): + setattr(dataclass, k, getattr(args, k)) + + return dataclass + + +def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): + # this will be deprecated when we get rid of argparse and model_overrides logic + + with open_dict(cfg): + for k in cfg.keys(): + if isinstance(cfg[k], DictConfig): + overwrite_args_by_name(cfg[k], overrides) + elif k in overrides: + cfg[k] = overrides[k] diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index bcb0595e..23cdfc69 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -11,35 +11,38 @@ import socket import struct import subprocess import warnings +from argparse import Namespace from collections import OrderedDict from typing import Any, Dict, Mapping import torch import torch.distributed as dist from fairseq import utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from omegaconf import DictConfig, open_dict logger = logging.getLogger(__name__) -def is_master(args): - return args.distributed_rank == 0 +def is_master(cfg: DictConfig): + return cfg.distributed_rank == 0 -def infer_init_method(args, force_distributed=False): - if args.distributed_init_method is not None or getattr(args, "tpu", False): +def infer_init_method(cfg: DictConfig, force_distributed=False): + if cfg.distributed_init_method is not None or cfg.tpu: return - if args.pipeline_model_parallel: + if cfg.pipeline_model_parallel: balance_exists = ( - args.pipeline_balance is not None - or args.pipeline_encoder_balance is not None - or args.pipeline_decoder_balance is not None + cfg.pipeline_balance is not None + or cfg.pipeline_encoder_balance is not None + or cfg.pipeline_decoder_balance is not None ) devices_exist = ( - args.pipeline_devices is not None - or args.pipeline_encoder_devices is not None - or args.pipeline_decoder_devices is not None + cfg.pipeline_devices is not None + or cfg.pipeline_encoder_devices is not None + or cfg.pipeline_decoder_devices is not None ) if not balance_exists: raise ValueError( @@ -50,19 +53,19 @@ def infer_init_method(args, force_distributed=False): "--pipeline-devices is currently required for pipeline model parallelism" ) - args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int) - if args.pipeline_devices is not None: - args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int) - num_pipeline_devices = len(set(args.pipeline_devices)) + cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int) + if cfg.pipeline_devices is not None: + cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int) + num_pipeline_devices = len(set(cfg.pipeline_devices)) else: - args.pipeline_encoder_devices = utils.eval_str_list( - args.pipeline_encoder_devices, type=int + cfg.pipeline_encoder_devices = utils.eval_str_list( + cfg.pipeline_encoder_devices, type=int ) - args.pipeline_decoder_devices = utils.eval_str_list( - args.pipeline_decoder_devices, type=int + cfg.pipeline_decoder_devices = utils.eval_str_list( + cfg.pipeline_decoder_devices, type=int ) num_pipeline_devices = len( - set(args.pipeline_encoder_devices + args.pipeline_decoder_devices) + set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices) ) gpus_per_node = torch.cuda.device_count() assert ( @@ -79,14 +82,14 @@ def infer_init_method(args, force_distributed=False): key in os.environ for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] ): - args.distributed_init_method = "env://" - args.distributed_world_size = int(os.environ["WORLD_SIZE"]) - args.distributed_rank = int(os.environ["RANK"]) + cfg.distributed_init_method = "env://" + cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) + cfg.distributed_rank = int(os.environ["RANK"]) # processes are created by torch.distributed.launch - args.distributed_no_spawn = True + cfg.distributed_no_spawn = True # we can determine the init method automatically for Slurm - elif args.distributed_port > 0: + elif cfg.distributed_port > 0: node_list = os.environ.get("SLURM_STEP_NODELIST") if node_list is None: node_list = os.environ.get("SLURM_JOB_NODELIST") @@ -95,9 +98,9 @@ def infer_init_method(args, force_distributed=False): hostnames = subprocess.check_output( ["scontrol", "show", "hostnames", node_list] ) - args.distributed_init_method = "tcp://{host}:{port}".format( + cfg.distributed_init_method = "tcp://{host}:{port}".format( host=hostnames.split()[0].decode("utf-8"), - port=args.distributed_port, + port=cfg.distributed_port, ) nnodes = int(os.environ.get("SLURM_NNODES")) ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") @@ -111,88 +114,94 @@ def infer_init_method(args, force_distributed=False): if ntasks_per_node == 1: gpus_per_node = torch.cuda.device_count() node_id = int(os.environ.get("SLURM_NODEID")) - args.distributed_rank = node_id * gpus_per_node - args.distributed_world_size = nnodes * gpus_per_node - elif args.pipeline_model_parallel: + cfg.distributed_rank = node_id * gpus_per_node + cfg.distributed_world_size = nnodes * gpus_per_node + elif cfg.pipeline_model_parallel: assert ntasks_per_node == num_pipelines_per_node, ( "SLURM --ntasks-per-node must match number of pipelines per " "node (={})".format(num_pipelines_per_node) ) - args.distributed_no_spawn = True + cfg.distributed_no_spawn = True # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on # the first node, [1, 2] on the second node, etc. This # matches torch.distributed.launch. node_id = int(os.environ.get("SLURM_NODEID")) local_id = int(os.environ.get("SLURM_LOCALID")) - args.distributed_rank = node_id * num_pipelines_per_node + local_id + cfg.distributed_rank = node_id * num_pipelines_per_node + local_id # In the above example, device_id will always be in [0, 1], # which also matches torch.distributed.launch. - args.device_id = local_id + cfg.device_id = local_id # We also want to set distributed_world_size to be the total # number of pipelines across all nodes. - args.distributed_world_size = nnodes * num_pipelines_per_node + cfg.distributed_world_size = nnodes * num_pipelines_per_node else: - assert ntasks_per_node == args.distributed_world_size // nnodes - args.distributed_no_spawn = True - args.distributed_rank = int(os.environ.get("SLURM_PROCID")) - args.device_id = int(os.environ.get("SLURM_LOCALID")) + assert ntasks_per_node == cfg.distributed_world_size // nnodes + cfg.distributed_no_spawn = True + cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) + cfg.device_id = int(os.environ.get("SLURM_LOCALID")) except subprocess.CalledProcessError as e: # scontrol failed raise e except FileNotFoundError: # Slurm is not installed pass - elif args.distributed_world_size > 1 or force_distributed: + elif cfg.distributed_world_size > 1 or force_distributed: # fallback for single node with multiple GPUs - assert args.distributed_world_size <= torch.cuda.device_count() + assert cfg.distributed_world_size <= torch.cuda.device_count() port = random.randint(10000, 20000) - args.distributed_init_method = "tcp://localhost:{port}".format(port=port) + cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) - if args.pipeline_model_parallel: - if not args.distributed_no_spawn: + if cfg.pipeline_model_parallel: + if not cfg.distributed_no_spawn: # When distributed_no_spawn is False, we expect distributed_rank and # distributed_world_size to be based on the total number of GPUs, so # we need to correct them to be based on the number of pipelines. - assert args.distributed_world_size % num_pipeline_devices == 0 - args.distributed_world_size = ( - args.distributed_world_size // num_pipeline_devices + assert cfg.distributed_world_size % num_pipeline_devices == 0 + cfg.distributed_world_size = ( + cfg.distributed_world_size // num_pipeline_devices ) # In the case of 4-way MP on nodes with 8 GPUs, we want # distributed_rank to be the starting GPU index for each pipeline # i.e., 0, 2, ... - assert args.distributed_rank % gpus_per_node == 0 - assert args.distributed_rank % num_pipeline_devices == 0 - args.distributed_rank = args.distributed_rank // num_pipeline_devices - # launch one process per pipeline - args.distributed_num_procs = num_pipelines_per_node + assert cfg.distributed_rank % gpus_per_node == 0 + assert cfg.distributed_rank % num_pipeline_devices == 0 + + with open_dict(cfg): + cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices + # launch one process per pipeline + cfg.distributed_num_procs = num_pipelines_per_node # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 # and 4, indicating the starting device IDs for each pipeline - args.device_id *= num_pipeline_devices + cfg.device_id *= num_pipeline_devices - if args.device_id > 0: + if cfg.device_id > 0: # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 # GPU node), we need to adjust pipeline_devices accordingly logger.debug( "setting CUDA device={} on rank {}".format( - args.device_id, args.distributed_rank + cfg.device_id, cfg.distributed_rank ) ) - torch.cuda.set_device(args.device_id) - args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices] + torch.cuda.set_device(cfg.device_id) + with open_dict(cfg): + cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices] logger.info( "setting pipeline_devices={} on rank {}".format( - args.pipeline_devices, args.distributed_rank - ), + cfg.pipeline_devices, cfg.distributed_rank + ) + ) + elif not cfg.distributed_no_spawn: + with open_dict(cfg): + cfg.distributed_num_procs = min( + torch.cuda.device_count(), cfg.distributed_world_size ) - elif not args.distributed_no_spawn: - args.distributed_num_procs = min( - torch.cuda.device_count(), - args.distributed_world_size, - ) -def distributed_init(args): - if not getattr(args, "tpu", False): +def distributed_init(cfg: DictConfig): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + if not cfg.common.tpu: if torch.distributed.is_initialized(): warnings.warn( "Distributed is already initialized, cannot initialize twice!" @@ -200,20 +209,20 @@ def distributed_init(args): else: logger.info( "distributed init (rank {}): {}".format( - args.distributed_rank, - args.distributed_init_method, + cfg.distributed_training.distributed_rank, + cfg.distributed_training.distributed_init_method, ) ) dist.init_process_group( - backend=args.distributed_backend, - init_method=args.distributed_init_method, - world_size=args.distributed_world_size, - rank=args.distributed_rank, + backend=cfg.distributed_training.distributed_backend, + init_method=cfg.distributed_training.distributed_init_method, + world_size=cfg.distributed_training.distributed_world_size, + rank=cfg.distributed_training.distributed_rank, ) logger.info( "initialized host {} as rank {}".format( socket.gethostname(), - args.distributed_rank, + cfg.distributed_training.distributed_rank, ) ) @@ -221,20 +230,22 @@ def distributed_init(args): if torch.cuda.is_available(): dist.all_reduce(torch.zeros(1).cuda()) - args.distributed_rank = torch.distributed.get_rank() + cfg.distributed_training.distributed_rank = torch.distributed.get_rank() else: import torch_xla.core.xla_model as xm - assert xm.xrt_world_size() == args.distributed_world_size - args.device_id = xm.get_local_ordinal() - args.distributed_rank = xm.get_ordinal() + assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size + cfg.distributed_training.device_id = xm.get_local_ordinal() + cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers xm.mark_step() - if not is_master(args): + if is_master(cfg.distributed_training): + logging.getLogger().setLevel(logging.INFO) + else: logging.getLogger().setLevel(logging.WARNING) - if args.model_parallel_size > 1: + if cfg.common.model_parallel_size > 1: try: from fairseq.model_parallel.megatron.mpu import ( get_model_parallel_rank, @@ -247,58 +258,61 @@ def distributed_init(args): "\n\n git submodule update --init " "fairseq/model_parallel/megatron" ) - initialize_model_parallel(args.model_parallel_size) - model_parallel_cuda_manual_seed(args.seed) + initialize_model_parallel(cfg.common.model_parallel_size) + model_parallel_cuda_manual_seed(cfg.common.seed) model_part_number = get_model_parallel_rank() - args.checkpoint_suffix += "-model_part-{0}".format(model_part_number) - return args.distributed_rank + cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + return cfg.distributed_training.distributed_rank -def distributed_main(i, main, args, kwargs): - args.device_id = i - if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): - torch.cuda.set_device(args.device_id) - if args.distributed_rank is None: # torch.multiprocessing.spawn - args.distributed_rank = kwargs.pop("start_rank", 0) + i +def distributed_main(i, main, cfg: DictConfig, kwargs): + cfg.distributed_training.device_id = i + if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu: + torch.cuda.set_device(cfg.distributed_training.device_id) + if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn + cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i - args.distributed_rank = distributed_init(args) + cfg.distributed_training.distributed_rank = distributed_init(cfg) after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) if after_distributed_init_fn: - args = after_distributed_init_fn(args) + cfg = after_distributed_init_fn(cfg) - main(args, **kwargs) + main(cfg, **kwargs) -def call_main(args, main, **kwargs): - if args.distributed_init_method is None: - infer_init_method(args) +def call_main(cfg: DictConfig, main, **kwargs): + if cfg.distributed_training.distributed_init_method is None: + infer_init_method(cfg.distributed_training) - if args.distributed_init_method is not None: + if cfg.distributed_training.distributed_init_method is not None: # distributed training - if not args.distributed_no_spawn: - start_rank = args.distributed_rank - args.distributed_rank = None # assign automatically + if not cfg.distributed_training.distributed_no_spawn: + start_rank = cfg.distributed_training.distributed_rank + cfg.distributed_training.distributed_rank = None # assign automatically kwargs["start_rank"] = start_rank torch.multiprocessing.spawn( fn=distributed_main, - args=(main, args, kwargs), - nprocs=args.distributed_num_procs, + args=(main, cfg, kwargs), + nprocs=min( + torch.cuda.device_count(), + cfg.distributed_training.distributed_world_size, + ), ) else: - distributed_main(args.device_id, main, args, kwargs) - elif getattr(args, "tpu", False) and args.distributed_world_size > 1: + distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) + elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1: import torch_xla.distributed.xla_multiprocessing as xmp torch.multiprocessing.set_sharing_strategy("file_system") xmp.spawn( fn=distributed_main, - args=(main, args, kwargs), + args=(main, cfg, kwargs), nprocs=8, # use all 8 TPU cores ) else: # single GPU main - main(args, **kwargs) + main(cfg, **kwargs) def get_rank(): @@ -392,11 +406,7 @@ def all_gather_list(data, group=None, max_size=16384): ) -def all_reduce_dict( - data: Mapping[str, Any], - device, - group=None, -) -> Dict[str, Any]: +def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, Any]: """ AllReduce a dictionary of values across workers. We separately reduce items that are already on the device and items on CPU for diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index b293e54e..3be7078b 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -8,11 +8,12 @@ import argparse import copy import logging import os -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List import torch from fairseq import utils from fairseq.data import encoders +from omegaconf import open_dict from torch import nn @@ -85,9 +86,9 @@ class GeneratorHubInterface(nn.Module): translation or language model. """ - def __init__(self, args, task, models): + def __init__(self, cfg, task, models): super().__init__() - self.args = args + self.cfg = cfg self.task = task self.models = nn.ModuleList(models) self.src_dict = task.source_dictionary @@ -95,14 +96,14 @@ class GeneratorHubInterface(nn.Module): # optimize model for generation for model in self.models: - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None)) + self.align_dict = utils.load_align_dict(cfg.generation.replace_unk) - self.tokenizer = encoders.build_tokenizer(args) - self.bpe = encoders.build_bpe(args) + self.tokenizer = encoders.build_tokenizer(cfg.tokenizer) + self.bpe = encoders.build_bpe(cfg.bpe) self.max_positions = utils.resolve_max_positions( self.task.max_positions(), *[model.max_positions() for model in models] @@ -156,10 +157,11 @@ class GeneratorHubInterface(nn.Module): )[0] # build generator using current args as well as any kwargs - gen_args = copy.copy(self.args) - gen_args.beam = beam - for k, v in kwargs.items(): - setattr(gen_args, k, v) + gen_args = copy.copy(self.cfg) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) generator = self.task.build_generator(self.models, gen_args) inference_step_args = inference_step_args or {} @@ -253,8 +255,8 @@ class GeneratorHubInterface(nn.Module): lengths = torch.LongTensor([t.numel() for t in tokens]) batch_iterator = self.task.get_batch_iterator( dataset=self.task.build_dataset_for_inference(tokens, lengths), - max_tokens=self.args.max_tokens, - max_sentences=self.args.batch_size, + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, max_positions=self.max_positions, ignore_invalid_inputs=skip_invalid_size_inputs, disable_iterator_cache=True, diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index 761ffc8e..258551c9 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -9,6 +9,7 @@ Train a network across multiple GPUs. from fairseq import distributed_utils from fairseq.trainer import Trainer +from omegaconf import DictConfig try: @@ -28,14 +29,14 @@ except (ImportError, ModuleNotFoundError): class MegatronTrainer(Trainer): """Main class for model parallel with data parallel training.""" - def __init__(self, args, task, model, criterion): + def __init__(self, cfg: DictConfig, task, model, criterion, **kwargs): if not has_megatron_submodule: raise ImportError( "\n\nPlease install the megatron submodule:" "\n\n git submodule update --init " "fairseq/model_parallel/megatron" ) - super().__init__(args, task, model, criterion) + super().__init__(cfg, task, model, criterion, **kwargs) @property def data_parallel_world_size(self): diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index cbfc6ae4..76cfe3b0 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -96,7 +96,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel): encoder_output_tuple = self.encoder(input) return self.decoder(encoder_output_tuple) - def prepare_for_inference_(self, args): + def prepare_for_inference_(self, cfg): if self.encoder is not None and self.decoder is not None: logger.info("Encoder and Decoder already initialized") return @@ -111,9 +111,9 @@ class PipelineParallelTransformerModel(BaseFairseqModel): decoder_module_list.append(module) module_count += 1 self.model = None - self.encoder = TransformerEncoder(args, None, None, encoder_module_list) + self.encoder = TransformerEncoder(cfg.model, None, None, encoder_module_list) self.decoder = TransformerDecoder( - args, None, None, decoder_module_list=decoder_module_list + cfg.model, None, None, decoder_module_list=decoder_module_list ) @staticmethod @@ -320,7 +320,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel): """Maximum length supported by the decoder.""" return self.decoder_max_positions - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict(self, state_dict, strict=True, cfg=None): """Copies parameters and buffers from *state_dict* into this module and its descendants. diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 5db6efb7..dc52f6e8 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -72,6 +72,10 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel): ) return cls(decoder) + @staticmethod + def add_args(parser): + TransformerLanguageModel.add_args(parser) + @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): def _vocab_init(tensor, **kwargs): diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 7ff94427..3b4fd51d 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -7,8 +7,6 @@ import argparse import importlib import os -from argparse import Namespace -from typing import Union import fairseq from fairseq.dataclass import FairseqDataclass @@ -52,10 +50,10 @@ __all__ = [ ] -def build_model(model_cfg: Union[DictConfig, Namespace], task): - if isinstance(model_cfg, DictConfig): - return ARCH_MODEL_REGISTRY[model_cfg._name].build_model(model_cfg, task) - return ARCH_MODEL_REGISTRY[model_cfg.arch].build_model(model_cfg, task) +def build_model(cfg: DictConfig, task): + if isinstance(cfg, DictConfig): + return ARCH_MODEL_REGISTRY[cfg._name].build_model(cfg, task) + return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task) def register_model(name, dataclass=None): @@ -92,7 +90,8 @@ def register_model(name, dataclass=None): ) cls.__dataclass = dataclass - MODEL_DATACLASS_REGISTRY[name] = dataclass + if dataclass is not None: + MODEL_DATACLASS_REGISTRY[name] = dataclass return cls return register_model_cls @@ -108,14 +107,13 @@ def register_model_architecture(model_name, arch_name): For example:: @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') - def lstm_luong_wmt_en_de(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) + def lstm_luong_wmt_en_de(cfg): + args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000) (...) - The decorated function should take a single argument *args*, which is a - :class:`argparse.Namespace` of arguments parsed from the command-line. The - decorated function should modify these arguments in-place to match the - desired architecture. + The decorated function should take a single argument *cfg*, which is a + :class:`omegaconf.DictConfig`. The decorated function should modify these + arguments in-place to match the desired architecture. Args: model_name (str): the name of the Model (Model must already be diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index cdabe360..6a520cb9 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -13,6 +13,7 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import utils from fairseq.data import encoders +from omegaconf import open_dict logger = logging.getLogger(__name__) @@ -24,13 +25,13 @@ class BARTHubInterface(nn.Module): Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart """ - def __init__(self, args, task, model): + def __init__(self, cfg, task, model): super().__init__() - self.args = args + self.cfg = cfg self.task = task self.model = model - self.bpe = encoders.build_bpe(args) + self.bpe = encoders.build_bpe(cfg.bpe) self.max_positions = min( utils.resolve_max_positions( @@ -120,10 +121,11 @@ class BARTHubInterface(nn.Module): sample = self._build_sample(tokens) # build generator using current args as well as any kwargs - gen_args = copy.copy(self.args) - gen_args.beam = beam - for k, v in kwargs.items(): - setattr(gen_args, k, v) + gen_args = copy.copy(self.cfg) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) generator = self.task.build_generator([self.model], gen_args) translations = self.task.inference_step( generator, diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 0f22352b..7263a78d 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -144,7 +144,9 @@ class BARTModel(TransformerModel): num_classes=num_classes, activation_fn=self.args.pooler_activation_fn, pooler_dropout=self.args.pooler_dropout, - do_spectral_norm=self.args.spectral_norm_classification_head, + do_spectral_norm=getattr( + self.args, "spectral_norm_classification_head", False + ), ) def upgrade_state_dict_named(self, state_dict, name): diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index bfd41777..3ebb30e3 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -7,6 +7,7 @@ Base classes for various fairseq models. """ import logging +from argparse import Namespace from typing import Dict, List, Optional, Tuple import torch @@ -15,8 +16,12 @@ import torch.nn.functional as F from fairseq import utils from fairseq.checkpoint_utils import prune_state_dict from fairseq.data import Dictionary -from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + gen_parser_from_dataclass, +) from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig from torch import Tensor @@ -86,15 +91,26 @@ class BaseFairseqModel(nn.Module): """Maximum length supported by the model.""" return None - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): """Copies parameters and buffers from *state_dict* into this module and its descendants. Overrides the method in :class:`nn.Module`. Compared with that method this additionally "upgrades" *state_dicts* from old checkpoints. """ + + if model_cfg is None and args is not None: + logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + model_cfg = convert_namespace_to_omegaconf(args).model + self.upgrade_state_dict(state_dict) - new_state_dict = prune_state_dict(state_dict, args) + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) def upgrade_state_dict(self, state_dict): @@ -133,18 +149,18 @@ class BaseFairseqModel(nn.Module): self.apply(_apply) - def prepare_for_inference_(self, args): + def prepare_for_inference_(self, cfg: DictConfig): """Prepare model for inference.""" kwargs = {} kwargs["beamable_mm_beam_size"] = ( - None if getattr(args, "no_beamable_mm", False) else getattr(args, "beam", 5) + None + if getattr(cfg.generation, "no_beamable_mm", False) + else getattr(cfg.generation, "beam", 5) ) - kwargs["need_attn"] = getattr(args, "print_alignment", False) - if hasattr(args, "retain_dropout"): - kwargs["retain_dropout"] = args.retain_dropout - kwargs["retain_dropout_modules"] = getattr( - args, "retain_dropout_modules", None - ) + kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False) + if getattr(cfg.generation, "retain_dropout", False): + kwargs["retain_dropout"] = cfg.generation.retain_dropout + kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules self.make_generation_fast_(**kwargs) def make_generation_fast_(self, **kwargs): @@ -437,15 +453,26 @@ class FairseqMultiModel(BaseFairseqModel): def forward_decoder(self, prev_output_tokens, **kwargs): return self.decoder(prev_output_tokens, **kwargs) - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args: Optional[Namespace] = None, + ): """Copies parameters and buffers from *state_dict* into this module and its descendants. Overrides the method in :class:`nn.Module`. Compared with that method this additionally "upgrades" *state_dicts* from old checkpoints. """ + + if model_cfg is None and args is not None: + logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + model_cfg = convert_namespace_to_omegaconf(args).model + self.upgrade_state_dict(state_dict) - new_state_dict = prune_state_dict(state_dict, args) + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py index e3fbbd57..2e1f86f3 100644 --- a/fairseq/models/multilingual_transformer.py +++ b/fairseq/models/multilingual_transformer.py @@ -194,14 +194,14 @@ class MultilingualTransformerModel(FairseqMultiModel): module_class = TransformerEncoder if is_encoder else TransformerDecoder return module_class(args, lang_dict, embed_tokens) - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict(self, state_dict, strict=True, model_cfg=None): state_dict_subset = state_dict.copy() for k, _ in state_dict.items(): assert k.startswith("models.") lang_pair = k.split(".")[1] if lang_pair not in self.models: del state_dict_subset[k] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg) @register_model_architecture("multilingual_transformer", "multilingual_transformer") diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index 526823bd..0c723f06 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -17,13 +17,13 @@ class RobertaHubInterface(nn.Module): Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta """ - def __init__(self, args, task, model): + def __init__(self, cfg, task, model): super().__init__() - self.args = args + self.cfg = cfg self.task = task self.model = model - self.bpe = encoders.build_bpe(args) + self.bpe = encoders.build_bpe(cfg.bpe) # this is useful for determining the device self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 6ce216a6..d1a63196 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -494,7 +494,7 @@ def base_architecture(args): args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.spectral_norm_classification_head = getattr( - args, "spectral_nrom_classification_head", False + args, "spectral_norm_classification_head", False ) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index fbb7ce23..f87fa50d 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -578,10 +578,9 @@ class TransformerDecoder(FairseqIncrementalDecoder): if embed_dim != input_embed_dim else None ) - self.embed_positions = ( PositionalEmbedding( - args.max_target_positions, + self.max_target_positions, embed_dim, self.padding_idx, learned=args.decoder_learned_pos, @@ -963,6 +962,14 @@ def base_architecture(args): args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + @register_model_architecture("transformer", "transformer_iwslt_de_en") def transformer_iwslt_de_en(args): diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 22b17f06..df809bdb 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -159,7 +159,7 @@ class TransformerLanguageModelConfig(FairseqDataclass): add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") - tpu: bool = II("params.common.tpu") + tpu: bool = II("common.tpu") @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 48cd4c73..8775aa77 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -32,20 +32,20 @@ class TransformerEncoderLayer(nn.Module): def __init__(self, args): super().__init__() self.embed_dim = args.encoder_embed_dim - self.quant_noise = getattr(args, "quant_noise_pq", 0) - self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + self.quant_noise = getattr(args, 'quant_noise_pq', 0) + self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) self.activation_fn = utils.get_activation_fn( - activation=getattr(args, "activation_fn", "relu") + activation=getattr(args, 'activation_fn', 'relu') or "relu" ) - activation_dropout_p = getattr(args, "activation_dropout", 0) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) @@ -197,10 +197,10 @@ class TransformerDecoderLayer(nn.Module): if getattr(args, "activation_fn", None) is not None else "relu" ) - activation_dropout_p = getattr(args, "activation_dropout", 0) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 94eb2c7e..d8e58172 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.optim.bmuf import FairseqBMUF # noqa @@ -19,7 +17,6 @@ from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optim from fairseq.optim.shard import shard_ from omegaconf import DictConfig - __all__ = [ "FairseqOptimizer", "FP16Optimizer", @@ -27,7 +24,6 @@ __all__ = [ "shard_", ] - ( _build_optimizer, register_optimizer, @@ -37,12 +33,12 @@ __all__ = [ def build_optimizer( - optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs + cfg: DictConfig, params, *extra_args, **extra_kwargs ): if all(isinstance(p, dict) for p in params): params = [t for p in params for t in p.values()] params = list(filter(lambda p: p.requires_grad, params)) - return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs) + return _build_optimizer(cfg, params, *extra_args, **extra_kwargs) # automatically import any Python files in the optim/ directory diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index f678a9f5..9b8ddffd 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -5,6 +5,7 @@ import logging import math +from collections import Collection from dataclasses import dataclass, field from typing import List @@ -14,7 +15,7 @@ import torch.optim from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer from fairseq.optim.fused_adam import get_fused_adam_class -from omegaconf import II +from omegaconf import II, DictConfig logger = logging.getLogger(__name__) @@ -33,8 +34,8 @@ class FairseqAdamConfig(FairseqDataclass): default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} ) # TODO common vars below in parent - tpu: bool = II("params.common.tpu") - lr: List[float] = II("params.optimization.lr") + tpu: bool = II("common.tpu") + lr: List[float] = II("optimization.lr") @register_optimizer("adam", dataclass=FairseqAdamConfig) @@ -46,15 +47,15 @@ class FairseqAdam(FairseqOptimizer): analogous to torch.optim.AdamW from PyTorch. """ - def __init__(self, args, params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) fused_adam_cls = get_fused_adam_class() use_fused_adam = ( - not getattr(args, "use_old_adam", False) + not getattr(cfg, "use_old_adam", False) and fused_adam_cls is not None and torch.cuda.is_available() ) - if getattr(args, "tpu", False): + if getattr(cfg, "tpu", False): # on TPUs we use the Adam defined here, since it # automatically casts gradients to FP32 self._optimizer = Adam(params, **self.optimizer_config) @@ -73,10 +74,12 @@ class FairseqAdam(FairseqOptimizer): different learning rate. """ return { - "lr": self.args.lr[0], - "betas": eval(self.args.adam_betas), - "eps": self.args.adam_eps, - "weight_decay": self.args.weight_decay, + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "betas": eval(self.cfg.adam_betas), + "eps": self.cfg.adam_eps, + "weight_decay": self.cfg.weight_decay, } def average_params(self): diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py index 3312f811..55f225ba 100644 --- a/fairseq/optim/bmuf.py +++ b/fairseq/optim/bmuf.py @@ -10,7 +10,7 @@ import torch.distributed as dist from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.fairseq_optimizer import FairseqOptimizer -from omegaconf import II +from omegaconf import II, DictConfig @dataclass @@ -38,7 +38,7 @@ class FairseqBMUFConfig(FairseqDataclass): }, ) distributed_world_size: int = II( - "params.distributed_training.distributed_world_size" + "distributed_training.distributed_world_size" ) @@ -52,20 +52,19 @@ class FairseqBMUF(FairseqOptimizer): model-update filtering """ - def __init__(self, args, optimizer): - - super().__init__(args) + def __init__(self, cfg: DictConfig, optimizer): + super().__init__(cfg) self._optimizer = optimizer self._num_updates = 0 - self.sync_iter = self.args.global_sync_iter - self.block_momentum = self.args.block_momentum - self.block_lr = self.args.block_lr + self.sync_iter = cfg.global_sync_iter + self.block_momentum = cfg.block_momentum + self.block_lr = cfg.block_lr self._reset_local_data() - self.warmup_iteration = self.args.warmup_iterations - self.use_nbm = self.args.use_nbm + self.warmup_iteration = cfg.warmup_iterations + self.use_nbm = cfg.use_nbm self.initial_state = self._optimizer.state_dict() - self.average_sync = self.args.average_sync - self.world_size = self.args.distributed_world_size + self.average_sync = self.cfg.average_sync + self.world_size = self.cfg.distributed_world_size @staticmethod def add_args(parser): diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 8a10399a..9c093833 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -9,9 +9,9 @@ from fairseq.dataclass.utils import gen_parser_from_dataclass class FairseqOptimizer(object): - def __init__(self, args): + def __init__(self, cfg): super().__init__() - self.args = args + self.cfg = cfg @classmethod def add_args(cls, parser): diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index b622fbde..b08a7237 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -7,7 +7,8 @@ from collections import defaultdict from itertools import chain import torch -from fairseq import optim, utils +from fairseq import optim +from omegaconf import DictConfig from .dynamic_loss_scaler import DynamicLossScaler @@ -211,7 +212,7 @@ class _FP16OptimizerMixin(object): for fp32_params in self.fp32_params.values(): fp32_params.grad.zero_() else: - raise ("self.fp32_params must be a tensor or dict") + raise RuntimeError("self.fp32_params must be a tensor or dict") else: for p32 in self.fp32_params: p32.grad.zero_() @@ -226,58 +227,60 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): Wrap an *optimizer* to support FP16 (mixed precision) training. """ - def __init__(self, args, params, fp32_optimizer, fp32_params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs): + super().__init__(cfg.optimizer) self.fp16_params = params self.fp32_optimizer = fp32_optimizer self.fp32_params = fp32_params - if getattr(args, "fp16_scale_window", None) is None: - if len(args.update_freq) > 1: + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: raise ValueError( "--fp16-scale-window must be given explicitly when using a " "custom --update-freq schedule" ) data_parallel_size = int( - args.distributed_world_size / args.model_parallel_size + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = int( + 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] ) - scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0]) else: - scale_window = args.fp16_scale_window + scale_window = cfg.common.fp16_scale_window - if not getattr(args, "bf16", False): + if not getattr(cfg.common, "bf16", False): self.scaler = DynamicLossScaler( - init_scale=args.fp16_init_scale, + init_scale=cfg.common.fp16_init_scale, scale_window=scale_window, - tolerance=args.fp16_scale_tolerance, - threshold=args.threshold_loss_scale, - min_loss_scale=args.min_loss_scale, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, ) else: # disable loss scaling for bfloat16 self.scaler = None @classmethod - def build_optimizer(cls, args, params): + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): """ Args: - args (argparse.Namespace): fairseq args + cfg (omegaconf.DictConfig): fairseq args params (iterable): iterable of parameters to optimize """ - flatten = not getattr(args, "fp16_no_flatten_grads", False) - if getattr(args, "bf16", False): + flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False) + if getattr(cfg.common, "bf16", False): flatten = False # mixed precision is faster on TPUs without flat grads - fp32_params = cls.build_fp32_params(args, params, flatten=flatten) + fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten) if flatten: - fp32_optimizer = optim.build_optimizer(args, [fp32_params]) + fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params]) else: - fp32_optimizer = optim.build_optimizer(args, fp32_params) + fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params) if flatten and not fp32_optimizer.supports_flat_params: raise RuntimeError( - "chosen optimizer does not support flat params, " - "please set --fp16-no-flatten-grads" + f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads" ) - return cls(args, params, fp32_optimizer, fp32_params) + return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs) @property def optimizer(self): @@ -427,49 +430,52 @@ class MemoryEfficientFP16Optimizer( *supports_memory_efficient_fp16* property. """ - def __init__(self, args, params, optimizer): + def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): if not optimizer.supports_memory_efficient_fp16: raise ValueError( "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) - super().__init__(args) + super().__init__(cfg.optimizer) self.wrapped_optimizer = optimizer - if getattr(args, "fp16_scale_window", None) is None: - if len(args.update_freq) > 1: + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: raise ValueError( "--fp16-scale-window must be given explicitly when using a " "custom --update-freq schedule" ) data_parallel_size = int( - args.distributed_world_size / args.model_parallel_size + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = ( + 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] ) - scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0] else: - scale_window = args.fp16_scale_window + scale_window = cfg.common.fp16_scale_window - if not getattr(args, "bf16", False): + if not getattr(cfg.common, "bf16", False): self.scaler = DynamicLossScaler( - init_scale=args.fp16_init_scale, + init_scale=cfg.common.fp16_init_scale, scale_window=scale_window, - tolerance=args.fp16_scale_tolerance, - threshold=args.threshold_loss_scale, - min_loss_scale=args.min_loss_scale, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, ) else: # disable loss scaling for bfloat16 self.scaler = None @classmethod - def build_optimizer(cls, args, params): + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): """ Args: args (argparse.Namespace): fairseq args params (iterable): iterable of parameters to optimize """ - fp16_optimizer = optim.build_optimizer(args, params) - return cls(args, params, fp16_optimizer) + fp16_optimizer = optim.build_optimizer(cfg.optimizer, params) + return cls(cfg, params, fp16_optimizer, **kwargs) @property def optimizer(self): diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index 7b72c257..f07d43c7 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa @@ -27,8 +25,8 @@ from omegaconf import DictConfig ) -def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer): - return build_lr_scheduler_(lr_scheduler_cfg, optimizer) +def build_lr_scheduler(cfg: DictConfig, optimizer): + return build_lr_scheduler_(cfg, optimizer) # automatically import any Python files in the optim/lr_scheduler/ directory diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 98d55750..c3c6663e 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -4,11 +4,12 @@ # LICENSE file in the root directory of this source tree. import math +from collections import Collection from dataclasses import dataclass, field from typing import List from fairseq.dataclass import FairseqDataclass -from omegaconf import II +from omegaconf import II, DictConfig from . import FairseqLRScheduler, register_lr_scheduler @@ -38,8 +39,8 @@ class CosineConfig(FairseqDataclass): default=0.1, metadata={"help": "shrink factor for annealing"} ) # TODO common var for parent class - lr: List[float] = II("params.optimization.lr") - max_update: int = II("params.optimization.max_update") + lr: List[float] = II("optimization.lr") + max_update: int = II("optimization.max_update") @register_lr_scheduler("cosine", dataclass=CosineConfig) @@ -66,43 +67,51 @@ class CosineSchedule(FairseqLRScheduler): after every iteration. """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__( + self, cfg: DictConfig, fairseq_optimizer + ): + super().__init__(cfg, fairseq_optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with cosine." " Consider --lr-scheduler=fixed instead." ) - warmup_end_lr = args.max_lr - if args.warmup_init_lr < 0: - args.warmup_init_lr = args.lr[0] - - self.min_lr = args.lr[0] - self.max_lr = args.max_lr + warmup_end_lr = cfg.max_lr + lr = ( + cfg.lr[0] + if isinstance(cfg.lr, Collection) + else cfg.lr + ) + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = lr + self.min_lr = lr + self.max_lr = cfg.max_lr assert self.max_lr > self.min_lr, "max_lr must be more than lr" - self.t_mult = args.t_mult - self.period = args.lr_period_updates + self.t_mult = cfg.t_mult + self.period = cfg.lr_period_updates if self.period <= 0: assert ( - args.max_update >= 0 + cfg.max_update >= 0 ), "Either --max_update or --lr-period-updates must be set" - self.period = args.max_update - args.warmup_updates + self.period = cfg.max_update - cfg.warmup_updates - if args.warmup_updates > 0: + if cfg.warmup_updates > 0: # linearly warmup for the first args.warmup_updates - self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + self.lr_step = ( + warmup_end_lr - cfg.warmup_init_lr + ) / cfg.warmup_updates else: self.lr_step = 1 - self.warmup_updates = args.warmup_updates - self.lr_shrink = args.lr_shrink + self.warmup_updates = cfg.warmup_updates + self.lr_shrink = cfg.lr_shrink # initial learning rate - self.lr = args.warmup_init_lr + self.lr = cfg.warmup_init_lr self.optimizer.set_lr(self.lr) def step(self, epoch, val_loss=None): @@ -113,10 +122,10 @@ class CosineSchedule(FairseqLRScheduler): def step_update(self, num_updates): """Update the learning rate after each update.""" - if num_updates < self.args.warmup_updates: - self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step else: - curr_updates = num_updates - self.args.warmup_updates + curr_updates = num_updates - self.cfg.warmup_updates if self.t_mult != 1: i = math.floor( math.log( diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index 8fde0713..569e4482 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -11,11 +11,11 @@ from .. import FairseqOptimizer class FairseqLRScheduler(object): - def __init__(self, args, optimizer): + def __init__(self, cfg, optimizer): super().__init__() if not isinstance(optimizer, FairseqOptimizer): raise ValueError("optimizer must be an instance of FairseqOptimizer") - self.args = args + self.cfg = cfg self.optimizer = optimizer self.best = None diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index d27261ad..c42e0906 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -3,11 +3,12 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import Collection from dataclasses import dataclass, field from typing import List from fairseq.dataclass import FairseqDataclass -from omegaconf import II +from omegaconf import II, DictConfig from . import FairseqLRScheduler, register_lr_scheduler @@ -25,7 +26,7 @@ class InverseSquareRootScheduleConfig(FairseqDataclass): }, ) # TODO common vars at parent class - lr: List[float] = II("params.optimization.lr") + lr: List[float] = II("optimization.lr") @register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig) @@ -48,25 +49,33 @@ class InverseSquareRootSchedule(FairseqLRScheduler): lr = decay_factor / sqrt(update_num) """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__(self, cfg: DictConfig, optimizer): + super().__init__(cfg, optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with inverse_sqrt." " Consider --lr-scheduler=fixed instead." ) - warmup_end_lr = args.lr[0] - if args.warmup_init_lr < 0: - args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr + warmup_end_lr = ( + cfg.lr[0] + if isinstance(cfg.lr, Collection) + else cfg.lr + ) + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = ( + 0 if cfg.warmup_updates > 0 else warmup_end_lr + ) # linearly warmup for the first args.warmup_updates - self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + self.lr_step = ( + warmup_end_lr - cfg.warmup_init_lr + ) / cfg.warmup_updates # then, decay prop. to the inverse square root of the update number - self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5 + self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5 # initial learning rate - self.lr = args.warmup_init_lr + self.lr = cfg.warmup_init_lr self.optimizer.set_lr(self.lr) def step(self, epoch, val_loss=None): @@ -77,8 +86,8 @@ class InverseSquareRootSchedule(FairseqLRScheduler): def step_update(self, num_updates): """Update the learning rate after each update.""" - if num_updates < self.args.warmup_updates: - self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step else: self.lr = self.decay_factor * num_updates ** -0.5 self.optimizer.set_lr(self.lr) diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index 58d2f356..3982a827 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -3,12 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import Collection from dataclasses import dataclass, field from typing import List import torch from fairseq.dataclass import FairseqDataclass -from omegaconf import II +from omegaconf import II, DictConfig from torch.optim.optimizer import Optimizer, required from . import FairseqOptimizer, register_optimizer @@ -19,13 +20,13 @@ class FairseqNAGConfig(FairseqDataclass): momentum: float = field(default=0.99, metadata={"help": "momentum factor"}) weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) # TODO common vars in parent class - lr: List[float] = II("params.optimization.lr") + lr: List[float] = II("optimization.lr") @register_optimizer("nag", dataclass=FairseqNAGConfig) class FairseqNAG(FairseqOptimizer): - def __init__(self, args, params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) self._optimizer = NAG(params, **self.optimizer_config) @property @@ -37,9 +38,11 @@ class FairseqNAG(FairseqOptimizer): different learning rate. """ return { - "lr": self.args.lr[0], - "momentum": self.args.momentum, - "weight_decay": self.args.weight_decay, + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "momentum": self.cfg.momentum, + "weight_decay": self.cfg.weight_decay, } diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index a035a1c1..ecef05b4 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -12,7 +12,7 @@ except ImportError: _has_fairscale = False -def shard_(args, optimizer, group): +def shard_(optimizer, group): if not _has_fairscale: raise ImportError( "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" diff --git a/fairseq/options.py b/fairseq/options.py index 1a24fcca..6bc526ce 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -10,13 +10,15 @@ import torch from fairseq import utils from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.dataclass.data_class import ( - CheckpointParams, - CommonEvalParams, - CommonParams, - DatasetParams, - DistributedTrainingParams, - EvalLMParams, - OptimizationParams, + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + EvalLMConfig, + GenerationConfig, + InteractiveConfig, + OptimizationConfig, ) from fairseq.dataclass.utils import gen_parser_from_dataclass @@ -45,6 +47,7 @@ def get_generation_parser(interactive=False, default_task="translation"): add_dataset_args(parser, gen=True) add_distributed_training_args(parser, default_world_size=1) add_generation_args(parser) + add_checkpoint_args(parser) if interactive: add_interactive_args(parser) return parser @@ -67,7 +70,7 @@ def get_validation_parser(default_task=None): add_dataset_args(parser, train=True) add_distributed_training_args(parser, default_world_size=1) group = parser.add_argument_group("Evaluation") - gen_parser_from_dataclass(group, CommonEvalParams()) + gen_parser_from_dataclass(group, CommonEvalConfig()) return parser @@ -210,7 +213,7 @@ def get_parser(desc, default_task="translation"): utils.import_user_module(usr_args) parser = argparse.ArgumentParser(allow_abbrev=False) - gen_parser_from_dataclass(parser, CommonParams()) + gen_parser_from_dataclass(parser, CommonConfig()) from fairseq.registry import REGISTRIES @@ -283,7 +286,7 @@ def add_preprocess_args(parser): def add_dataset_args(parser, train=False, gen=False): group = parser.add_argument_group("dataset_data_loading") - gen_parser_from_dataclass(group, DatasetParams()) + gen_parser_from_dataclass(group, DatasetConfig()) # fmt: on return group @@ -293,7 +296,7 @@ def add_distributed_training_args(parser, default_world_size=None): if default_world_size is None: default_world_size = max(1, torch.cuda.device_count()) gen_parser_from_dataclass( - group, DistributedTrainingParams(distributed_world_size=default_world_size) + group, DistributedTrainingConfig(distributed_world_size=default_world_size) ) return group @@ -301,7 +304,7 @@ def add_distributed_training_args(parser, default_world_size=None): def add_optimization_args(parser): group = parser.add_argument_group("optimization") # fmt: off - gen_parser_from_dataclass(group, OptimizationParams()) + gen_parser_from_dataclass(group, OptimizationConfig()) # fmt: on return group @@ -309,117 +312,31 @@ def add_optimization_args(parser): def add_checkpoint_args(parser): group = parser.add_argument_group("checkpoint") # fmt: off - gen_parser_from_dataclass(group, CheckpointParams()) + gen_parser_from_dataclass(group, CheckpointConfig()) # fmt: on return group def add_common_eval_args(group): - gen_parser_from_dataclass(group, CommonEvalParams()) + gen_parser_from_dataclass(group, CommonEvalConfig()) def add_eval_lm_args(parser): group = parser.add_argument_group("LM Evaluation") add_common_eval_args(group) - gen_parser_from_dataclass(group, EvalLMParams()) + gen_parser_from_dataclass(group, EvalLMConfig()) def add_generation_args(parser): group = parser.add_argument_group("Generation") add_common_eval_args(group) - # fmt: off - group.add_argument('--beam', default=5, type=int, metavar='N', - help='beam size') - group.add_argument('--nbest', default=1, type=int, metavar='N', - help='number of hypotheses to output') - group.add_argument('--max-len-a', default=0, type=float, metavar='N', - help=('generate sequences of maximum length ax + b, ' - 'where x is the source length')) - group.add_argument('--max-len-b', default=200, type=int, metavar='N', - help=('generate sequences of maximum length ax + b, ' - 'where x is the source length')) - group.add_argument('--min-len', default=1, type=float, metavar='N', - help=('minimum generation length')) - group.add_argument('--match-source-len', default=False, action='store_true', - help=('generations should match the source length')) - group.add_argument('--no-early-stop', action='store_true', - help='deprecated') - group.add_argument('--unnormalized', action='store_true', - help='compare unnormalized hypothesis scores') - group.add_argument('--no-beamable-mm', action='store_true', - help='don\'t use BeamableMM in attention layers') - group.add_argument('--lenpen', default=1, type=float, - help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences') - group.add_argument('--unkpen', default=0, type=float, - help='unknown word penalty: <0 produces more unks, >0 produces fewer') - group.add_argument('--replace-unk', nargs='?', const=True, default=None, - help='perform unknown replacement (optionally with alignment dictionary)') - group.add_argument('--sacrebleu', action='store_true', - help='score with sacrebleu') - group.add_argument('--score-reference', action='store_true', - help='just score the reference translation') - group.add_argument('--prefix-size', default=0, type=int, metavar='PS', - help='initialize generation by target prefix of given length') - group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N', - help='ngram blocking such that this size ngram cannot be repeated in the generation') - group.add_argument('--sampling', action='store_true', - help='sample hypotheses instead of using beam search') - group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS', - help='sample from top K likely next words instead of all words') - group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS', - help='sample from the smallest set whose cumulative probability mass exceeds p for next words') - group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"], - help='enables lexically constrained decoding') - group.add_argument('--temperature', default=1., type=float, metavar='N', - help='temperature for generation') - group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N', - help='number of groups for Diverse Beam Search') - group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N', - help='strength of diversity penalty for Diverse Beam Search') - group.add_argument('--diversity-rate', default=-1.0, type=float, metavar='N', - help='strength of diversity penalty for Diverse Siblings Search') - group.add_argument('--print-alignment', action='store_true', - help='if set, uses attention feedback to compute and print alignment to source tokens') - group.add_argument('--print-step', action='store_true') - - group.add_argument('--lm-path', default=None, type=str, metavar='PATH', - help='path to lm checkpoint for lm fusion') - group.add_argument('--lm-weight', default=0.0, type=float, metavar='N', - help='weight for lm probs for lm fusion') - - # arguments for iterative refinement generator - group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N', - help='if > 0.0, it penalized early-stopping in decoding.') - group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N', - help='maximum iterations for iterative refinement.') - group.add_argument('--iter-decode-force-max-iter', action='store_true', - help='if set, run exact the maximum number of iterations without early stop') - group.add_argument('--iter-decode-with-beam', default=1, type=int, metavar='N', - help='if > 1, model will generate translations varying by the lengths.') - group.add_argument('--iter-decode-with-external-reranker', action='store_true', - help='if set, the last checkpoint are assumed to be a reranker to rescore the translations'), - group.add_argument('--retain-iter-history', action='store_true', - help='if set, decoding returns the whole history of iterative refinement') - group.add_argument('--retain-dropout', action='store_true', - help='Use dropout at inference time') - group.add_argument('--retain-dropout-modules', default=None, nargs='+', type=str, - help='if set, only retain dropout for the specified modules; ' - 'if not set, then dropout will be retained for all modules') - - # special decoding format for advanced decoding. - group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs']) - # fmt: on + gen_parser_from_dataclass(group, GenerationConfig()) return group def add_interactive_args(parser): group = parser.add_argument_group("Interactive") - # fmt: off - group.add_argument('--buffer-size', default=0, type=int, metavar='N', - help='read this many sentences into a buffer before processing them') - group.add_argument('--input', default='-', type=str, metavar='FILE', - help='file to read from; use - for stdin') - # fmt: on + gen_parser_from_dataclass(group, InteractiveConfig()) def add_model_args(parser): diff --git a/fairseq/quantization_utils.py b/fairseq/quantization_utils.py index 69dd61d7..11fc414c 100644 --- a/fairseq/quantization_utils.py +++ b/fairseq/quantization_utils.py @@ -6,13 +6,14 @@ import logging from fairseq.modules.quantization import pq, quantization_options, scalar +from omegaconf import DictConfig logger = logging.getLogger(__name__) -def quantize_model_scalar(model, args): - quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) +def quantize_model_scalar(model, model_cfg: DictConfig): + quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0 if quant_noise_scalar > 0: # quantize_model edits the model in place scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000) diff --git a/fairseq/registry.py b/fairseq/registry.py index 382dec22..4446084d 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -3,14 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import argparse from argparse import Namespace + from typing import Union - from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import populate_dataclass from omegaconf import DictConfig - REGISTRIES = {} @@ -25,33 +24,30 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F # maintain a registry of all registries if registry_name in REGISTRIES: return # registry already exists - REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default} + REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default, "dataclass_registry": DATACLASS_REGISTRY} - def build_x(args: Union[DictConfig, Namespace], *extra_args, **extra_kwargs): - if isinstance(args, DictConfig): - if getattr(args, "_name", None) is not None: - choice = args._name - elif hasattr(args, registry_name): - choice = args.registry_name - else: - raise RuntimeError( - f"Neither _name nor {registry_name} in args, args = {args}" - ) + def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs): + if isinstance(cfg, DictConfig): + choice = cfg._name + elif isinstance(cfg, str): + choice = cfg else: - choice = getattr(args, registry_name, None) + choice = getattr(cfg, registry_name, None) + if choice in DATACLASS_REGISTRY: + cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]()) if choice is None: if required: - raise ValueError("--{} is required!".format(registry_name)) + raise ValueError('{} is required!'.format(registry_name)) return None + cls = REGISTRY[choice] if hasattr(cls, "build_" + registry_name): builder = getattr(cls, "build_" + registry_name) else: builder = cls - if isinstance(args, Namespace): - set_defaults(args, cls) - return builder(args, *extra_args, **extra_kwargs) + + return builder(cfg, *extra_args, **extra_kwargs) def register_x(name, dataclass=None): def register_x_cls(cls): @@ -77,30 +73,10 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F cls.__dataclass = dataclass REGISTRY[name] = cls - DATACLASS_REGISTRY[name] = cls.__dataclass - REGISTRY_CLASS_NAMES.add(cls.__name__) + if cls.__dataclass is not None: + DATACLASS_REGISTRY[name] = cls.__dataclass return cls return register_x_cls return build_x, register_x, REGISTRY, DATACLASS_REGISTRY - - -def set_defaults(args: Namespace, cls): - """Helper to set default arguments based on *add_args*.""" - if not hasattr(cls, "add_args"): - return - parser = argparse.ArgumentParser( - argument_default=argparse.SUPPRESS, allow_abbrev=False - ) - cls.add_args(parser) - # copied from argparse.py: - defaults = argparse.Namespace() - for action in parser._actions: - if action.dest is not argparse.SUPPRESS: - if not hasattr(defaults, action.dest): - if action.default is not argparse.SUPPRESS: - setattr(defaults, action.dest, action.default) - for key, default_value in vars(defaults).items(): - if not hasattr(args, key): - setattr(args, key, default_value) diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index 4be0cb51..8c706cb5 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -9,11 +9,12 @@ import os from abc import ABC, abstractmethod from fairseq import registry +from omegaconf import DictConfig class BaseScorer(ABC): - def __init__(self, args): - self.args = args + def __init__(self, cfg): + self.cfg = cfg self.ref = [] self.pred = [] @@ -39,19 +40,17 @@ _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry( ) -def build_scorer(args, tgt_dict): - from fairseq import utils +def build_scorer(choice, tgt_dict): + if isinstance(choice, DictConfig): + choice = choice._name - if args.sacrebleu: - utils.deprecation_warning( - "--sacrebleu is deprecated. Please use --scoring sacrebleu instead." - ) - args.scoring = "sacrebleu" - if args.scoring == "bleu": + if choice == "bleu": from fairseq.scoring import bleu - return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) - return _build_scorer(args) + return bleu.Scorer( + bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()) + ) + return _build_scorer(choice) # automatically import any Python files in the current directory diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py index 7f8bd73b..97de5f96 100644 --- a/fairseq/scoring/bleu.py +++ b/fairseq/scoring/bleu.py @@ -6,8 +6,10 @@ import ctypes import math import sys +from dataclasses import dataclass, field import torch +from fairseq.dataclass import FairseqDataclass from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring.tokenizer import EvaluationTokenizer @@ -27,31 +29,32 @@ class BleuStat(ctypes.Structure): ] -@register_scorer("sacrebleu") +@dataclass +class SacrebleuConfig(FairseqDataclass): + sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="13a", metadata={"help": "tokenizer"} + ) + sacrebleu_lowercase: bool = field( + default=False, metadata={"help": "apply lowercasing"} + ) + sacrebleu_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + + +@register_scorer("sacrebleu", dataclass=SacrebleuConfig) class SacrebleuScorer(BaseScorer): - def __init__(self, args): - super(SacrebleuScorer, self).__init__(args) + def __init__(self, cfg): + super(SacrebleuScorer, self).__init__(cfg) import sacrebleu self.sacrebleu = sacrebleu self.tokenizer = EvaluationTokenizer( - tokenizer_type=self.args.sacrebleu_tokenizer, - lowercase=self.args.sacrebleu_lowercase, - character_tokenization=self.args.sacrebleu_char_level, + tokenizer_type=cfg.sacrebleu_tokenizer, + lowercase=cfg.sacrebleu_lowercase, + character_tokenization=cfg.sacrebleu_char_level, ) - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--sacrebleu-tokenizer', type=str, default='13a', - choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES, - help='tokenizer') - parser.add_argument('--sacrebleu-lowercase', type=str, default=False, - help='apply lowercasing') - parser.add_argument('--sacrebleu-char-level', action='store_true', - help='evaluate at character level') - # fmt: on - def add_string(self, ref, pred): self.ref.append(self.tokenizer.tokenize(ref)) self.pred.append(self.tokenizer.tokenize(pred)) @@ -68,13 +71,20 @@ class SacrebleuScorer(BaseScorer): ).format() -@register_scorer("bleu") +@dataclass +class BleuConfig(FairseqDataclass): + pad: int = field(default=1, metadata={"help": "padding index"}) + eos: int = field(default=2, metadata={"help": "eos index"}) + unk: int = field(default=3, metadata={"help": "unk index"}) + + +@register_scorer("bleu", dataclass=BleuConfig) class Scorer(object): - def __init__(self, pad, eos, unk): + def __init__(self, cfg): self.stat = BleuStat() - self.pad = pad - self.eos = eos - self.unk = unk + self.pad = cfg.pad + self.eos = cfg.eos + self.unk = cfg.unk try: from fairseq import libbleu diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py index dbcc6e4d..0d0702bf 100644 --- a/fairseq/scoring/tokenizer.py +++ b/fairseq/scoring/tokenizer.py @@ -5,6 +5,8 @@ import unicodedata +from fairseq.dataclass.utils import ChoiceEnum + class EvaluationTokenizer(object): """A generic evaluation-time tokenizer, which leverages built-in tokenizers @@ -22,7 +24,7 @@ class EvaluationTokenizer(object): SPACE = chr(32) SPACE_ESCAPE = chr(9601) - ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"] + ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"]) def __init__( self, @@ -33,7 +35,7 @@ class EvaluationTokenizer(object): ): from sacrebleu.tokenizers import TOKENIZERS - assert tokenizer_type in self.ALL_TOKENIZER_TYPES + assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}" self.lowercase = lowercase self.punctuation_removal = punctuation_removal self.character_tokenization = character_tokenization diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py index 21efefd9..633dc47c 100644 --- a/fairseq/scoring/wer.py +++ b/fairseq/scoring/wer.py @@ -3,14 +3,31 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + +from fairseq.dataclass import FairseqDataclass from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring.tokenizer import EvaluationTokenizer -@register_scorer("wer") +@dataclass +class WerScorerConfig(FairseqDataclass): + wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"} + ) + wer_remove_punct: bool = field( + default=False, metadata={"help": "remove punctuation"} + ) + wer_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"}) + + +@register_scorer("wer", dataclass=WerScorerConfig) class WerScorer(BaseScorer): - def __init__(self, args): - super().__init__(args) + def __init__(self, cfg): + super().__init__(cfg) self.reset() try: import editdistance as ed @@ -18,26 +35,12 @@ class WerScorer(BaseScorer): raise ImportError("Please install editdistance to use WER scorer") self.ed = ed self.tokenizer = EvaluationTokenizer( - tokenizer_type=self.args.wer_tokenizer, - lowercase=self.args.wer_lowercase, - punctuation_removal=self.args.wer_remove_punct, - character_tokenization=self.args.wer_char_level, + tokenizer_type=self.cfg.wer_tokenizer, + lowercase=self.cfg.wer_lowercase, + punctuation_removal=self.cfg.wer_remove_punct, + character_tokenization=self.cfg.wer_char_level, ) - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--wer-tokenizer', type=str, default='none', - choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES, - help='sacreBLEU tokenizer to use for evaluation') - parser.add_argument('--wer-remove-punct', action='store_true', - help='remove punctuation') - parser.add_argument('--wer-char-level', action='store_true', - help='evaluate at character level') - parser.add_argument('--wer-lowercase', action='store_true', - help='lowercasing') - # fmt: on - def reset(self): self.distance = 0 self.ref_length = 0 diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index e0abce25..41f461f8 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -7,8 +7,6 @@ import argparse import importlib import os -from argparse import Namespace -from typing import Union from fairseq.dataclass import FairseqDataclass from omegaconf import DictConfig @@ -22,10 +20,10 @@ TASK_REGISTRY = {} TASK_CLASS_NAMES = set() -def setup_task(task_cfg: Union[DictConfig, Namespace], **kwargs): - if isinstance(task_cfg, DictConfig): - return TASK_REGISTRY[task_cfg._name].setup_task(task_cfg, **kwargs) - return TASK_REGISTRY[task_cfg.task].setup_task(task_cfg, **kwargs) +def setup_task(cfg: DictConfig, **kwargs): + if isinstance(cfg, DictConfig): + return TASK_REGISTRY[cfg._name].setup_task(cfg, **kwargs) + return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs) def register_task(name, dataclass=None): @@ -70,7 +68,8 @@ def register_task(name, dataclass=None): ) cls.__dataclass = dataclass - TASK_DATACLASS_REGISTRY[name] = dataclass + if dataclass is not None: + TASK_DATACLASS_REGISTRY[name] = dataclass return cls diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index ff2342af..a831ad6e 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -79,7 +79,7 @@ class AudioPretrainingTask(LegacyFairseqTask): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + args (omegaconf.DictConfig): parsed command-line arguments """ return cls(args) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 0a96aeb1..3cdb64cf 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -12,6 +12,7 @@ import torch from fairseq import metrics, search, tokenizer, utils from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -39,8 +40,8 @@ class FairseqTask(object): """ return criterion.logging_outputs_can_be_summed() - def __init__(self, args): - self.args = args + def __init__(self, cfg: DictConfig, **kwargs): + self.cfg = cfg self.datasets = {} self.dataset_to_epoch_iter = {} @@ -78,16 +79,16 @@ class FairseqTask(object): return d @classmethod - def setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: DictConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): parsed command-line arguments """ - return cls(args, **kwargs) + return cls(cfg, **kwargs) def has_sharded_data(self, split): - return os.pathsep in getattr(self.args, "data", "") + return os.pathsep in getattr(self.cfg, "data", "") def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split. @@ -254,39 +255,39 @@ class FairseqTask(object): return epoch_iter - def build_model(self, args): + def build_model(self, cfg: DictConfig): """ Build the :class:`~fairseq.models.BaseFairseqModel` instance for this task. Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): configuration object Returns: a :class:`~fairseq.models.BaseFairseqModel` instance """ from fairseq import models, quantization_utils - model = models.build_model(args, self) - if getattr(args, "tpu", False): + model = models.build_model(cfg, self) + if getattr(cfg, "tpu", False): model.prepare_for_tpu_() - model = quantization_utils.quantize_model_scalar(model, args) + model = quantization_utils.quantize_model_scalar(model, cfg) return model - def build_criterion(self, args): + def build_criterion(self, cfg: DictConfig): """ Build the :class:`~fairseq.criterions.FairseqCriterion` instance for this task. Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): configration object Returns: a :class:`~fairseq.criterions.FairseqCriterion` instance """ from fairseq import criterions - return criterions.build_criterion(args, self) + return criterions.build_criterion(cfg, self) def build_generator( self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 8792c648..6e85417f 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -28,7 +28,7 @@ from fairseq.data import ( from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.dataclass import ChoiceEnum, FairseqDataclass -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import LegacyFairseqTask, register_task from omegaconf import II @@ -85,16 +85,16 @@ class LanguageModelingConfig(FairseqDataclass): }, ) # TODO common vars below add to parent - seed: int = II("params.common.seed") + seed: int = II("common.seed") dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( - "params.dataset.dataset_impl" + "dataset.dataset_impl" ) - data_buffer_size: int = II("params.dataset.data_buffer_size") - tpu: bool = II("params.common.tpu") + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") @register_task("language_modeling", dataclass=LanguageModelingConfig) -class LanguageModelingTask(FairseqTask): +class LanguageModelingTask(LegacyFairseqTask): """ Train a language model. diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index f6cb17f1..26e0b529 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -117,7 +117,7 @@ class MultilingualTranslationTask(LegacyFairseqTask): return cls(args, dicts, training) @classmethod - def prepare(cls, args, **kargs): + def update_args(cls, args): args.left_pad_source = utils.eval_bool(args.left_pad_source) args.left_pad_target = utils.eval_bool(args.left_pad_target) @@ -127,6 +127,10 @@ class MultilingualTranslationTask(LegacyFairseqTask): ) if isinstance(args.lang_pairs, str): args.lang_pairs = args.lang_pairs.split(",") + + @classmethod + def prepare(cls, args, **kargs): + cls.update_args(args) sorted_langs = sorted( list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")}) ) @@ -298,6 +302,10 @@ class MultilingualTranslationTask(LegacyFairseqTask): if len(messages) > 0: raise ValueError(" ".join(messages)) + # Update args -> the fact that the constructor here + # changes the args object doesn't mean you get the same one here + self.update_args(args) + # Check if task args are consistant with model args check_args() diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 6d222f0d..c200bb14 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -13,7 +13,7 @@ from fairseq.data.audio.speech_to_text_dataset import ( SpeechToTextDataset, SpeechToTextDatasetCreator, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import LegacyFairseqTask, register_task logging.basicConfig( @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) @register_task("speech_to_text") -class SpeechToTextTask(FairseqTask): +class SpeechToTextTask(LegacyFairseqTask): @staticmethod def add_args(parser): parser.add_argument("data", help="manifest root path") diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 0069b794..8b00e8b4 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -11,15 +11,18 @@ import contextlib import logging import sys import time +from argparse import Namespace from itertools import chain from typing import Any, Dict, List import torch from fairseq import checkpoint_utils, distributed_utils, models, optim, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.file_io import PathManager from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -35,19 +38,25 @@ class Trainer(object): communication of the gradients across workers. """ - def __init__(self, args, task, model, criterion, quantizer=None): - self.args = args + def __init__(self, cfg: DictConfig, task, model, criterion, quantizer=None): + + if isinstance(cfg, Namespace): + logger.warning( + "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf" + ) + cfg = convert_namespace_to_omegaconf(cfg) + + self.cfg = cfg self.task = task # catalog shared parameters shared_params = _catalog_shared_params(model) - - self.tpu = getattr(args, "tpu", False) - self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu + self.tpu = cfg.common.tpu + self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu if self.cuda: self.device = torch.device("cuda") elif self.tpu: - self.device = utils.get_tpu_device(args) + self.device = utils.get_tpu_device() else: self.device = torch.device("cpu") @@ -58,19 +67,21 @@ class Trainer(object): import torch_xla.core.xla_model as xm self._model = xm.send_cpu_data_to_device(self._model, self.device) - if args.fp16: + if cfg.common.fp16: self._criterion = self._criterion.half() self._model = self._model.half() - elif args.bf16: + elif cfg.common.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) - if not args.pipeline_model_parallel: + if not cfg.distributed_training.pipeline_model_parallel: self._criterion = self._criterion.to(device=self.device) self._model = self._model.to(device=self.device) - self.pipeline_model_parallel = args.pipeline_model_parallel + self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel self.last_device = None if self.cuda and self.pipeline_model_parallel: - self.last_device = torch.device(args.pipeline_devices[-1]) + self.last_device = torch.device( + cfg.distributed_training.pipeline_devices[-1] + ) # check that shared parameters are preserved after device transfer for shared_param in shared_params: @@ -129,7 +140,7 @@ class Trainer(object): @property def data_parallel_world_size(self): - return self.args.distributed_world_size + return self.cfg.distributed_training.distributed_world_size @property def data_parallel_process_group(self): @@ -140,11 +151,11 @@ class Trainer(object): @property def data_parallel_rank(self): - return self.args.distributed_rank + return self.cfg.distributed_training.distributed_rank @property def is_data_parallel_master(self): - return distributed_utils.is_master(self.args) + return distributed_utils.is_master(self.cfg.distributed_training) @property def criterion(self): @@ -152,11 +163,11 @@ class Trainer(object): if ( utils.has_parameters(self._criterion) and self.data_parallel_world_size > 1 - and not self.args.use_bmuf + and not self.cfg.optimization.use_bmuf and not self.tpu ): self._wrapped_criterion = models.DistributedFairseqModel( - self.args, + self.cfg.distributed_training, self._criterion, process_group=self.data_parallel_process_group, ) @@ -169,11 +180,11 @@ class Trainer(object): if self._wrapped_model is None: if ( self.data_parallel_world_size > 1 - and not self.args.use_bmuf + and not self.cfg.optimization.use_bmuf and not self.tpu ): self._wrapped_model = models.DistributedFairseqModel( - self.args, + self.cfg.distributed_training, self._model, process_group=self.data_parallel_process_group, ) @@ -201,44 +212,51 @@ class Trainer(object): ) ) - if self.args.fp16 or self.args.bf16: + if self.cfg.common.fp16 or self.cfg.common.bf16: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: logger.info( "NOTE: your device does NOT support faster training with --fp16, " "please switch to FP32 which is likely to be faster" ) - if self.args.memory_efficient_fp16 or self.args.memory_efficient_bf16: + if ( + self.cfg.common.memory_efficient_fp16 + or self.cfg.common.memory_efficient_bf16 + ): self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( - self.args, params + self.cfg, params ) else: - self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) + self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params) else: if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: logger.info("NOTE: your device may support faster training with --fp16") - self._optimizer = optim.build_optimizer(self.args, params) + self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) - if self.args.use_bmuf: - self._optimizer = optim.FairseqBMUF(self.args, self._optimizer) + if self.cfg.optimization.use_bmuf: + self._optimizer = optim.FairseqBMUF( + self.cfg.bmuf, + self._optimizer, + ) - if self.args.zero_sharding == "os": + if self.cfg.distributed_training.zero_sharding == "os": if ( - self.args.fp16 - and not self.args.memory_efficient_fp16 - and not self.args.memory_efficient_bf16 - ) and not self.args.fp16_no_flatten_grads: + self.cfg.common.fp16 + and not self.cfg.common.memory_efficient_fp16 + and not self.cfg.common.memory_efficient_bf16 + ) and not self.cfg.common.fp16_no_flatten_grads: raise ValueError( "ZeRO is incomptabile with fp16 and flattened grads. " "Please use --fp16-no-flatten-grads" ) else: - optim.shard_( - self.args, self._optimizer, self.data_parallel_process_group - ) + optim.shard_(self._optimizer, self.data_parallel_process_group) # We should initialize the learning rate scheduler immediately after # building the optimizer, so that the initial learning rate is set. - self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) + self._lr_scheduler = lr_scheduler.build_lr_scheduler( + self.cfg.lr_scheduler, + self.optimizer, + ) self._lr_scheduler.step_update(0) def consolidate_optimizer(self): @@ -253,7 +271,7 @@ class Trainer(object): extra_state["previous_training_time"] = self.cumulative_training_time() checkpoint_utils.save_state( filename, - self.args, + self.cfg, self.get_model().state_dict(), self.get_criterion(), self.optimizer, @@ -277,11 +295,10 @@ class Trainer(object): bexists = PathManager.isfile(filename) if bexists: state = checkpoint_utils.load_checkpoint_to_cpu(filename) - # load model parameters try: self.get_model().load_state_dict( - state["model"], strict=True, args=self.args + state["model"], strict=True, model_cfg=self.cfg.model ) if utils.has_parameters(self.get_criterion()): self.get_criterion().load_state_dict( @@ -355,28 +372,28 @@ class Trainer(object): if load_dataset: logger.info("loading train data for epoch {}".format(epoch)) self.task.load_dataset( - self.args.train_subset, + self.cfg.dataset.train_subset, epoch=epoch, combine=combine, data_selector=data_selector, ) batch_iterator = self.task.get_batch_iterator( - dataset=self.task.dataset(self.args.train_subset), - max_tokens=self.args.max_tokens, - max_sentences=self.args.batch_size, + dataset=self.task.dataset(self.cfg.dataset.train_subset), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( self.task.max_positions(), self.model.max_positions(), - self.args.max_tokens, + self.cfg.dataset.max_tokens, ), ignore_invalid_inputs=True, - required_batch_size_multiple=self.args.required_batch_size_multiple, - seed=self.args.seed, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, num_shards=self.data_parallel_world_size if shard_batch_itr else 1, shard_id=self.data_parallel_rank if shard_batch_itr else 0, - num_workers=self.args.num_workers, + num_workers=self.cfg.dataset.num_workers, epoch=epoch, - data_buffer_size=self.args.data_buffer_size, + data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) self.reset_dummy_batch(batch_iterator.first_batch) @@ -390,19 +407,19 @@ class Trainer(object): """Return an EpochBatchIterator over given validation subset for a given epoch.""" batch_iterator = self.task.get_batch_iterator( dataset=self.task.dataset(subset), - max_tokens=self.args.max_tokens_valid, - max_sentences=self.args.batch_size_valid, + max_tokens=self.cfg.dataset.max_tokens_valid, + max_sentences=self.cfg.dataset.batch_size_valid, max_positions=utils.resolve_max_positions( self.task.max_positions(), self.model.max_positions(), ), - ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=self.args.required_batch_size_multiple, - seed=self.args.seed, + ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, num_shards=self.data_parallel_world_size, shard_id=self.data_parallel_rank, - num_workers=self.args.num_workers, - data_buffer_size=self.args.data_buffer_size, + num_workers=self.cfg.dataset.num_workers, + data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) self.reset_dummy_batch(batch_iterator.first_batch) @@ -504,7 +521,7 @@ class Trainer(object): self.zero_grad() if self.cuda: torch.cuda.empty_cache() - if self.args.distributed_world_size == 1: + if self.cfg.distributed_training.distributed_world_size == 1: return None else: raise e @@ -565,7 +582,7 @@ class Trainer(object): # multiply gradients by (# GPUs / sample_size) since DDP # already normalizes by the number of GPUs. Thus we get # (sum_of_gradients / sample_size). - if not self.args.use_bmuf: + if not self.cfg.optimization.use_bmuf: self.optimizer.multiply_grads( self.data_parallel_world_size / sample_size ) @@ -575,12 +592,12 @@ class Trainer(object): with torch.autograd.profiler.record_function("clip-grads"): # clip grads - grad_norm = self.clip_grad_norm(self.args.clip_norm) + grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) # check that grad norms are consistent across workers if ( - not self.args.use_bmuf - and self.args.distributed_wrapper != "SlowMo" + not self.cfg.optimization.use_bmuf + and self.cfg.distributed_training.distributed_wrapper != "SlowMo" and not self.tpu ): self._check_grad_norms(grad_norm) @@ -624,7 +641,10 @@ class Trainer(object): self.optimizer.optimizer ) - if not overflow or self.args.distributed_wrapper == "SlowMo": + if ( + not overflow + or self.cfg.distributed_training.distributed_wrapper == "SlowMo" + ): self.set_num_updates(self.get_num_updates() + 1) if self.tpu: @@ -636,7 +656,7 @@ class Trainer(object): # only log stats every log_interval steps # this causes wps to be misreported when log_interval > 1 logging_output = {} - if self.get_num_updates() % self.args.log_interval == 0: + if self.get_num_updates() % self.cfg.common.log_interval == 0: # log memory usage mem_info = xm.get_memory_info(self.device) gb_free = mem_info["kb_free"] / 1024 / 1024 @@ -677,16 +697,16 @@ class Trainer(object): # clear CUDA cache to reduce memory fragmentation if ( self.cuda - and self.args.empty_cache_freq > 0 + and self.cfg.common.empty_cache_freq > 0 and ( - (self.get_num_updates() + self.args.empty_cache_freq - 1) - % self.args.empty_cache_freq + (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1) + % self.cfg.common.empty_cache_freq ) == 0 ): torch.cuda.empty_cache() - if self.args.fp16: + if self.cfg.common.fp16: metrics.log_scalar( "loss_scale", self.optimizer.scaler.loss_scale, @@ -883,10 +903,10 @@ class Trainer(object): return t.to(dtype=torch.bfloat16) return t - if self.args.fp16: + if self.cfg.common.fp16: sample = utils.apply_to_sample(apply_half, sample) - if self.args.bf16: + if self.cfg.common.bf16: sample = utils.apply_to_sample(apply_bfloat16, sample) return sample @@ -894,7 +914,7 @@ class Trainer(object): def _set_seed(self): # Set seed based on args.seed and the update number so that we get # reproducible results when resuming from checkpoints - seed = self.args.seed + self.get_num_updates() + seed = self.cfg.common.seed + self.get_num_updates() utils.set_torch_seed(seed) def _sync_stats(self): @@ -902,10 +922,12 @@ class Trainer(object): # BMUF and it's a bmuf sync with warmup iterations completed before. if self.data_parallel_world_size == 1: return False - elif self.args.use_bmuf: - return (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and ( + elif self.cfg.optimization.use_bmuf: + return ( self.get_num_updates() + 1 - ) > self.args.warmup_iterations + ) % self.cfg.bmuf.global_sync_iter == 0 and ( + self.get_num_updates() + 1 + ) > self.cfg.bmuf.warmup_iterations else: return True @@ -950,7 +972,7 @@ class Trainer(object): zip( *distributed_utils.all_gather_list( [logging_outputs] + list(extra_stats_to_sum), - max_size=getattr(self.args, "all_gather_list_size", 16384), + max_size=getattr(self.cfg.common, "all_gather_list_size", 16384), group=self.data_parallel_process_group, ) ) @@ -1038,11 +1060,11 @@ class Trainer(object): if grad_norm is not None: metrics.log_speed("ups", 1.0, priority=100, round=2) metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) - if self.args.clip_norm > 0: + if self.cfg.optimization.clip_norm > 0: metrics.log_scalar( "clip", torch.where( - grad_norm > self.args.clip_norm, + grad_norm > self.cfg.optimization.clip_norm, grad_norm.new_tensor(100), grad_norm.new_tensor(0), ), @@ -1087,7 +1109,7 @@ class Trainer(object): logger.warning( "XLA compilation detected on device #{}; too many of these can lead " "to slow training, but we expect a few in the beginning".format( - self.args.distributed_rank + self.cfg.distributed_training.distributed_rank ) ) self._num_xla_compiles = num_xla_compiles diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 9a4ff8ee..4621a66a 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -11,13 +11,19 @@ Evaluate the perplexity of a trained language model. import logging import math import os +from argparse import Namespace import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import LMContextWindowDataset +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.sequence_scorer import SequenceScorer +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig logging.basicConfig( @@ -60,65 +66,60 @@ class WordStat(object): ) -def main(parsed_args, **unused_kwargs): - assert parsed_args.path is not None, "--path required for evaluation!" +def main(cfg: DictConfig, override_args=None, **unused_kwargs): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) - if torch.cuda.is_available() and not parsed_args.cpu: - torch.cuda.set_device(parsed_args.device_id) + utils.import_user_module(cfg.common) - utils.import_user_module(parsed_args) + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu - logger.info(parsed_args) + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) - use_cuda = torch.cuda.is_available() and not parsed_args.cpu + if override_args is not None: + overrides = vars(override_args) + overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) + else: + overrides = None - task = tasks.setup_task(parsed_args) + logger.info(cfg) # Load ensemble - logger.info("loading model(s) from {}".format(parsed_args.path)) - models, args = checkpoint_utils.load_model_ensemble( - parsed_args.path.split(os.pathsep), - arg_overrides=eval(parsed_args.model_overrides), - task=task, - suffix=getattr(parsed_args, "checkpoint_suffix", ""), - strict=(parsed_args.checkpoint_shard_count == 1), - num_shards=parsed_args.checkpoint_shard_count, - ) - - for arg in vars(parsed_args).keys(): - if arg not in { - "self_target", - "future_target", - "past_target", - "tokens_per_sample", - "output_size_dictionary", - "add_bos_token", - }: - setattr(args, arg, getattr(parsed_args, arg)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) # reduce tokens per sample by the required context window size - args.tokens_per_sample -= args.context_window - task = tasks.setup_task(args) + cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=overrides, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + ) # Load dataset splits - task.load_dataset(args.gen_subset) - dataset = task.dataset(args.gen_subset) - if args.context_window > 0: + gen_subset = cfg.dataset.gen_subset + task.load_dataset(gen_subset) + dataset = task.dataset(gen_subset) + if cfg.eval_lm.context_window > 0: dataset = LMContextWindowDataset( dataset=dataset, - tokens_per_sample=args.tokens_per_sample, - context_window=args.context_window, + tokens_per_sample=cfg.task.tokens_per_sample, + context_window=cfg.eval_lm.context_window, pad_idx=task.source_dictionary.pad(), ) - logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset))) + logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset))) # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) for model in models: - if args.fp16: + if use_fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) assert len(models) > 0 @@ -128,35 +129,41 @@ def main(parsed_args, **unused_kwargs): itr = task.get_batch_iterator( dataset=dataset, - max_tokens=args.max_tokens or 36000, - max_sentences=args.batch_size, + max_tokens=cfg.dataset.max_tokens or 36000, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( *[model.max_positions() for model in models] ), ignore_invalid_inputs=True, - num_shards=args.num_shards, - shard_id=args.shard_id, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + num_shards=max( + cfg.dataset.num_shards, + cfg.distributed_training.distributed_world_size, + ), + shard_id=max( + cfg.dataset.shard_id, + cfg.distributed_training.distributed_rank, + ), + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) gen_timer = StopwatchMeter() - scorer = SequenceScorer(task.target_dictionary, args.softmax_batch) + scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch) score_sum = 0.0 count = 0 - if args.remove_bpe is not None: - if args.remove_bpe == "sentencepiece": + if cfg.common_eval.remove_bpe is not None: + if cfg.common_eval.remove_bpe == "sentencepiece": raise NotImplementedError else: - bpe_cont = args.remove_bpe.rstrip() + bpe_cont = cfg.common_eval.remove_bpe.rstrip() bpe_toks = { i for i in range(len(task.source_dictionary)) @@ -189,7 +196,7 @@ def main(parsed_args, **unused_kwargs): tgt_len = tokens.numel() pos_scores = hypo["positional_scores"].float() - if getattr(args, "add_bos_token", False): + if cfg.task.add_bos_token: assert hypo["tokens"][0].item() == task.target_dictionary.bos() tokens = tokens[1:] pos_scores = pos_scores[1:] @@ -212,7 +219,7 @@ def main(parsed_args, **unused_kwargs): score_sum += pos_scores.sum().cpu() count += pos_scores.numel() - skipped_toks - if args.output_word_probs or args.output_word_stats: + if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats: w = "" word_prob = [] is_bpe = False @@ -238,7 +245,7 @@ def main(parsed_args, **unused_kwargs): ) is_bpe = False w = "" - if args.output_word_probs: + if cfg.eval_lm.output_word_probs: logger.info( str(int(sample_id)) + " " @@ -264,7 +271,7 @@ def main(parsed_args, **unused_kwargs): ) ) - if args.output_word_stats: + if cfg.eval_lm.output_word_stats: for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): logger.info(ws) @@ -272,8 +279,16 @@ def main(parsed_args, **unused_kwargs): def cli_main(): parser = options.get_eval_lm_parser() args = options.parse_args_and_arch(parser) - distributed_utils.call_main(args, main) + + # only override args that are explicitly given on the command line + override_parser = options.get_validation_parser() + override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + + distributed_utils.call_main(args, main, override_args=override_args) if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 8ddf981c..6a6f7465 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -12,33 +12,45 @@ import logging import math import os import sys +from argparse import Namespace from itertools import chain import numpy as np import torch from fairseq import checkpoint_utils, options, scoring, tasks, utils +from fairseq.data import encoders +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig -def main(args): - assert args.path is not None, "--path required for generation!" +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required for generation!" assert ( - not args.sampling or args.nbest == args.beam + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert ( - args.replace_unk is None or args.dataset_impl == "raw" + cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" - if args.results_path is not None: - os.makedirs(args.results_path, exist_ok=True) + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) output_path = os.path.join( - args.results_path, "generate-{}.txt".format(args.gen_subset) + cfg.common_eval.results_path, + "generate-{}.txt".format(cfg.dataset.gen_subset), ) with open(output_path, "w", buffering=1, encoding="utf-8") as h: - return _main(args, h) + return _main(cfg, h) else: - return _main(args, sys.stdout) + return _main(cfg, sys.stdout) def get_symbols_to_strip_from_output(generator): @@ -48,7 +60,7 @@ def get_symbols_to_strip_from_output(generator): return {generator.eos} -def _main(args, output_file): +def _main(cfg: DictConfig, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", @@ -57,22 +69,22 @@ def _main(args, output_file): ) logger = logging.getLogger("fairseq_cli.generate") - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.max_tokens is None and args.batch_size is None: - args.max_tokens = 12000 - logger.info(args) + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 12000 + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Load dataset splits - task = tasks.setup_task(args) - task.load_dataset(args.gen_subset) + task = tasks.setup_task(cfg.task) + task.load_dataset(cfg.dataset.gen_subset) # Set dictionaries try: @@ -81,32 +93,30 @@ def _main(args, output_file): src_dict = None tgt_dict = task.target_dictionary - overrides = ast.literal_eval(args.model_overrides) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - utils.split_paths(args.path), + utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) - if args.lm_path is not None: - overrides["data"] = args.data + if cfg.generation.lm_path is not None: + overrides["data"] = cfg.task.data try: lms, _ = checkpoint_utils.load_model_ensemble( - [args.lm_path], - arg_overrides=overrides, - task=None, + [cfg.generation.lm_path], arg_overrides=overrides, task=None ) except: logger.warning( f"Failed to load language model! Please make sure that the language model dict is the same " - f"as target dict and is located in the data dir ({args.data})" + f"as target dict and is located in the data dir ({cfg.task.data})" ) raise @@ -118,49 +128,50 @@ def _main(args, output_file): for model in chain(models, lms): if model is None: continue - if args.fp16: + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - align_dict = utils.load_align_dict(args.replace_unk) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) # Load dataset (possibly sharded) itr = task.get_batch_iterator( - dataset=task.dataset(args.gen_subset), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( - task.max_positions(), *[model.max_positions() for model in models] + task.max_positions(), *[m.max_positions() for m in models] ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - num_shards=args.num_shards, - shard_id=args.shard_id, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) # Initialize generator gen_timer = StopwatchMeter() - extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight} + extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight} generator = task.build_generator( - models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs + models, cfg.task, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) # Handle tokenization and BPE - tokenizer = task.build_tokenizer(args) - bpe = task.build_bpe(args) + tokenizer = encoders.build_tokenizer(cfg.tokenizer) + bpe = encoders.build_bpe(cfg.bpe) def decode_fn(x): if bpe is not None: @@ -169,7 +180,7 @@ def _main(args, output_file): x = tokenizer.decode(x) return x - scorer = scoring.build_scorer(args, tgt_dict) + scorer = scoring.build_scorer(cfg.scoring, tgt_dict) num_sentences = 0 has_target = True @@ -180,8 +191,8 @@ def _main(args, output_file): continue prefix_tokens = None - if args.prefix_size > 0: - prefix_tokens = sample["target"][:, : args.prefix_size] + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] constraints = None if "constraints" in sample: @@ -217,19 +228,21 @@ def _main(args, output_file): # Either retrieve the original sentences or regenerate them from tokens. if align_dict is not None: - src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) - target_str = task.dataset(args.gen_subset).tgt.get_original_text( + src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text( + sample_id + ) + target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text( sample_id ) else: if src_dict is not None: - src_str = src_dict.string(src_tokens, args.remove_bpe) + src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe) else: src_str = "" if has_target: target_str = tgt_dict.string( target_tokens, - args.remove_bpe, + cfg.common_eval.remove_bpe, escape_unk=True, extra_symbols_to_ignore=get_symbols_to_strip_from_output( generator @@ -240,25 +253,25 @@ def _main(args, output_file): if has_target: target_str = decode_fn(target_str) - if not args.quiet: + if not cfg.common_eval.quiet: if src_dict is not None: print("S-{}\t{}".format(sample_id, src_str), file=output_file) if has_target: print("T-{}\t{}".format(sample_id, target_str), file=output_file) # Process top predictions - for j, hypo in enumerate(hypos[i][: args.nbest]): + for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, - remove_bpe=args.remove_bpe, + remove_bpe=cfg.common_eval.remove_bpe, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) - if not args.quiet: + if not cfg.common_eval.quiet: score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) print( @@ -286,7 +299,7 @@ def _main(args, output_file): file=output_file, ) - if args.print_alignment: + if cfg.generation.print_alignment: print( "A-{}\t{}".format( sample_id, @@ -300,13 +313,13 @@ def _main(args, output_file): file=output_file, ) - if args.print_step: + if cfg.generation.print_step: print( "I-{}\t{}".format(sample_id, hypo["steps"]), file=output_file, ) - if getattr(args, "retain_iter_history", False): + if cfg.generation.retain_iter_history: for step, h in enumerate(hypo["history"]): _, h_str, _ = utils.post_process_prediction( hypo_tokens=h["tokens"].int().cpu(), @@ -323,7 +336,7 @@ def _main(args, output_file): # Score only the top hypothesis if has_target and j == 0: - if align_dict is not None or args.remove_bpe is not None: + if align_dict is not None or cfg.common_eval.remove_bpe is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.encode_line( target_str, add_if_not_exist=True @@ -353,8 +366,8 @@ def _main(args, output_file): ) ) if has_target: - if args.bpe and not args.sacrebleu: - if args.remove_bpe: + if cfg.bpe and not cfg.generation.sacrebleu: + if cfg.common_eval.remove_bpe: logger.warning( "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" ) @@ -365,7 +378,7 @@ def _main(args, output_file): # use print to be consistent with other main outputs: S-, H-, T-, D- and so on print( "Generate {} with beam={}: {}".format( - args.gen_subset, args.beam, scorer.result_string() + cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string() ), file=output_file, ) @@ -380,4 +393,7 @@ def cli_main(): if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index de3893a3..ddd2617c 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -7,20 +7,27 @@ Translate raw text with a trained model. Batches data on-the-fly. """ +import ast import fileinput import logging import math import os import sys import time +from argparse import Namespace from collections import namedtuple import numpy as np import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq_cli.generate import get_symbols_to_strip_from_output +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig logging.basicConfig( @@ -49,11 +56,11 @@ def buffered_read(input, buffer_size): yield buffer -def make_batches(lines, args, task, max_positions, encode_fn): +def make_batches(lines, cfg, task, max_positions, encode_fn): def encode_fn_target(x): return encode_fn(x) - if args.constraints: + if cfg.generation.constraints: # Strip (tab-delimited) contraints, if present, from input lines, # store them in batch_constraints batch_constraints = [list() for _ in lines] @@ -79,7 +86,7 @@ def make_batches(lines, args, task, max_positions, encode_fn): for src_str in lines ] - if args.constraints: + if cfg.generation.constraints: constraints_tensor = pack_constraints(batch_constraints) else: constraints_tensor = None @@ -89,10 +96,10 @@ def make_batches(lines, args, task, max_positions, encode_fn): dataset=task.build_dataset_for_inference( tokens, lengths, constraints=constraints_tensor ), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=max_positions, - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, ).next_epoch_itr(shuffle=False) for batch in itr: ids = batch["id"] @@ -108,45 +115,50 @@ def make_batches(lines, args, task, max_positions, encode_fn): ) -def main(args): +def main(cfg: DictConfig): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + start_time = time.time() total_translate_time = 0 - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.buffer_size < 1: - args.buffer_size = 1 - if args.max_tokens is None and args.batch_size is None: - args.batch_size = 1 + if cfg.interactive.buffer_size < 1: + cfg.interactive.buffer_size = 1 + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.batch_size = 1 assert ( - not args.sampling or args.nbest == args.beam + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert ( - not args.batch_size or args.batch_size <= args.buffer_size + not cfg.dataset.batch_size + or cfg.dataset.batch_size <= cfg.interactive.buffer_size ), "--batch-size cannot be larger than --buffer-size" - logger.info(args) + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Setup task, e.g., translation - task = tasks.setup_task(args) + task = tasks.setup_task(cfg.task) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - args.path.split(os.pathsep), - arg_overrides=eval(args.model_overrides), + utils.split_paths(cfg.common_eval.path), + arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Set dictionaries @@ -155,18 +167,20 @@ def main(args): # Optimize ensemble for generation for model in models: - if args.fp16: + if model is None: + continue + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Initialize generator - generator = task.build_generator(models, args) + generator = task.build_generator(models, cfg.task) # Handle tokenization and BPE - tokenizer = encoders.build_tokenizer(args) - bpe = encoders.build_bpe(args) + tokenizer = encoders.build_tokenizer(cfg.tokenizer) + bpe = encoders.build_bpe(cfg.bpe) def encode_fn(x): if tokenizer is not None: @@ -184,25 +198,25 @@ def main(args): # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - align_dict = utils.load_align_dict(args.replace_unk) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) max_positions = utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models] ) - if args.constraints: + if cfg.generation.constraints: logger.warning( "NOTE: Constrained decoding currently assumes a shared subword vocabulary." ) - if args.buffer_size > 1: - logger.info("Sentence buffer size: %s", args.buffer_size) + if cfg.interactive.buffer_size > 1: + logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size) logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("Type the input sentence and press return:") start_id = 0 - for inputs in buffered_read(args.input, args.buffer_size): + for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size): results = [] - for batch in make_batches(inputs, args, task, max_positions, encode_fn): + for batch in make_batches(inputs, cfg, task, max_positions, encode_fn): bsz = batch.src_tokens.size(0) src_tokens = batch.src_tokens src_lengths = batch.src_lengths @@ -226,7 +240,7 @@ def main(args): translate_time = time.time() - translate_start_time total_translate_time += translate_time list_constraints = [[] for _ in range(bsz)] - if args.constraints: + if cfg.generation.constraints: list_constraints = [unpack_constraints(c) for c in constraints] for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) @@ -246,25 +260,25 @@ def main(args): # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): if src_dict is not None: - src_str = src_dict.string(src_tokens, args.remove_bpe) + src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe) print("S-{}\t{}".format(id_, src_str)) print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) for constraint in info["constraints"]: print( "C-{}\t{}".format( - id_, tgt_dict.string(constraint, args.remove_bpe) + id_, tgt_dict.string(constraint, cfg.common_eval.remove_bpe) ) ) # Process top predictions - for hypo in hypos[: min(len(hypos), args.nbest)]: + for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]: hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, - remove_bpe=args.remove_bpe, + remove_bpe=cfg.common_eval.remove_bpe, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) @@ -285,7 +299,7 @@ def main(args): ), ) ) - if args.print_alignment: + if cfg.generation.print_alignment: alignment_str = " ".join( ["{}-{}".format(src, tgt) for src, tgt in alignment] ) @@ -308,4 +322,7 @@ def cli_main(): if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/score.py b/fairseq_cli/score.py index b8354eb9..e06d6725 100644 --- a/fairseq_cli/score.py +++ b/fairseq_cli/score.py @@ -78,7 +78,13 @@ def cli_main(): def score(fdsys): with open(args.ref) as fdref: - scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): sys_tok = dict.encode_line(sys_tok) ref_tok = dict.encode_line(ref_tok) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index cd3a93b1..4c007610 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -11,11 +11,13 @@ import argparse import logging import math import os -import random import sys +from typing import Dict, Optional, Any, List, Tuple, Callable import numpy as np import torch +from hydra.core.config_store import ConfigStore + from fairseq import ( checkpoint_utils, distributed_utils, @@ -25,8 +27,12 @@ from fairseq import ( utils, ) from fairseq.data import iterators +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from omegaconf import DictConfig +from hydra.experimental import initialize +from fairseq.dataclass.data_class import register_hydra_cfg from fairseq.trainer import Trainer @@ -39,90 +45,86 @@ logging.basicConfig( logger = logging.getLogger("fairseq_cli.train") -def main(args): - utils.import_user_module(args) +def main(cfg: DictConfig) -> None: + if isinstance(cfg, argparse.Namespace): + cfg = convert_namespace_to_omegaconf(cfg) - assert ( - args.max_tokens is not None or args.batch_size is not None - ), "Must specify batch size either with --max-tokens or --batch-size" + utils.import_user_module(cfg.common) + assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \ + 'Must specify batch size either with --max-tokens or --batch-size' metrics.reset() - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - if distributed_utils.is_master(args): - checkpoint_utils.verify_checkpoint_directory(args.save_dir) + if distributed_utils.is_master(cfg.distributed_training): + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) # Print args - logger.info(args) + logger.info(cfg) # Setup task, e.g., translation, language modeling, etc. - task = tasks.setup_task(args) - + task = tasks.setup_task(cfg.task) # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in args.valid_subset.split(","): + for valid_sub_split in cfg.dataset.valid_subset.split(','): task.load_dataset(valid_sub_split, combine=False, epoch=1) # Build model and criterion - model = task.build_model(args) - criterion = task.build_criterion(args) + model = task.build_model(cfg.model) + criterion = task.build_criterion(cfg.criterion) logger.info(model) - logger.info("task: {} ({})".format(args.task, task.__class__.__name__)) - logger.info("model: {} ({})".format(args.arch, model.__class__.__name__)) + logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__)) + logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__)) logger.info( - "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__) - ) - logger.info( - "num. model params: {} (num. trained: {})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), - ) + "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__) ) + logger.info("num. model params: {} (num. trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + )) # (optionally) Configure quantization - if args.quantization_config_path is not None: + if cfg.common.quantization_config_path is not None: quantizer = quantization_utils.Quantizer( - config_path=args.quantization_config_path, - max_epoch=args.max_epoch, - max_update=args.max_update, + config_path=cfg.common.quantization_config_path, + max_epoch=cfg.optimization.max_epoch, + max_update=cfg.optimization.max_update, ) else: quantizer = None # Build trainer - if args.model_parallel_size == 1: - trainer = Trainer(args, task, model, criterion, quantizer) + if cfg.common.model_parallel_size == 1: + trainer = Trainer(cfg, task, model, criterion, quantizer) else: - trainer = MegatronTrainer(args, task, model, criterion) + trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info( - "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) - ) - logger.info( - "max tokens per GPU = {} and max sentences per GPU = {}".format( - args.max_tokens, args.batch_size - ) - ) + logger.info('training on {} devices (GPUs/TPUs)'.format(cfg.distributed_training.distributed_world_size)) + logger.info('max tokens per GPU = {} and batch size per GPU = {}'.format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + )) # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint( - args, + cfg.checkpoint, trainer, # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) - # Train until the learning rate gets too small - max_epoch = args.max_epoch or math.inf + max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - - while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: + while ( + lr > cfg.optimization.min_lr + and epoch_itr.next_epoch_idx <= max_epoch + ): # train for one epoch - valid_losses, should_stop = train(args, trainer, task, epoch_itr) + valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: break @@ -140,15 +142,15 @@ def main(args): logger.info("done training in {:.1f} seconds".format(train_meter.sum)) -def should_stop_early(args, valid_loss): +def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: # skip check if no validation was done in the current epoch if valid_loss is None: return False - if args.patience <= 0: + if cfg.checkpoint.patience <= 0: return False def is_better(a, b): - return a > b if args.maximize_best_checkpoint_metric else a < b + return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b prev_best = getattr(should_stop_early, "best", None) if prev_best is None or is_better(valid_loss, prev_best): @@ -157,48 +159,43 @@ def should_stop_early(args, valid_loss): return False else: should_stop_early.num_runs += 1 - if should_stop_early.num_runs >= args.patience: - logger.info( - "early stop since valid performance hasn't improved for last {} runs".format( - args.patience - ) - ) + if should_stop_early.num_runs >= cfg.checkpoint.patience: + logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(cfg.checkpoint.patience)) return True else: return False @metrics.aggregate("train") -def train(args, trainer, task, epoch_itr): +def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]: """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( - fix_batches_to_gpus=args.fix_batches_to_gpus, - shuffle=(epoch_itr.next_epoch_idx > args.curriculum), + fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus, + shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum), ) update_freq = ( - args.update_freq[epoch_itr.epoch - 1] - if epoch_itr.epoch <= len(args.update_freq) - else args.update_freq[-1] + cfg.optimization.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(cfg.optimization.update_freq) + else cfg.optimization.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(args, "tpu", False): + if getattr(cfg.common, "tpu", False): itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( - args.tensorboard_logdir if distributed_utils.is_master(args) else None + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), + default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'), ) trainer.begin_epoch(epoch_itr.epoch) - valid_losses = [None] - valid_subsets = args.valid_subset.split(",") + valid_subsets = cfg.dataset.valid_subset.split(',') should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): @@ -210,7 +207,7 @@ def train(args, trainer, task, epoch_itr): if log_output is not None: # not OOM, overflow, ... # log mid-epoch stats num_updates = trainer.get_num_updates() - if num_updates % args.log_interval == 0: + if num_updates % cfg.common.log_interval == 0: stats = get_training_stats(metrics.get_smoothed_values("train_inner")) progress.log(stats, tag="train_inner", step=num_updates) @@ -220,7 +217,7 @@ def train(args, trainer, task, epoch_itr): end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( - args, trainer, task, epoch_itr, valid_subsets, end_of_epoch + cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) if should_stop: @@ -236,64 +233,64 @@ def train(args, trainer, task, epoch_itr): return valid_losses, should_stop -def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): +def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() - max_update = args.max_update or math.inf + max_update = cfg.optimization.max_update or math.inf do_save = ( - (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) + (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) or num_updates >= max_update or ( - args.save_interval_updates > 0 + cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates >= args.validate_after_updates + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates >= cfg.dataset.validate_after_updates ) ) do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves - or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) + or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) or num_updates >= max_update or ( - args.validate_interval_updates > 0 + cfg.dataset.validate_interval_updates > 0 and num_updates > 0 - and num_updates % args.validate_interval_updates == 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 ) - ) and not args.disable_validation + ) and not cfg.dataset.disable_validation # Validate valid_losses = [None] if do_validate: - valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) # Stopping conditions should_stop = ( - should_stop_early(args, valid_losses[0]) + should_stop_early(cfg, valid_losses[0]) or num_updates >= max_update or ( - args.stop_time_hours > 0 - and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours + cfg.optimization.stop_time_hours > 0 + and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours ) ) # Save checkpoint if do_save or should_stop: logger.info("begin save checkpoint") - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0]) return valid_losses, should_stop -def get_training_stats(stats): +def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) return stats -def validate(args, trainer, task, epoch_itr, subsets): +def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]: """Evaluate the model on the validation set(s) and return the losses.""" - if args.fixed_validation_seed is not None: + if cfg.dataset.fixed_validation_seed is not None: # set fixed seed for every validation - utils.set_torch_seed(args.fixed_validation_seed) + utils.set_torch_seed(cfg.dataset.fixed_validation_seed) trainer.begin_valid_epoch(epoch_itr.epoch) valid_losses = [] @@ -302,18 +299,18 @@ def validate(args, trainer, task, epoch_itr, subsets): # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) - if getattr(args, "tpu", False): + if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, prefix=f"valid on '{subset}' subset", tensorboard_logdir=( - args.tensorboard_logdir if distributed_utils.is_master(args) else None + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), + default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'), ) # create a new root metrics aggregator so validation metrics @@ -323,34 +320,40 @@ def validate(args, trainer, task, epoch_itr, subsets): trainer.valid_step(sample) # log validation stats - stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) + stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values()) progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats[args.best_checkpoint_metric]) + valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) return valid_losses -def get_valid_stats(args, trainer, stats): +def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]: stats["num_updates"] = trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, "best"): - key = "best_{0}".format(args.best_checkpoint_metric) - best_function = max if args.maximize_best_checkpoint_metric else min + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min stats[key] = best_function( - checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric] + checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric] ) return stats -def cli_main(modify_parser=None): +def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None: parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + + cfg = convert_namespace_to_omegaconf(args) + if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): - distributed_utils.call_main(args, main) + distributed_utils.call_main(cfg, main) else: - distributed_utils.call_main(args, main) + distributed_utils.call_main(cfg, main) -if __name__ == "__main__": +if __name__ == '__main__': + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index df857550..368c9cb5 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -u -#!/usr/bin/env python3 -u +# !/usr/bin/env python3 -u # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the @@ -8,11 +8,17 @@ import logging import os import sys +from argparse import Namespace from itertools import chain import torch from fairseq import checkpoint_utils, distributed_utils, options, utils +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import metrics, progress_bar +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig logging.basicConfig( @@ -24,18 +30,21 @@ logging.basicConfig( logger = logging.getLogger("fairseq_cli.validate") -def main(args, override_args=None): - utils.import_user_module(args) +def main(cfg: DictConfig, override_args=None): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) assert ( - args.max_tokens is not None or args.batch_size is not None + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" - use_fp16 = args.fp16 - use_cuda = torch.cuda.is_available() and not args.cpu + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu if use_cuda: - torch.cuda.set_device(args.device_id) + torch.cuda.set_device(cfg.distributed_training.device_id) if override_args is not None: overrides = vars(override_args) @@ -44,11 +53,11 @@ def main(args, override_args=None): overrides = None # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( - [args.path], + [cfg.common_eval.path], arg_overrides=overrides, - suffix=getattr(args, "checkpoint_suffix", ""), + suffix=cfg.checkpoint.checkpoint_suffix, ) model = models[0] @@ -63,10 +72,10 @@ def main(args, override_args=None): logger.info(model_args) # Build criterion - criterion = task.build_criterion(model_args) + criterion = task.build_criterion(model_args.criterion) criterion.eval() - for subset in args.valid_subset.split(","): + for subset in cfg.dataset.valid_subset.split(","): try: task.load_dataset(subset, combine=False, epoch=1) dataset = task.dataset(subset) @@ -76,26 +85,26 @@ def main(args, override_args=None): # Initialize data iterator itr = task.get_batch_iterator( dataset=dataset, - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[m.max_positions() for m in models], ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - seed=args.seed, - num_shards=args.distributed_world_size, - shard_id=args.distributed_rank, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, prefix=f"valid on '{subset}' subset", - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) log_outputs = [] @@ -105,10 +114,10 @@ def main(args, override_args=None): progress.log(log_output, step=i) log_outputs.append(log_output) - if args.distributed_world_size > 1: + if cfg.distributed_training.distributed_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, - max_size=getattr(args, "all_gather_list_size", 16384), + max_size=cfg.common.all_gather_list_size, ) log_outputs = list(chain.from_iterable(log_outputs)) @@ -131,4 +140,7 @@ def cli_main(): if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/tests/speech_recognition/asr_test_base.py b/tests/speech_recognition/asr_test_base.py index 03410313..8c5d414e 100644 --- a/tests/speech_recognition/asr_test_base.py +++ b/tests/speech_recognition/asr_test_base.py @@ -272,6 +272,7 @@ class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase): model_cls.add_args(parser) args = parser.parse_args([]) + if extra_args_setters is not None: for args_setter in extra_args_setters: args_setter(args) @@ -515,9 +516,7 @@ class CrossEntropyCriterionTestBase(unittest.TestCase): def setUp(self): args = self.setUpArgs() self.model = DummyEncoderModel(encoder=DummyEncoder()) - self.criterion = self.criterion_cls.build_criterion( - args=args, task=DummyTask(args) - ) + self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args)) def get_src_tokens(self, correct_prediction, aggregate): """ diff --git a/tests/test_bmuf.py b/tests/test_bmuf.py index 0165b295..e7aa6da1 100644 --- a/tests/test_bmuf.py +++ b/tests/test_bmuf.py @@ -11,7 +11,7 @@ from multiprocessing import Manager import torch import torch.nn as nn from fairseq import distributed_utils, optim - +from omegaconf import OmegaConf class Model(nn.Module): def __init__(self, input_size, output_size): @@ -23,13 +23,14 @@ class Model(nn.Module): return output -def setup_model_loss_criterion(args, rank, is_cuda): +def setup_model_loss_criterion(cfg, args, rank, is_cuda): """ setup model, criterion and optimizer based on input args """ args.distributed_rank = rank - if args.distributed_world_size > 1: - distributed_utils.distributed_init(args) + cfg.distributed_training.distributed_rank = args.distributed_rank + if cfg.distributed_training.distributed_world_size > 1: + distributed_utils.distributed_init(cfg) torch.manual_seed(1) model = Model(args.input_size, args.nb_classes) loss_fn = nn.CrossEntropyLoss() @@ -38,7 +39,10 @@ def setup_model_loss_criterion(args, rank, is_cuda): loss_fn = loss_fn.cuda() optimizer = optim.sgd.SGD(args, model.parameters()) - optimizer = optim.FairseqBMUF(args, optimizer) + optimizer = optim.FairseqBMUF( + cfg=cfg.bmuf, + optimizer=optimizer + ) return model, loss_fn, optimizer @@ -52,13 +56,13 @@ def train_step(input, target, model, loss_fn, optimizer, **unused): optimizer.step() -def single_gpu_training(args, rank, iterations, shared_results): +def single_gpu_training(cfg, args, rank, iterations, shared_results): is_cuda = torch.cuda.is_available() if is_cuda: torch.cuda.set_device(rank) - model, loss_fn, optimizer = setup_model_loss_criterion(args, rank, is_cuda) + model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda) for _ in range(iterations): input = torch.randn(1, args.input_size) @@ -103,18 +107,44 @@ def setup_args(): args.distributed_init_host = "localhost" args.distributed_port = port + 1 args.local_world_size = args.distributed_world_size - return args + + cfg = OmegaConf.create() + cfg.optimization = OmegaConf.create() + cfg.common = OmegaConf.create() + cfg.distributed_training = OmegaConf.create() + cfg.dataset = OmegaConf.create() + cfg.bmuf = OmegaConf.create() + cfg.optimizer = OmegaConf.create() + + cfg.bmuf.global_sync_iter = args.global_sync_iter + cfg.bmuf.block_momentum = args.block_momentum + cfg.bmuf.block_lr = args.block_lr + cfg.dataset.batch_size = args.batch_size + cfg.optimization.lr = args.lr + cfg.optimizer.momentum = args.momentum + cfg.optimizer.weight_decay = args.weight_decay + cfg.bmuf.warmup_iterations = args.warmup_iterations + cfg.bmuf.use_nbm = args.use_nbm + cfg.bmuf.average_sync = args.average_sync + cfg.common.model_parallel_size = args.model_parallel_size + cfg.distributed_training.distributed_backend = args.distributed_backend + cfg.distributed_training.distributed_world_size = args.distributed_world_size + cfg.bmuf.distributed_world_size = args.distributed_world_size + cfg.distributed_training.distributed_init_method = args.distributed_init_method + cfg.distributed_training.distributed_port = args.distributed_port + + return cfg, args @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs") class TestBMUF(unittest.TestCase): - def bmuf_process(self, args, iterations): + def bmuf_process(self, cfg, args, iterations): processes = [] results = Manager().dict() ctx = torch.multiprocessing.get_context("spawn") for rank in range(args.distributed_world_size): p = ctx.Process( - target=single_gpu_training, args=(args, rank, iterations, results) + target=single_gpu_training, args=(cfg, args, rank, iterations, results) ) p.start() processes.append(p) @@ -125,19 +155,20 @@ class TestBMUF(unittest.TestCase): def test_bmuf_sync(self): # Train model for 1 iteration and do bmuf sync without doing warmup - args = setup_args() + cfg, args = setup_args() iterations = 1 - results = self.bmuf_process(args, iterations) + results = self.bmuf_process(cfg, args, iterations) # Make sure params in both machines are same assert len(results) == 2 self.assertAlmostEqual(results[0], results[1]) def test_warmup_sync(self): # Train model for 20 iteration and do warmup sync without doing bmuf sync - args = setup_args() + cfg, args = setup_args() args.warmup_iterations = 20 + cfg.bmuf.warmup_iterations = args.warmup_iterations iterations = 20 - results = self.bmuf_process(args, iterations) + results = self.bmuf_process(cfg, args, iterations) # Make sure params in both machines are same assert len(results) == 2 self.assertAlmostEqual(results[0], results[1]) @@ -145,22 +176,27 @@ class TestBMUF(unittest.TestCase): def test_warmup_sync_bmuf_sync(self): # Train model for 25 iteration and do warmup sync after 20 iteration # and bmuf sync after 25 iteration - args = setup_args() + cfg, args = setup_args() args.warmup_iterations = 20 args.global_sync_iter = 5 + cfg.bmuf.warmup_iterations = args.warmup_iterations + cfg.bmuf.global_sync_iter = args.global_sync_iter iterations = 25 - results = self.bmuf_process(args, iterations) + results = self.bmuf_process(cfg, args, iterations) # Make sure params in both machines are same assert len(results) == 2 self.assertAlmostEqual(results[0], results[1]) def test_single_gpu_bmuf(self): # Train model for 5 iterations and use GPU 1 - args = setup_args() + cfg, args = setup_args() args.distributed_world_size = 1 args.warmup_iterations = 5 + cfg.distributed_training.distributed_world_size = args.distributed_world_size + cfg.bmuf.distributed_world_size = args.distributed_world_size + cfg.bmuf.warmup_iterations = args.warmup_iterations iterations = 20 - results = self.bmuf_process(args, iterations) + results = self.bmuf_process(cfg, args, iterations) assert len(results) == 1 def assertAlmostEqual(self, t1, t2): diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py index c4195273..aa6a863d 100644 --- a/tests/test_fp16_optimizer.py +++ b/tests/test_fp16_optimizer.py @@ -9,6 +9,7 @@ import unittest import torch from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer +from omegaconf import OmegaConf @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") @@ -27,17 +28,23 @@ class TestGradientScaling(unittest.TestCase): self.model.cuda().half() self.params = list(self.model.parameters()) - self.namespace_dls = argparse.Namespace( - optimizer="adam", - lr=[0.1], - adam_betas="(0.9, 0.999)", - adam_eps=1e-8, - weight_decay=0.0, - fp16_init_scale=1, - fp16_scale_window=1, - fp16_scale_tolerance=1, - threshold_loss_scale=1, - min_loss_scale=1e-4, + self.cfg_dls = OmegaConf.create( + { + "optimizer": { + "_name": "adam", + "lr": [0.1], + "adam_betas": "(0.9, 0.999)", + "adam_eps": 1e-8, + "weight_decay": 0.0, + }, + "common": { + "fp16_init_scale": 1, + "fp16_scale_window": 1, + "fp16_scale_tolerance": 1, + "threshold_loss_scale": 1, + "min_loss_scale": 1e-4, + }, + } ) def run_iter(self, model, params, optimizer): @@ -68,7 +75,7 @@ class TestGradientScaling(unittest.TestCase): def test_mixed_precision(self): model = copy.deepcopy(self.model) params = list(model.parameters()) - optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params) + optimizer = FP16Optimizer.build_optimizer(self.cfg_dls, params) self.run_iter(model, params, optimizer) self.assertTrue( @@ -87,9 +94,7 @@ class TestGradientScaling(unittest.TestCase): def test_memory_efficient(self): model = copy.deepcopy(self.model) params = list(model.parameters()) - optimizer = MemoryEfficientFP16Optimizer.build_optimizer( - self.namespace_dls, params - ) + optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.cfg_dls, params) self.run_iter(model, params, optimizer) diff --git a/tests/test_inference_dropout.py b/tests/test_inference_dropout.py index fd5edd43..353ac674 100644 --- a/tests/test_inference_dropout.py +++ b/tests/test_inference_dropout.py @@ -6,6 +6,7 @@ import logging import unittest +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.models.transformer import TransformerModel from tests.test_sequence_generator import get_dummy_task_and_parser @@ -25,7 +26,8 @@ class TestInferenceDropout(unittest.TestCase): def test_sets_inference_dropout_to_true(self): self.args.retain_dropout = True self.transformer_model = TransformerModel.build_model(self.args, self.task) - self.transformer_model.prepare_for_inference_(self.args) + cfg = convert_namespace_to_omegaconf(self.args) + self.transformer_model.prepare_for_inference_(cfg) assert self.transformer_model.encoder.dropout_module.apply_during_inference assert self.transformer_model.decoder.dropout_module.apply_during_inference for layer in self.transformer_model.encoder.layers: @@ -33,7 +35,8 @@ class TestInferenceDropout(unittest.TestCase): def test_inference_dropout_false_by_default(self): self.transformer_model = TransformerModel.build_model(self.args, self.task) - self.transformer_model.prepare_for_inference_(self.args) + cfg = convert_namespace_to_omegaconf(self.args) + self.transformer_model.prepare_for_inference_(cfg) assert not self.transformer_model.encoder.dropout_module.apply_during_inference assert not self.transformer_model.decoder.dropout_module.apply_during_inference for layer in self.transformer_model.encoder.layers: @@ -59,7 +62,8 @@ class TestInferenceDropout(unittest.TestCase): "TransformerEncoderLayer", ] self.transformer_model = TransformerModel.build_model(self.args, self.task) - self.transformer_model.prepare_for_inference_(self.args) + cfg = convert_namespace_to_omegaconf(self.args) + self.transformer_model.prepare_for_inference_(cfg) assert self.transformer_model.encoder.dropout_module.apply_during_inference assert not self.transformer_model.decoder.dropout_module.apply_during_inference for layer in self.transformer_model.decoder.layers: diff --git a/tests/test_memory_efficient_fp16.py b/tests/test_memory_efficient_fp16.py index e10636d9..2bf2f298 100644 --- a/tests/test_memory_efficient_fp16.py +++ b/tests/test_memory_efficient_fp16.py @@ -10,6 +10,7 @@ import unittest import torch from fairseq.optim.adam import FairseqAdam from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer +from omegaconf import OmegaConf @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") @@ -26,25 +27,36 @@ class TestMemoryEfficientFP16(unittest.TestCase): params = list(model.parameters()) # initialize memory efficient FP16 optimizer + # with pseudo DictConfigs optimizer = FairseqAdam( - argparse.Namespace( - lr=[0.00001], - adam_betas="(0.9, 0.999)", - adam_eps=1e-8, - weight_decay=0.0, + cfg=OmegaConf.create( + vars( + argparse.Namespace( + adam_betas="(0.9, 0.999)", + adam_eps=1e-8, + weight_decay=0.0, + lr=[0.00001], + ) + ) ), - params, + params=params, ) me_optimizer = MemoryEfficientFP16Optimizer( - argparse.Namespace( - fp16_init_scale=1, - fp16_scale_window=1, - fp16_scale_tolerance=1, - threshold_loss_scale=1, - min_loss_scale=1e-4, + cfg=OmegaConf.create( + { + "common": vars( + argparse.Namespace( + fp16_init_scale=1, + fp16_scale_window=1, + fp16_scale_tolerance=1, + threshold_loss_scale=1, + min_loss_scale=1e-4, + ) + ) + } ), - params, - optimizer, + params=params, + optimizer=optimizer, ) # optimizer state is created in the first step diff --git a/tests/test_train.py b/tests/test_train.py index 1b7e027c..57daa194 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch import torch from fairseq import checkpoint_utils, data +from omegaconf import OmegaConf def mock_trainer(epoch, num_updates, iterations_in_epoch): @@ -56,21 +57,29 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc return trainer, epoch_itr -def get_mock_args(finetune_from_model=None): - args_mock = MagicMock() - args_mock.optimizer_overrides = "{}" - args_mock.reset_dataloader = False - args_mock.reset_meters = False - args_mock.reset_optimizer = False - args_mock.reset_lr_scheduler = False - args_mock.finetune_from_model = finetune_from_model - args_mock.model_parallel_size = 1 - return args_mock +def get_mock_cfg(finetune_from_model): + cfg_mock = OmegaConf.create( + { + "checkpoint": { + "optimizer_overrides": "{}", + "reset_dataloader": False, + "reset_meters": False, + "reset_optimizer": False, + "reset_lr_scheduler": False, + "finetune_from_model": finetune_from_model, + "model_parallel_size": 1, + }, + "common": { + "model_parallel_size": 1, + }, + } + ) + return cfg_mock class TestLoadCheckpoint(unittest.TestCase): def setUp(self): - self.args_mock = get_mock_args() + self.cfg_mock = get_mock_cfg(None) self.patches = { "os.makedirs": MagicMock(), "os.path.join": MagicMock(), @@ -91,7 +100,9 @@ class TestLoadCheckpoint(unittest.TestCase): trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) - _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer) + _, epoch_itr = checkpoint_utils.load_checkpoint( + self.cfg_mock.checkpoint, trainer + ) self.assertEqual(epoch_itr.epoch, 2) self.assertEqual(epoch_itr.iterations_in_epoch, 50) @@ -120,7 +131,9 @@ class TestLoadCheckpoint(unittest.TestCase): trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) - _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer) + _, epoch_itr = checkpoint_utils.load_checkpoint( + self.cfg_mock.checkpoint, trainer + ) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertEqual(epoch_itr.epoch, 3) @@ -133,7 +146,9 @@ class TestLoadCheckpoint(unittest.TestCase): trainer.get_train_iterator = MagicMock(return_value=epoch_itr) self.patches["os.path.isfile"].return_value = False - _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer) + _, epoch_itr = checkpoint_utils.load_checkpoint( + self.cfg_mock.checkpoint, trainer + ) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertEqual(epoch_itr.epoch, 1) @@ -152,10 +167,12 @@ class TestLoadCheckpoint(unittest.TestCase): "reset_dataloader", ]: with self.subTest(arg=arg): - args_mock = get_mock_args("/temp/checkpoint_pretrained.pt") - setattr(args_mock, arg, True) + cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt") + cfg_mock["checkpoint"][arg] = True with self.assertRaises(Exception) as context: - _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer) + _, _ = checkpoint_utils.load_checkpoint( + cfg_mock.checkpoint, trainer + ) self.assertTrue( "--finetune-from-model can not be set together with either --reset-optimizer" @@ -168,8 +185,6 @@ class TestLoadCheckpoint(unittest.TestCase): trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) from_model_path = "/temp/checkpoint_pretrained.pt" - args_mock = get_mock_args(from_model_path) - args_mock.restore_file = "checkpoint_last.pt" def mock_finetune_exist(path): if path == from_model_path: @@ -180,7 +195,9 @@ class TestLoadCheckpoint(unittest.TestCase): self.patches[ "fairseq.file_io.PathManager.exists" ].side_effect = mock_finetune_exist - _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer) + cfg_mock = get_mock_cfg(from_model_path) + cfg_mock.checkpoint.restore_file = "checkpoint_last.pt" + _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer) ( checkpoint_path, reset_optimizer, @@ -197,8 +214,6 @@ class TestLoadCheckpoint(unittest.TestCase): trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) from_model_path = "/temp/checkpoint_pretrained.pt" - args_mock = get_mock_args(from_model_path) - args_mock.restore_file = "checkpoint_last.pt" # launch second time # both restore_file=checkpoint_last.pt and finetune_from_model are set @@ -211,7 +226,9 @@ class TestLoadCheckpoint(unittest.TestCase): self.patches[ "fairseq.file_io.PathManager.exists" ].side_effect = mock_finetune_exist - _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer) + cfg_mock = get_mock_cfg(from_model_path) + cfg_mock.checkpoint.restore_file = "checkpoint_last.pt" + _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer) ( checkpoint_path, reset_optimizer, diff --git a/tests/utils.py b/tests/utils.py index 91feca6b..a145aa58 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,7 +20,7 @@ from fairseq.models import ( FairseqIncrementalDecoder, ) from fairseq.models.fairseq_encoder import EncoderOut -from fairseq.tasks import FairseqTask, LegacyFairseqTask +from fairseq.tasks import LegacyFairseqTask from fairseq_cli import generate, interactive, preprocess, train, validate