diff --git a/config/config.yaml b/config/config.yaml index 66723e70..b9ee6c74 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,7 +1,111 @@ +# @package _group_ +common: + no_progress_bar: false + log_interval: 100 + log_format: null + tensorboard_logdir: null + seed: 1 + cpu: false + tpu: false + bf16: false + fp16: false + memory_efficient_fp16: false + memory_efficient_bf16: false + fp16_no_flatten_grads: false + fp16_init_scale: 128 + fp16_scale_window: null + fp16_scale_tolerance: 0.0 + min_loss_scale: 1.0e-4 + threshold_loss_scale: null + user_dir: null + empty_cache_freq: 0 + all_gather_list_size: 16384 + model_parallel_size: 1 + quantization_config_path: null + profile: false +distributed_training: + distributed_rank: 0 + distributed_backend: "nccl" + distributed_init_method: null + distributed_port: -1 + device_id: 0 + local_rank: 0 + distributed_no_spawn: false + ddp_backend: "c10d" + bucket_cap_mb: 25 + fix_batches_to_gpus: false + find_unused_parameters: false + fast_stat_sync: false + broadcast_buffers: false + distributed_wrapper: "DDP" + slowmo_momentum: null + slowmo_algorithm: "LocalSGD" + localsgd_frequency: 3 +dataset: + num_workers: 1 + skip_invalid_size_inputs_valid_test: false + max_tokens: null + batch_size: null + required_batch_size_multiple: 8 + dataset_impl: null + data_buffer_size: 10 + train_subset: "train" + valid_subset: "valid" + validate_interval: 1 + fixed_validation_seed: null + disable_validation: false + curriculum: 0 + gen_subset: "test" + num_shards: 1 + shard_id: 0 + max_tokens_valid: ${dataset.max_tokens} + batch_size_valid: ${dataset.batch_size} +optimization: + max_epoch: 0 + max_update: 0 + clip_norm: 25.0 + sentence_avg: false + update_freq: [ 1 ] + lr: [ 0.25 ] + min_lr: -1.0 + use_bmuf: false +checkpoint: + save_dir: "checkpoints" + restore_file: "checkpoint_last.pt" + reset_dataloader: false + reset_lr_scheduler: false + reset_meters: false + reset_optimizer: false + optimizer_overrides: "{}" + save_interval: 1 + save_interval_updates: 0 + keep_interval_updates: -1 + keep_last_epochs: -1 + keep_best_checkpoints: -1 + no_save: false + no_epoch_checkpoints: false + no_last_checkpoints: false + no_save_optimizer_state: false + best_checkpoint_metric: "loss" + maximize_best_checkpoint_metric: false + patience: -1 + checkpoint_suffix: "" +bmuf: + block_lr: 1 + block_momentum: 0.875 + global_sync_iter: 50 + warmup_iterations: 500 + use_nbm: false + average_sync: false defaults: - - params: training_params - - task: language_modeling - - model: transformer_lm - - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: inverse_sqrt + - task: language_modeling + - model: null + - criterion: null + - optimizer: null + - lr_scheduler: null + - bpe: null + - tokenizer: null + - scoring: null + - generation: null + - common_eval: null + - eval_lm: null diff --git a/config/config_eval_lm.yaml b/config/config_eval_lm.yaml deleted file mode 100644 index 5a93cb5d..00000000 --- a/config/config_eval_lm.yaml +++ /dev/null @@ -1,7 +0,0 @@ -defaults: - - params: eval_lm_params - - task: language_modeling - - model: transformer_lm - - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: inverse_sqrt diff --git a/config/criterion/adaptive_loss.yaml b/config/criterion/adaptive_loss.yaml index a85a7eed..7997b076 100644 --- a/config/criterion/adaptive_loss.yaml +++ b/config/criterion/adaptive_loss.yaml @@ -1,3 +1,3 @@ # @package _group_ -sentence_avg: ${params.optimization.sentence_avg} -ddp_backend: 
${params.distributed_training.ddp_backend} +sentence_avg: ${optimization.sentence_avg} +ddp_backend: ${distributed_training.ddp_backend} diff --git a/config/criterion/cross_entropy.yaml b/config/criterion/cross_entropy.yaml index a85a7eed..ad3d4148 100644 --- a/config/criterion/cross_entropy.yaml +++ b/config/criterion/cross_entropy.yaml @@ -1,3 +1,2 @@ # @package _group_ -sentence_avg: ${params.optimization.sentence_avg} -ddp_backend: ${params.distributed_training.ddp_backend} +sentence_avg: ${optimization.sentence_avg} diff --git a/config/params/eval_lm_params.yaml b/config/params/eval_lm_params.yaml deleted file mode 100644 index 6f27055d..00000000 --- a/config/params/eval_lm_params.yaml +++ /dev/null @@ -1,105 +0,0 @@ -# @package _group_ -common: - no_progress_bar: false - log_interval: 100 - log_format: null - tensorboard_logdir: null - seed: 1 - cpu: false - fp16: false - memory_efficient_fp16: false - fp16_no_flatten_grads: false - fp16_init_scale: 128 - fp16_scale_window: null - fp16_scale_tolerance: 0.0 - min_loss_scale: 1.0e-4 - threshold_loss_scale: null - user_dir: null - empty_cache_freq: 0 - all_gather_list_size: 16384 - model_parallel_size: 1 - checkpoint_suffix: "" - quantization_config_path: null -distributed_training: - distributed_rank: 0 - distributed_backend: "nccl" - distributed_init_method: null - distributed_port: -1 - device_id: 0 - local_rank: 0 - distributed_no_spawn: false - ddp_backend: "c10d" - bucket_cap_mb: 25 - fix_batches_to_gpus: false - find_unused_parameters: false - fast_stat_sync: false - broadcast_buffers: false - distributed_wrapper: "DDP" - slowmo_momentum: null - slowmo_algorithm: "LocalSGD" - localsgd_frequency: 3 -dataset: - num_workers: 1 - skip_invalid_size_inputs_valid_test: false - max_tokens: null - batch_size: ${params.dataset.batch_size} - required_batch_size_multiple: 8 - dataset_impl: null - data_buffer_size: 10 - train_subset: "train" - valid_subset: "valid" - validate_interval: 1 - fixed_validation_seed: null - disable_validation: false - curriculum: 0 - gen_subset: "test" - num_shards: 1 - shard_id: 0 - max_tokens_valid: ${params.dataset.max_tokens} - batch_size_valid: ${params.dataset.batch_size} -optimization: - max_epoch: 0 - max_update: 0 - clip_norm: 25.0 - sentence_avg: false - update_freq: [1] - lr: [0.25] - min_lr: -1.0 - use_bmuf: false -checkpoint: - save_dir: "checkpoints" - restore_file: "checkpoint_last.pt" - reset_dataloader: false - reset_lr_scheduler: false - reset_meters: false - reset_optimizer: false - optimizer_overrides: "{}" - save_interval: 1 - save_interval_updates: 0 - keep_interval_updates: -1 - keep_last_epochs: -1 - keep_best_checkpoints: -1 - no_save: false - no_epoch_checkpoints: false - no_last_checkpoints: false - no_save_optimizer_state: false - best_checkpoint_metric: "loss" - maximize_best_checkpoint_metric: false - patience: -1 -common_eval: - path: null - remove_bpe: null - quiet: false - model_overrides: "{}" - results_path: null -eval_lm: - output_word_probs: false - output_word_stats: false - context_window: 0 -bmuf: - block_lr: 1 - block_momentum: 0.875 - global_sync_iter: 50 - warmup_iterations: 500 - use_nbm: false - average_sync: false diff --git a/config/params/training_params.yaml b/config/params/training_params.yaml deleted file mode 100644 index 2ce94f92..00000000 --- a/config/params/training_params.yaml +++ /dev/null @@ -1,95 +0,0 @@ -# @package _group_ -common: - no_progress_bar: false - log_interval: 100 - log_format: null - tensorboard_logdir: null - seed: 1 - cpu: false - fp16: 
false - memory_efficient_fp16: false - fp16_no_flatten_grads: false - fp16_init_scale: 128 - fp16_scale_window: null - fp16_scale_tolerance: 0.0 - min_loss_scale: 1.0e-4 - threshold_loss_scale: null - user_dir: null - empty_cache_freq: 0 - all_gather_list_size: 16384 - model_parallel_size: 1 - checkpoint_suffix: "" - quantization_config_path: null -distributed_training: - distributed_rank: 0 - distributed_backend: "nccl" - distributed_init_method: null - distributed_port: -1 - device_id: 0 - local_rank: 0 - distributed_no_spawn: false - ddp_backend: "c10d" - bucket_cap_mb: 25 - fix_batches_to_gpus: false - find_unused_parameters: false - fast_stat_sync: false - broadcast_buffers: false - distributed_wrapper: "DDP" - slowmo_momentum: null - slowmo_algorithm: "LocalSGD" - localsgd_frequency: 3 -dataset: - num_workers: 1 - skip_invalid_size_inputs_valid_test: false - max_tokens: null - batch_size: ${params.dataset.batch_size} - required_batch_size_multiple: 8 - dataset_impl: null - data_buffer_size: 10 - train_subset: "train" - valid_subset: "valid" - validate_interval: 1 - fixed_validation_seed: null - disable_validation: false - curriculum: 0 - gen_subset: "test" - num_shards: 1 - shard_id: 0 - max_tokens_valid: ${params.dataset.max_tokens} - batch_size_valid: ${params.dataset.batch_size} -optimization: - max_epoch: 0 - max_update: 0 - clip_norm: 25.0 - sentence_avg: false - update_freq: [1] - lr: [0.25] - min_lr: -1.0 - use_bmuf: false -checkpoint: - save_dir: "checkpoints" - restore_file: "checkpoint_last.pt" - reset_dataloader: false - reset_lr_scheduler: false - reset_meters: false - reset_optimizer: false - optimizer_overrides: "{}" - save_interval: 1 - save_interval_updates: 0 - keep_interval_updates: -1 - keep_last_epochs: -1 - keep_best_checkpoints: -1 - no_save: false - no_epoch_checkpoints: false - no_last_checkpoints: false - no_save_optimizer_state: false - best_checkpoint_metric: "loss" - maximize_best_checkpoint_metric: false - patience: -1 -bmuf: - block_lr: 1 - block_momentum: 0.875 - global_sync_iter: 50 - warmup_iterations: 500 - use_nbm: false - average_sync: false diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index 9b77dd83..0973cd27 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -13,7 +13,6 @@ For example, if we'd like to train a language model with transformer, we could p ``` defaults: - - params: training_params - task: language_modeling - model: transformer_lm - criterion: cross_entropy @@ -21,7 +20,7 @@ defaults: - lr_scheduler: inverse_sqrt ``` -- Provide generic parameters common across different training jobs: `config/params/training_params.yaml` +- Provide generic parameters common across different jobs: `config.yaml` - Provide task parameters: `config/task/language_modeling.yaml` - Provide model parameters: `config/model/transformer_lm.yaml` - Provide criterion parameters: `config/criterion/cross_entropy.yaml` @@ -41,7 +40,6 @@ Alternatively, if we need to override certain params from the command line, we c ``` python fairseq_cli/train_hydra.py -params=training_params \ task=language_modeling \ task.data=/private/home/abaevski/data/wiki103 \ task.tokens_per_sample=512 \ @@ -56,17 +54,17 @@ lr_scheduler=inverse_sqrt \ lr_scheduler.warmup_updates=4000 \ lr_scheduler.warmup_init_lr=1e-07 \ criterion=cross_entropy \ -params.common.fp16=true \ -params.common.log_format=json \ -params.common.log_interval=1 \ -params.dataset.max_tokens=1024 \ -params.dataset.num_workers=4 \ -params.optimization.update_freq=[16] \ 
-params.optimization.max_update=50000 \ -params.optimization.clip_norm=0.0 \ -params.optimization.lr=[0.0005] \ -params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ -params.checkpoint.save_interval_updates=10 +common.fp16=true \ +common.log_format=json \ +common.log_interval=1 \ +dataset.max_tokens=1024 \ +dataset.num_workers=4 \ +optimization.update_freq=[16] \ +optimization.max_update=50000 \ +optimization.clip_norm=0.0 \ +optimization.lr=[0.0005] \ +checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \ +checkpoint.save_interval_updates=10 ``` ## Migrate existing/Creating new modules to hydra interface diff --git a/docs/tutorial_classifying_names.rst b/docs/tutorial_classifying_names.rst index 40a3cb6f..b02fec04 100644 --- a/docs/tutorial_classifying_names.rst +++ b/docs/tutorial_classifying_names.rst @@ -212,7 +212,7 @@ following contents:: @register_task('simple_classification') - class SimpleClassificationTask(FairseqTask): + class SimpleClassificationTask(LegacyFairseqTask): @staticmethod def add_args(parser): diff --git a/examples/noisychannel/rerank.py b/examples/noisychannel/rerank.py index 4df424e6..13036926 100644 --- a/examples/noisychannel/rerank.py +++ b/examples/noisychannel/rerank.py @@ -27,7 +27,13 @@ def score_target_hypo( print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c) gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) dict = dictionary.Dictionary() - scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) ordered_hypos = {} ordered_targets = {} diff --git a/examples/roberta/wsc/wsc_criterion.py b/examples/roberta/wsc/wsc_criterion.py index 1a590123..ed0251fd 100644 --- a/examples/roberta/wsc/wsc_criterion.py +++ b/examples/roberta/wsc/wsc_criterion.py @@ -20,8 +20,8 @@ class WSCCriterion(LegacyFairseqCriterion): self.prediction_h = open(self.args.save_predictions, "w") else: self.prediction_h = None - self.bpe = encoders.build_bpe(args) - self.tokenizer = encoders.build_tokenizer(args) + self.bpe = encoders.build_bpe(args.bpe) + self.tokenizer = encoders.build_tokenizer(args.tokenizer) def __del__(self): if self.prediction_h is not None: diff --git a/examples/unsupervised_quality_estimation/README.md b/examples/unsupervised_quality_estimation/README.md index 809a58e4..aeb96a14 100644 --- a/examples/unsupervised_quality_estimation/README.md +++ b/examples/unsupervised_quality_estimation/README.md @@ -85,7 +85,7 @@ Produce model scores for the generated translations using `--retain-dropout` opt ``` CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout - --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer + --retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]' TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 75e2c68c..c036e129 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -3,36 +3,42 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source 
tree. +import ast import collections import logging import os import re import traceback from collections import OrderedDict -from typing import Union +from typing import Optional, Union import torch +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + overwrite_args_by_name, +) from fairseq.file_io import PathManager from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig, open_dict from torch.serialization import default_restore_location logger = logging.getLogger(__name__) -def save_checkpoint(args, trainer, epoch_itr, val_loss): - from fairseq import distributed_utils, meters +def save_checkpoint(cfg: DictConfig, trainer, epoch_itr, val_loss): + from fairseq import meters # only one worker should attempt to create the required dir - if args.distributed_rank == 0: - os.makedirs(args.save_dir, exist_ok=True) + if cfg.distributed_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) prev_best = getattr(save_checkpoint, "best", val_loss) if val_loss is not None: - best_function = max if args.maximize_best_checkpoint_metric else min + best_function = max if cfg.maximize_best_checkpoint_metric else min save_checkpoint.best = best_function(val_loss, prev_best) - if args.no_save: + if cfg.no_save: return trainer.consolidate_optimizer() @@ -41,7 +47,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): return def is_better(a, b): - return a >= b if args.maximize_best_checkpoint_metric else a <= b + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b write_timer = meters.StopwatchMeter() write_timer.start() @@ -50,38 +56,36 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss): end_of_epoch = epoch_itr.end_of_epoch() updates = trainer.get_num_updates() - suffix = getattr(args, "checkpoint_suffix", "") + suffix = cfg.checkpoint_suffix or "" checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( - end_of_epoch - and not args.no_epoch_checkpoints - and epoch % args.save_interval == 0 + end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 ) checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( not end_of_epoch - and args.save_interval_updates > 0 - and updates % args.save_interval_updates == 0 + and cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 ) checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( not hasattr(save_checkpoint, "best") or is_better(val_loss, save_checkpoint.best) ) - if val_loss is not None and args.keep_best_checkpoints > 0: + if val_loss is not None and cfg.keep_best_checkpoints > 0: checkpoint_conds[ - "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss) + "checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss) ] = not hasattr(save_checkpoint, "best") or is_better( val_loss, save_checkpoint.best ) checkpoint_conds[ "checkpoint_last{}.pt".format(suffix) - ] = not args.no_last_checkpoints + ] = not cfg.no_last_checkpoints extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} if hasattr(save_checkpoint, "best"): extra_state.update({"best": save_checkpoint.best}) checkpoints = [ - os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond ] if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) @@ -95,51 +99,52 @@ def save_checkpoint(args, 
trainer, epoch_itr, val_loss): ) ) - if not end_of_epoch and args.keep_interval_updates > 0: + if not end_of_epoch and cfg.keep_interval_updates > 0: # remove old checkpoints; checkpoints are sorted in descending order checkpoints = checkpoint_paths( - args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" + cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" ) - for old_chk in checkpoints[args.keep_interval_updates :]: + for old_chk in checkpoints[cfg.keep_interval_updates :]: if os.path.lexists(old_chk): os.remove(old_chk) - if args.keep_last_epochs > 0: + if cfg.keep_last_epochs > 0: # remove old epoch checkpoints; checkpoints are sorted in descending order - checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt") - for old_chk in checkpoints[args.keep_last_epochs :]: + checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt") + for old_chk in checkpoints[cfg.keep_last_epochs :]: if os.path.lexists(old_chk): os.remove(old_chk) - if args.keep_best_checkpoints > 0: + if cfg.keep_best_checkpoints > 0: # only keep the best N checkpoints according to validation metric checkpoints = checkpoint_paths( - args.save_dir, + cfg.save_dir, pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( - args.best_checkpoint_metric + cfg.best_checkpoint_metric ), ) - if not args.maximize_best_checkpoint_metric: + if not cfg.maximize_best_checkpoint_metric: checkpoints = checkpoints[::-1] - for old_chk in checkpoints[args.keep_best_checkpoints :]: + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: if os.path.lexists(old_chk): os.remove(old_chk) -def load_checkpoint(args, trainer, **passthrough_args): +def load_checkpoint(cfg: DictConfig, trainer, **passthrough_args): """ Load a checkpoint and restore the training iterator. *passthrough_args* will be passed through to ``trainer.get_train_iterator``. """ - reset_optimizer = args.reset_optimizer - reset_lr_scheduler = args.reset_lr_scheduler - optimizer_overrides = eval(args.optimizer_overrides) - reset_meters = args.reset_meters - reset_dataloader = args.reset_dataloader - if getattr(args, "finetune_from_model", None) is not None and ( + reset_optimizer = cfg.reset_optimizer + reset_lr_scheduler = cfg.reset_lr_scheduler + optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides) + reset_meters = cfg.reset_meters + reset_dataloader = cfg.reset_dataloader + + if cfg.finetune_from_model is not None and ( reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader ): raise ValueError( @@ -147,19 +152,19 @@ def load_checkpoint(args, trainer, **passthrough_args): " or reset_lr_scheduler or reset_meters or reset_dataloader" ) - suffix = getattr(args, "checkpoint_suffix", "") + suffix = cfg.checkpoint_suffix if ( - args.restore_file == "checkpoint_last.pt" + cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' checkpoint_path = os.path.join( - args.save_dir, "checkpoint_last{}.pt".format(suffix) + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) ) first_launch = not PathManager.exists(checkpoint_path) - if getattr(args, "finetune_from_model", None) is not None and first_launch: + if cfg.finetune_from_model is not None and first_launch: # if there is no last checkpoint to restore, start the finetune from pretrained model # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
- if PathManager.exists(args.finetune_from_model): - checkpoint_path = args.finetune_from_model + if PathManager.exists(cfg.finetune_from_model): + checkpoint_path = cfg.finetune_from_model reset_optimizer = True reset_lr_scheduler = True reset_meters = True @@ -170,19 +175,17 @@ def load_checkpoint(args, trainer, **passthrough_args): ) else: raise ValueError( - f"--funetune-from-model {args.finetune_from_model} does not exist" + f"--finetune-from-model {cfg.finetune_from_model} does not exist" ) - elif getattr(args, "model_parallel_size", 1) > 1: - checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt") + elif cfg.model_parallel_size > 1: + checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") else: - checkpoint_path = args.restore_file + checkpoint_path = cfg.restore_file - if args.restore_file != "checkpoint_last.pt" and getattr( - args, "finetune_from_model", None - ): + if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model: raise ValueError( "--finetune-from-model and --restore-file (non-default value) " - "can not be specified together: " + str(args) + "can not be specified together: " + str(cfg) ) extra_state = trainer.load_checkpoint( @@ -225,10 +228,14 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): f, map_location=lambda s, l: default_restore_location(s, "cpu") ) - args = state["args"] - if arg_overrides is not None: + if "args" in state and state["args"] is not None and arg_overrides is not None: + args = state["args"] for arg_name, arg_val in arg_overrides.items(): setattr(args, arg_name, arg_val) + + if "cfg" in state and state["cfg"] is not None and arg_overrides is not None: + overwrite_args_by_name(state["cfg"], arg_overrides) + state = _upgrade_state_dict(state) return state @@ -274,19 +281,28 @@ def load_model_ensemble_and_task( filename = filename.replace(".pt", suffix + ".pt") else: filename = orig_filename[:-3] + f"_part{shard_idx}.pt" + if not PathManager.exists(filename): raise IOError("Model file not found: {}".format(filename)) state = load_checkpoint_to_cpu(filename, arg_overrides) - if shard_idx == 0: - args = state["args"] - if task is None: - task = tasks.setup_task(args) + if "args" in state and state["args"] is not None: + cfg = convert_namespace_to_omegaconf(state["args"]) + elif "cfg" in state and state["cfg"] is not None: + cfg = state["cfg"] + else: + raise RuntimeError( + f"Neither args nor cfg exist in state keys = {state.keys()}" + ) - # build model for ensemble - model = task.build_model(args) - model.load_state_dict(state["model"], strict=strict, args=args) + if task is None: + task = tasks.setup_task(cfg.task) + + # build model for ensemble + model = task.build_model(cfg.model) + + model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) ensemble.append(model) - return ensemble, args, task + return ensemble, cfg, task def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): @@ -323,7 +339,7 @@ def torch_persistent_save(obj, f): def save_state( filename, - args, + cfg: DictConfig, model_state_dict, criterion, optimizer, @@ -331,6 +347,7 @@ def save_state( num_updates, optim_history=None, extra_state=None, + **kwargs, ): from fairseq import utils @@ -339,7 +356,8 @@ def save_state( if extra_state is None: extra_state = {} state_dict = { - "args": args, + "cfg": cfg, + "args": kwargs.get("args", None), "model": model_state_dict or {}, "optimizer_history": optim_history + [ @@ -354,11 +372,17 @@ def save_state( } if utils.has_parameters(criterion): state_dict["criterion"] = 
criterion.state_dict() - if not args.no_save_optimizer_state: - state_dict["last_optimizer_state"] = optimizer.state_dict() - # convert all state to CPU - state_dict = utils.move_to_cpu(state_dict) + if cfg is None: + cfg = state_dict["args"] + assert cfg is not None, "must provide cfg or args" + + if isinstance(cfg, DictConfig): + no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state + else: + no_save_optimizer_state = cfg.no_save_optimizer_state + if not no_save_optimizer_state: + state_dict["last_optimizer_state"] = optimizer.state_dict() with PathManager.open(filename, "wb") as f: torch_persistent_save(state_dict, f) @@ -403,46 +427,49 @@ def _upgrade_state_dict(state): # keep track of number of updates if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 - # old model checkpoints may not have separate source/target positions - if hasattr(state["args"], "max_positions") and not hasattr( - state["args"], "max_source_positions" - ): - state["args"].max_source_positions = state["args"].max_positions - state["args"].max_target_positions = state["args"].max_positions # use stateful training data iterator if "train_iterator" not in state["extra_state"]: state["extra_state"]["train_iterator"] = { "epoch": state["extra_state"]["epoch"], "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), } - # default to translation task - if not hasattr(state["args"], "task"): - state["args"].task = "translation" - # --raw-text and --lazy-load are deprecated - if getattr(state["args"], "raw_text", False): - state["args"].dataset_impl = "raw" - elif getattr(state["args"], "lazy_load", False): - state["args"].dataset_impl = "lazy" - # epochs start at 1 - if state["extra_state"]["train_iterator"] is not None: - state["extra_state"]["train_iterator"]["epoch"] = max( - state["extra_state"]["train_iterator"].get("epoch", 1), - 1, - ) - # set any missing default values in the task, model or other registries - registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task]) - registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch]) - for registry_name, REGISTRY in registry.REGISTRIES.items(): - choice = getattr(state["args"], registry_name, None) - if choice is not None: - cls = REGISTRY["registry"][choice] - registry.set_defaults(state["args"], cls) + # old model checkpoints may not have separate source/target positions + # backward compatibility, cfg updates + if "args" in state and state["args"] is not None: + # default to translation task + if not hasattr(state["args"], "task"): + state["args"].task = "translation" + # --raw-text and --lazy-load are deprecated + if getattr(state["args"], "raw_text", False): + state["args"].dataset_impl = "raw" + elif getattr(state["args"], "lazy_load", False): + state["args"].dataset_impl = "lazy" + # epochs start at 1 + if state["extra_state"]["train_iterator"] is not None: + state["extra_state"]["train_iterator"]["epoch"] = max( + state["extra_state"]["train_iterator"].get("epoch", 1), 1 + ) + + state["cfg"] = convert_namespace_to_omegaconf(state["args"]) + + if "cfg" in state and state["cfg"] is not None: + with open_dict(state["cfg"]): + if state["cfg"].task is not None: + if hasattr(state["cfg"].task, "max_positions") and not hasattr( + state["cfg"].task, "max_source_positions" + ): + state["cfg"].task.max_source_positions = state[ + "cfg" + ].task.max_positions + state["cfg"].task.max_target_positions = state[ + "cfg" + ].task.max_positions return state 
-def prune_state_dict(state_dict, args): +def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]): """Prune the given state_dict if desired for LayerDrop (https://arxiv.org/abs/1909.11556). @@ -453,16 +480,20 @@ def prune_state_dict(state_dict, args): It's called by functions that load models from checkpoints and does not need to be called directly. """ - if not args or args.arch == "ptt_transformer": + arch = None + if model_cfg is not None: + arch = ( + model_cfg._name + if isinstance(model_cfg, DictConfig) + else getattr(model_cfg, "arch", None) + ) + + if not model_cfg or arch is None or arch == "ptt_transformer": # args should not be none, but don't crash if it is. return state_dict - encoder_layers_to_keep = ( - args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None - ) - decoder_layers_to_keep = ( - args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None - ) + encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None) + decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None) if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict @@ -474,7 +505,7 @@ def prune_state_dict(state_dict, args): def create_pruning_pass(layers_to_keep, layer_name): keep_layers = sorted( - [int(layer_string) for layer_string in layers_to_keep.split(",")] + int(layer_string) for layer_string in layers_to_keep.split(",") ) mapping_dict = {} for i in range(len(keep_layers)): @@ -518,10 +549,12 @@ def prune_state_dict(state_dict, args): # Since layers are now pruned, *_layers_to_keep are no longer needed. # This is more of "It would make it work fix" rather than a proper fix. - if "encoder_layers_to_keep" in vars(args): - args.encoder_layers_to_keep = None - if "decoder_layers_to_keep" in vars(args): - args.decoder_layers_to_keep = None + + with open_dict(model_cfg): + if hasattr(model_cfg, "encoder_layers_to_keep"): + model_cfg.encoder_layers_to_keep = None + if hasattr(model_cfg, "decoder_layers_to_keep"): + model_cfg.decoder_layers_to_keep = None return new_state_dict diff --git a/fairseq/criterions/__init__.py b/fairseq/criterions/__init__.py index a7eb5f6f..8cc6c0f0 100644 --- a/fairseq/criterions/__init__.py +++ b/fairseq/criterions/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.criterions.fairseq_criterion import ( # noqa @@ -27,8 +25,8 @@ from omegaconf import DictConfig ) -def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task): - return build_criterion_(criterion_cfg, task) +def build_criterion(cfg: DictConfig, task): + return build_criterion_(cfg, task) # automatically import any Python files in the criterions/ directory diff --git a/fairseq/criterions/adaptive_loss.py b/fairseq/criterions/adaptive_loss.py index 74ba37c3..04832295 100644 --- a/fairseq/criterions/adaptive_loss.py +++ b/fairseq/criterions/adaptive_loss.py @@ -11,13 +11,13 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.constants import DDP_BACKEND_CHOICES -from omegaconf import II +from omegaconf import II, DictConfig @dataclass class AdaptiveLossConfig(FairseqDataclass): - sentence_avg: bool = II("params.optimization.sentence_avg") - ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend") + sentence_avg: bool = II("optimization.sentence_avg") + 
ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend") @register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig) @@ -31,14 +31,14 @@ class AdaptiveLoss(FairseqCriterion): self.sentence_avg = sentence_avg @classmethod - def build_criterion(cls, args, task): - if getattr(args, "ddp_backend", None) == "c10d": + def build_criterion(cls, cfg: DictConfig, task): + if cfg.ddp_backend == "c10d": raise Exception( "AdaptiveLoss is not compatible with the c10d " "version of DistributedDataParallel. Please use " "`--ddp-backend=no_c10d` instead." ) - return cls(task, args.sentence_avg) + return cls(task, cfg.sentence_avg) def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. diff --git a/fairseq/criterions/cross_entropy.py b/fairseq/criterions/cross_entropy.py index 91b58545..758e7276 100644 --- a/fairseq/criterions/cross_entropy.py +++ b/fairseq/criterions/cross_entropy.py @@ -15,7 +15,7 @@ from omegaconf import II @dataclass class CrossEntropyCriterionConfig(FairseqDataclass): - sentence_avg: bool = II("params.optimization.sentence_avg") + sentence_avg: bool = II("optimization.sentence_avg") @register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig) diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 4f93b3cb..9310024f 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -10,24 +10,24 @@ from argparse import Namespace import torch import torch.nn.functional as F from fairseq import metrics, utils -from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.criterions import LegacyFairseqCriterion, register_criterion from fairseq.data.data_utils import post_process from fairseq.logging.meters import safe_round @register_criterion("ctc") -class CtcCriterion(FairseqCriterion): - def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe): - super().__init__(task) +class CtcCriterion(LegacyFairseqCriterion): + def __init__(self, args, task): + super().__init__(args, task) self.blank_idx = task.target_dictionary.bos() self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() - self.post_process = remove_bpe if remove_bpe else "letter" + self.post_process = args.remove_bpe if args.remove_bpe else "letter" - if wer_args is not None: + if args.wer_args is not None: from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder - wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args) + wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args) dec_args = Namespace() dec_args.nbest = 1 @@ -46,8 +46,8 @@ class CtcCriterion(FairseqCriterion): else: self.w2l_decoder = None - self.zero_infinity = zero_infinity - self.sentence_avg = sentence_avg + self.zero_infinity = args.zero_infinity + self.sentence_avg = args.sentence_avg @staticmethod def add_args(parser): diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index ef94a863..b2eda1a7 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List from fairseq import metrics, utils from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import DictConfig from torch.nn.modules.loss import _Loss @@ -27,10 +28,8 @@ class FairseqCriterion(_Loss): gen_parser_from_dataclass(parser, dc()) @classmethod - def build_criterion(cls, args, task): + def build_criterion(cls, cfg: DictConfig, task): """Construct a criterion 
from command-line args.""" - # Criterions can override this, but for convenience we also try - # to automatically map argparse.Namespace keys to corresponding # arguments in the __init__. init_args = {} for p in inspect.signature(cls).parameters.values(): @@ -47,8 +46,8 @@ class FairseqCriterion(_Loss): if p.name == "task": init_args["task"] = task - elif hasattr(args, p.name): - init_args[p.name] = getattr(args, p.name) + elif hasattr(cfg, p.name): + init_args[p.name] = getattr(cfg, p.name) elif p.default != p.empty: pass # we'll use the default value else: @@ -70,7 +69,7 @@ class FairseqCriterion(_Loss): @staticmethod def aggregate_logging_outputs( - logging_outputs: List[Dict[str, Any]], + logging_outputs: List[Dict[str, Any]] ) -> Dict[str, Any]: """Aggregate logging outputs from data parallel training.""" utils.deprecation_warning( diff --git a/fairseq/data/encoders/byte_bpe.py b/fairseq/data/encoders/byte_bpe.py index 0d2da3ea..31e3a062 100644 --- a/fairseq/data/encoders/byte_bpe.py +++ b/fairseq/data/encoders/byte_bpe.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe from fairseq.data.encoders.byte_utils import ( @@ -12,19 +14,20 @@ from fairseq.data.encoders.byte_utils import ( byte_encode, smart_byte_decode, ) +from fairseq.dataclass import FairseqDataclass -@register_bpe("byte_bpe") +@dataclass +class ByteBpeConfig(FairseqDataclass): + sentencepiece_model_path: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) + + +@register_bpe("byte_bpe", dataclass=ByteBpeConfig) class ByteBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--sentencepiece-model-path', type=str, - help='path to sentencepiece model') - # fmt: on - - def __init__(self, args): - vocab = file_utils.cached_path(args.sentencepiece_model_path) + def __init__(self, cfg): + vocab = file_utils.cached_path(cfg.sentencepiece_model_path) try: import sentencepiece as spm diff --git a/fairseq/data/encoders/bytes.py b/fairseq/data/encoders/bytes.py index bb9554ed..f88f8f69 100644 --- a/fairseq/data/encoders/bytes.py +++ b/fairseq/data/encoders/bytes.py @@ -15,7 +15,7 @@ from fairseq.data.encoders.byte_utils import ( @register_bpe("bytes") class Bytes(object): - def __init__(self, args): + def __init__(self, *unused): pass @staticmethod diff --git a/fairseq/data/encoders/characters.py b/fairseq/data/encoders/characters.py index cffc5751..494ea219 100644 --- a/fairseq/data/encoders/characters.py +++ b/fairseq/data/encoders/characters.py @@ -13,7 +13,7 @@ SPACE_ESCAPE = chr(9601) @register_bpe("characters") class Characters(object): - def __init__(self, args): + def __init__(self, *unused): pass @staticmethod diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py index 74d4ad85..f7c21039 100644 --- a/fairseq/data/encoders/fastbpe.py +++ b/fairseq/data/encoders/fastbpe.py @@ -3,23 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("fastbpe") +@dataclass +class fastBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"}) + + +@register_bpe("fastbpe", dataclass=fastBPEConfig) class fastBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-codes', type=str, - help='path to fastBPE BPE') - # fmt: on - - def __init__(self, args): - if args.bpe_codes is None: + def __init__(self, cfg): + if cfg.bpe_codes is None: raise ValueError("--bpe-codes is required for --bpe=fastbpe") - codes = file_utils.cached_path(args.bpe_codes) + codes = file_utils.cached_path(cfg.bpe_codes) try: import fastBPE diff --git a/fairseq/data/encoders/gpt2_bpe.py b/fairseq/data/encoders/gpt2_bpe.py index 8ac099a6..e661426a 100644 --- a/fairseq/data/encoders/gpt2_bpe.py +++ b/fairseq/data/encoders/gpt2_bpe.py @@ -3,8 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass from .gpt2_bpe_utils import get_encoder @@ -13,26 +16,21 @@ DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder. DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" -@register_bpe("gpt2") -class GPT2BPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--gpt2-encoder-json', type=str, - default=DEFAULT_ENCODER_JSON, - help='path to encoder.json') - parser.add_argument('--gpt2-vocab-bpe', type=str, - default=DEFAULT_VOCAB_BPE, - help='path to vocab.bpe') - # fmt: on +@dataclass +class GPT2BPEConfig(FairseqDataclass): + gpt2_encoder_json: str = field( + default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"} + ) + gpt2_vocab_bpe: str = field( + default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"} + ) - def __init__(self, args): - encoder_json = file_utils.cached_path( - getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON) - ) - vocab_bpe = file_utils.cached_path( - getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE) - ) + +@register_bpe("gpt2", dataclass=GPT2BPEConfig) +class GPT2BPE(object): + def __init__(self, cfg): + encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json) + vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe) self.bpe = get_encoder(encoder_json, vocab_bpe) def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/hf_bert_bpe.py b/fairseq/data/encoders/hf_bert_bpe.py index a968fe88..a41c0593 100644 --- a/fairseq/data/encoders/hf_bert_bpe.py +++ b/fairseq/data/encoders/hf_bert_bpe.py @@ -3,22 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field +from typing import Optional + from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("bert") +@dataclass +class BertBPEConfig(FairseqDataclass): + bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"}) + bpe_vocab_file: Optional[str] = field( + default=None, metadata={"help": "bpe vocab file"} + ) + + +@register_bpe("bert", dataclass=BertBPEConfig) class BertBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-cased', action='store_true', - help='set for cased BPE', - default=False) - parser.add_argument('--bpe-vocab-file', type=str, - help='bpe vocab file.') - # fmt: on - - def __init__(self, args): + def __init__(self, cfg): try: from transformers import BertTokenizer except ImportError: @@ -26,13 +28,13 @@ class BertBPE(object): "Please install transformers with: pip install transformers" ) - if "bpe_vocab_file" in args: + if cfg.bpe_vocab_file: self.bert_tokenizer = BertTokenizer( - args.bpe_vocab_file, do_lower_case=not args.bpe_cased + cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased ) else: vocab_file_name = ( - "bert-base-cased" if args.bpe_cased else "bert-base-uncased" + "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased" ) self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py index 544d4082..92d2c392 100644 --- a/fairseq/data/encoders/hf_byte_bpe.py +++ b/fairseq/data/encoders/hf_byte_bpe.py @@ -3,21 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("hf_byte_bpe") +@dataclass +class HuggingFaceByteLevelBPEConfig(FairseqDataclass): + bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"}) + bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"}) + bpe_add_prefix_space: bool = field( + default=False, metadata={"help": "add prefix space before encoding"} + ) + + +@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig) class HuggingFaceByteLevelBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-merges', help='path to merges.txt') - parser.add_argument('--bpe-vocab', help='path to vocab.json') - parser.add_argument('--bpe-add-prefix-space', action='store_true', - help='add prefix space before encoding') - # fmt: on - - def __init__(self, args): + def __init__(self, cfg): try: from tokenizers import ByteLevelBPETokenizer except ImportError: @@ -26,9 +29,9 @@ class HuggingFaceByteLevelBPE(object): ) self.bpe = ByteLevelBPETokenizer( - args.bpe_vocab, - args.bpe_merges, - add_prefix_space=getattr(args, "bpe_add_prefix_space", False), + cfg.bpe_vocab, + cfg.bpe_merges, + add_prefix_space=cfg.bpe_add_prefix_space, ) def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/moses_tokenizer.py b/fairseq/data/encoders/moses_tokenizer.py index 8c248442..fa004dd4 100644 --- a/fairseq/data/encoders/moses_tokenizer.py +++ b/fairseq/data/encoders/moses_tokenizer.py @@ -3,37 +3,35 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field + from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass -@register_tokenizer("moses") +@dataclass +class MosesTokenizerConfig(FairseqDataclass): + source_lang: str = field(default="en", metadata={"help": "source language"}) + target_lang: str = field(default="en", metadata={"help": "target language"}) + moses_no_dash_splits: bool = field( + default=False, metadata={"help": "don't apply dash split rules"} + ) + moses_no_escape: bool = field( + default=False, + metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."}, + ) + + +@register_tokenizer("moses", dataclass=MosesTokenizerConfig) class MosesTokenizer(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--moses-source-lang', metavar='SRC', - help='source language') - parser.add_argument('--moses-target-lang', metavar='TARGET', - help='target language') - parser.add_argument('--moses-no-dash-splits', action='store_true', default=False, - help='don\'t apply dash split rules') - parser.add_argument('--moses-no-escape', action='store_true', default=False, - help='don\'t perform HTML escaping on apostrophy, quotes, etc.') - # fmt: on - - def __init__(self, args): - self.args = args - - if getattr(args, "moses_source_lang", None) is None: - args.moses_source_lang = getattr(args, "source_lang", "en") - if getattr(args, "moses_target_lang", None) is None: - args.moses_target_lang = getattr(args, "target_lang", "en") + def __init__(self, cfg): + self.cfg = cfg try: from sacremoses import MosesTokenizer, MosesDetokenizer - self.tok = MosesTokenizer(args.moses_source_lang) - self.detok = MosesDetokenizer(args.moses_target_lang) + self.tok = MosesTokenizer(cfg.source_lang) + self.detok = MosesDetokenizer(cfg.target_lang) except ImportError: raise ImportError( "Please install Moses tokenizer with: pip install sacremoses" @@ -42,9 +40,9 @@ class MosesTokenizer(object): def encode(self, x: str) -> str: return self.tok.tokenize( x, - aggressive_dash_splits=(not self.args.moses_no_dash_splits), + aggressive_dash_splits=(not self.cfg.moses_no_dash_splits), return_str=True, - escape=(not self.args.moses_no_escape), + escape=(not self.cfg.moses_no_escape), ) def decode(self, x: str) -> str: diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py index 3b617e73..ee164710 100644 --- a/fairseq/data/encoders/nltk_tokenizer.py +++ b/fairseq/data/encoders/nltk_tokenizer.py @@ -8,7 +8,7 @@ from fairseq.data.encoders import register_tokenizer @register_tokenizer("nltk") class NLTKTokenizer(object): - def __init__(self, source_lang=None, target_lang=None): + def __init__(self, *unused): try: from nltk.tokenize import word_tokenize diff --git a/fairseq/data/encoders/sentencepiece_bpe.py b/fairseq/data/encoders/sentencepiece_bpe.py index b25c6cae..a76d46a2 100644 --- a/fairseq/data/encoders/sentencepiece_bpe.py +++ b/fairseq/data/encoders/sentencepiece_bpe.py @@ -3,21 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("sentencepiece") +@dataclass +class SentencepieceConfig(FairseqDataclass): + sentencepiece_model: str = field( + default="???", metadata={"help": "path to sentencepiece model"} + ) + + +@register_bpe("sentencepiece", dataclass=SentencepieceConfig) class SentencepieceBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--sentencepiece-model', type=str, - help='path to sentencepiece model') - # fmt: on - - def __init__(self, args): - sentencepiece_model = file_utils.cached_path(args.sentencepiece_model) + def __init__(self, cfg): + sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model) try: import sentencepiece as spm diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py index 3bc7ce49..7c7f644d 100644 --- a/fairseq/data/encoders/space_tokenizer.py +++ b/fairseq/data/encoders/space_tokenizer.py @@ -10,7 +10,7 @@ from fairseq.data.encoders import register_tokenizer @register_tokenizer("space") class SpaceTokenizer(object): - def __init__(self, source_lang=None, target_lang=None): + def __init__(self, *unused): self.space_tok = re.compile(r"\s+") def encode(self, x: str) -> str: diff --git a/fairseq/data/encoders/subword_nmt_bpe.py b/fairseq/data/encoders/subword_nmt_bpe.py index e85f99af..5d724d27 100644 --- a/fairseq/data/encoders/subword_nmt_bpe.py +++ b/fairseq/data/encoders/subword_nmt_bpe.py @@ -3,25 +3,25 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + from fairseq import file_utils from fairseq.data.encoders import register_bpe +from fairseq.dataclass import FairseqDataclass -@register_bpe("subword_nmt") +@dataclass +class SubwordNMTBPEConfig(FairseqDataclass): + bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"}) + bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"}) + + +@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig) class SubwordNMTBPE(object): - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--bpe-codes', type=str, - help='path to subword NMT BPE') - parser.add_argument('--bpe-separator', default='@@', - help='BPE separator') - # fmt: on - - def __init__(self, args): - if args.bpe_codes is None: + def __init__(self, cfg): + if cfg.bpe_codes is None: raise ValueError("--bpe-codes is required for --bpe=subword_nmt") - codes = file_utils.cached_path(args.bpe_codes) + codes = file_utils.cached_path(cfg.bpe_codes) try: from subword_nmt import apply_bpe @@ -31,7 +31,7 @@ class SubwordNMTBPE(object): "--codes", codes, "--separator", - args.bpe_separator, + cfg.bpe_separator, ] ) self.bpe = apply_bpe.BPE( diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index 21b36450..2fd87f5f 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -9,5 +9,7 @@ from fairseq.dataclass.utils import ChoiceEnum LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"]) DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"]) DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"]) +GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"]) +GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(["unigram", "ensemble", "vote", "dp", "bs"]) 
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) diff --git a/fairseq/dataclass/data_class.py b/fairseq/dataclass/data_class.py index ed1d12d8..b0c17ba0 100644 --- a/fairseq/dataclass/data_class.py +++ b/fairseq/dataclass/data_class.py @@ -3,32 +3,37 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import sys from argparse import Namespace -from dataclasses import dataclass, field +from dataclasses import _MISSING_TYPE, dataclass, field from typing import Any, Dict, List, Optional, Tuple, Type import torch -from fairseq.criterions import CRITERION_DATACLASS_REGISTRY from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.dataclass.constants import ( DDP_BACKEND_CHOICES, DISTRIBUTED_WRAPPER_CHOICES, + GENERATION_CONSTRAINTS_CHOICES, + GENERATION_DECODING_FORMAT_CHOICES, LOG_FORMAT_CHOICES, PIPELINE_CHECKPOINT_CHOICES, ZERO_SHARDING_CHOICES, ) from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY -from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY from fairseq.optim.bmuf import FairseqBMUFConfig -from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY +from fairseq.registry import REGISTRIES from fairseq.tasks import TASK_DATACLASS_REGISTRY from hydra.core.config_store import ConfigStore +from omegaconf import II + + +logger = logging.getLogger(__name__) @dataclass -class CommonParams(FairseqDataclass): +class CommonConfig(FairseqDataclass): # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc. 
no_progress_bar: bool = field( @@ -109,18 +114,6 @@ class CommonParams(FairseqDataclass): model_parallel_size: int = field( default=1, metadata={"help": "total number of GPUs to parallelize model over"} ) - checkpoint_suffix: str = field( - default="", metadata={"help": "suffix to add to the checkpoint file name"} - ) - checkpoint_shard_count: int = field( - default=1, - metadata={ - "help": "Number of shards containing the checkpoint - " - "if the checkpoint is over 300GB, it is preferable " - "to split it into shards to prevent OOM on CPU while loading " - "the checkpoint" - }, - ) quantization_config_path: Optional[str] = field( default=None, metadata={"help": "path to quantization config file"} ) @@ -130,7 +123,7 @@ class CommonParams(FairseqDataclass): @dataclass -class DistributedTrainingParams(FairseqDataclass): +class DistributedTrainingConfig(FairseqDataclass): distributed_world_size: int = field( default=max(1, torch.cuda.device_count()), metadata={ @@ -229,7 +222,7 @@ class DistributedTrainingParams(FairseqDataclass): default=False, metadata={"help": "if set, use pipeline model parallelism across GPUs"}, ) - pipeline_balance: str = field( + pipeline_balance: Optional[str] = field( default=None, metadata={ "help": "partition the model into N_K pieces, where each piece " @@ -237,7 +230,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of layers in the model" }, ) - pipeline_devices: str = field( + pipeline_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -245,10 +238,10 @@ class DistributedTrainingParams(FairseqDataclass): "equal the length of the --pipeline-balance argument" }, ) - pipeline_chunks: int = field( + pipeline_chunks: Optional[int] = field( default=0, metadata={"help": "microbatch count for pipeline model parallelism"} ) - pipeline_encoder_balance: str = field( + pipeline_encoder_balance: Optional[str] = field( default=None, metadata={ "help": "partition the pipeline parallel encoder into N_K pieces, where each piece " @@ -256,7 +249,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of encoder layers in the model" }, ) - pipeline_encoder_devices: str = field( + pipeline_encoder_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -264,7 +257,7 @@ class DistributedTrainingParams(FairseqDataclass): "equal the length of the --pipeline-encoder-balance argument" }, ) - pipeline_decoder_balance: str = field( + pipeline_decoder_balance: Optional[str] = field( default=None, metadata={ "help": "partition the pipeline parallel decoder into N_K pieces, where each piece " @@ -272,7 +265,7 @@ class DistributedTrainingParams(FairseqDataclass): "should equal the total number of decoder layers in the model" }, ) - pipeline_decoder_devices: str = field( + pipeline_decoder_devices: Optional[str] = field( default=None, metadata={ "help": "a list of device indices indicating which device to place " @@ -287,10 +280,11 @@ class DistributedTrainingParams(FairseqDataclass): zero_sharding: ZERO_SHARDING_CHOICES = field( default="none", metadata={"help": "ZeRO sharding"} ) + tpu: bool = II("common.tpu") @dataclass -class DatasetParams(FairseqDataclass): +class DatasetConfig(FairseqDataclass): num_workers: int = field( default=1, metadata={"help": "how many subprocesses to use for data loading"} ) @@ -374,7 +368,7 @@ class DatasetParams(FairseqDataclass): @dataclass 
-class OptimizationParams(FairseqDataclass): +class OptimizationConfig(FairseqDataclass): max_epoch: int = field( default=0, metadata={"help": "force stop training at specified epoch"} ) @@ -421,7 +415,7 @@ class OptimizationParams(FairseqDataclass): @dataclass -class CheckpointParams(FairseqDataclass): +class CheckpointConfig(FairseqDataclass): save_dir: str = field( default="checkpoints", metadata={"help": "path to save checkpoints"} ) @@ -514,12 +508,217 @@ class CheckpointParams(FairseqDataclass): ) }, ) + checkpoint_suffix: str = field( + default="", metadata={"help": "suffix to add to the checkpoint file name"} + ) + checkpoint_shard_count: int = field( + default=1, + metadata={ + "help": "Number of shards containing the checkpoint - " + "if the checkpoint is over 300GB, it is preferable " + "to split it into shards to prevent OOM on CPU while loading " + "the checkpoint" + }, + ) + model_parallel_size: int = II("common.model_parallel_size") + distributed_rank: int = II("distributed_training.distributed_rank") @dataclass -class CommonEvalParams(FairseqDataclass): +class GenerationConfig(FairseqDataclass): + beam: int = field( + default=5, + metadata={"help": "beam size"}, + ) + nbest: int = field( + default=1, + metadata={"help": "number of hypotheses to output"}, + ) + max_len_a: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + max_len_b: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length" + }, + ) + min_len: int = field( + default=1, + metadata={"help": "minimum generation length"}, + ) + match_source_len: bool = field( + default=False, + metadata={"help": "generations should match the source length"}, + ) + unnormalized: bool = field( + default=False, + metadata={"help": "compare unnormalized hypothesis scores"}, + ) + no_early_stop: bool = field( + default=False, + metadata={"help": "deprecated"}, + ) + no_beamable_mm: bool = field( + default=False, + metadata={"help": "don't use BeamableMM in attention layers"}, + ) + lenpen: float = field( + default=1, + metadata={ + "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) + unkpen: float = field( + default=0, + metadata={ + "help": "unknown word penalty: <0 produces more unks, >0 produces fewer" + }, + ) + replace_unk: Optional[str] = field( + default=None, + metadata={ + "help": "perform unknown replacement (optionally with alignment dictionary)", + "argparse_const": "@@ ", + }, + ) + sacrebleu: bool = field( + default=False, + metadata={"help": "score with sacrebleu"}, + ) + score_reference: bool = field( + default=False, + metadata={"help": "just score the reference translation"}, + ) + prefix_size: int = field( + default=0, + metadata={"help": "initialize generation by target prefix of given length"}, + ) + no_repeat_ngram_size: int = field( + default=0, + metadata={ + "help": "ngram blocking such that this size ngram cannot be repeated in the generation" + }, + ) + sampling: bool = field( + default=False, + metadata={"help": "sample hypotheses instead of using beam search"}, + ) + sampling_topk: int = field( + default=-1, + metadata={"help": "sample from top K likely next words instead of all words"}, + ) + sampling_topp: float = field( + default=-1.0, + metadata={ + "help": "sample from the smallest set whose cumulative probability mass exceeds p for next words" + }, + ) + constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = 
field( + default=None, + metadata={ + "help": "enables lexically constrained decoding", + "argparse_const": "ordered", + }, + ) + temperature: float = field( + default=1.0, + metadata={"help": "temperature for generation"}, + ) + diverse_beam_groups: int = field( + default=-1, + metadata={"help": "number of groups for Diverse Beam Search"}, + ) + diverse_beam_strength: float = field( + default=0.5, + metadata={"help": "strength of diversity penalty for Diverse Beam Search"}, + ) + diversity_rate: float = field( + default=-1.0, + metadata={"help": "strength of diversity penalty for Diverse Siblings Search"}, + ) + print_alignment: bool = field( + default=False, + metadata={ + "help": "if set, uses attention feedback to compute and print alignment to source tokens" + }, + ) + print_step: bool = field( + default=False, + metadata={"help": "print steps"}, + ) + lm_path: Optional[str] = field( + default=None, + metadata={"help": "path to lm checkpoint for lm fusion"}, + ) + lm_weight: float = field( + default=0.0, + metadata={"help": "weight for lm probs for lm fusion"}, + ) + + # arguments for iterative refinement generator + iter_decode_eos_penalty: float = field( + default=0.0, + metadata={"help": "if > 0.0, it penalized early-stopping in decoding."}, + ) + iter_decode_max_iter: int = field( + default=10, + metadata={"help": "maximum iterations for iterative refinement."}, + ) + iter_decode_force_max_iter: bool = field( + default=False, + metadata={ + "help": "if set, run exact the maximum number of iterations without early stop" + }, + ) + iter_decode_with_beam: int = field( + default=1, + metadata={ + "help": "if > 1, model will generate translations varying by the lengths." + }, + ) + iter_decode_with_external_reranker: bool = field( + default=False, + metadata={ + "help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations" + }, + ) + retain_iter_history: bool = field( + default=False, + metadata={ + "help": "if set, decoding returns the whole history of iterative refinement" + }, + ) + retain_dropout: bool = field( + default=False, + metadata={"help": "Use dropout at inference time"}, + ) + retain_dropout_modules: Optional[List[str]] = field( + default=None, + metadata={ + "help": "if set, only retain dropout for the specified modules; " + "if not set, then dropout will be retained for all modules" + }, + ) + # special decoding format for advanced decoding. 
+ decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field( + default=None, + metadata={"help": "special decoding format for advanced decoding."}, + ) + no_seed_provided: bool = field( + default=False, + metadata={"help": "if set, dont use seed for initializing random generators"}, + ) + + +@dataclass +class CommonEvalConfig(FairseqDataclass): path: Optional[str] = field( - default=None, metadata={"help": "path(s) to model file(s), colon separated"} + default=None, + metadata={"help": "path(s) to model file(s), colon separated"}, ) remove_bpe: Optional[str] = field( default=None, @@ -541,7 +740,7 @@ class CommonEvalParams(FairseqDataclass): @dataclass -class EvalLMParams(FairseqDataclass): +class EvalLMConfig(FairseqDataclass): output_word_probs: bool = field( default=False, metadata={ @@ -569,37 +768,31 @@ class EvalLMParams(FairseqDataclass): @dataclass -class TrainingConfig(FairseqDataclass): - """Config for training, a composition of training params""" - - common: CommonParams = CommonParams() - distributed_training: DistributedTrainingParams = DistributedTrainingParams() - dataset: DatasetParams = DatasetParams() - optimization: OptimizationParams = OptimizationParams() - checkpoint: CheckpointParams = CheckpointParams() - bmuf: FairseqBMUFConfig = FairseqBMUFConfig() +class InteractiveConfig(FairseqDataclass): + buffer_size: int = field( + default=0, + metadata={ + "help": "read this many sentences into a buffer before processing them" + }, + ) + input: str = field( + default="-", + metadata={"help": "file to read from; use - for stdin"}, + ) -@dataclass -class EvalLMConfig(FairseqDataclass): - """Config for eval lm, a composition of eval_lm params""" - - common: CommonParams = CommonParams() - distributed_training: DistributedTrainingParams = DistributedTrainingParams() - dataset: DatasetParams = DatasetParams() - optimization: OptimizationParams = OptimizationParams() - checkpoint: CheckpointParams = CheckpointParams() - bmuf: FairseqBMUFConfig = FairseqBMUFConfig() - common_eval: CommonEvalParams = CommonEvalParams() - eval_lm: EvalLMParams = EvalLMParams() - - -def register_params_dataclass( - cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass] -) -> None: - """register params dataclass in config store""" - node_ = data_class(_name=data_class.name()) - cs.store(name=name, group=group, node=node_) +CONFIGS = { + "common": CommonConfig, + "common_eval": CommonEvalConfig, + "distributed_training": DistributedTrainingConfig, + "dataset": DatasetConfig, + "optimization": OptimizationConfig, + "checkpoint": CheckpointConfig, + "bmuf": FairseqBMUFConfig, + "generation": GenerationConfig, + "eval_lm": EvalLMConfig, + "interactive": InteractiveConfig, +} def register_module_dataclass( @@ -608,100 +801,67 @@ def register_module_dataclass( """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" # note that if `group == model`, we register all model archs, not the model name. 
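As a rough illustration of what the `CONFIGS` mapping and the registration loop that follows accomplish: each dataclass is stored as a node in Hydra's `ConfigStore`, so it can later be composed and overridden by name. This is a minimal sketch assuming the Hydra 1.0-era `hydra.experimental` compose API used later in this patch (in `fairseq/dataclass/utils.py`); `DemoGeneration` is a made-up stand-in for the real groups:

```python
from dataclasses import dataclass, field

from hydra.core.config_store import ConfigStore
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize


@dataclass
class DemoGeneration:
    beam: int = field(default=5, metadata={"help": "beam size"})
    max_len_b: int = field(default=200, metadata={"help": "length offset b"})


cs = ConfigStore.instance()
cs.store(name="demo_generation", node=DemoGeneration)

# everything lives in the ConfigStore, so no config directory is needed
if not GlobalHydra().is_initialized():
    initialize()

cfg = compose("demo_generation", overrides=["beam=10"])
assert cfg.beam == 10 and cfg.max_len_b == 200
```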
for k, v in registry.items(): - if v is not None: - node_ = v(_name=k) - cs.store(name=k, group=group, node=node_) + node_ = v() + node_._name = k + cs.store(name=k, group=group, node=node_, provider="fairseq") -def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: +def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: """cs: config store instance, register common training configs""" - register_params_dataclass( - cs, name="training_params", group="params", data_class=TrainingConfig - ) + for k, v in CONFIGS.items(): + try: + cs.store(name=k, node=v()) + except BaseException: + logger.error(f"{k} - {v()}") + raise register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") - register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") - register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") - register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") - -def register_eval_lm_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: - """cs: config store instance, register common training configs""" - - register_params_dataclass( - cs, name="eval_lm_params", group="params", data_class=EvalLMConfig - ) - - register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") - register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion") - register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer") - register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler") + for k, v in REGISTRIES.items(): + register_module_dataclass(cs, v["dataclass_registry"], k) def _override_attr( sub_node: str, data_class: Type[FairseqDataclass], args: Namespace ) -> List[str]: overrides = [] - for k in data_class.__dataclass_fields__.keys(): - if k == "_name": + + def get_default(f): + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default + + for k, v in data_class.__dataclass_fields__.items(): + if k.startswith("_"): # private member, skip continue - if not hasattr(args, k): - # print(f"cannot override {sub_node}.{k} since args does not have attribute {k}") - continue - if getattr(args, k) is None: + + val = get_default(v) if not hasattr(args, k) else getattr(args, k) + + if val is None: overrides.append("{}.{}=null".format(sub_node, k)) - elif getattr(args, k) == "": + elif val == "": overrides.append("{}.{}=''".format(sub_node, k)) - elif isinstance(getattr(args, k), str): - if ( - getattr(args, k).startswith("[") - or getattr(args, k).startswith("(") - or getattr(args, k).startswith("{") - or ("," in getattr(args, k)) - ): - overrides.append("{}.{}='{}'".format(sub_node, k, getattr(args, k))) - else: - overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) + elif isinstance(val, str): + overrides.append("{}.{}='{}'".format(sub_node, k, val)) else: - overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) + overrides.append("{}.{}={}".format(sub_node, k, val)) return overrides -def override_training_args(args: Namespace) -> Tuple[List[str], List[str]]: - overrides = [] - - overrides.extend(_override_attr("params.common", CommonParams, args)) - overrides.extend(_override_attr("params.dataset", DatasetParams, args)) - overrides.extend( - _override_attr("params.distributed_training", DistributedTrainingParams, args) - ) - overrides.extend(_override_attr("params.optimization", OptimizationParams, args)) - 
overrides.extend(_override_attr("params.checkpoint", CheckpointParams, args)) - overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) - module_overrides, module_deletes = override_module_args(args) - overrides.extend(module_overrides) - - return overrides, module_deletes - - -def override_eval_lm_args(args: Namespace) -> Tuple[List[str], List[str]]: - overrides = [] - - overrides.extend(_override_attr("params.common", CommonParams, args)) - overrides.extend(_override_attr("params.dataset", DatasetParams, args)) - overrides.extend( - _override_attr("params.distributed_training", DistributedTrainingParams, args) - ) - overrides.extend(_override_attr("params.common_eval", CommonEvalParams, args)) - overrides.extend(_override_attr("params.eval_lm", EvalLMParams, args)) - overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) - module_overrides, module_deletes = override_module_args(args) - overrides.extend(module_overrides) - - return overrides, module_deletes +def migrate_registry( + name, value, registry, args, overrides, deletes, use_name_as_val=False +): + if value in registry: + overrides.append("{}={}".format(name, value)) + overrides.append("{}._name={}".format(name, value)) + overrides.extend(_override_attr(name, registry[value], args)) + elif use_name_as_val and value is not None: + overrides.append("{}={}".format(name, value)) + else: + deletes.append(name) def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: @@ -709,53 +869,34 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: overrides = [] deletes = [] + for k, v in CONFIGS.items(): + overrides.extend(_override_attr(k, v, args)) + if args is not None: - assert ( - hasattr(args, "task") - and hasattr(args, "criterion") - and hasattr(args, "optimizer") - and hasattr(args, "lr_scheduler") - ) - if args.task in TASK_DATACLASS_REGISTRY: - overrides.append("task={}".format(args.task)) - overrides.append("task._name={}".format(args.task)) - overrides.extend( - _override_attr("task", TASK_DATACLASS_REGISTRY[args.task], args) + if hasattr(args, "task"): + migrate_registry( + "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes ) else: deletes.append("task") - if args.criterion in CRITERION_DATACLASS_REGISTRY: - overrides.append("criterion={}".format(args.criterion)) - overrides.append("criterion._name={}".format(args.criterion)) - overrides.extend( - _override_attr( - "criterion", CRITERION_DATACLASS_REGISTRY[args.criterion], args - ) - ) - else: - deletes.append("criterion") - if args.optimizer in OPTIMIZER_DATACLASS_REGISTRY: - overrides.append("optimizer={}".format(args.optimizer)) - overrides.append("optimizer._name={}".format(args.optimizer)) - overrides.extend( - _override_attr( - "optimizer", OPTIMIZER_DATACLASS_REGISTRY[args.optimizer], args - ) - ) - else: - deletes.append("optimizer") - if args.lr_scheduler in LR_SCHEDULER_DATACLASS_REGISTRY: - overrides.append("lr_scheduler={}".format(args.lr_scheduler)) - overrides.append("lr_scheduler._name={}".format(args.lr_scheduler)) - overrides.extend( - _override_attr( - "lr_scheduler", - LR_SCHEDULER_DATACLASS_REGISTRY[args.lr_scheduler], + + # these options will be set to "None" if they have not yet been migrated + # so we can populate them with the entire flat args + CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"} + + for k, v in REGISTRIES.items(): + if hasattr(args, k): + migrate_registry( + k, + getattr(args, k), + v["dataclass_registry"], args, + overrides, + deletes, + 
use_name_as_val=k not in CORE_REGISTRIES, ) - ) - else: - deletes.append("lr_scheduler") + else: + deletes.append(k) no_dc = True if hasattr(args, "arch"): diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index 599cc2b4..bcfe2329 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -3,17 +3,24 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from argparse import ArgumentParser -from dataclasses import MISSING, dataclass +import ast +from argparse import ArgumentParser, Namespace +from dataclasses import _MISSING_TYPE, MISSING, dataclass from enum import Enum from typing import Any, Dict, List, Optional +from hydra.core.global_hydra import GlobalHydra +from hydra.experimental import compose, initialize +from omegaconf import DictConfig, OmegaConf, open_dict + def eval_str_list(x, x_type=float): if x is None: return None if isinstance(x, str): - x = eval(x) + if len(x) == 0: + return [] + x = ast.literal_eval(x) try: return list(map(x_type, x)) except TypeError: @@ -70,22 +77,11 @@ class FairseqDataclass: != self.__dataclass_fields__[attribute_name].default ): return getattr(self, attribute_name) - return self.__dataclass_fields__[attribute_name].default - def _get_default_factory(self, attribute_name: str) -> Any: - if hasattr(self, attribute_name): - if str(getattr(self, attribute_name)).startswith("${"): - return str(getattr(self, attribute_name)) - elif str(self.__dataclass_fields__[attribute_name].default).startswith( - "${" - ): - return str(self.__dataclass_fields__[attribute_name].default) - elif ( - getattr(self, attribute_name) - != self.__dataclass_fields__[attribute_name].default_factory() - ): - return getattr(self, attribute_name) - return self.__dataclass_fields__[attribute_name].default_factory() + f = self.__dataclass_fields__[attribute_name] + if not isinstance(f.default_factory, _MISSING_TYPE): + return f.default_factory() + return f.default def _get_type(self, attribute_name: str) -> Any: return self.__dataclass_fields__[attribute_name].type @@ -119,7 +115,7 @@ def gen_parser_from_dataclass( def interpret_dc_type(field_type): if isinstance(field_type, str): - raise RuntimeError() + raise RuntimeError("field should be a type") typestring = str(field_type) if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): return field_type.__args__[0] @@ -129,12 +125,13 @@ def gen_parser_from_dataclass( dataclass_instance: FairseqDataclass, k: str ) -> Dict[str, Any]: """k: dataclass attributes""" + + kwargs = {} + field_type = dataclass_instance._get_type(k) inter_type = interpret_dc_type(field_type) - if isinstance(inter_type, type) and issubclass(inter_type, List): - field_default = dataclass_instance._get_default_factory(k) - else: - field_default = dataclass_instance._get_default(k) + + field_default = dataclass_instance._get_default(k) if isinstance(inter_type, type) and issubclass(inter_type, Enum): field_choices = [t.value for t in list(inter_type)] @@ -143,7 +140,7 @@ def gen_parser_from_dataclass( field_help = dataclass_instance._get_help(k) field_const = dataclass_instance._get_argparse_const(k) - kwargs = {} + if isinstance(field_default, str) and field_default.startswith("${"): kwargs["default"] = field_default else: @@ -163,7 +160,11 @@ def gen_parser_from_dataclass( else: raise NotImplementedError() if field_default is not MISSING: - kwargs["default"] = ",".join(map(str, field_default)) + kwargs["default"] = ( + ",".join(map(str, 
field_default)) + if field_default is not None + else None + ) elif ( isinstance(inter_type, type) and issubclass(inter_type, Enum) ) or "Enum" in str(inter_type): @@ -187,6 +188,7 @@ def gen_parser_from_dataclass( if field_const is not None: kwargs["const"] = field_const kwargs["nargs"] = "?" + return kwargs for k in dataclass_instance._get_all_attributes(): @@ -194,8 +196,122 @@ def gen_parser_from_dataclass( if field_name is None: continue kwargs = get_kwargs_from_dc(dataclass_instance, k) - if isinstance(kwargs["default"], str) and kwargs["default"].startswith("${"): - continue - if delete_default: - del kwargs["default"] + + if "default" in kwargs: + if isinstance(kwargs["default"], str) and kwargs["default"].startswith( + "${" + ): + continue + if delete_default: + del kwargs["default"] parser.add_argument(field_name, **kwargs) + + +def _set_legacy_defaults(args, cls): + """Helper to set default arguments based on *add_args*.""" + if not hasattr(cls, "add_args"): + return + + import argparse + + parser = argparse.ArgumentParser( + argument_default=argparse.SUPPRESS, allow_abbrev=False + ) + cls.add_args(parser) + # copied from argparse.py: + defaults = argparse.Namespace() + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if not hasattr(defaults, action.dest): + if action.default is not argparse.SUPPRESS: + setattr(defaults, action.dest, action.default) + for key, default_value in vars(defaults).items(): + if not hasattr(args, key): + setattr(args, key, default_value) + + +def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: + from fairseq.dataclass.data_class import override_module_args + + # Here we are using field values provided in args to override counterparts inside config object + overrides, deletes = override_module_args(args) + + cfg_name = "config" + cfg_path = f"../../{cfg_name}" + + if not GlobalHydra().is_initialized(): + initialize(config_path=cfg_path) + + composed_cfg = compose(cfg_name, overrides=overrides, strict=False) + for k in deletes: + composed_cfg[k] = None + + cfg = OmegaConf.create( + OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True) + ) + + # hack to be able to set Namespace in dict config. 
this should be removed when we update to newer + # omegaconf version that supports object flags, or when we migrate all existing models + from omegaconf import _utils + + old_primitive = _utils.is_primitive_type + _utils.is_primitive_type = lambda _: True + + if cfg.task is None and getattr(args, "task", None): + cfg.task = Namespace(**vars(args)) + from fairseq.tasks import TASK_REGISTRY + + _set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task]) + cfg.task._name = args.task + if cfg.model is None and getattr(args, "arch", None): + cfg.model = Namespace(**vars(args)) + from fairseq.models import ARCH_MODEL_REGISTRY + + _set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch]) + cfg.model._name = args.arch + if cfg.optimizer is None and getattr(args, "optimizer", None): + cfg.optimizer = Namespace(**vars(args)) + from fairseq.optim import OPTIMIZER_REGISTRY + + _set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer]) + cfg.optimizer._name = args.optimizer + if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None): + cfg.lr_scheduler = Namespace(**vars(args)) + from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY + + _set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler]) + cfg.lr_scheduler._name = args.lr_scheduler + if cfg.criterion is None and getattr(args, "criterion", None): + cfg.criterion = Namespace(**vars(args)) + from fairseq.criterions import CRITERION_REGISTRY + + _set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion]) + cfg.criterion._name = args.criterion + + _utils.is_primitive_type = old_primitive + OmegaConf.set_struct(cfg, True) + return cfg + + +def populate_dataclass( + args: Namespace, dataclass: FairseqDataclass +) -> FairseqDataclass: + for k in dataclass.__dataclass_fields__.keys(): + if k.startswith("_"): + # private member, skip + continue + if hasattr(args, k): + setattr(dataclass, k, getattr(args, k)) + + return dataclass + + +def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): + # this will be deprecated when we get rid of argparse and model_overrides logic + + with open_dict(cfg): + for k in cfg.keys(): + if isinstance(cfg[k], DictConfig): + overwrite_args_by_name(cfg[k], overrides) + elif k in overrides: + cfg[k] = overrides[k] diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index bcb0595e..23cdfc69 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -11,35 +11,38 @@ import socket import struct import subprocess import warnings +from argparse import Namespace from collections import OrderedDict from typing import Any, Dict, Mapping import torch import torch.distributed as dist from fairseq import utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from omegaconf import DictConfig, open_dict logger = logging.getLogger(__name__) -def is_master(args): - return args.distributed_rank == 0 +def is_master(cfg: DictConfig): + return cfg.distributed_rank == 0 -def infer_init_method(args, force_distributed=False): - if args.distributed_init_method is not None or getattr(args, "tpu", False): +def infer_init_method(cfg: DictConfig, force_distributed=False): + if cfg.distributed_init_method is not None or cfg.tpu: return - if args.pipeline_model_parallel: + if cfg.pipeline_model_parallel: balance_exists = ( - args.pipeline_balance is not None - or args.pipeline_encoder_balance is not None - or args.pipeline_decoder_balance is not None + cfg.pipeline_balance is not None + or 
cfg.pipeline_encoder_balance is not None + or cfg.pipeline_decoder_balance is not None ) devices_exist = ( - args.pipeline_devices is not None - or args.pipeline_encoder_devices is not None - or args.pipeline_decoder_devices is not None + cfg.pipeline_devices is not None + or cfg.pipeline_encoder_devices is not None + or cfg.pipeline_decoder_devices is not None ) if not balance_exists: raise ValueError( @@ -50,19 +53,19 @@ def infer_init_method(args, force_distributed=False): "--pipeline-devices is currently required for pipeline model parallelism" ) - args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int) - if args.pipeline_devices is not None: - args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int) - num_pipeline_devices = len(set(args.pipeline_devices)) + cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int) + if cfg.pipeline_devices is not None: + cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int) + num_pipeline_devices = len(set(cfg.pipeline_devices)) else: - args.pipeline_encoder_devices = utils.eval_str_list( - args.pipeline_encoder_devices, type=int + cfg.pipeline_encoder_devices = utils.eval_str_list( + cfg.pipeline_encoder_devices, type=int ) - args.pipeline_decoder_devices = utils.eval_str_list( - args.pipeline_decoder_devices, type=int + cfg.pipeline_decoder_devices = utils.eval_str_list( + cfg.pipeline_decoder_devices, type=int ) num_pipeline_devices = len( - set(args.pipeline_encoder_devices + args.pipeline_decoder_devices) + set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices) ) gpus_per_node = torch.cuda.device_count() assert ( @@ -79,14 +82,14 @@ def infer_init_method(args, force_distributed=False): key in os.environ for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] ): - args.distributed_init_method = "env://" - args.distributed_world_size = int(os.environ["WORLD_SIZE"]) - args.distributed_rank = int(os.environ["RANK"]) + cfg.distributed_init_method = "env://" + cfg.distributed_world_size = int(os.environ["WORLD_SIZE"]) + cfg.distributed_rank = int(os.environ["RANK"]) # processes are created by torch.distributed.launch - args.distributed_no_spawn = True + cfg.distributed_no_spawn = True # we can determine the init method automatically for Slurm - elif args.distributed_port > 0: + elif cfg.distributed_port > 0: node_list = os.environ.get("SLURM_STEP_NODELIST") if node_list is None: node_list = os.environ.get("SLURM_JOB_NODELIST") @@ -95,9 +98,9 @@ def infer_init_method(args, force_distributed=False): hostnames = subprocess.check_output( ["scontrol", "show", "hostnames", node_list] ) - args.distributed_init_method = "tcp://{host}:{port}".format( + cfg.distributed_init_method = "tcp://{host}:{port}".format( host=hostnames.split()[0].decode("utf-8"), - port=args.distributed_port, + port=cfg.distributed_port, ) nnodes = int(os.environ.get("SLURM_NNODES")) ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") @@ -111,88 +114,94 @@ def infer_init_method(args, force_distributed=False): if ntasks_per_node == 1: gpus_per_node = torch.cuda.device_count() node_id = int(os.environ.get("SLURM_NODEID")) - args.distributed_rank = node_id * gpus_per_node - args.distributed_world_size = nnodes * gpus_per_node - elif args.pipeline_model_parallel: + cfg.distributed_rank = node_id * gpus_per_node + cfg.distributed_world_size = nnodes * gpus_per_node + elif cfg.pipeline_model_parallel: assert ntasks_per_node == num_pipelines_per_node, ( "SLURM --ntasks-per-node must match 
number of pipelines per " "node (={})".format(num_pipelines_per_node) ) - args.distributed_no_spawn = True + cfg.distributed_no_spawn = True # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on # the first node, [1, 2] on the second node, etc. This # matches torch.distributed.launch. node_id = int(os.environ.get("SLURM_NODEID")) local_id = int(os.environ.get("SLURM_LOCALID")) - args.distributed_rank = node_id * num_pipelines_per_node + local_id + cfg.distributed_rank = node_id * num_pipelines_per_node + local_id # In the above example, device_id will always be in [0, 1], # which also matches torch.distributed.launch. - args.device_id = local_id + cfg.device_id = local_id # We also want to set distributed_world_size to be the total # number of pipelines across all nodes. - args.distributed_world_size = nnodes * num_pipelines_per_node + cfg.distributed_world_size = nnodes * num_pipelines_per_node else: - assert ntasks_per_node == args.distributed_world_size // nnodes - args.distributed_no_spawn = True - args.distributed_rank = int(os.environ.get("SLURM_PROCID")) - args.device_id = int(os.environ.get("SLURM_LOCALID")) + assert ntasks_per_node == cfg.distributed_world_size // nnodes + cfg.distributed_no_spawn = True + cfg.distributed_rank = int(os.environ.get("SLURM_PROCID")) + cfg.device_id = int(os.environ.get("SLURM_LOCALID")) except subprocess.CalledProcessError as e: # scontrol failed raise e except FileNotFoundError: # Slurm is not installed pass - elif args.distributed_world_size > 1 or force_distributed: + elif cfg.distributed_world_size > 1 or force_distributed: # fallback for single node with multiple GPUs - assert args.distributed_world_size <= torch.cuda.device_count() + assert cfg.distributed_world_size <= torch.cuda.device_count() port = random.randint(10000, 20000) - args.distributed_init_method = "tcp://localhost:{port}".format(port=port) + cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) - if args.pipeline_model_parallel: - if not args.distributed_no_spawn: + if cfg.pipeline_model_parallel: + if not cfg.distributed_no_spawn: # When distributed_no_spawn is False, we expect distributed_rank and # distributed_world_size to be based on the total number of GPUs, so # we need to correct them to be based on the number of pipelines. - assert args.distributed_world_size % num_pipeline_devices == 0 - args.distributed_world_size = ( - args.distributed_world_size // num_pipeline_devices + assert cfg.distributed_world_size % num_pipeline_devices == 0 + cfg.distributed_world_size = ( + cfg.distributed_world_size // num_pipeline_devices ) # In the case of 4-way MP on nodes with 8 GPUs, we want # distributed_rank to be the starting GPU index for each pipeline # i.e., 0, 2, ... 
- assert args.distributed_rank % gpus_per_node == 0 - assert args.distributed_rank % num_pipeline_devices == 0 - args.distributed_rank = args.distributed_rank // num_pipeline_devices - # launch one process per pipeline - args.distributed_num_procs = num_pipelines_per_node + assert cfg.distributed_rank % gpus_per_node == 0 + assert cfg.distributed_rank % num_pipeline_devices == 0 + + with open_dict(cfg): + cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices + # launch one process per pipeline + cfg.distributed_num_procs = num_pipelines_per_node # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 # and 4, indicating the starting device IDs for each pipeline - args.device_id *= num_pipeline_devices + cfg.device_id *= num_pipeline_devices - if args.device_id > 0: + if cfg.device_id > 0: # if there's multiple pipelines on a node (e.g., 4-way MP on an 8 # GPU node), we need to adjust pipeline_devices accordingly logger.debug( "setting CUDA device={} on rank {}".format( - args.device_id, args.distributed_rank + cfg.device_id, cfg.distributed_rank ) ) - torch.cuda.set_device(args.device_id) - args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices] + torch.cuda.set_device(cfg.device_id) + with open_dict(cfg): + cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices] logger.info( "setting pipeline_devices={} on rank {}".format( - args.pipeline_devices, args.distributed_rank - ), + cfg.pipeline_devices, cfg.distributed_rank + ) + ) + elif not cfg.distributed_no_spawn: + with open_dict(cfg): + cfg.distributed_num_procs = min( + torch.cuda.device_count(), cfg.distributed_world_size ) - elif not args.distributed_no_spawn: - args.distributed_num_procs = min( - torch.cuda.device_count(), - args.distributed_world_size, - ) -def distributed_init(args): - if not getattr(args, "tpu", False): +def distributed_init(cfg: DictConfig): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + if not cfg.common.tpu: if torch.distributed.is_initialized(): warnings.warn( "Distributed is already initialized, cannot initialize twice!" 
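Context for the `open_dict(cfg)` blocks introduced above: a config built from these dataclasses is in struct mode, so assigning a key that is not part of the schema (such as `distributed_num_procs`) is rejected unless the struct flag is temporarily lifted. A tiny sketch of that behaviour, with illustrative key names:

```python
from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"distributed_rank": 0, "device_id": 0})
OmegaConf.set_struct(cfg, True)

try:
    cfg.distributed_num_procs = 4  # not in the schema -> rejected in struct mode
except Exception:
    pass

with open_dict(cfg):
    cfg.distributed_num_procs = 4  # allowed while the struct flag is lifted

assert cfg.distributed_num_procs == 4
```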
@@ -200,20 +209,20 @@ def distributed_init(args): else: logger.info( "distributed init (rank {}): {}".format( - args.distributed_rank, - args.distributed_init_method, + cfg.distributed_training.distributed_rank, + cfg.distributed_training.distributed_init_method, ) ) dist.init_process_group( - backend=args.distributed_backend, - init_method=args.distributed_init_method, - world_size=args.distributed_world_size, - rank=args.distributed_rank, + backend=cfg.distributed_training.distributed_backend, + init_method=cfg.distributed_training.distributed_init_method, + world_size=cfg.distributed_training.distributed_world_size, + rank=cfg.distributed_training.distributed_rank, ) logger.info( "initialized host {} as rank {}".format( socket.gethostname(), - args.distributed_rank, + cfg.distributed_training.distributed_rank, ) ) @@ -221,20 +230,22 @@ def distributed_init(args): if torch.cuda.is_available(): dist.all_reduce(torch.zeros(1).cuda()) - args.distributed_rank = torch.distributed.get_rank() + cfg.distributed_training.distributed_rank = torch.distributed.get_rank() else: import torch_xla.core.xla_model as xm - assert xm.xrt_world_size() == args.distributed_world_size - args.device_id = xm.get_local_ordinal() - args.distributed_rank = xm.get_ordinal() + assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size + cfg.distributed_training.device_id = xm.get_local_ordinal() + cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers xm.mark_step() - if not is_master(args): + if is_master(cfg.distributed_training): + logging.getLogger().setLevel(logging.INFO) + else: logging.getLogger().setLevel(logging.WARNING) - if args.model_parallel_size > 1: + if cfg.common.model_parallel_size > 1: try: from fairseq.model_parallel.megatron.mpu import ( get_model_parallel_rank, @@ -247,58 +258,61 @@ def distributed_init(args): "\n\n git submodule update --init " "fairseq/model_parallel/megatron" ) - initialize_model_parallel(args.model_parallel_size) - model_parallel_cuda_manual_seed(args.seed) + initialize_model_parallel(cfg.common.model_parallel_size) + model_parallel_cuda_manual_seed(cfg.common.seed) model_part_number = get_model_parallel_rank() - args.checkpoint_suffix += "-model_part-{0}".format(model_part_number) - return args.distributed_rank + cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number) + return cfg.distributed_training.distributed_rank -def distributed_main(i, main, args, kwargs): - args.device_id = i - if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): - torch.cuda.set_device(args.device_id) - if args.distributed_rank is None: # torch.multiprocessing.spawn - args.distributed_rank = kwargs.pop("start_rank", 0) + i +def distributed_main(i, main, cfg: DictConfig, kwargs): + cfg.distributed_training.device_id = i + if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu: + torch.cuda.set_device(cfg.distributed_training.device_id) + if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn + cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i - args.distributed_rank = distributed_init(args) + cfg.distributed_training.distributed_rank = distributed_init(cfg) after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) if after_distributed_init_fn: - args = after_distributed_init_fn(args) + cfg = after_distributed_init_fn(cfg) - main(args, **kwargs) + main(cfg, **kwargs) -def 
call_main(args, main, **kwargs): - if args.distributed_init_method is None: - infer_init_method(args) +def call_main(cfg: DictConfig, main, **kwargs): + if cfg.distributed_training.distributed_init_method is None: + infer_init_method(cfg.distributed_training) - if args.distributed_init_method is not None: + if cfg.distributed_training.distributed_init_method is not None: # distributed training - if not args.distributed_no_spawn: - start_rank = args.distributed_rank - args.distributed_rank = None # assign automatically + if not cfg.distributed_training.distributed_no_spawn: + start_rank = cfg.distributed_training.distributed_rank + cfg.distributed_training.distributed_rank = None # assign automatically kwargs["start_rank"] = start_rank torch.multiprocessing.spawn( fn=distributed_main, - args=(main, args, kwargs), - nprocs=args.distributed_num_procs, + args=(main, cfg, kwargs), + nprocs=min( + torch.cuda.device_count(), + cfg.distributed_training.distributed_world_size, + ), ) else: - distributed_main(args.device_id, main, args, kwargs) - elif getattr(args, "tpu", False) and args.distributed_world_size > 1: + distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs) + elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1: import torch_xla.distributed.xla_multiprocessing as xmp torch.multiprocessing.set_sharing_strategy("file_system") xmp.spawn( fn=distributed_main, - args=(main, args, kwargs), + args=(main, cfg, kwargs), nprocs=8, # use all 8 TPU cores ) else: # single GPU main - main(args, **kwargs) + main(cfg, **kwargs) def get_rank(): @@ -392,11 +406,7 @@ def all_gather_list(data, group=None, max_size=16384): ) -def all_reduce_dict( - data: Mapping[str, Any], - device, - group=None, -) -> Dict[str, Any]: +def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, Any]: """ AllReduce a dictionary of values across workers. We separately reduce items that are already on the device and items on CPU for diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index b293e54e..3be7078b 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -8,11 +8,12 @@ import argparse import copy import logging import os -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List import torch from fairseq import utils from fairseq.data import encoders +from omegaconf import open_dict from torch import nn @@ -85,9 +86,9 @@ class GeneratorHubInterface(nn.Module): translation or language model. 
""" - def __init__(self, args, task, models): + def __init__(self, cfg, task, models): super().__init__() - self.args = args + self.cfg = cfg self.task = task self.models = nn.ModuleList(models) self.src_dict = task.source_dictionary @@ -95,14 +96,14 @@ class GeneratorHubInterface(nn.Module): # optimize model for generation for model in self.models: - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None)) + self.align_dict = utils.load_align_dict(cfg.generation.replace_unk) - self.tokenizer = encoders.build_tokenizer(args) - self.bpe = encoders.build_bpe(args) + self.tokenizer = encoders.build_tokenizer(cfg.tokenizer) + self.bpe = encoders.build_bpe(cfg.bpe) self.max_positions = utils.resolve_max_positions( self.task.max_positions(), *[model.max_positions() for model in models] @@ -156,10 +157,11 @@ class GeneratorHubInterface(nn.Module): )[0] # build generator using current args as well as any kwargs - gen_args = copy.copy(self.args) - gen_args.beam = beam - for k, v in kwargs.items(): - setattr(gen_args, k, v) + gen_args = copy.copy(self.cfg) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) generator = self.task.build_generator(self.models, gen_args) inference_step_args = inference_step_args or {} @@ -253,8 +255,8 @@ class GeneratorHubInterface(nn.Module): lengths = torch.LongTensor([t.numel() for t in tokens]) batch_iterator = self.task.get_batch_iterator( dataset=self.task.build_dataset_for_inference(tokens, lengths), - max_tokens=self.args.max_tokens, - max_sentences=self.args.batch_size, + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, max_positions=self.max_positions, ignore_invalid_inputs=skip_invalid_size_inputs, disable_iterator_cache=True, diff --git a/fairseq/model_parallel/megatron_trainer.py b/fairseq/model_parallel/megatron_trainer.py index 761ffc8e..258551c9 100644 --- a/fairseq/model_parallel/megatron_trainer.py +++ b/fairseq/model_parallel/megatron_trainer.py @@ -9,6 +9,7 @@ Train a network across multiple GPUs. 
from fairseq import distributed_utils from fairseq.trainer import Trainer +from omegaconf import DictConfig try: @@ -28,14 +29,14 @@ except (ImportError, ModuleNotFoundError): class MegatronTrainer(Trainer): """Main class for model parallel with data parallel training.""" - def __init__(self, args, task, model, criterion): + def __init__(self, cfg: DictConfig, task, model, criterion, **kwargs): if not has_megatron_submodule: raise ImportError( "\n\nPlease install the megatron submodule:" "\n\n git submodule update --init " "fairseq/model_parallel/megatron" ) - super().__init__(args, task, model, criterion) + super().__init__(cfg, task, model, criterion, **kwargs) @property def data_parallel_world_size(self): diff --git a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py index cbfc6ae4..76cfe3b0 100644 --- a/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py +++ b/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py @@ -96,7 +96,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel): encoder_output_tuple = self.encoder(input) return self.decoder(encoder_output_tuple) - def prepare_for_inference_(self, args): + def prepare_for_inference_(self, cfg): if self.encoder is not None and self.decoder is not None: logger.info("Encoder and Decoder already initialized") return @@ -111,9 +111,9 @@ class PipelineParallelTransformerModel(BaseFairseqModel): decoder_module_list.append(module) module_count += 1 self.model = None - self.encoder = TransformerEncoder(args, None, None, encoder_module_list) + self.encoder = TransformerEncoder(cfg.model, None, None, encoder_module_list) self.decoder = TransformerDecoder( - args, None, None, decoder_module_list=decoder_module_list + cfg.model, None, None, decoder_module_list=decoder_module_list ) @staticmethod @@ -320,7 +320,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel): """Maximum length supported by the decoder.""" return self.decoder_max_positions - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict(self, state_dict, strict=True, cfg=None): """Copies parameters and buffers from *state_dict* into this module and its descendants. 
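The `copy.copy(self.cfg)` / `open_dict` pattern in the hub interface above is how per-call generation overrides are layered on top of the stored config. A hedged sketch of the same idea with a toy config (key names are illustrative):

```python
import copy

from omegaconf import OmegaConf, open_dict

base = OmegaConf.create({"generation": {"beam": 5, "lenpen": 1.0}})
OmegaConf.set_struct(base, True)

# shallow-copy the config, then add per-call overrides at the top level
gen_args = copy.copy(base)
with open_dict(gen_args):
    gen_args.beam = 10
    gen_args.sampling_topk = 3

assert gen_args.beam == 10
assert gen_args.generation.beam == 5  # nested defaults are still visible
```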
diff --git a/fairseq/model_parallel/models/transformer_lm.py b/fairseq/model_parallel/models/transformer_lm.py index 5db6efb7..dc52f6e8 100644 --- a/fairseq/model_parallel/models/transformer_lm.py +++ b/fairseq/model_parallel/models/transformer_lm.py @@ -72,6 +72,10 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel): ) return cls(decoder) + @staticmethod + def add_args(parser): + TransformerLanguageModel.add_args(parser) + @classmethod def build_embedding(cls, args, dictionary, embed_dim, path=None): def _vocab_init(tensor, **kwargs): diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 7ff94427..3b4fd51d 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -7,8 +7,6 @@ import argparse import importlib import os -from argparse import Namespace -from typing import Union import fairseq from fairseq.dataclass import FairseqDataclass @@ -52,10 +50,10 @@ __all__ = [ ] -def build_model(model_cfg: Union[DictConfig, Namespace], task): - if isinstance(model_cfg, DictConfig): - return ARCH_MODEL_REGISTRY[model_cfg._name].build_model(model_cfg, task) - return ARCH_MODEL_REGISTRY[model_cfg.arch].build_model(model_cfg, task) +def build_model(cfg: DictConfig, task): + if isinstance(cfg, DictConfig): + return ARCH_MODEL_REGISTRY[cfg._name].build_model(cfg, task) + return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task) def register_model(name, dataclass=None): @@ -92,7 +90,8 @@ def register_model(name, dataclass=None): ) cls.__dataclass = dataclass - MODEL_DATACLASS_REGISTRY[name] = dataclass + if dataclass is not None: + MODEL_DATACLASS_REGISTRY[name] = dataclass return cls return register_model_cls @@ -108,14 +107,13 @@ def register_model_architecture(model_name, arch_name): For example:: @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') - def lstm_luong_wmt_en_de(args): - args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) + def lstm_luong_wmt_en_de(cfg): + args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000) (...) - The decorated function should take a single argument *args*, which is a - :class:`argparse.Namespace` of arguments parsed from the command-line. The - decorated function should modify these arguments in-place to match the - desired architecture. + The decorated function should take a single argument *cfg*, which is a + :class:`omegaconf.DictConfig`. The decorated function should modify these + arguments in-place to match the desired architecture. 
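To make the `build_model` change above concrete: new-style grouped configs identify the architecture through the group's `_name` field, while legacy `argparse.Namespace` objects still carry a flat `arch` attribute, so call sites dispatch on the config type. A small illustrative sketch; the registry dict here is a stand-in, not fairseq's `ARCH_MODEL_REGISTRY`:

```python
from argparse import Namespace

from omegaconf import DictConfig, OmegaConf

# stand-in for fairseq's ARCH_MODEL_REGISTRY, just for illustration
ARCH_REGISTRY = {"transformer_lm": "FakeTransformerLM"}


def lookup_arch(cfg):
    # grouped configs carry the architecture in `_name`,
    # legacy Namespaces still expose a flat `arch` attribute
    key = cfg._name if isinstance(cfg, DictConfig) else cfg.arch
    return ARCH_REGISTRY[key]


assert lookup_arch(Namespace(arch="transformer_lm")) == "FakeTransformerLM"
assert lookup_arch(OmegaConf.create({"_name": "transformer_lm"})) == "FakeTransformerLM"
```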
Args: model_name (str): the name of the Model (Model must already be diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py index cdabe360..6a520cb9 100644 --- a/fairseq/models/bart/hub_interface.py +++ b/fairseq/models/bart/hub_interface.py @@ -13,6 +13,7 @@ import torch.nn as nn import torch.nn.functional as F from fairseq import utils from fairseq.data import encoders +from omegaconf import open_dict logger = logging.getLogger(__name__) @@ -24,13 +25,13 @@ class BARTHubInterface(nn.Module): Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart """ - def __init__(self, args, task, model): + def __init__(self, cfg, task, model): super().__init__() - self.args = args + self.cfg = cfg self.task = task self.model = model - self.bpe = encoders.build_bpe(args) + self.bpe = encoders.build_bpe(cfg.bpe) self.max_positions = min( utils.resolve_max_positions( @@ -120,10 +121,11 @@ class BARTHubInterface(nn.Module): sample = self._build_sample(tokens) # build generator using current args as well as any kwargs - gen_args = copy.copy(self.args) - gen_args.beam = beam - for k, v in kwargs.items(): - setattr(gen_args, k, v) + gen_args = copy.copy(self.cfg) + with open_dict(gen_args): + gen_args.beam = beam + for k, v in kwargs.items(): + setattr(gen_args, k, v) generator = self.task.build_generator([self.model], gen_args) translations = self.task.inference_step( generator, diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index 0f22352b..7263a78d 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -144,7 +144,9 @@ class BARTModel(TransformerModel): num_classes=num_classes, activation_fn=self.args.pooler_activation_fn, pooler_dropout=self.args.pooler_dropout, - do_spectral_norm=self.args.spectral_norm_classification_head, + do_spectral_norm=getattr( + self.args, "spectral_norm_classification_head", False + ), ) def upgrade_state_dict_named(self, state_dict, name): diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index bfd41777..3ebb30e3 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -7,6 +7,7 @@ Base classes for various fairseq models. """ import logging +from argparse import Namespace from typing import Dict, List, Optional, Tuple import torch @@ -15,8 +16,12 @@ import torch.nn.functional as F from fairseq import utils from fairseq.checkpoint_utils import prune_state_dict from fairseq.data import Dictionary -from fairseq.dataclass.utils import gen_parser_from_dataclass +from fairseq.dataclass.utils import ( + convert_namespace_to_omegaconf, + gen_parser_from_dataclass, +) from fairseq.models import FairseqDecoder, FairseqEncoder +from omegaconf import DictConfig from torch import Tensor @@ -86,15 +91,26 @@ class BaseFairseqModel(nn.Module): """Maximum length supported by the model.""" return None - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg: Optional[DictConfig] = None, + args: Optional[Namespace] = None, + ): """Copies parameters and buffers from *state_dict* into this module and its descendants. Overrides the method in :class:`nn.Module`. Compared with that method this additionally "upgrades" *state_dicts* from old checkpoints. 
""" + + if model_cfg is None and args is not None: + logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + model_cfg = convert_namespace_to_omegaconf(args).model + self.upgrade_state_dict(state_dict) - new_state_dict = prune_state_dict(state_dict, args) + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) def upgrade_state_dict(self, state_dict): @@ -133,18 +149,18 @@ class BaseFairseqModel(nn.Module): self.apply(_apply) - def prepare_for_inference_(self, args): + def prepare_for_inference_(self, cfg: DictConfig): """Prepare model for inference.""" kwargs = {} kwargs["beamable_mm_beam_size"] = ( - None if getattr(args, "no_beamable_mm", False) else getattr(args, "beam", 5) + None + if getattr(cfg.generation, "no_beamable_mm", False) + else getattr(cfg.generation, "beam", 5) ) - kwargs["need_attn"] = getattr(args, "print_alignment", False) - if hasattr(args, "retain_dropout"): - kwargs["retain_dropout"] = args.retain_dropout - kwargs["retain_dropout_modules"] = getattr( - args, "retain_dropout_modules", None - ) + kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False) + if getattr(cfg.generation, "retain_dropout", False): + kwargs["retain_dropout"] = cfg.generation.retain_dropout + kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules self.make_generation_fast_(**kwargs) def make_generation_fast_(self, **kwargs): @@ -437,15 +453,26 @@ class FairseqMultiModel(BaseFairseqModel): def forward_decoder(self, prev_output_tokens, **kwargs): return self.decoder(prev_output_tokens, **kwargs) - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict( + self, + state_dict, + strict=True, + model_cfg=None, + args: Optional[Namespace] = None, + ): """Copies parameters and buffers from *state_dict* into this module and its descendants. Overrides the method in :class:`nn.Module`. Compared with that method this additionally "upgrades" *state_dicts* from old checkpoints. 
""" + + if model_cfg is None and args is not None: + logger.warn("using 'args' is deprecated, please update your code to use dataclass config") + model_cfg = convert_namespace_to_omegaconf(args).model + self.upgrade_state_dict(state_dict) - new_state_dict = prune_state_dict(state_dict, args) + new_state_dict = prune_state_dict(state_dict, model_cfg) return super().load_state_dict(new_state_dict, strict) diff --git a/fairseq/models/multilingual_transformer.py b/fairseq/models/multilingual_transformer.py index e3fbbd57..2e1f86f3 100644 --- a/fairseq/models/multilingual_transformer.py +++ b/fairseq/models/multilingual_transformer.py @@ -194,14 +194,14 @@ class MultilingualTransformerModel(FairseqMultiModel): module_class = TransformerEncoder if is_encoder else TransformerDecoder return module_class(args, lang_dict, embed_tokens) - def load_state_dict(self, state_dict, strict=True, args=None): + def load_state_dict(self, state_dict, strict=True, model_cfg=None): state_dict_subset = state_dict.copy() for k, _ in state_dict.items(): assert k.startswith("models.") lang_pair = k.split(".")[1] if lang_pair not in self.models: del state_dict_subset[k] - super().load_state_dict(state_dict_subset, strict=strict, args=args) + super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg) @register_model_architecture("multilingual_transformer", "multilingual_transformer") diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py index 526823bd..0c723f06 100644 --- a/fairseq/models/roberta/hub_interface.py +++ b/fairseq/models/roberta/hub_interface.py @@ -17,13 +17,13 @@ class RobertaHubInterface(nn.Module): Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta """ - def __init__(self, args, task, model): + def __init__(self, cfg, task, model): super().__init__() - self.args = args + self.cfg = cfg self.task = task self.model = model - self.bpe = encoders.build_bpe(args) + self.bpe = encoders.build_bpe(cfg.bpe) # this is useful for determining the device self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 6ce216a6..d1a63196 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -494,7 +494,7 @@ def base_architecture(args): args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.spectral_norm_classification_head = getattr( - args, "spectral_nrom_classification_head", False + args, "spectral_norm_classification_head", False ) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index fbb7ce23..f87fa50d 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -578,10 +578,9 @@ class TransformerDecoder(FairseqIncrementalDecoder): if embed_dim != input_embed_dim else None ) - self.embed_positions = ( PositionalEmbedding( - args.max_target_positions, + self.max_target_positions, embed_dim, self.padding_idx, learned=args.decoder_learned_pos, @@ -963,6 +962,14 @@ def base_architecture(args): args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) + args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) + args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None) + args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0) + args.decoder_layerdrop = 
getattr(args, "decoder_layerdrop", 0) + args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) + args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8) + args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) + @register_model_architecture("transformer", "transformer_iwslt_de_en") def transformer_iwslt_de_en(args): diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 22b17f06..df809bdb 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -159,7 +159,7 @@ class TransformerLanguageModelConfig(FairseqDataclass): add_bos_token: bool = II("task.add_bos_token") tokens_per_sample: int = II("task.tokens_per_sample") max_target_positions: Optional[int] = II("task.max_target_positions") - tpu: bool = II("params.common.tpu") + tpu: bool = II("common.tpu") @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 48cd4c73..8775aa77 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -32,20 +32,20 @@ class TransformerEncoderLayer(nn.Module): def __init__(self, args): super().__init__() self.embed_dim = args.encoder_embed_dim - self.quant_noise = getattr(args, "quant_noise_pq", 0) - self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) + self.quant_noise = getattr(args, 'quant_noise_pq', 0) + self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8 self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.dropout_module = FairseqDropout( args.dropout, module_name=self.__class__.__name__ ) self.activation_fn = utils.get_activation_fn( - activation=getattr(args, "activation_fn", "relu") + activation=getattr(args, 'activation_fn', 'relu') or "relu" ) - activation_dropout_p = getattr(args, "activation_dropout", 0) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) @@ -197,10 +197,10 @@ class TransformerDecoderLayer(nn.Module): if getattr(args, "activation_fn", None) is not None else "relu" ) - activation_dropout_p = getattr(args, "activation_dropout", 0) + activation_dropout_p = getattr(args, "activation_dropout", 0) or 0 if activation_dropout_p == 0: # for backwards compatibility with models that use args.relu_dropout - activation_dropout_p = getattr(args, "relu_dropout", 0) + activation_dropout_p = getattr(args, "relu_dropout", 0) or 0 self.activation_dropout_module = FairseqDropout( float(activation_dropout_p), module_name=self.__class__.__name__ ) diff --git a/fairseq/optim/__init__.py b/fairseq/optim/__init__.py index 94eb2c7e..d8e58172 100644 --- a/fairseq/optim/__init__.py +++ b/fairseq/optim/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.optim.bmuf import FairseqBMUF # noqa @@ -19,7 +17,6 @@ from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optim from fairseq.optim.shard import shard_ from omegaconf import DictConfig - __all__ = [ "FairseqOptimizer", 
"FP16Optimizer", @@ -27,7 +24,6 @@ __all__ = [ "shard_", ] - ( _build_optimizer, register_optimizer, @@ -37,12 +33,12 @@ __all__ = [ def build_optimizer( - optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs + cfg: DictConfig, params, *extra_args, **extra_kwargs ): if all(isinstance(p, dict) for p in params): params = [t for p in params for t in p.values()] params = list(filter(lambda p: p.requires_grad, params)) - return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs) + return _build_optimizer(cfg, params, *extra_args, **extra_kwargs) # automatically import any Python files in the optim/ directory diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index f678a9f5..9b8ddffd 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -5,6 +5,7 @@ import logging import math +from collections import Collection from dataclasses import dataclass, field from typing import List @@ -14,7 +15,7 @@ import torch.optim from fairseq.dataclass import FairseqDataclass from fairseq.optim import FairseqOptimizer, register_optimizer from fairseq.optim.fused_adam import get_fused_adam_class -from omegaconf import II +from omegaconf import II, DictConfig logger = logging.getLogger(__name__) @@ -33,8 +34,8 @@ class FairseqAdamConfig(FairseqDataclass): default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} ) # TODO common vars below in parent - tpu: bool = II("params.common.tpu") - lr: List[float] = II("params.optimization.lr") + tpu: bool = II("common.tpu") + lr: List[float] = II("optimization.lr") @register_optimizer("adam", dataclass=FairseqAdamConfig) @@ -46,15 +47,15 @@ class FairseqAdam(FairseqOptimizer): analogous to torch.optim.AdamW from PyTorch. """ - def __init__(self, args, params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) fused_adam_cls = get_fused_adam_class() use_fused_adam = ( - not getattr(args, "use_old_adam", False) + not getattr(cfg, "use_old_adam", False) and fused_adam_cls is not None and torch.cuda.is_available() ) - if getattr(args, "tpu", False): + if getattr(cfg, "tpu", False): # on TPUs we use the Adam defined here, since it # automatically casts gradients to FP32 self._optimizer = Adam(params, **self.optimizer_config) @@ -73,10 +74,12 @@ class FairseqAdam(FairseqOptimizer): different learning rate. 
""" return { - "lr": self.args.lr[0], - "betas": eval(self.args.adam_betas), - "eps": self.args.adam_eps, - "weight_decay": self.args.weight_decay, + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "betas": eval(self.cfg.adam_betas), + "eps": self.cfg.adam_eps, + "weight_decay": self.cfg.weight_decay, } def average_params(self): diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py index 3312f811..55f225ba 100644 --- a/fairseq/optim/bmuf.py +++ b/fairseq/optim/bmuf.py @@ -10,7 +10,7 @@ import torch.distributed as dist from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.optim.fairseq_optimizer import FairseqOptimizer -from omegaconf import II +from omegaconf import II, DictConfig @dataclass @@ -38,7 +38,7 @@ class FairseqBMUFConfig(FairseqDataclass): }, ) distributed_world_size: int = II( - "params.distributed_training.distributed_world_size" + "distributed_training.distributed_world_size" ) @@ -52,20 +52,19 @@ class FairseqBMUF(FairseqOptimizer): model-update filtering """ - def __init__(self, args, optimizer): - - super().__init__(args) + def __init__(self, cfg: DictConfig, optimizer): + super().__init__(cfg) self._optimizer = optimizer self._num_updates = 0 - self.sync_iter = self.args.global_sync_iter - self.block_momentum = self.args.block_momentum - self.block_lr = self.args.block_lr + self.sync_iter = cfg.global_sync_iter + self.block_momentum = cfg.block_momentum + self.block_lr = cfg.block_lr self._reset_local_data() - self.warmup_iteration = self.args.warmup_iterations - self.use_nbm = self.args.use_nbm + self.warmup_iteration = cfg.warmup_iterations + self.use_nbm = cfg.use_nbm self.initial_state = self._optimizer.state_dict() - self.average_sync = self.args.average_sync - self.world_size = self.args.distributed_world_size + self.average_sync = self.cfg.average_sync + self.world_size = self.cfg.distributed_world_size @staticmethod def add_args(parser): diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index 8a10399a..9c093833 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -9,9 +9,9 @@ from fairseq.dataclass.utils import gen_parser_from_dataclass class FairseqOptimizer(object): - def __init__(self, args): + def __init__(self, cfg): super().__init__() - self.args = args + self.cfg = cfg @classmethod def add_args(cls, parser): diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index b622fbde..b08a7237 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -7,7 +7,8 @@ from collections import defaultdict from itertools import chain import torch -from fairseq import optim, utils +from fairseq import optim +from omegaconf import DictConfig from .dynamic_loss_scaler import DynamicLossScaler @@ -211,7 +212,7 @@ class _FP16OptimizerMixin(object): for fp32_params in self.fp32_params.values(): fp32_params.grad.zero_() else: - raise ("self.fp32_params must be a tensor or dict") + raise RuntimeError("self.fp32_params must be a tensor or dict") else: for p32 in self.fp32_params: p32.grad.zero_() @@ -226,58 +227,60 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer): Wrap an *optimizer* to support FP16 (mixed precision) training. 
""" - def __init__(self, args, params, fp32_optimizer, fp32_params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs): + super().__init__(cfg.optimizer) self.fp16_params = params self.fp32_optimizer = fp32_optimizer self.fp32_params = fp32_params - if getattr(args, "fp16_scale_window", None) is None: - if len(args.update_freq) > 1: + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: raise ValueError( "--fp16-scale-window must be given explicitly when using a " "custom --update-freq schedule" ) data_parallel_size = int( - args.distributed_world_size / args.model_parallel_size + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = int( + 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] ) - scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0]) else: - scale_window = args.fp16_scale_window + scale_window = cfg.common.fp16_scale_window - if not getattr(args, "bf16", False): + if not getattr(cfg.common, "bf16", False): self.scaler = DynamicLossScaler( - init_scale=args.fp16_init_scale, + init_scale=cfg.common.fp16_init_scale, scale_window=scale_window, - tolerance=args.fp16_scale_tolerance, - threshold=args.threshold_loss_scale, - min_loss_scale=args.min_loss_scale, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, ) else: # disable loss scaling for bfloat16 self.scaler = None @classmethod - def build_optimizer(cls, args, params): + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): """ Args: - args (argparse.Namespace): fairseq args + cfg (omegaconf.DictConfig): fairseq args params (iterable): iterable of parameters to optimize """ - flatten = not getattr(args, "fp16_no_flatten_grads", False) - if getattr(args, "bf16", False): + flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False) + if getattr(cfg.common, "bf16", False): flatten = False # mixed precision is faster on TPUs without flat grads - fp32_params = cls.build_fp32_params(args, params, flatten=flatten) + fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten) if flatten: - fp32_optimizer = optim.build_optimizer(args, [fp32_params]) + fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params]) else: - fp32_optimizer = optim.build_optimizer(args, fp32_params) + fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params) if flatten and not fp32_optimizer.supports_flat_params: raise RuntimeError( - "chosen optimizer does not support flat params, " - "please set --fp16-no-flatten-grads" + f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads" ) - return cls(args, params, fp32_optimizer, fp32_params) + return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs) @property def optimizer(self): @@ -427,49 +430,52 @@ class MemoryEfficientFP16Optimizer( *supports_memory_efficient_fp16* property. 
""" - def __init__(self, args, params, optimizer): + def __init__(self, cfg: DictConfig, params, optimizer, **kwargs): if not optimizer.supports_memory_efficient_fp16: raise ValueError( "Unsupported optimizer: {}".format(optimizer.__class__.__name__) ) - super().__init__(args) + super().__init__(cfg.optimizer) self.wrapped_optimizer = optimizer - if getattr(args, "fp16_scale_window", None) is None: - if len(args.update_freq) > 1: + if getattr(cfg.common, "fp16_scale_window", None) is None: + if len(cfg.optimization.update_freq) > 1: raise ValueError( "--fp16-scale-window must be given explicitly when using a " "custom --update-freq schedule" ) data_parallel_size = int( - args.distributed_world_size / args.model_parallel_size + cfg.distributed_training.distributed_world_size + / cfg.common.model_parallel_size + ) + scale_window = ( + 2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0] ) - scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0] else: - scale_window = args.fp16_scale_window + scale_window = cfg.common.fp16_scale_window - if not getattr(args, "bf16", False): + if not getattr(cfg.common, "bf16", False): self.scaler = DynamicLossScaler( - init_scale=args.fp16_init_scale, + init_scale=cfg.common.fp16_init_scale, scale_window=scale_window, - tolerance=args.fp16_scale_tolerance, - threshold=args.threshold_loss_scale, - min_loss_scale=args.min_loss_scale, + tolerance=cfg.common.fp16_scale_tolerance, + threshold=cfg.common.threshold_loss_scale, + min_loss_scale=cfg.common.min_loss_scale, ) else: # disable loss scaling for bfloat16 self.scaler = None @classmethod - def build_optimizer(cls, args, params): + def build_optimizer(cls, cfg: DictConfig, params, **kwargs): """ Args: args (argparse.Namespace): fairseq args params (iterable): iterable of parameters to optimize """ - fp16_optimizer = optim.build_optimizer(args, params) - return cls(args, params, fp16_optimizer) + fp16_optimizer = optim.build_optimizer(cfg.optimizer, params) + return cls(cfg, params, fp16_optimizer, **kwargs) @property def optimizer(self): diff --git a/fairseq/optim/lr_scheduler/__init__.py b/fairseq/optim/lr_scheduler/__init__.py index 7b72c257..f07d43c7 100644 --- a/fairseq/optim/lr_scheduler/__init__.py +++ b/fairseq/optim/lr_scheduler/__init__.py @@ -6,8 +6,6 @@ import importlib import os -from argparse import Namespace -from typing import Union from fairseq import registry from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa @@ -27,8 +25,8 @@ from omegaconf import DictConfig ) -def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer): - return build_lr_scheduler_(lr_scheduler_cfg, optimizer) +def build_lr_scheduler(cfg: DictConfig, optimizer): + return build_lr_scheduler_(cfg, optimizer) # automatically import any Python files in the optim/lr_scheduler/ directory diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 98d55750..c3c6663e 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -4,11 +4,12 @@ # LICENSE file in the root directory of this source tree. import math +from collections import Collection from dataclasses import dataclass, field from typing import List from fairseq.dataclass import FairseqDataclass -from omegaconf import II +from omegaconf import II, DictConfig from . 
import FairseqLRScheduler, register_lr_scheduler @@ -38,8 +39,8 @@ class CosineConfig(FairseqDataclass): default=0.1, metadata={"help": "shrink factor for annealing"} ) # TODO common var for parent class - lr: List[float] = II("params.optimization.lr") - max_update: int = II("params.optimization.max_update") + lr: List[float] = II("optimization.lr") + max_update: int = II("optimization.max_update") @register_lr_scheduler("cosine", dataclass=CosineConfig) @@ -66,43 +67,51 @@ class CosineSchedule(FairseqLRScheduler): after every iteration. """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__( + self, cfg: DictConfig, fairseq_optimizer + ): + super().__init__(cfg, fairseq_optimizer) + if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with cosine." " Consider --lr-scheduler=fixed instead." ) - warmup_end_lr = args.max_lr - if args.warmup_init_lr < 0: - args.warmup_init_lr = args.lr[0] - - self.min_lr = args.lr[0] - self.max_lr = args.max_lr + warmup_end_lr = cfg.max_lr + lr = ( + cfg.lr[0] + if isinstance(cfg.lr, Collection) + else cfg.lr + ) + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = lr + self.min_lr = lr + self.max_lr = cfg.max_lr assert self.max_lr > self.min_lr, "max_lr must be more than lr" - self.t_mult = args.t_mult - self.period = args.lr_period_updates + self.t_mult = cfg.t_mult + self.period = cfg.lr_period_updates if self.period <= 0: assert ( - args.max_update >= 0 + cfg.max_update >= 0 ), "Either --max_update or --lr-period-updates must be set" - self.period = args.max_update - args.warmup_updates + self.period = cfg.max_update - cfg.warmup_updates - if args.warmup_updates > 0: + if cfg.warmup_updates > 0: # linearly warmup for the first args.warmup_updates - self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + self.lr_step = ( + warmup_end_lr - cfg.warmup_init_lr + ) / cfg.warmup_updates else: self.lr_step = 1 - self.warmup_updates = args.warmup_updates - self.lr_shrink = args.lr_shrink + self.warmup_updates = cfg.warmup_updates + self.lr_shrink = cfg.lr_shrink # initial learning rate - self.lr = args.warmup_init_lr + self.lr = cfg.warmup_init_lr self.optimizer.set_lr(self.lr) def step(self, epoch, val_loss=None): @@ -113,10 +122,10 @@ class CosineSchedule(FairseqLRScheduler): def step_update(self, num_updates): """Update the learning rate after each update.""" - if num_updates < self.args.warmup_updates: - self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + if num_updates < self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step else: - curr_updates = num_updates - self.args.warmup_updates + curr_updates = num_updates - self.cfg.warmup_updates if self.t_mult != 1: i = math.floor( math.log( diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index 8fde0713..569e4482 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -11,11 +11,11 @@ from .. 
import FairseqOptimizer


 class FairseqLRScheduler(object):
-    def __init__(self, args, optimizer):
+    def __init__(self, cfg, optimizer):
         super().__init__()
         if not isinstance(optimizer, FairseqOptimizer):
             raise ValueError("optimizer must be an instance of FairseqOptimizer")
-        self.args = args
+        self.cfg = cfg
         self.optimizer = optimizer
         self.best = None
diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
index d27261ad..c42e0906 100644
--- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
+++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
@@ -3,11 +3,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import Collection
 from dataclasses import dataclass, field
 from typing import List

 from fairseq.dataclass import FairseqDataclass
-from omegaconf import II
+from omegaconf import II, DictConfig

 from . import FairseqLRScheduler, register_lr_scheduler

@@ -25,7 +26,7 @@ class InverseSquareRootScheduleConfig(FairseqDataclass):
         },
     )
     # TODO common vars at parent class
-    lr: List[float] = II("params.optimization.lr")
+    lr: List[float] = II("optimization.lr")


 @register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig)
@@ -48,25 +49,33 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
       lr = decay_factor / sqrt(update_num)
     """

-    def __init__(self, args, optimizer):
-        super().__init__(args, optimizer)
-        if len(args.lr) > 1:
+    def __init__(self, cfg: DictConfig, optimizer):
+        super().__init__(cfg, optimizer)
+        if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
             raise ValueError(
                 "Cannot use a fixed learning rate schedule with inverse_sqrt."
                 " Consider --lr-scheduler=fixed instead."
             )
-        warmup_end_lr = args.lr[0]
-        if args.warmup_init_lr < 0:
-            args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr
+        warmup_end_lr = (
+            cfg.lr[0]
+            if isinstance(cfg.lr, Collection)
+            else cfg.lr
+        )
+        if cfg.warmup_init_lr < 0:
+            cfg.warmup_init_lr = (
+                0 if cfg.warmup_updates > 0 else warmup_end_lr
+            )

         # linearly warmup for the first args.warmup_updates
-        self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
+        self.lr_step = (
+            warmup_end_lr - cfg.warmup_init_lr
+        ) / cfg.warmup_updates

         # then, decay prop. to the inverse square root of the update number
-        self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5
+        self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5

         # initial learning rate
-        self.lr = args.warmup_init_lr
+        self.lr = cfg.warmup_init_lr
         self.optimizer.set_lr(self.lr)

     def step(self, epoch, val_loss=None):
@@ -77,8 +86,8 @@ class InverseSquareRootSchedule(FairseqLRScheduler):

     def step_update(self, num_updates):
         """Update the learning rate after each update."""
-        if num_updates < self.args.warmup_updates:
-            self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
+        if num_updates < self.cfg.warmup_updates:
+            self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
         else:
             self.lr = self.decay_factor * num_updates ** -0.5
         self.optimizer.set_lr(self.lr)
diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py
index 58d2f356..3982a827 100644
--- a/fairseq/optim/nag.py
+++ b/fairseq/optim/nag.py
@@ -3,12 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
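# Minimal sketch (not part of the patch): how the II("optimization.lr") fields
# used by the scheduler/optimizer dataclasses above resolve once the "params."
# prefix is dropped. II(...) is shorthand for a "${...}" interpolation, so the
# field now points straight at the top-level optimization group. The concrete
# values below are placeholder assumptions for the example only.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "optimization": {"lr": [0.25], "max_update": 50000},
        "lr_scheduler": {
            "_name": "inverse_sqrt",
            "lr": "${optimization.lr}",
            "warmup_updates": 4000,
        },
    }
)
assert cfg.lr_scheduler.lr == [0.25]  # resolved against optimization.lr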
+from collections import Collection from dataclasses import dataclass, field from typing import List import torch from fairseq.dataclass import FairseqDataclass -from omegaconf import II +from omegaconf import II, DictConfig from torch.optim.optimizer import Optimizer, required from . import FairseqOptimizer, register_optimizer @@ -19,13 +20,13 @@ class FairseqNAGConfig(FairseqDataclass): momentum: float = field(default=0.99, metadata={"help": "momentum factor"}) weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) # TODO common vars in parent class - lr: List[float] = II("params.optimization.lr") + lr: List[float] = II("optimization.lr") @register_optimizer("nag", dataclass=FairseqNAGConfig) class FairseqNAG(FairseqOptimizer): - def __init__(self, args, params): - super().__init__(args) + def __init__(self, cfg: DictConfig, params): + super().__init__(cfg) self._optimizer = NAG(params, **self.optimizer_config) @property @@ -37,9 +38,11 @@ class FairseqNAG(FairseqOptimizer): different learning rate. """ return { - "lr": self.args.lr[0], - "momentum": self.args.momentum, - "weight_decay": self.args.weight_decay, + "lr": self.cfg.lr[0] + if isinstance(self.cfg.lr, Collection) + else self.cfg.lr, + "momentum": self.cfg.momentum, + "weight_decay": self.cfg.weight_decay, } diff --git a/fairseq/optim/shard.py b/fairseq/optim/shard.py index a035a1c1..ecef05b4 100644 --- a/fairseq/optim/shard.py +++ b/fairseq/optim/shard.py @@ -12,7 +12,7 @@ except ImportError: _has_fairscale = False -def shard_(args, optimizer, group): +def shard_(optimizer, group): if not _has_fairscale: raise ImportError( "\n\nPlease install the fairscale package:" "\n\n pip install fairscale" diff --git a/fairseq/options.py b/fairseq/options.py index 1a24fcca..6bc526ce 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -10,13 +10,15 @@ import torch from fairseq import utils from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.dataclass.data_class import ( - CheckpointParams, - CommonEvalParams, - CommonParams, - DatasetParams, - DistributedTrainingParams, - EvalLMParams, - OptimizationParams, + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + EvalLMConfig, + GenerationConfig, + InteractiveConfig, + OptimizationConfig, ) from fairseq.dataclass.utils import gen_parser_from_dataclass @@ -45,6 +47,7 @@ def get_generation_parser(interactive=False, default_task="translation"): add_dataset_args(parser, gen=True) add_distributed_training_args(parser, default_world_size=1) add_generation_args(parser) + add_checkpoint_args(parser) if interactive: add_interactive_args(parser) return parser @@ -67,7 +70,7 @@ def get_validation_parser(default_task=None): add_dataset_args(parser, train=True) add_distributed_training_args(parser, default_world_size=1) group = parser.add_argument_group("Evaluation") - gen_parser_from_dataclass(group, CommonEvalParams()) + gen_parser_from_dataclass(group, CommonEvalConfig()) return parser @@ -210,7 +213,7 @@ def get_parser(desc, default_task="translation"): utils.import_user_module(usr_args) parser = argparse.ArgumentParser(allow_abbrev=False) - gen_parser_from_dataclass(parser, CommonParams()) + gen_parser_from_dataclass(parser, CommonConfig()) from fairseq.registry import REGISTRIES @@ -283,7 +286,7 @@ def add_preprocess_args(parser): def add_dataset_args(parser, train=False, gen=False): group = parser.add_argument_group("dataset_data_loading") - gen_parser_from_dataclass(group, DatasetParams()) + 
gen_parser_from_dataclass(group, DatasetConfig()) # fmt: on return group @@ -293,7 +296,7 @@ def add_distributed_training_args(parser, default_world_size=None): if default_world_size is None: default_world_size = max(1, torch.cuda.device_count()) gen_parser_from_dataclass( - group, DistributedTrainingParams(distributed_world_size=default_world_size) + group, DistributedTrainingConfig(distributed_world_size=default_world_size) ) return group @@ -301,7 +304,7 @@ def add_distributed_training_args(parser, default_world_size=None): def add_optimization_args(parser): group = parser.add_argument_group("optimization") # fmt: off - gen_parser_from_dataclass(group, OptimizationParams()) + gen_parser_from_dataclass(group, OptimizationConfig()) # fmt: on return group @@ -309,117 +312,31 @@ def add_optimization_args(parser): def add_checkpoint_args(parser): group = parser.add_argument_group("checkpoint") # fmt: off - gen_parser_from_dataclass(group, CheckpointParams()) + gen_parser_from_dataclass(group, CheckpointConfig()) # fmt: on return group def add_common_eval_args(group): - gen_parser_from_dataclass(group, CommonEvalParams()) + gen_parser_from_dataclass(group, CommonEvalConfig()) def add_eval_lm_args(parser): group = parser.add_argument_group("LM Evaluation") add_common_eval_args(group) - gen_parser_from_dataclass(group, EvalLMParams()) + gen_parser_from_dataclass(group, EvalLMConfig()) def add_generation_args(parser): group = parser.add_argument_group("Generation") add_common_eval_args(group) - # fmt: off - group.add_argument('--beam', default=5, type=int, metavar='N', - help='beam size') - group.add_argument('--nbest', default=1, type=int, metavar='N', - help='number of hypotheses to output') - group.add_argument('--max-len-a', default=0, type=float, metavar='N', - help=('generate sequences of maximum length ax + b, ' - 'where x is the source length')) - group.add_argument('--max-len-b', default=200, type=int, metavar='N', - help=('generate sequences of maximum length ax + b, ' - 'where x is the source length')) - group.add_argument('--min-len', default=1, type=float, metavar='N', - help=('minimum generation length')) - group.add_argument('--match-source-len', default=False, action='store_true', - help=('generations should match the source length')) - group.add_argument('--no-early-stop', action='store_true', - help='deprecated') - group.add_argument('--unnormalized', action='store_true', - help='compare unnormalized hypothesis scores') - group.add_argument('--no-beamable-mm', action='store_true', - help='don\'t use BeamableMM in attention layers') - group.add_argument('--lenpen', default=1, type=float, - help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences') - group.add_argument('--unkpen', default=0, type=float, - help='unknown word penalty: <0 produces more unks, >0 produces fewer') - group.add_argument('--replace-unk', nargs='?', const=True, default=None, - help='perform unknown replacement (optionally with alignment dictionary)') - group.add_argument('--sacrebleu', action='store_true', - help='score with sacrebleu') - group.add_argument('--score-reference', action='store_true', - help='just score the reference translation') - group.add_argument('--prefix-size', default=0, type=int, metavar='PS', - help='initialize generation by target prefix of given length') - group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N', - help='ngram blocking such that this size ngram cannot be repeated in the generation') - group.add_argument('--sampling', 
action='store_true', - help='sample hypotheses instead of using beam search') - group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS', - help='sample from top K likely next words instead of all words') - group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS', - help='sample from the smallest set whose cumulative probability mass exceeds p for next words') - group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"], - help='enables lexically constrained decoding') - group.add_argument('--temperature', default=1., type=float, metavar='N', - help='temperature for generation') - group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N', - help='number of groups for Diverse Beam Search') - group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N', - help='strength of diversity penalty for Diverse Beam Search') - group.add_argument('--diversity-rate', default=-1.0, type=float, metavar='N', - help='strength of diversity penalty for Diverse Siblings Search') - group.add_argument('--print-alignment', action='store_true', - help='if set, uses attention feedback to compute and print alignment to source tokens') - group.add_argument('--print-step', action='store_true') - - group.add_argument('--lm-path', default=None, type=str, metavar='PATH', - help='path to lm checkpoint for lm fusion') - group.add_argument('--lm-weight', default=0.0, type=float, metavar='N', - help='weight for lm probs for lm fusion') - - # arguments for iterative refinement generator - group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N', - help='if > 0.0, it penalized early-stopping in decoding.') - group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N', - help='maximum iterations for iterative refinement.') - group.add_argument('--iter-decode-force-max-iter', action='store_true', - help='if set, run exact the maximum number of iterations without early stop') - group.add_argument('--iter-decode-with-beam', default=1, type=int, metavar='N', - help='if > 1, model will generate translations varying by the lengths.') - group.add_argument('--iter-decode-with-external-reranker', action='store_true', - help='if set, the last checkpoint are assumed to be a reranker to rescore the translations'), - group.add_argument('--retain-iter-history', action='store_true', - help='if set, decoding returns the whole history of iterative refinement') - group.add_argument('--retain-dropout', action='store_true', - help='Use dropout at inference time') - group.add_argument('--retain-dropout-modules', default=None, nargs='+', type=str, - help='if set, only retain dropout for the specified modules; ' - 'if not set, then dropout will be retained for all modules') - - # special decoding format for advanced decoding. 
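# Minimal sketch (not part of the patch): the hand-written generation flags
# being deleted here move into a dataclass, and gen_parser_from_dataclass()
# regenerates the corresponding --flags from its fields. ToyGenerationConfig
# is a made-up stand-in with two fields, not the real GenerationConfig.
from dataclasses import dataclass, field

from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass


@dataclass
class ToyGenerationConfig(FairseqDataclass):
    beam: int = field(default=5, metadata={"help": "beam size"})
    max_len_b: int = field(default=200, metadata={"help": "max length coefficient b"})


# group = parser.add_argument_group("Generation")
# gen_parser_from_dataclass(group, ToyGenerationConfig())  # adds --beam, --max-len-b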
- group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs']) - # fmt: on + gen_parser_from_dataclass(group, GenerationConfig()) return group def add_interactive_args(parser): group = parser.add_argument_group("Interactive") - # fmt: off - group.add_argument('--buffer-size', default=0, type=int, metavar='N', - help='read this many sentences into a buffer before processing them') - group.add_argument('--input', default='-', type=str, metavar='FILE', - help='file to read from; use - for stdin') - # fmt: on + gen_parser_from_dataclass(group, InteractiveConfig()) def add_model_args(parser): diff --git a/fairseq/quantization_utils.py b/fairseq/quantization_utils.py index 69dd61d7..11fc414c 100644 --- a/fairseq/quantization_utils.py +++ b/fairseq/quantization_utils.py @@ -6,13 +6,14 @@ import logging from fairseq.modules.quantization import pq, quantization_options, scalar +from omegaconf import DictConfig logger = logging.getLogger(__name__) -def quantize_model_scalar(model, args): - quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) +def quantize_model_scalar(model, model_cfg: DictConfig): + quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0 if quant_noise_scalar > 0: # quantize_model edits the model in place scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000) diff --git a/fairseq/registry.py b/fairseq/registry.py index 382dec22..4446084d 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -3,14 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import argparse from argparse import Namespace + from typing import Union - from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import populate_dataclass from omegaconf import DictConfig - REGISTRIES = {} @@ -25,33 +24,30 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F # maintain a registry of all registries if registry_name in REGISTRIES: return # registry already exists - REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default} + REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default, "dataclass_registry": DATACLASS_REGISTRY} - def build_x(args: Union[DictConfig, Namespace], *extra_args, **extra_kwargs): - if isinstance(args, DictConfig): - if getattr(args, "_name", None) is not None: - choice = args._name - elif hasattr(args, registry_name): - choice = args.registry_name - else: - raise RuntimeError( - f"Neither _name nor {registry_name} in args, args = {args}" - ) + def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs): + if isinstance(cfg, DictConfig): + choice = cfg._name + elif isinstance(cfg, str): + choice = cfg else: - choice = getattr(args, registry_name, None) + choice = getattr(cfg, registry_name, None) + if choice in DATACLASS_REGISTRY: + cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]()) if choice is None: if required: - raise ValueError("--{} is required!".format(registry_name)) + raise ValueError('{} is required!'.format(registry_name)) return None + cls = REGISTRY[choice] if hasattr(cls, "build_" + registry_name): builder = getattr(cls, "build_" + registry_name) else: builder = cls - if isinstance(args, Namespace): - set_defaults(args, cls) - return builder(args, *extra_args, **extra_kwargs) + + return builder(cfg, *extra_args, **extra_kwargs) def register_x(name, dataclass=None): def register_x_cls(cls): @@ 
-77,30 +73,10 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F cls.__dataclass = dataclass REGISTRY[name] = cls - DATACLASS_REGISTRY[name] = cls.__dataclass - REGISTRY_CLASS_NAMES.add(cls.__name__) + if cls.__dataclass is not None: + DATACLASS_REGISTRY[name] = cls.__dataclass return cls return register_x_cls return build_x, register_x, REGISTRY, DATACLASS_REGISTRY - - -def set_defaults(args: Namespace, cls): - """Helper to set default arguments based on *add_args*.""" - if not hasattr(cls, "add_args"): - return - parser = argparse.ArgumentParser( - argument_default=argparse.SUPPRESS, allow_abbrev=False - ) - cls.add_args(parser) - # copied from argparse.py: - defaults = argparse.Namespace() - for action in parser._actions: - if action.dest is not argparse.SUPPRESS: - if not hasattr(defaults, action.dest): - if action.default is not argparse.SUPPRESS: - setattr(defaults, action.dest, action.default) - for key, default_value in vars(defaults).items(): - if not hasattr(args, key): - setattr(args, key, default_value) diff --git a/fairseq/scoring/__init__.py b/fairseq/scoring/__init__.py index 4be0cb51..8c706cb5 100644 --- a/fairseq/scoring/__init__.py +++ b/fairseq/scoring/__init__.py @@ -9,11 +9,12 @@ import os from abc import ABC, abstractmethod from fairseq import registry +from omegaconf import DictConfig class BaseScorer(ABC): - def __init__(self, args): - self.args = args + def __init__(self, cfg): + self.cfg = cfg self.ref = [] self.pred = [] @@ -39,19 +40,17 @@ _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry( ) -def build_scorer(args, tgt_dict): - from fairseq import utils +def build_scorer(choice, tgt_dict): + if isinstance(choice, DictConfig): + choice = choice._name - if args.sacrebleu: - utils.deprecation_warning( - "--sacrebleu is deprecated. Please use --scoring sacrebleu instead." 
- ) - args.scoring = "sacrebleu" - if args.scoring == "bleu": + if choice == "bleu": from fairseq.scoring import bleu - return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) - return _build_scorer(args) + return bleu.Scorer( + bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()) + ) + return _build_scorer(choice) # automatically import any Python files in the current directory diff --git a/fairseq/scoring/bleu.py b/fairseq/scoring/bleu.py index 7f8bd73b..97de5f96 100644 --- a/fairseq/scoring/bleu.py +++ b/fairseq/scoring/bleu.py @@ -6,8 +6,10 @@ import ctypes import math import sys +from dataclasses import dataclass, field import torch +from fairseq.dataclass import FairseqDataclass from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring.tokenizer import EvaluationTokenizer @@ -27,31 +29,32 @@ class BleuStat(ctypes.Structure): ] -@register_scorer("sacrebleu") +@dataclass +class SacrebleuConfig(FairseqDataclass): + sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="13a", metadata={"help": "tokenizer"} + ) + sacrebleu_lowercase: bool = field( + default=False, metadata={"help": "apply lowercasing"} + ) + sacrebleu_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + + +@register_scorer("sacrebleu", dataclass=SacrebleuConfig) class SacrebleuScorer(BaseScorer): - def __init__(self, args): - super(SacrebleuScorer, self).__init__(args) + def __init__(self, cfg): + super(SacrebleuScorer, self).__init__(cfg) import sacrebleu self.sacrebleu = sacrebleu self.tokenizer = EvaluationTokenizer( - tokenizer_type=self.args.sacrebleu_tokenizer, - lowercase=self.args.sacrebleu_lowercase, - character_tokenization=self.args.sacrebleu_char_level, + tokenizer_type=cfg.sacrebleu_tokenizer, + lowercase=cfg.sacrebleu_lowercase, + character_tokenization=cfg.sacrebleu_char_level, ) - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--sacrebleu-tokenizer', type=str, default='13a', - choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES, - help='tokenizer') - parser.add_argument('--sacrebleu-lowercase', type=str, default=False, - help='apply lowercasing') - parser.add_argument('--sacrebleu-char-level', action='store_true', - help='evaluate at character level') - # fmt: on - def add_string(self, ref, pred): self.ref.append(self.tokenizer.tokenize(ref)) self.pred.append(self.tokenizer.tokenize(pred)) @@ -68,13 +71,20 @@ class SacrebleuScorer(BaseScorer): ).format() -@register_scorer("bleu") +@dataclass +class BleuConfig(FairseqDataclass): + pad: int = field(default=1, metadata={"help": "padding index"}) + eos: int = field(default=2, metadata={"help": "eos index"}) + unk: int = field(default=3, metadata={"help": "unk index"}) + + +@register_scorer("bleu", dataclass=BleuConfig) class Scorer(object): - def __init__(self, pad, eos, unk): + def __init__(self, cfg): self.stat = BleuStat() - self.pad = pad - self.eos = eos - self.unk = unk + self.pad = cfg.pad + self.eos = cfg.eos + self.unk = cfg.unk try: from fairseq import libbleu diff --git a/fairseq/scoring/tokenizer.py b/fairseq/scoring/tokenizer.py index dbcc6e4d..0d0702bf 100644 --- a/fairseq/scoring/tokenizer.py +++ b/fairseq/scoring/tokenizer.py @@ -5,6 +5,8 @@ import unicodedata +from fairseq.dataclass.utils import ChoiceEnum + class EvaluationTokenizer(object): """A generic evaluation-time tokenizer, which leverages built-in tokenizers @@ -22,7 +24,7 @@ class EvaluationTokenizer(object): SPACE = chr(32) SPACE_ESCAPE = 
chr(9601) - ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"] + ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"]) def __init__( self, @@ -33,7 +35,7 @@ class EvaluationTokenizer(object): ): from sacrebleu.tokenizers import TOKENIZERS - assert tokenizer_type in self.ALL_TOKENIZER_TYPES + assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}" self.lowercase = lowercase self.punctuation_removal = punctuation_removal self.character_tokenization = character_tokenization diff --git a/fairseq/scoring/wer.py b/fairseq/scoring/wer.py index 21efefd9..633dc47c 100644 --- a/fairseq/scoring/wer.py +++ b/fairseq/scoring/wer.py @@ -3,14 +3,31 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field + +from fairseq.dataclass import FairseqDataclass from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring.tokenizer import EvaluationTokenizer -@register_scorer("wer") +@dataclass +class WerScorerConfig(FairseqDataclass): + wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field( + default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"} + ) + wer_remove_punct: bool = field( + default=False, metadata={"help": "remove punctuation"} + ) + wer_char_level: bool = field( + default=False, metadata={"help": "evaluate at character level"} + ) + wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"}) + + +@register_scorer("wer", dataclass=WerScorerConfig) class WerScorer(BaseScorer): - def __init__(self, args): - super().__init__(args) + def __init__(self, cfg): + super().__init__(cfg) self.reset() try: import editdistance as ed @@ -18,26 +35,12 @@ class WerScorer(BaseScorer): raise ImportError("Please install editdistance to use WER scorer") self.ed = ed self.tokenizer = EvaluationTokenizer( - tokenizer_type=self.args.wer_tokenizer, - lowercase=self.args.wer_lowercase, - punctuation_removal=self.args.wer_remove_punct, - character_tokenization=self.args.wer_char_level, + tokenizer_type=self.cfg.wer_tokenizer, + lowercase=self.cfg.wer_lowercase, + punctuation_removal=self.cfg.wer_remove_punct, + character_tokenization=self.cfg.wer_char_level, ) - @staticmethod - def add_args(parser): - # fmt: off - parser.add_argument('--wer-tokenizer', type=str, default='none', - choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES, - help='sacreBLEU tokenizer to use for evaluation') - parser.add_argument('--wer-remove-punct', action='store_true', - help='remove punctuation') - parser.add_argument('--wer-char-level', action='store_true', - help='evaluate at character level') - parser.add_argument('--wer-lowercase', action='store_true', - help='lowercasing') - # fmt: on - def reset(self): self.distance = 0 self.ref_length = 0 diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index e0abce25..41f461f8 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -7,8 +7,6 @@ import argparse import importlib import os -from argparse import Namespace -from typing import Union from fairseq.dataclass import FairseqDataclass from omegaconf import DictConfig @@ -22,10 +20,10 @@ TASK_REGISTRY = {} TASK_CLASS_NAMES = set() -def setup_task(task_cfg: Union[DictConfig, Namespace], **kwargs): - if isinstance(task_cfg, DictConfig): - return TASK_REGISTRY[task_cfg._name].setup_task(task_cfg, **kwargs) - return TASK_REGISTRY[task_cfg.task].setup_task(task_cfg, **kwargs) +def setup_task(cfg: 
DictConfig, **kwargs): + if isinstance(cfg, DictConfig): + return TASK_REGISTRY[cfg._name].setup_task(cfg, **kwargs) + return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs) def register_task(name, dataclass=None): @@ -70,7 +68,8 @@ def register_task(name, dataclass=None): ) cls.__dataclass = dataclass - TASK_DATACLASS_REGISTRY[name] = dataclass + if dataclass is not None: + TASK_DATACLASS_REGISTRY[name] = dataclass return cls diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index ff2342af..a831ad6e 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -79,7 +79,7 @@ class AudioPretrainingTask(LegacyFairseqTask): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + args (omegaconf.DictConfig): parsed command-line arguments """ return cls(args) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 0a96aeb1..3cdb64cf 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -12,6 +12,7 @@ import torch from fairseq import metrics, search, tokenizer, utils from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators from fairseq.dataclass.utils import gen_parser_from_dataclass +from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -39,8 +40,8 @@ class FairseqTask(object): """ return criterion.logging_outputs_can_be_summed() - def __init__(self, args): - self.args = args + def __init__(self, cfg: DictConfig, **kwargs): + self.cfg = cfg self.datasets = {} self.dataset_to_epoch_iter = {} @@ -78,16 +79,16 @@ class FairseqTask(object): return d @classmethod - def setup_task(cls, args, **kwargs): + def setup_task(cls, cfg: DictConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): parsed command-line arguments """ - return cls(args, **kwargs) + return cls(cfg, **kwargs) def has_sharded_data(self, split): - return os.pathsep in getattr(self.args, "data", "") + return os.pathsep in getattr(self.cfg, "data", "") def load_dataset(self, split, combine=False, **kwargs): """Load a given dataset split. @@ -254,39 +255,39 @@ class FairseqTask(object): return epoch_iter - def build_model(self, args): + def build_model(self, cfg: DictConfig): """ Build the :class:`~fairseq.models.BaseFairseqModel` instance for this task. Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): configuration object Returns: a :class:`~fairseq.models.BaseFairseqModel` instance """ from fairseq import models, quantization_utils - model = models.build_model(args, self) - if getattr(args, "tpu", False): + model = models.build_model(cfg, self) + if getattr(cfg, "tpu", False): model.prepare_for_tpu_() - model = quantization_utils.quantize_model_scalar(model, args) + model = quantization_utils.quantize_model_scalar(model, cfg) return model - def build_criterion(self, args): + def build_criterion(self, cfg: DictConfig): """ Build the :class:`~fairseq.criterions.FairseqCriterion` instance for this task. 
Args: - args (argparse.Namespace): parsed command-line arguments + cfg (omegaconf.DictConfig): configration object Returns: a :class:`~fairseq.criterions.FairseqCriterion` instance """ from fairseq import criterions - return criterions.build_criterion(args, self) + return criterions.build_criterion(cfg, self) def build_generator( self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index 8792c648..6e85417f 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -28,7 +28,7 @@ from fairseq.data import ( from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.dataclass import ChoiceEnum, FairseqDataclass -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import LegacyFairseqTask, register_task from omegaconf import II @@ -85,16 +85,16 @@ class LanguageModelingConfig(FairseqDataclass): }, ) # TODO common vars below add to parent - seed: int = II("params.common.seed") + seed: int = II("common.seed") dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( - "params.dataset.dataset_impl" + "dataset.dataset_impl" ) - data_buffer_size: int = II("params.dataset.data_buffer_size") - tpu: bool = II("params.common.tpu") + data_buffer_size: int = II("dataset.data_buffer_size") + tpu: bool = II("common.tpu") @register_task("language_modeling", dataclass=LanguageModelingConfig) -class LanguageModelingTask(FairseqTask): +class LanguageModelingTask(LegacyFairseqTask): """ Train a language model. diff --git a/fairseq/tasks/multilingual_translation.py b/fairseq/tasks/multilingual_translation.py index f6cb17f1..26e0b529 100644 --- a/fairseq/tasks/multilingual_translation.py +++ b/fairseq/tasks/multilingual_translation.py @@ -117,7 +117,7 @@ class MultilingualTranslationTask(LegacyFairseqTask): return cls(args, dicts, training) @classmethod - def prepare(cls, args, **kargs): + def update_args(cls, args): args.left_pad_source = utils.eval_bool(args.left_pad_source) args.left_pad_target = utils.eval_bool(args.left_pad_target) @@ -127,6 +127,10 @@ class MultilingualTranslationTask(LegacyFairseqTask): ) if isinstance(args.lang_pairs, str): args.lang_pairs = args.lang_pairs.split(",") + + @classmethod + def prepare(cls, args, **kargs): + cls.update_args(args) sorted_langs = sorted( list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")}) ) @@ -298,6 +302,10 @@ class MultilingualTranslationTask(LegacyFairseqTask): if len(messages) > 0: raise ValueError(" ".join(messages)) + # Update args -> the fact that the constructor here + # changes the args object doesn't mean you get the same one here + self.update_args(args) + # Check if task args are consistant with model args check_args() diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 6d222f0d..c200bb14 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -13,7 +13,7 @@ from fairseq.data.audio.speech_to_text_dataset import ( SpeechToTextDataset, SpeechToTextDatasetCreator, ) -from fairseq.tasks import FairseqTask, register_task +from fairseq.tasks import LegacyFairseqTask, register_task logging.basicConfig( @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) @register_task("speech_to_text") -class SpeechToTextTask(FairseqTask): +class SpeechToTextTask(LegacyFairseqTask): @staticmethod def add_args(parser): 
parser.add_argument("data", help="manifest root path") diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 0069b794..8b00e8b4 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -11,15 +11,18 @@ import contextlib import logging import sys import time +from argparse import Namespace from itertools import chain from typing import Any, Dict, List import torch from fairseq import checkpoint_utils, distributed_utils, models, optim, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.file_io import PathManager from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from omegaconf import DictConfig logger = logging.getLogger(__name__) @@ -35,19 +38,25 @@ class Trainer(object): communication of the gradients across workers. """ - def __init__(self, args, task, model, criterion, quantizer=None): - self.args = args + def __init__(self, cfg: DictConfig, task, model, criterion, quantizer=None): + + if isinstance(cfg, Namespace): + logger.warning( + "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf" + ) + cfg = convert_namespace_to_omegaconf(cfg) + + self.cfg = cfg self.task = task # catalog shared parameters shared_params = _catalog_shared_params(model) - - self.tpu = getattr(args, "tpu", False) - self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu + self.tpu = cfg.common.tpu + self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu if self.cuda: self.device = torch.device("cuda") elif self.tpu: - self.device = utils.get_tpu_device(args) + self.device = utils.get_tpu_device() else: self.device = torch.device("cpu") @@ -58,19 +67,21 @@ class Trainer(object): import torch_xla.core.xla_model as xm self._model = xm.send_cpu_data_to_device(self._model, self.device) - if args.fp16: + if cfg.common.fp16: self._criterion = self._criterion.half() self._model = self._model.half() - elif args.bf16: + elif cfg.common.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) - if not args.pipeline_model_parallel: + if not cfg.distributed_training.pipeline_model_parallel: self._criterion = self._criterion.to(device=self.device) self._model = self._model.to(device=self.device) - self.pipeline_model_parallel = args.pipeline_model_parallel + self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel self.last_device = None if self.cuda and self.pipeline_model_parallel: - self.last_device = torch.device(args.pipeline_devices[-1]) + self.last_device = torch.device( + cfg.distributed_training.pipeline_devices[-1] + ) # check that shared parameters are preserved after device transfer for shared_param in shared_params: @@ -129,7 +140,7 @@ class Trainer(object): @property def data_parallel_world_size(self): - return self.args.distributed_world_size + return self.cfg.distributed_training.distributed_world_size @property def data_parallel_process_group(self): @@ -140,11 +151,11 @@ class Trainer(object): @property def data_parallel_rank(self): - return self.args.distributed_rank + return self.cfg.distributed_training.distributed_rank @property def is_data_parallel_master(self): - return distributed_utils.is_master(self.args) + return distributed_utils.is_master(self.cfg.distributed_training) @property def criterion(self): @@ -152,11 +163,11 @@ class Trainer(object): if ( utils.has_parameters(self._criterion) and self.data_parallel_world_size > 1 - and 
not self.args.use_bmuf
+                and not self.cfg.optimization.use_bmuf
                 and not self.tpu
             ):
                 self._wrapped_criterion = models.DistributedFairseqModel(
-                    self.args,
+                    self.cfg.distributed_training,
                     self._criterion,
                     process_group=self.data_parallel_process_group,
                 )
@@ -169,11 +180,11 @@
         if self._wrapped_model is None:
             if (
                 self.data_parallel_world_size > 1
-                and not self.args.use_bmuf
+                and not self.cfg.optimization.use_bmuf
                 and not self.tpu
             ):
                 self._wrapped_model = models.DistributedFairseqModel(
-                    self.args,
+                    self.cfg.distributed_training,
                     self._model,
                     process_group=self.data_parallel_process_group,
                 )
@@ -201,44 +212,51 @@
             )
         )

-        if self.args.fp16 or self.args.bf16:
+        if self.cfg.common.fp16 or self.cfg.common.bf16:
             if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                 logger.info(
                     "NOTE: your device does NOT support faster training with --fp16, "
                     "please switch to FP32 which is likely to be faster"
                 )
-            if self.args.memory_efficient_fp16 or self.args.memory_efficient_bf16:
+            if (
+                self.cfg.common.memory_efficient_fp16
+                or self.cfg.common.memory_efficient_bf16
+            ):
                 self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
-                    self.args, params
+                    self.cfg, params
                 )
             else:
-                self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
+                self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
         else:
             if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                 logger.info("NOTE: your device may support faster training with --fp16")
-            self._optimizer = optim.build_optimizer(self.args, params)
+            self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)

-        if self.args.use_bmuf:
-            self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
+        if self.cfg.optimization.use_bmuf:
+            self._optimizer = optim.FairseqBMUF(
+                self.cfg.bmuf,
+                self._optimizer,
+            )

-        if self.args.zero_sharding == "os":
+        if self.cfg.distributed_training.zero_sharding == "os":
             if (
-                self.args.fp16
-                and not self.args.memory_efficient_fp16
-                and not self.args.memory_efficient_bf16
-            ) and not self.args.fp16_no_flatten_grads:
+                self.cfg.common.fp16
+                and not self.cfg.common.memory_efficient_fp16
+                and not self.cfg.common.memory_efficient_bf16
+            ) and not self.cfg.common.fp16_no_flatten_grads:
                 raise ValueError(
                     "ZeRO is incomptabile with fp16 and flattened grads. "
                     "Please use --fp16-no-flatten-grads"
                 )
             else:
-                optim.shard_(
-                    self.args, self._optimizer, self.data_parallel_process_group
-                )
+                optim.shard_(self._optimizer, self.data_parallel_process_group)

         # We should initialize the learning rate scheduler immediately after
         # building the optimizer, so that the initial learning rate is set.
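# Minimal sketch (not part of the patch): with the new layout each component
# is built from its own config node rather than the flat args namespace, and
# the node's _name selects the registered implementation. The values below are
# placeholders, and "params" is assumed to be a model's trainable parameters.
from omegaconf import OmegaConf

from fairseq import optim
from fairseq.optim import lr_scheduler

cfg = OmegaConf.create(
    {
        "optimizer": {"_name": "nag", "lr": [0.25], "momentum": 0.99, "weight_decay": 0.0},
        "lr_scheduler": {"_name": "inverse_sqrt", "lr": [0.25], "warmup_updates": 4000, "warmup_init_lr": -1.0},
    }
)
# optimizer = optim.build_optimizer(cfg.optimizer, params)
# scheduler = lr_scheduler.build_lr_scheduler(cfg.lr_scheduler, optimizer)
# scheduler.step_update(num_updates=0)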
- self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) + self._lr_scheduler = lr_scheduler.build_lr_scheduler( + self.cfg.lr_scheduler, + self.optimizer, + ) self._lr_scheduler.step_update(0) def consolidate_optimizer(self): @@ -253,7 +271,7 @@ class Trainer(object): extra_state["previous_training_time"] = self.cumulative_training_time() checkpoint_utils.save_state( filename, - self.args, + self.cfg, self.get_model().state_dict(), self.get_criterion(), self.optimizer, @@ -277,11 +295,10 @@ class Trainer(object): bexists = PathManager.isfile(filename) if bexists: state = checkpoint_utils.load_checkpoint_to_cpu(filename) - # load model parameters try: self.get_model().load_state_dict( - state["model"], strict=True, args=self.args + state["model"], strict=True, model_cfg=self.cfg.model ) if utils.has_parameters(self.get_criterion()): self.get_criterion().load_state_dict( @@ -355,28 +372,28 @@ class Trainer(object): if load_dataset: logger.info("loading train data for epoch {}".format(epoch)) self.task.load_dataset( - self.args.train_subset, + self.cfg.dataset.train_subset, epoch=epoch, combine=combine, data_selector=data_selector, ) batch_iterator = self.task.get_batch_iterator( - dataset=self.task.dataset(self.args.train_subset), - max_tokens=self.args.max_tokens, - max_sentences=self.args.batch_size, + dataset=self.task.dataset(self.cfg.dataset.train_subset), + max_tokens=self.cfg.dataset.max_tokens, + max_sentences=self.cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( self.task.max_positions(), self.model.max_positions(), - self.args.max_tokens, + self.cfg.dataset.max_tokens, ), ignore_invalid_inputs=True, - required_batch_size_multiple=self.args.required_batch_size_multiple, - seed=self.args.seed, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, num_shards=self.data_parallel_world_size if shard_batch_itr else 1, shard_id=self.data_parallel_rank if shard_batch_itr else 0, - num_workers=self.args.num_workers, + num_workers=self.cfg.dataset.num_workers, epoch=epoch, - data_buffer_size=self.args.data_buffer_size, + data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) self.reset_dummy_batch(batch_iterator.first_batch) @@ -390,19 +407,19 @@ class Trainer(object): """Return an EpochBatchIterator over given validation subset for a given epoch.""" batch_iterator = self.task.get_batch_iterator( dataset=self.task.dataset(subset), - max_tokens=self.args.max_tokens_valid, - max_sentences=self.args.batch_size_valid, + max_tokens=self.cfg.dataset.max_tokens_valid, + max_sentences=self.cfg.dataset.batch_size_valid, max_positions=utils.resolve_max_positions( self.task.max_positions(), self.model.max_positions(), ), - ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=self.args.required_batch_size_multiple, - seed=self.args.seed, + ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, + seed=self.cfg.common.seed, num_shards=self.data_parallel_world_size, shard_id=self.data_parallel_rank, - num_workers=self.args.num_workers, - data_buffer_size=self.args.data_buffer_size, + num_workers=self.cfg.dataset.num_workers, + data_buffer_size=self.cfg.dataset.data_buffer_size, disable_iterator_cache=disable_iterator_cache, ) self.reset_dummy_batch(batch_iterator.first_batch) @@ -504,7 +521,7 @@ class Trainer(object): 
self.zero_grad() if self.cuda: torch.cuda.empty_cache() - if self.args.distributed_world_size == 1: + if self.cfg.distributed_training.distributed_world_size == 1: return None else: raise e @@ -565,7 +582,7 @@ class Trainer(object): # multiply gradients by (# GPUs / sample_size) since DDP # already normalizes by the number of GPUs. Thus we get # (sum_of_gradients / sample_size). - if not self.args.use_bmuf: + if not self.cfg.optimization.use_bmuf: self.optimizer.multiply_grads( self.data_parallel_world_size / sample_size ) @@ -575,12 +592,12 @@ class Trainer(object): with torch.autograd.profiler.record_function("clip-grads"): # clip grads - grad_norm = self.clip_grad_norm(self.args.clip_norm) + grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) # check that grad norms are consistent across workers if ( - not self.args.use_bmuf - and self.args.distributed_wrapper != "SlowMo" + not self.cfg.optimization.use_bmuf + and self.cfg.distributed_training.distributed_wrapper != "SlowMo" and not self.tpu ): self._check_grad_norms(grad_norm) @@ -624,7 +641,10 @@ class Trainer(object): self.optimizer.optimizer ) - if not overflow or self.args.distributed_wrapper == "SlowMo": + if ( + not overflow + or self.cfg.distributed_training.distributed_wrapper == "SlowMo" + ): self.set_num_updates(self.get_num_updates() + 1) if self.tpu: @@ -636,7 +656,7 @@ class Trainer(object): # only log stats every log_interval steps # this causes wps to be misreported when log_interval > 1 logging_output = {} - if self.get_num_updates() % self.args.log_interval == 0: + if self.get_num_updates() % self.cfg.common.log_interval == 0: # log memory usage mem_info = xm.get_memory_info(self.device) gb_free = mem_info["kb_free"] / 1024 / 1024 @@ -677,16 +697,16 @@ class Trainer(object): # clear CUDA cache to reduce memory fragmentation if ( self.cuda - and self.args.empty_cache_freq > 0 + and self.cfg.common.empty_cache_freq > 0 and ( - (self.get_num_updates() + self.args.empty_cache_freq - 1) - % self.args.empty_cache_freq + (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1) + % self.cfg.common.empty_cache_freq ) == 0 ): torch.cuda.empty_cache() - if self.args.fp16: + if self.cfg.common.fp16: metrics.log_scalar( "loss_scale", self.optimizer.scaler.loss_scale, @@ -883,10 +903,10 @@ class Trainer(object): return t.to(dtype=torch.bfloat16) return t - if self.args.fp16: + if self.cfg.common.fp16: sample = utils.apply_to_sample(apply_half, sample) - if self.args.bf16: + if self.cfg.common.bf16: sample = utils.apply_to_sample(apply_bfloat16, sample) return sample @@ -894,7 +914,7 @@ class Trainer(object): def _set_seed(self): # Set seed based on args.seed and the update number so that we get # reproducible results when resuming from checkpoints - seed = self.args.seed + self.get_num_updates() + seed = self.cfg.common.seed + self.get_num_updates() utils.set_torch_seed(seed) def _sync_stats(self): @@ -902,10 +922,12 @@ class Trainer(object): # BMUF and it's a bmuf sync with warmup iterations completed before. 
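The _sync_stats rule rewritten just below depends on two bmuf settings. A standalone sketch of the same condition, with the function name and default values here chosen purely for illustration:

def should_sync_stats(num_updates, world_size, use_bmuf,
                      global_sync_iter=50, warmup_iterations=500):
    # A single worker has nothing to synchronize.
    if world_size == 1:
        return False
    # Under BMUF, sync only every global_sync_iter updates, and only after warmup.
    if use_bmuf:
        step = num_updates + 1
        return step % global_sync_iter == 0 and step > warmup_iterations
    # Plain data-parallel training syncs logging stats on every update.
    return True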
if self.data_parallel_world_size == 1: return False - elif self.args.use_bmuf: - return (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and ( + elif self.cfg.optimization.use_bmuf: + return ( self.get_num_updates() + 1 - ) > self.args.warmup_iterations + ) % self.cfg.bmuf.global_sync_iter == 0 and ( + self.get_num_updates() + 1 + ) > self.cfg.bmuf.warmup_iterations else: return True @@ -950,7 +972,7 @@ class Trainer(object): zip( *distributed_utils.all_gather_list( [logging_outputs] + list(extra_stats_to_sum), - max_size=getattr(self.args, "all_gather_list_size", 16384), + max_size=getattr(self.cfg.common, "all_gather_list_size", 16384), group=self.data_parallel_process_group, ) ) @@ -1038,11 +1060,11 @@ class Trainer(object): if grad_norm is not None: metrics.log_speed("ups", 1.0, priority=100, round=2) metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) - if self.args.clip_norm > 0: + if self.cfg.optimization.clip_norm > 0: metrics.log_scalar( "clip", torch.where( - grad_norm > self.args.clip_norm, + grad_norm > self.cfg.optimization.clip_norm, grad_norm.new_tensor(100), grad_norm.new_tensor(0), ), @@ -1087,7 +1109,7 @@ class Trainer(object): logger.warning( "XLA compilation detected on device #{}; too many of these can lead " "to slow training, but we expect a few in the beginning".format( - self.args.distributed_rank + self.cfg.distributed_training.distributed_rank ) ) self._num_xla_compiles = num_xla_compiles diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index 9a4ff8ee..4621a66a 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -11,13 +11,19 @@ Evaluate the perplexity of a trained language model. import logging import math import os +from argparse import Namespace import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import LMContextWindowDataset +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.sequence_scorer import SequenceScorer +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig logging.basicConfig( @@ -60,65 +66,60 @@ class WordStat(object): ) -def main(parsed_args, **unused_kwargs): - assert parsed_args.path is not None, "--path required for evaluation!" 
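Each rewritten CLI entry point below (eval_lm, generate, interactive, train, validate) keeps accepting a legacy argparse Namespace and normalizes it to the grouped DictConfig on entry. The shared guard, shown here in isolation as a sketch rather than a verbatim excerpt:

from argparse import Namespace
from omegaconf import DictConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

def main(cfg: DictConfig, **unused_kwargs):
    # Callers may still pass the old flat Namespace; convert it once up front
    # so the rest of the function can rely on cfg.common, cfg.dataset, etc.
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)
    ...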
+def main(cfg: DictConfig, override_args=None, **unused_kwargs): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) - if torch.cuda.is_available() and not parsed_args.cpu: - torch.cuda.set_device(parsed_args.device_id) + utils.import_user_module(cfg.common) - utils.import_user_module(parsed_args) + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu - logger.info(parsed_args) + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) - use_cuda = torch.cuda.is_available() and not parsed_args.cpu + if override_args is not None: + overrides = vars(override_args) + overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) + else: + overrides = None - task = tasks.setup_task(parsed_args) + logger.info(cfg) # Load ensemble - logger.info("loading model(s) from {}".format(parsed_args.path)) - models, args = checkpoint_utils.load_model_ensemble( - parsed_args.path.split(os.pathsep), - arg_overrides=eval(parsed_args.model_overrides), - task=task, - suffix=getattr(parsed_args, "checkpoint_suffix", ""), - strict=(parsed_args.checkpoint_shard_count == 1), - num_shards=parsed_args.checkpoint_shard_count, - ) - - for arg in vars(parsed_args).keys(): - if arg not in { - "self_target", - "future_target", - "past_target", - "tokens_per_sample", - "output_size_dictionary", - "add_bos_token", - }: - setattr(args, arg, getattr(parsed_args, arg)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) # reduce tokens per sample by the required context window size - args.tokens_per_sample -= args.context_window - task = tasks.setup_task(args) + cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=overrides, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + ) # Load dataset splits - task.load_dataset(args.gen_subset) - dataset = task.dataset(args.gen_subset) - if args.context_window > 0: + gen_subset = cfg.dataset.gen_subset + task.load_dataset(gen_subset) + dataset = task.dataset(gen_subset) + if cfg.eval_lm.context_window > 0: dataset = LMContextWindowDataset( dataset=dataset, - tokens_per_sample=args.tokens_per_sample, - context_window=args.context_window, + tokens_per_sample=cfg.task.tokens_per_sample, + context_window=cfg.eval_lm.context_window, pad_idx=task.source_dictionary.pad(), ) - logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset))) + logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset))) # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) for model in models: - if args.fp16: + if use_fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) assert len(models) > 0 @@ -128,35 +129,41 @@ def main(parsed_args, **unused_kwargs): itr = task.get_batch_iterator( dataset=dataset, - max_tokens=args.max_tokens or 36000, - max_sentences=args.batch_size, + max_tokens=cfg.dataset.max_tokens or 36000, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( *[model.max_positions() for model in models] ), ignore_invalid_inputs=True, - num_shards=args.num_shards, - shard_id=args.shard_id, - 
num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + num_shards=max( + cfg.dataset.num_shards, + cfg.distributed_training.distributed_world_size, + ), + shard_id=max( + cfg.dataset.shard_id, + cfg.distributed_training.distributed_rank, + ), + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) gen_timer = StopwatchMeter() - scorer = SequenceScorer(task.target_dictionary, args.softmax_batch) + scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch) score_sum = 0.0 count = 0 - if args.remove_bpe is not None: - if args.remove_bpe == "sentencepiece": + if cfg.common_eval.remove_bpe is not None: + if cfg.common_eval.remove_bpe == "sentencepiece": raise NotImplementedError else: - bpe_cont = args.remove_bpe.rstrip() + bpe_cont = cfg.common_eval.remove_bpe.rstrip() bpe_toks = { i for i in range(len(task.source_dictionary)) @@ -189,7 +196,7 @@ def main(parsed_args, **unused_kwargs): tgt_len = tokens.numel() pos_scores = hypo["positional_scores"].float() - if getattr(args, "add_bos_token", False): + if cfg.task.add_bos_token: assert hypo["tokens"][0].item() == task.target_dictionary.bos() tokens = tokens[1:] pos_scores = pos_scores[1:] @@ -212,7 +219,7 @@ def main(parsed_args, **unused_kwargs): score_sum += pos_scores.sum().cpu() count += pos_scores.numel() - skipped_toks - if args.output_word_probs or args.output_word_stats: + if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats: w = "" word_prob = [] is_bpe = False @@ -238,7 +245,7 @@ def main(parsed_args, **unused_kwargs): ) is_bpe = False w = "" - if args.output_word_probs: + if cfg.eval_lm.output_word_probs: logger.info( str(int(sample_id)) + " " @@ -264,7 +271,7 @@ def main(parsed_args, **unused_kwargs): ) ) - if args.output_word_stats: + if cfg.eval_lm.output_word_stats: for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): logger.info(ws) @@ -272,8 +279,16 @@ def main(parsed_args, **unused_kwargs): def cli_main(): parser = options.get_eval_lm_parser() args = options.parse_args_and_arch(parser) - distributed_utils.call_main(args, main) + + # only override args that are explicitly given on the command line + override_parser = options.get_validation_parser() + override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True) + + distributed_utils.call_main(args, main, override_args=override_args) if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 8ddf981c..6a6f7465 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -12,33 +12,45 @@ import logging import math import os import sys +from argparse import Namespace from itertools import chain import numpy as np import torch from fairseq import checkpoint_utils, options, scoring, tasks, utils +from fairseq.data import encoders +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar from fairseq.logging.meters import 
StopwatchMeter, TimeMeter +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig -def main(args): - assert args.path is not None, "--path required for generation!" +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required for generation!" assert ( - not args.sampling or args.nbest == args.beam + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert ( - args.replace_unk is None or args.dataset_impl == "raw" + cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw" ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" - if args.results_path is not None: - os.makedirs(args.results_path, exist_ok=True) + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) output_path = os.path.join( - args.results_path, "generate-{}.txt".format(args.gen_subset) + cfg.common_eval.results_path, + "generate-{}.txt".format(cfg.dataset.gen_subset), ) with open(output_path, "w", buffering=1, encoding="utf-8") as h: - return _main(args, h) + return _main(cfg, h) else: - return _main(args, sys.stdout) + return _main(cfg, sys.stdout) def get_symbols_to_strip_from_output(generator): @@ -48,7 +60,7 @@ def get_symbols_to_strip_from_output(generator): return {generator.eos} -def _main(args, output_file): +def _main(cfg: DictConfig, output_file): logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", @@ -57,22 +69,22 @@ def _main(args, output_file): ) logger = logging.getLogger("fairseq_cli.generate") - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.max_tokens is None and args.batch_size is None: - args.max_tokens = 12000 - logger.info(args) + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.max_tokens = 12000 + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Load dataset splits - task = tasks.setup_task(args) - task.load_dataset(args.gen_subset) + task = tasks.setup_task(cfg.task) + task.load_dataset(cfg.dataset.gen_subset) # Set dictionaries try: @@ -81,32 +93,30 @@ def _main(args, output_file): src_dict = None tgt_dict = task.target_dictionary - overrides = ast.literal_eval(args.model_overrides) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - utils.split_paths(args.path), + utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) - if args.lm_path is 
not None: - overrides["data"] = args.data + if cfg.generation.lm_path is not None: + overrides["data"] = cfg.task.data try: lms, _ = checkpoint_utils.load_model_ensemble( - [args.lm_path], - arg_overrides=overrides, - task=None, + [cfg.generation.lm_path], arg_overrides=overrides, task=None ) except: logger.warning( f"Failed to load language model! Please make sure that the language model dict is the same " - f"as target dict and is located in the data dir ({args.data})" + f"as target dict and is located in the data dir ({cfg.task.data})" ) raise @@ -118,49 +128,50 @@ def _main(args, output_file): for model in chain(models, lms): if model is None: continue - if args.fp16: + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - align_dict = utils.load_align_dict(args.replace_unk) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) # Load dataset (possibly sharded) itr = task.get_batch_iterator( - dataset=task.dataset(args.gen_subset), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( - task.max_positions(), *[model.max_positions() for model in models] + task.max_positions(), *[m.max_positions() for m in models] ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - num_shards=args.num_shards, - shard_id=args.shard_id, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, - default_log_format=("tqdm" if not args.no_progress_bar else "none"), + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) # Initialize generator gen_timer = StopwatchMeter() - extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight} + extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight} generator = task.build_generator( - models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs + models, cfg.task, extra_gen_cls_kwargs=extra_gen_cls_kwargs ) # Handle tokenization and BPE - tokenizer = task.build_tokenizer(args) - bpe = task.build_bpe(args) + tokenizer = encoders.build_tokenizer(cfg.tokenizer) + bpe = encoders.build_bpe(cfg.bpe) def decode_fn(x): if bpe is not None: @@ -169,7 +180,7 @@ def _main(args, output_file): x = tokenizer.decode(x) return x - scorer = scoring.build_scorer(args, tgt_dict) + scorer = scoring.build_scorer(cfg.scoring, tgt_dict) num_sentences = 0 has_target = True @@ -180,8 +191,8 @@ def _main(args, output_file): continue prefix_tokens = None - 
if args.prefix_size > 0: - prefix_tokens = sample["target"][:, : args.prefix_size] + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] constraints = None if "constraints" in sample: @@ -217,19 +228,21 @@ def _main(args, output_file): # Either retrieve the original sentences or regenerate them from tokens. if align_dict is not None: - src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) - target_str = task.dataset(args.gen_subset).tgt.get_original_text( + src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text( + sample_id + ) + target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text( sample_id ) else: if src_dict is not None: - src_str = src_dict.string(src_tokens, args.remove_bpe) + src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe) else: src_str = "" if has_target: target_str = tgt_dict.string( target_tokens, - args.remove_bpe, + cfg.common_eval.remove_bpe, escape_unk=True, extra_symbols_to_ignore=get_symbols_to_strip_from_output( generator @@ -240,25 +253,25 @@ def _main(args, output_file): if has_target: target_str = decode_fn(target_str) - if not args.quiet: + if not cfg.common_eval.quiet: if src_dict is not None: print("S-{}\t{}".format(sample_id, src_str), file=output_file) if has_target: print("T-{}\t{}".format(sample_id, target_str), file=output_file) # Process top predictions - for j, hypo in enumerate(hypos[i][: args.nbest]): + for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]): hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, - remove_bpe=args.remove_bpe, + remove_bpe=cfg.common_eval.remove_bpe, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) - if not args.quiet: + if not cfg.common_eval.quiet: score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) print( @@ -286,7 +299,7 @@ def _main(args, output_file): file=output_file, ) - if args.print_alignment: + if cfg.generation.print_alignment: print( "A-{}\t{}".format( sample_id, @@ -300,13 +313,13 @@ def _main(args, output_file): file=output_file, ) - if args.print_step: + if cfg.generation.print_step: print( "I-{}\t{}".format(sample_id, hypo["steps"]), file=output_file, ) - if getattr(args, "retain_iter_history", False): + if cfg.generation.retain_iter_history: for step, h in enumerate(hypo["history"]): _, h_str, _ = utils.post_process_prediction( hypo_tokens=h["tokens"].int().cpu(), @@ -323,7 +336,7 @@ def _main(args, output_file): # Score only the top hypothesis if has_target and j == 0: - if align_dict is not None or args.remove_bpe is not None: + if align_dict is not None or cfg.common_eval.remove_bpe is not None: # Convert back to tokens for evaluation with unk replacement and/or without BPE target_tokens = tgt_dict.encode_line( target_str, add_if_not_exist=True @@ -353,8 +366,8 @@ def _main(args, output_file): ) ) if has_target: - if args.bpe and not args.sacrebleu: - if args.remove_bpe: + if cfg.bpe and not cfg.generation.sacrebleu: + if cfg.common_eval.remove_bpe: logger.warning( "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. 
Use --sacrebleu for standard 13a BLEU tokenization" ) @@ -365,7 +378,7 @@ def _main(args, output_file): # use print to be consistent with other main outputs: S-, H-, T-, D- and so on print( "Generate {} with beam={}: {}".format( - args.gen_subset, args.beam, scorer.result_string() + cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string() ), file=output_file, ) @@ -380,4 +393,7 @@ def cli_main(): if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index de3893a3..ddd2617c 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -7,20 +7,27 @@ Translate raw text with a trained model. Batches data on-the-fly. """ +import ast import fileinput import logging import math import os import sys import time +from argparse import Namespace from collections import namedtuple import numpy as np import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq.data import encoders +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq_cli.generate import get_symbols_to_strip_from_output +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig logging.basicConfig( @@ -49,11 +56,11 @@ def buffered_read(input, buffer_size): yield buffer -def make_batches(lines, args, task, max_positions, encode_fn): +def make_batches(lines, cfg, task, max_positions, encode_fn): def encode_fn_target(x): return encode_fn(x) - if args.constraints: + if cfg.generation.constraints: # Strip (tab-delimited) contraints, if present, from input lines, # store them in batch_constraints batch_constraints = [list() for _ in lines] @@ -79,7 +86,7 @@ def make_batches(lines, args, task, max_positions, encode_fn): for src_str in lines ] - if args.constraints: + if cfg.generation.constraints: constraints_tensor = pack_constraints(batch_constraints) else: constraints_tensor = None @@ -89,10 +96,10 @@ def make_batches(lines, args, task, max_positions, encode_fn): dataset=task.build_dataset_for_inference( tokens, lengths, constraints=constraints_tensor ), - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=max_positions, - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, ).next_epoch_itr(shuffle=False) for batch in itr: ids = batch["id"] @@ -108,45 +115,50 @@ def make_batches(lines, args, task, max_positions, encode_fn): ) -def main(args): +def main(cfg: DictConfig): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + start_time = time.time() total_translate_time = 0 - utils.import_user_module(args) + utils.import_user_module(cfg.common) - if args.buffer_size < 1: - args.buffer_size = 1 - if args.max_tokens is None and args.batch_size is None: - args.batch_size = 1 + if cfg.interactive.buffer_size < 1: + cfg.interactive.buffer_size = 1 + if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: + cfg.dataset.batch_size = 1 assert ( - not args.sampling or args.nbest == args.beam + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam 
), "--sampling requires --nbest to be equal to --beam" assert ( - not args.batch_size or args.batch_size <= args.buffer_size + not cfg.dataset.batch_size + or cfg.dataset.batch_size <= cfg.interactive.buffer_size ), "--batch-size cannot be larger than --buffer-size" - logger.info(args) + logger.info(cfg) # Fix seed for stochastic decoding - if args.seed is not None and not args.no_seed_provided: - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - use_cuda = torch.cuda.is_available() and not args.cpu + use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Setup task, e.g., translation - task = tasks.setup_task(args) + task = tasks.setup_task(cfg.task) # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + overrides = ast.literal_eval(cfg.common_eval.model_overrides) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( - args.path.split(os.pathsep), - arg_overrides=eval(args.model_overrides), + utils.split_paths(cfg.common_eval.path), + arg_overrides=overrides, task=task, - suffix=getattr(args, "checkpoint_suffix", ""), - strict=(args.checkpoint_shard_count == 1), - num_shards=args.checkpoint_shard_count, + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Set dictionaries @@ -155,18 +167,20 @@ def main(args): # Optimize ensemble for generation for model in models: - if args.fp16: + if model is None: + continue + if cfg.common.fp16: model.half() - if use_cuda and not args.pipeline_model_parallel: + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() - model.prepare_for_inference_(args) + model.prepare_for_inference_(cfg) # Initialize generator - generator = task.build_generator(models, args) + generator = task.build_generator(models, cfg.task) # Handle tokenization and BPE - tokenizer = encoders.build_tokenizer(args) - bpe = encoders.build_bpe(args) + tokenizer = encoders.build_tokenizer(cfg.tokenizer) + bpe = encoders.build_bpe(cfg.bpe) def encode_fn(x): if tokenizer is not None: @@ -184,25 +198,25 @@ def main(args): # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) - align_dict = utils.load_align_dict(args.replace_unk) + align_dict = utils.load_align_dict(cfg.generation.replace_unk) max_positions = utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models] ) - if args.constraints: + if cfg.generation.constraints: logger.warning( "NOTE: Constrained decoding currently assumes a shared subword vocabulary." 
) - if args.buffer_size > 1: - logger.info("Sentence buffer size: %s", args.buffer_size) + if cfg.interactive.buffer_size > 1: + logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size) logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("Type the input sentence and press return:") start_id = 0 - for inputs in buffered_read(args.input, args.buffer_size): + for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size): results = [] - for batch in make_batches(inputs, args, task, max_positions, encode_fn): + for batch in make_batches(inputs, cfg, task, max_positions, encode_fn): bsz = batch.src_tokens.size(0) src_tokens = batch.src_tokens src_lengths = batch.src_lengths @@ -226,7 +240,7 @@ def main(args): translate_time = time.time() - translate_start_time total_translate_time += translate_time list_constraints = [[] for _ in range(bsz)] - if args.constraints: + if cfg.generation.constraints: list_constraints = [unpack_constraints(c) for c in constraints] for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) @@ -246,25 +260,25 @@ def main(args): # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): if src_dict is not None: - src_str = src_dict.string(src_tokens, args.remove_bpe) + src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe) print("S-{}\t{}".format(id_, src_str)) print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) for constraint in info["constraints"]: print( "C-{}\t{}".format( - id_, tgt_dict.string(constraint, args.remove_bpe) + id_, tgt_dict.string(constraint, cfg.common_eval.remove_bpe) ) ) # Process top predictions - for hypo in hypos[: min(len(hypos), args.nbest)]: + for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]: hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, - remove_bpe=args.remove_bpe, + remove_bpe=cfg.common_eval.remove_bpe, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) detok_hypo_str = decode_fn(hypo_str) @@ -285,7 +299,7 @@ def main(args): ), ) ) - if args.print_alignment: + if cfg.generation.print_alignment: alignment_str = " ".join( ["{}-{}".format(src, tgt) for src, tgt in alignment] ) @@ -308,4 +322,7 @@ def cli_main(): if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/score.py b/fairseq_cli/score.py index b8354eb9..e06d6725 100644 --- a/fairseq_cli/score.py +++ b/fairseq_cli/score.py @@ -78,7 +78,13 @@ def cli_main(): def score(fdsys): with open(args.ref) as fdref: - scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) + scorer = bleu.Scorer( + bleu.BleuConfig( + pad=dict.pad(), + eos=dict.eos(), + unk=dict.unk(), + ) + ) for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): sys_tok = dict.encode_line(sys_tok) ref_tok = dict.encode_line(ref_tok) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index cd3a93b1..4c007610 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -11,11 +11,13 @@ import argparse import logging import math import os -import random import sys +from typing import Dict, Optional, Any, List, Tuple, Callable import numpy as np import torch +from hydra.core.config_store import ConfigStore + from fairseq import 
( checkpoint_utils, distributed_utils, @@ -25,8 +27,12 @@ from fairseq import ( utils, ) from fairseq.data import iterators +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from omegaconf import DictConfig +from hydra.experimental import initialize +from fairseq.dataclass.data_class import register_hydra_cfg from fairseq.trainer import Trainer @@ -39,90 +45,86 @@ logging.basicConfig( logger = logging.getLogger("fairseq_cli.train") -def main(args): - utils.import_user_module(args) +def main(cfg: DictConfig) -> None: + if isinstance(cfg, argparse.Namespace): + cfg = convert_namespace_to_omegaconf(cfg) - assert ( - args.max_tokens is not None or args.batch_size is not None - ), "Must specify batch size either with --max-tokens or --batch-size" + utils.import_user_module(cfg.common) + assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \ + 'Must specify batch size either with --max-tokens or --batch-size' metrics.reset() - np.random.seed(args.seed) - utils.set_torch_seed(args.seed) + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) - if distributed_utils.is_master(args): - checkpoint_utils.verify_checkpoint_directory(args.save_dir) + if distributed_utils.is_master(cfg.distributed_training): + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) # Print args - logger.info(args) + logger.info(cfg) # Setup task, e.g., translation, language modeling, etc. - task = tasks.setup_task(args) - + task = tasks.setup_task(cfg.task) # Load valid dataset (we load training data below, based on the latest checkpoint) - for valid_sub_split in args.valid_subset.split(","): + for valid_sub_split in cfg.dataset.valid_subset.split(','): task.load_dataset(valid_sub_split, combine=False, epoch=1) # Build model and criterion - model = task.build_model(args) - criterion = task.build_criterion(args) + model = task.build_model(cfg.model) + criterion = task.build_criterion(cfg.criterion) logger.info(model) - logger.info("task: {} ({})".format(args.task, task.__class__.__name__)) - logger.info("model: {} ({})".format(args.arch, model.__class__.__name__)) + logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__)) + logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__)) logger.info( - "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__) - ) - logger.info( - "num. model params: {} (num. trained: {})".format( - sum(p.numel() for p in model.parameters()), - sum(p.numel() for p in model.parameters() if p.requires_grad), - ) + "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__) ) + logger.info("num. model params: {} (num. 
trained: {})".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + )) # (optionally) Configure quantization - if args.quantization_config_path is not None: + if cfg.common.quantization_config_path is not None: quantizer = quantization_utils.Quantizer( - config_path=args.quantization_config_path, - max_epoch=args.max_epoch, - max_update=args.max_update, + config_path=cfg.common.quantization_config_path, + max_epoch=cfg.optimization.max_epoch, + max_update=cfg.optimization.max_update, ) else: quantizer = None # Build trainer - if args.model_parallel_size == 1: - trainer = Trainer(args, task, model, criterion, quantizer) + if cfg.common.model_parallel_size == 1: + trainer = Trainer(cfg, task, model, criterion, quantizer) else: - trainer = MegatronTrainer(args, task, model, criterion) + trainer = MegatronTrainer(cfg, task, model, criterion) - logger.info( - "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) - ) - logger.info( - "max tokens per GPU = {} and max sentences per GPU = {}".format( - args.max_tokens, args.batch_size - ) - ) + logger.info('training on {} devices (GPUs/TPUs)'.format(cfg.distributed_training.distributed_world_size)) + logger.info('max tokens per GPU = {} and batch size per GPU = {}'.format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + )) # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint( - args, + cfg.checkpoint, trainer, # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) - # Train until the learning rate gets too small - max_epoch = args.max_epoch or math.inf + max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - - while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: + while ( + lr > cfg.optimization.min_lr + and epoch_itr.next_epoch_idx <= max_epoch + ): # train for one epoch - valid_losses, should_stop = train(args, trainer, task, epoch_itr) + valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: break @@ -140,15 +142,15 @@ def main(args): logger.info("done training in {:.1f} seconds".format(train_meter.sum)) -def should_stop_early(args, valid_loss): +def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: # skip check if no validation was done in the current epoch if valid_loss is None: return False - if args.patience <= 0: + if cfg.checkpoint.patience <= 0: return False def is_better(a, b): - return a > b if args.maximize_best_checkpoint_metric else a < b + return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b prev_best = getattr(should_stop_early, "best", None) if prev_best is None or is_better(valid_loss, prev_best): @@ -157,48 +159,43 @@ def should_stop_early(args, valid_loss): return False else: should_stop_early.num_runs += 1 - if should_stop_early.num_runs >= args.patience: - logger.info( - "early stop since valid performance hasn't improved for last {} runs".format( - args.patience - ) - ) + if should_stop_early.num_runs >= cfg.checkpoint.patience: + logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(cfg.checkpoint.patience)) return True else: return False @metrics.aggregate("train") -def train(args, trainer, task, epoch_itr): +def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> 
Tuple[List[Optional[float]], bool]: """Train the model for one epoch and return validation losses.""" # Initialize data iterator itr = epoch_itr.next_epoch_itr( - fix_batches_to_gpus=args.fix_batches_to_gpus, - shuffle=(epoch_itr.next_epoch_idx > args.curriculum), + fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus, + shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum), ) update_freq = ( - args.update_freq[epoch_itr.epoch - 1] - if epoch_itr.epoch <= len(args.update_freq) - else args.update_freq[-1] + cfg.optimization.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(cfg.optimization.update_freq) + else cfg.optimization.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(args, "tpu", False): + if getattr(cfg.common, "tpu", False): itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, tensorboard_logdir=( - args.tensorboard_logdir if distributed_utils.is_master(args) else None + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), + default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'), ) trainer.begin_epoch(epoch_itr.epoch) - valid_losses = [None] - valid_subsets = args.valid_subset.split(",") + valid_subsets = cfg.dataset.valid_subset.split(',') should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): @@ -210,7 +207,7 @@ def train(args, trainer, task, epoch_itr): if log_output is not None: # not OOM, overflow, ... # log mid-epoch stats num_updates = trainer.get_num_updates() - if num_updates % args.log_interval == 0: + if num_updates % cfg.common.log_interval == 0: stats = get_training_stats(metrics.get_smoothed_values("train_inner")) progress.log(stats, tag="train_inner", step=num_updates) @@ -220,7 +217,7 @@ def train(args, trainer, task, epoch_itr): end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( - args, trainer, task, epoch_itr, valid_subsets, end_of_epoch + cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch ) if should_stop: @@ -236,64 +233,64 @@ def train(args, trainer, task, epoch_itr): return valid_losses, should_stop -def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): +def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() - max_update = args.max_update or math.inf + max_update = cfg.optimization.max_update or math.inf do_save = ( - (end_of_epoch and epoch_itr.epoch % args.save_interval == 0) + (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) or num_updates >= max_update or ( - args.save_interval_updates > 0 + cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 - and num_updates % args.save_interval_updates == 0 - and num_updates >= args.validate_after_updates + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates >= cfg.dataset.validate_after_updates ) ) do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves - or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) + or (end_of_epoch and epoch_itr.epoch % 
cfg.dataset.validate_interval == 0) or num_updates >= max_update or ( - args.validate_interval_updates > 0 + cfg.dataset.validate_interval_updates > 0 and num_updates > 0 - and num_updates % args.validate_interval_updates == 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 ) - ) and not args.disable_validation + ) and not cfg.dataset.disable_validation # Validate valid_losses = [None] if do_validate: - valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) + valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) # Stopping conditions should_stop = ( - should_stop_early(args, valid_losses[0]) + should_stop_early(cfg, valid_losses[0]) or num_updates >= max_update or ( - args.stop_time_hours > 0 - and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours + cfg.optimization.stop_time_hours > 0 + and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours ) ) # Save checkpoint if do_save or should_stop: logger.info("begin save checkpoint") - checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) + checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0]) return valid_losses, should_stop -def get_training_stats(stats): +def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) return stats -def validate(args, trainer, task, epoch_itr, subsets): +def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]: """Evaluate the model on the validation set(s) and return the losses.""" - if args.fixed_validation_seed is not None: + if cfg.dataset.fixed_validation_seed is not None: # set fixed seed for every validation - utils.set_torch_seed(args.fixed_validation_seed) + utils.set_torch_seed(cfg.dataset.fixed_validation_seed) trainer.begin_valid_epoch(epoch_itr.epoch) valid_losses = [] @@ -302,18 +299,18 @@ def validate(args, trainer, task, epoch_itr, subsets): # Initialize data iterator itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) - if getattr(args, "tpu", False): + if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, epoch=epoch_itr.epoch, prefix=f"valid on '{subset}' subset", tensorboard_logdir=( - args.tensorboard_logdir if distributed_utils.is_master(args) else None + cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None ), - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), + default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'), ) # create a new root metrics aggregator so validation metrics @@ -323,34 +320,40 @@ def validate(args, trainer, task, epoch_itr, subsets): trainer.valid_step(sample) # log validation stats - stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) + stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values()) progress.print(stats, tag=subset, step=trainer.get_num_updates()) - valid_losses.append(stats[args.best_checkpoint_metric]) + valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) return valid_losses -def get_valid_stats(args, trainer, stats): +def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]: stats["num_updates"] = 
trainer.get_num_updates() if hasattr(checkpoint_utils.save_checkpoint, "best"): - key = "best_{0}".format(args.best_checkpoint_metric) - best_function = max if args.maximize_best_checkpoint_metric else min + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min stats[key] = best_function( - checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric] + checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric] ) return stats -def cli_main(modify_parser=None): +def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None: parser = options.get_training_parser() args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + + cfg = convert_namespace_to_omegaconf(args) + if args.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): - distributed_utils.call_main(args, main) + distributed_utils.call_main(cfg, main) else: - distributed_utils.call_main(args, main) + distributed_utils.call_main(cfg, main) -if __name__ == "__main__": +if __name__ == '__main__': + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index df857550..368c9cb5 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -u -#!/usr/bin/env python3 -u +# !/usr/bin/env python3 -u # Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the @@ -8,11 +8,17 @@ import logging import os import sys +from argparse import Namespace from itertools import chain import torch from fairseq import checkpoint_utils, distributed_utils, options, utils +from fairseq.dataclass.data_class import register_hydra_cfg +from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import metrics, progress_bar +from hydra.core.config_store import ConfigStore +from hydra.experimental import initialize +from omegaconf import DictConfig logging.basicConfig( @@ -24,18 +30,21 @@ logging.basicConfig( logger = logging.getLogger("fairseq_cli.validate") -def main(args, override_args=None): - utils.import_user_module(args) +def main(cfg: DictConfig, override_args=None): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) assert ( - args.max_tokens is not None or args.batch_size is not None + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" - use_fp16 = args.fp16 - use_cuda = torch.cuda.is_available() and not args.cpu + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu if use_cuda: - torch.cuda.set_device(args.device_id) + torch.cuda.set_device(cfg.distributed_training.device_id) if override_args is not None: overrides = vars(override_args) @@ -44,11 +53,11 @@ def main(args, override_args=None): overrides = None # Load ensemble - logger.info("loading model(s) from {}".format(args.path)) + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( - [args.path], + [cfg.common_eval.path], arg_overrides=overrides, - suffix=getattr(args, "checkpoint_suffix", ""), + suffix=cfg.checkpoint.checkpoint_suffix, ) model = models[0] @@ 
-63,10 +72,10 @@ def main(args, override_args=None): logger.info(model_args) # Build criterion - criterion = task.build_criterion(model_args) + criterion = task.build_criterion(model_args.criterion) criterion.eval() - for subset in args.valid_subset.split(","): + for subset in cfg.dataset.valid_subset.split(","): try: task.load_dataset(subset, combine=False, epoch=1) dataset = task.dataset(subset) @@ -76,26 +85,26 @@ def main(args, override_args=None): # Initialize data iterator itr = task.get_batch_iterator( dataset=dataset, - max_tokens=args.max_tokens, - max_sentences=args.batch_size, + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[m.max_positions() for m in models], ), - ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, - required_batch_size_multiple=args.required_batch_size_multiple, - seed=args.seed, - num_shards=args.distributed_world_size, - shard_id=args.distributed_rank, - num_workers=args.num_workers, - data_buffer_size=args.data_buffer_size, + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, - log_format=args.log_format, - log_interval=args.log_interval, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, prefix=f"valid on '{subset}' subset", - default_log_format=("tqdm" if not args.no_progress_bar else "simple"), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) log_outputs = [] @@ -105,10 +114,10 @@ def main(args, override_args=None): progress.log(log_output, step=i) log_outputs.append(log_output) - if args.distributed_world_size > 1: + if cfg.distributed_training.distributed_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, - max_size=getattr(args, "all_gather_list_size", 16384), + max_size=cfg.common.all_gather_list_size, ) log_outputs = list(chain.from_iterable(log_outputs)) @@ -131,4 +140,7 @@ def cli_main(): if __name__ == "__main__": + cs = ConfigStore.instance() + register_hydra_cfg(cs) + initialize(config_path="../config", strict=True) cli_main() diff --git a/tests/speech_recognition/asr_test_base.py b/tests/speech_recognition/asr_test_base.py index 03410313..8c5d414e 100644 --- a/tests/speech_recognition/asr_test_base.py +++ b/tests/speech_recognition/asr_test_base.py @@ -272,6 +272,7 @@ class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase): model_cls.add_args(parser) args = parser.parse_args([]) + if extra_args_setters is not None: for args_setter in extra_args_setters: args_setter(args) @@ -515,9 +516,7 @@ class CrossEntropyCriterionTestBase(unittest.TestCase): def setUp(self): args = self.setUpArgs() self.model = DummyEncoderModel(encoder=DummyEncoder()) - self.criterion = self.criterion_cls.build_criterion( - args=args, task=DummyTask(args) - ) + self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args)) def get_src_tokens(self, correct_prediction, aggregate): """ diff --git a/tests/test_bmuf.py b/tests/test_bmuf.py index 0165b295..e7aa6da1 100644 --- a/tests/test_bmuf.py +++ b/tests/test_bmuf.py @@ -11,7 +11,7 @@ from multiprocessing import Manager 
import torch import torch.nn as nn from fairseq import distributed_utils, optim - +from omegaconf import OmegaConf class Model(nn.Module): def __init__(self, input_size, output_size): @@ -23,13 +23,14 @@ class Model(nn.Module): return output -def setup_model_loss_criterion(args, rank, is_cuda): +def setup_model_loss_criterion(cfg, args, rank, is_cuda): """ setup model, criterion and optimizer based on input args """ args.distributed_rank = rank - if args.distributed_world_size > 1: - distributed_utils.distributed_init(args) + cfg.distributed_training.distributed_rank = args.distributed_rank + if cfg.distributed_training.distributed_world_size > 1: + distributed_utils.distributed_init(cfg) torch.manual_seed(1) model = Model(args.input_size, args.nb_classes) loss_fn = nn.CrossEntropyLoss() @@ -38,7 +39,10 @@ def setup_model_loss_criterion(args, rank, is_cuda): loss_fn = loss_fn.cuda() optimizer = optim.sgd.SGD(args, model.parameters()) - optimizer = optim.FairseqBMUF(args, optimizer) + optimizer = optim.FairseqBMUF( + cfg=cfg.bmuf, + optimizer=optimizer + ) return model, loss_fn, optimizer @@ -52,13 +56,13 @@ def train_step(input, target, model, loss_fn, optimizer, **unused): optimizer.step() -def single_gpu_training(args, rank, iterations, shared_results): +def single_gpu_training(cfg, args, rank, iterations, shared_results): is_cuda = torch.cuda.is_available() if is_cuda: torch.cuda.set_device(rank) - model, loss_fn, optimizer = setup_model_loss_criterion(args, rank, is_cuda) + model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda) for _ in range(iterations): input = torch.randn(1, args.input_size) @@ -103,18 +107,44 @@ def setup_args(): args.distributed_init_host = "localhost" args.distributed_port = port + 1 args.local_world_size = args.distributed_world_size - return args + + cfg = OmegaConf.create() + cfg.optimization = OmegaConf.create() + cfg.common = OmegaConf.create() + cfg.distributed_training = OmegaConf.create() + cfg.dataset = OmegaConf.create() + cfg.bmuf = OmegaConf.create() + cfg.optimizer = OmegaConf.create() + + cfg.bmuf.global_sync_iter = args.global_sync_iter + cfg.bmuf.block_momentum = args.block_momentum + cfg.bmuf.block_lr = args.block_lr + cfg.dataset.batch_size = args.batch_size + cfg.optimization.lr = args.lr + cfg.optimizer.momentum = args.momentum + cfg.optimizer.weight_decay = args.weight_decay + cfg.bmuf.warmup_iterations = args.warmup_iterations + cfg.bmuf.use_nbm = args.use_nbm + cfg.bmuf.average_sync = args.average_sync + cfg.common.model_parallel_size = args.model_parallel_size + cfg.distributed_training.distributed_backend = args.distributed_backend + cfg.distributed_training.distributed_world_size = args.distributed_world_size + cfg.bmuf.distributed_world_size = args.distributed_world_size + cfg.distributed_training.distributed_init_method = args.distributed_init_method + cfg.distributed_training.distributed_port = args.distributed_port + + return cfg, args @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs") class TestBMUF(unittest.TestCase): - def bmuf_process(self, args, iterations): + def bmuf_process(self, cfg, args, iterations): processes = [] results = Manager().dict() ctx = torch.multiprocessing.get_context("spawn") for rank in range(args.distributed_world_size): p = ctx.Process( - target=single_gpu_training, args=(args, rank, iterations, results) + target=single_gpu_training, args=(cfg, args, rank, iterations, results) ) p.start() processes.append(p) @@ -125,19 +155,20 @@ class 
diff --git a/tests/test_fp16_optimizer.py b/tests/test_fp16_optimizer.py
index c4195273..aa6a863d 100644
--- a/tests/test_fp16_optimizer.py
+++ b/tests/test_fp16_optimizer.py
@@ -9,6 +9,7 @@ import unittest
 
 import torch
 from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
+from omegaconf import OmegaConf
 
 
 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
@@ -27,17 +28,23 @@ class TestGradientScaling(unittest.TestCase):
         self.model.cuda().half()
         self.params = list(self.model.parameters())
 
-        self.namespace_dls = argparse.Namespace(
-            optimizer="adam",
-            lr=[0.1],
-            adam_betas="(0.9, 0.999)",
-            adam_eps=1e-8,
-            weight_decay=0.0,
-            fp16_init_scale=1,
-            fp16_scale_window=1,
-            fp16_scale_tolerance=1,
-            threshold_loss_scale=1,
-            min_loss_scale=1e-4,
+        self.cfg_dls = OmegaConf.create(
+            {
+                "optimizer": {
+                    "_name": "adam",
+                    "lr": [0.1],
+                    "adam_betas": "(0.9, 0.999)",
+                    "adam_eps": 1e-8,
+                    "weight_decay": 0.0,
+                },
+                "common": {
+                    "fp16_init_scale": 1,
+                    "fp16_scale_window": 1,
+                    "fp16_scale_tolerance": 1,
+                    "threshold_loss_scale": 1,
+                    "min_loss_scale": 1e-4,
+                },
+            }
         )
 
     def run_iter(self, model, params, optimizer):
@@ -68,7 +75,7 @@ class TestGradientScaling(unittest.TestCase):
     def test_mixed_precision(self):
         model = copy.deepcopy(self.model)
         params = list(model.parameters())
-        optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params)
+        optimizer = FP16Optimizer.build_optimizer(self.cfg_dls, params)
 
         self.run_iter(model, params, optimizer)
         self.assertTrue(
@@ -87,9 +94,7 @@ class TestGradientScaling(unittest.TestCase):
     def test_memory_efficient(self):
         model = copy.deepcopy(self.model)
         params = list(model.parameters())
-        optimizer = MemoryEfficientFP16Optimizer.build_optimizer(
-            self.namespace_dls, params
-        )
+        optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.cfg_dls, params)
 
         self.run_iter(model, params, optimizer)
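Note on test_fp16_optimizer.py above: self.cfg_dls nests an "optimizer" group and a "common" group inside one DictConfig, matching the layout the test now hands to FP16Optimizer.build_optimizer. A minimal sketch of building and reading such a nested config (group contents trimmed to a couple of keys):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "optimizer": {"_name": "adam", "lr": [0.1]},
            "common": {"fp16_init_scale": 1, "min_loss_scale": 1e-4},
        }
    )
    assert cfg.optimizer["_name"] == "adam"
    assert cfg.common.fp16_init_scale == 1
    # convert a sub-group back to a plain dict when needed
    assert OmegaConf.to_container(cfg.common) == {"fp16_init_scale": 1, "min_loss_scale": 1e-4}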
diff --git a/tests/test_inference_dropout.py b/tests/test_inference_dropout.py
index fd5edd43..353ac674 100644
--- a/tests/test_inference_dropout.py
+++ b/tests/test_inference_dropout.py
@@ -6,6 +6,7 @@
 import logging
 import unittest
 
+from fairseq.dataclass.utils import convert_namespace_to_omegaconf
 from fairseq.models.transformer import TransformerModel
 from tests.test_sequence_generator import get_dummy_task_and_parser
 
@@ -25,7 +26,8 @@ class TestInferenceDropout(unittest.TestCase):
     def test_sets_inference_dropout_to_true(self):
         self.args.retain_dropout = True
         self.transformer_model = TransformerModel.build_model(self.args, self.task)
-        self.transformer_model.prepare_for_inference_(self.args)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
         assert self.transformer_model.encoder.dropout_module.apply_during_inference
         assert self.transformer_model.decoder.dropout_module.apply_during_inference
         for layer in self.transformer_model.encoder.layers:
@@ -33,7 +35,8 @@ class TestInferenceDropout(unittest.TestCase):
 
     def test_inference_dropout_false_by_default(self):
         self.transformer_model = TransformerModel.build_model(self.args, self.task)
-        self.transformer_model.prepare_for_inference_(self.args)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
         assert not self.transformer_model.encoder.dropout_module.apply_during_inference
         assert not self.transformer_model.decoder.dropout_module.apply_during_inference
         for layer in self.transformer_model.encoder.layers:
@@ -59,7 +62,8 @@ class TestInferenceDropout(unittest.TestCase):
             "TransformerEncoderLayer",
         ]
         self.transformer_model = TransformerModel.build_model(self.args, self.task)
-        self.transformer_model.prepare_for_inference_(self.args)
+        cfg = convert_namespace_to_omegaconf(self.args)
+        self.transformer_model.prepare_for_inference_(cfg)
         assert self.transformer_model.encoder.dropout_module.apply_during_inference
         assert not self.transformer_model.decoder.dropout_module.apply_during_inference
         for layer in self.transformer_model.decoder.layers:
diff --git a/tests/test_memory_efficient_fp16.py b/tests/test_memory_efficient_fp16.py
index e10636d9..2bf2f298 100644
--- a/tests/test_memory_efficient_fp16.py
+++ b/tests/test_memory_efficient_fp16.py
@@ -10,6 +10,7 @@ import unittest
 import torch
 from fairseq.optim.adam import FairseqAdam
 from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
+from omegaconf import OmegaConf
 
 
 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
@@ -26,25 +27,36 @@ class TestMemoryEfficientFP16(unittest.TestCase):
         params = list(model.parameters())
 
         # initialize memory efficient FP16 optimizer
+        # with pseudo DictConfigs
         optimizer = FairseqAdam(
-            argparse.Namespace(
-                lr=[0.00001],
-                adam_betas="(0.9, 0.999)",
-                adam_eps=1e-8,
-                weight_decay=0.0,
+            cfg=OmegaConf.create(
+                vars(
+                    argparse.Namespace(
+                        adam_betas="(0.9, 0.999)",
+                        adam_eps=1e-8,
+                        weight_decay=0.0,
+                        lr=[0.00001],
+                    )
+                )
             ),
-            params,
+            params=params,
         )
         me_optimizer = MemoryEfficientFP16Optimizer(
-            argparse.Namespace(
-                fp16_init_scale=1,
-                fp16_scale_window=1,
-                fp16_scale_tolerance=1,
-                threshold_loss_scale=1,
-                min_loss_scale=1e-4,
+            cfg=OmegaConf.create(
+                {
+                    "common": vars(
+                        argparse.Namespace(
+                            fp16_init_scale=1,
+                            fp16_scale_window=1,
+                            fp16_scale_tolerance=1,
+                            threshold_loss_scale=1,
+                            min_loss_scale=1e-4,
+                        )
+                    )
+                }
             ),
-            params,
-            optimizer,
+            params=params,
+            optimizer=optimizer,
         )
 
         # optimizer state is created in the first step
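Note on the two tests above: test_inference_dropout.py converts a full fairseq namespace with convert_namespace_to_omegaconf, while test_memory_efficient_fp16.py builds its config by hand, wrapping vars(argparse.Namespace(...)) in OmegaConf.create. A tiny sketch of the manual variant:

    import argparse

    from omegaconf import OmegaConf

    ns = argparse.Namespace(adam_eps=1e-8, weight_decay=0.0, lr=[0.00001])
    cfg = OmegaConf.create(vars(ns))  # vars() turns the namespace into a plain dict
    assert cfg.adam_eps == 1e-8
    assert cfg.lr[0] == 0.00001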
diff --git a/tests/test_train.py b/tests/test_train.py
index 1b7e027c..57daa194 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
 
 import torch
 from fairseq import checkpoint_utils, data
+from omegaconf import OmegaConf
 
 
 def mock_trainer(epoch, num_updates, iterations_in_epoch):
@@ -56,21 +57,29 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
     return trainer, epoch_itr
 
 
-def get_mock_args(finetune_from_model=None):
-    args_mock = MagicMock()
-    args_mock.optimizer_overrides = "{}"
-    args_mock.reset_dataloader = False
-    args_mock.reset_meters = False
-    args_mock.reset_optimizer = False
-    args_mock.reset_lr_scheduler = False
-    args_mock.finetune_from_model = finetune_from_model
-    args_mock.model_parallel_size = 1
-    return args_mock
+def get_mock_cfg(finetune_from_model):
+    cfg_mock = OmegaConf.create(
+        {
+            "checkpoint": {
+                "optimizer_overrides": "{}",
+                "reset_dataloader": False,
+                "reset_meters": False,
+                "reset_optimizer": False,
+                "reset_lr_scheduler": False,
+                "finetune_from_model": finetune_from_model,
+                "model_parallel_size": 1,
+            },
+            "common": {
+                "model_parallel_size": 1,
+            },
+        }
+    )
+    return cfg_mock
 
 
 class TestLoadCheckpoint(unittest.TestCase):
     def setUp(self):
-        self.args_mock = get_mock_args()
+        self.cfg_mock = get_mock_cfg(None)
         self.patches = {
             "os.makedirs": MagicMock(),
             "os.path.join": MagicMock(),
@@ -91,7 +100,9 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
 
-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )
 
             self.assertEqual(epoch_itr.epoch, 2)
             self.assertEqual(epoch_itr.iterations_in_epoch, 50)
@@ -120,7 +131,9 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
 
-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )
             itr = epoch_itr.next_epoch_itr(shuffle=False)
 
             self.assertEqual(epoch_itr.epoch, 3)
@@ -133,7 +146,9 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             self.patches["os.path.isfile"].return_value = False
 
-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )
             itr = epoch_itr.next_epoch_itr(shuffle=False)
 
             self.assertEqual(epoch_itr.epoch, 1)
@@ -152,10 +167,12 @@ class TestLoadCheckpoint(unittest.TestCase):
                 "reset_dataloader",
             ]:
                 with self.subTest(arg=arg):
-                    args_mock = get_mock_args("/temp/checkpoint_pretrained.pt")
-                    setattr(args_mock, arg, True)
+                    cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
+                    cfg_mock["checkpoint"][arg] = True
                     with self.assertRaises(Exception) as context:
-                        _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+                        _, _ = checkpoint_utils.load_checkpoint(
+                            cfg_mock.checkpoint, trainer
+                        )
 
                     self.assertTrue(
                         "--finetune-from-model can not be set together with either --reset-optimizer"
@@ -168,8 +185,6 @@ class TestLoadCheckpoint(unittest.TestCase):
         trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
         trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
         from_model_path = "/temp/checkpoint_pretrained.pt"
-        args_mock = get_mock_args(from_model_path)
-        args_mock.restore_file = "checkpoint_last.pt"
 
         def mock_finetune_exist(path):
             if path == from_model_path:
@@ -180,7 +195,9 @@ class TestLoadCheckpoint(unittest.TestCase):
         self.patches[
             "fairseq.file_io.PathManager.exists"
         ].side_effect = mock_finetune_exist
-        _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+        cfg_mock = get_mock_cfg(from_model_path)
+        cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
+        _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
         (
             checkpoint_path,
             reset_optimizer,
@@ -197,8 +214,6 @@ class TestLoadCheckpoint(unittest.TestCase):
         trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
         trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
         from_model_path = "/temp/checkpoint_pretrained.pt"
-        args_mock = get_mock_args(from_model_path)
-        args_mock.restore_file = "checkpoint_last.pt"
 
         # launch second time
         # both restore_file=checkpoint_last.pt and finetune_from_model are set
@@ -211,7 +226,9 @@ class TestLoadCheckpoint(unittest.TestCase):
         self.patches[
             "fairseq.file_io.PathManager.exists"
         ].side_effect = mock_finetune_exist
-        _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+        cfg_mock = get_mock_cfg(from_model_path)
+        cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
+        _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
         (
             checkpoint_path,
             reset_optimizer,
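Note on test_train.py above: checkpoint_utils.load_checkpoint now receives only the checkpoint group, so the mock passes cfg_mock.checkpoint instead of the whole config. Selecting a group from a DictConfig yields another DictConfig node rather than a copy, so it can be handed around and mutated on its own; a minimal sketch (keys trimmed):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {"checkpoint": {"reset_optimizer": False, "restore_file": "checkpoint_last.pt"}}
    )
    ckpt_cfg = cfg.checkpoint              # sub-node, still a DictConfig
    ckpt_cfg.restore_file = "other.pt"     # edits are visible through the parent
    assert cfg.checkpoint.restore_file == "other.pt"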
diff --git a/tests/utils.py b/tests/utils.py
index 91feca6b..a145aa58 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -20,7 +20,7 @@ from fairseq.models import (
     FairseqIncrementalDecoder,
 )
 from fairseq.models.fairseq_encoder import EncoderOut
-from fairseq.tasks import FairseqTask, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask
 from fairseq_cli import generate, interactive, preprocess, train, validate