Enable Hydra configs in fairseq (#1343) (#1510)

Summary:
Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1510

This is the main PR that switches on Hydra functionality in fairseq.

We migrate the "args" Namespace object to an OmegaConf "DictConfig" at all legacy entry points.

In addition, this migrates various components from secondary registries (such as BPE encoders and tokenizers) to make the transition smoother.

I am going through code that references migrated fairseq components and changing it to inherit from the "Legacy*" components instead; hopefully the tests will catch most of the remaining cases.
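
To make the direction concrete, here is a minimal sketch of the pattern this migration follows: legacy argparse values are overlaid onto a typed OmegaConf config. This is illustrative only and not code from this diff; `CommonConfig` and the `namespace_to_cfg` helper below are hypothetical stand-ins, not fairseq's actual dataclasses or utilities.

```python
# Hypothetical sketch of the args -> DictConfig migration pattern (not code from this diff).
from argparse import Namespace
from dataclasses import dataclass

from omegaconf import DictConfig, OmegaConf


@dataclass
class CommonConfig:
    # Field names mirror the "common" group in the config shown below; defaults are illustrative.
    seed: int = 1
    fp16: bool = False
    log_interval: int = 100


def namespace_to_cfg(args: Namespace) -> DictConfig:
    """Overlay flat legacy argparse values onto a typed, structured config."""
    cfg = OmegaConf.structured(CommonConfig)
    for key, value in vars(args).items():
        if key in cfg:  # ignore legacy flags the structured config does not know about
            cfg[key] = value
    return cfg


if __name__ == "__main__":
    legacy_args = Namespace(seed=7, fp16=True, unrelated_flag="ignored")
    cfg = namespace_to_cfg(legacy_args)
    print(OmegaConf.to_yaml(cfg))  # seed: 7, fp16: true, log_interval: 100
```

The components moved out of secondary registries (BPE encoders, tokenizers) follow the same idea: each gets a small dataclass config registered alongside it, as the diffs below show.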

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1343

Reviewed By: myleott

Differential Revision: D23973928

Pulled By: alexeib

fbshipit-source-id: dd9554981fff51ea75c1ff343874d1d6e61793c9
This commit is contained in:
alexeib 2020-10-20 00:31:00 -07:00 committed by Facebook GitHub Bot
parent c76cb6dfb9
commit 3b27ed7996
85 changed files with 2034 additions and 1681 deletions

View File

@ -1,7 +1,111 @@
# @package _group_
common:
  no_progress_bar: false
  log_interval: 100
  log_format: null
  tensorboard_logdir: null
  seed: 1
  cpu: false
  tpu: false
  bf16: false
  fp16: false
  memory_efficient_fp16: false
  memory_efficient_bf16: false
  fp16_no_flatten_grads: false
  fp16_init_scale: 128
  fp16_scale_window: null
  fp16_scale_tolerance: 0.0
  min_loss_scale: 1.0e-4
  threshold_loss_scale: null
  user_dir: null
  empty_cache_freq: 0
  all_gather_list_size: 16384
  model_parallel_size: 1
  quantization_config_path: null
  profile: false
distributed_training:
  distributed_rank: 0
  distributed_backend: "nccl"
  distributed_init_method: null
  distributed_port: -1
  device_id: 0
  local_rank: 0
  distributed_no_spawn: false
  ddp_backend: "c10d"
  bucket_cap_mb: 25
  fix_batches_to_gpus: false
  find_unused_parameters: false
  fast_stat_sync: false
  broadcast_buffers: false
  distributed_wrapper: "DDP"
  slowmo_momentum: null
  slowmo_algorithm: "LocalSGD"
  localsgd_frequency: 3
dataset:
  num_workers: 1
  skip_invalid_size_inputs_valid_test: false
  max_tokens: null
  batch_size: null
  required_batch_size_multiple: 8
  dataset_impl: null
  data_buffer_size: 10
  train_subset: "train"
  valid_subset: "valid"
  validate_interval: 1
  fixed_validation_seed: null
  disable_validation: false
  curriculum: 0
  gen_subset: "test"
  num_shards: 1
  shard_id: 0
  max_tokens_valid: ${dataset.max_tokens}
  batch_size_valid: ${dataset.batch_size}
optimization:
  max_epoch: 0
  max_update: 0
  clip_norm: 25.0
  sentence_avg: false
  update_freq: [ 1 ]
  lr: [ 0.25 ]
  min_lr: -1.0
  use_bmuf: false
checkpoint:
  save_dir: "checkpoints"
  restore_file: "checkpoint_last.pt"
  reset_dataloader: false
  reset_lr_scheduler: false
  reset_meters: false
  reset_optimizer: false
  optimizer_overrides: "{}"
  save_interval: 1
  save_interval_updates: 0
  keep_interval_updates: -1
  keep_last_epochs: -1
  keep_best_checkpoints: -1
  no_save: false
  no_epoch_checkpoints: false
  no_last_checkpoints: false
  no_save_optimizer_state: false
  best_checkpoint_metric: "loss"
  maximize_best_checkpoint_metric: false
  patience: -1
  checkpoint_suffix: ""
bmuf:
  block_lr: 1
  block_momentum: 0.875
  global_sync_iter: 50
  warmup_iterations: 500
  use_nbm: false
  average_sync: false
-defaults:
-  - params: training_params
-  - task: language_modeling
-  - model: transformer_lm
-  - criterion: cross_entropy
-  - optimizer: adam
-  - lr_scheduler: inverse_sqrt
+defaults:
+  - task: language_modeling
+  - model: null
+  - criterion: null
+  - optimizer: null
+  - lr_scheduler: null
+  - bpe: null
+  - tokenizer: null
+  - scoring: null
+  - generation: null
+  - common_eval: null
+  - eval_lm: null

View File

@ -1,7 +0,0 @@
defaults:
- params: eval_lm_params
- task: language_modeling
- model: transformer_lm
- criterion: cross_entropy
- optimizer: adam
- lr_scheduler: inverse_sqrt

View File

@ -1,3 +1,3 @@
# @package _group_
-sentence_avg: ${params.optimization.sentence_avg}
-ddp_backend: ${params.distributed_training.ddp_backend}
+sentence_avg: ${optimization.sentence_avg}
+ddp_backend: ${distributed_training.ddp_backend}

View File

@ -1,3 +1,2 @@
# @package _group_
-sentence_avg: ${params.optimization.sentence_avg}
-ddp_backend: ${params.distributed_training.ddp_backend}
+sentence_avg: ${optimization.sentence_avg}

View File

@ -1,105 +0,0 @@
# @package _group_
common:
no_progress_bar: false
log_interval: 100
log_format: null
tensorboard_logdir: null
seed: 1
cpu: false
fp16: false
memory_efficient_fp16: false
fp16_no_flatten_grads: false
fp16_init_scale: 128
fp16_scale_window: null
fp16_scale_tolerance: 0.0
min_loss_scale: 1.0e-4
threshold_loss_scale: null
user_dir: null
empty_cache_freq: 0
all_gather_list_size: 16384
model_parallel_size: 1
checkpoint_suffix: ""
quantization_config_path: null
distributed_training:
distributed_rank: 0
distributed_backend: "nccl"
distributed_init_method: null
distributed_port: -1
device_id: 0
local_rank: 0
distributed_no_spawn: false
ddp_backend: "c10d"
bucket_cap_mb: 25
fix_batches_to_gpus: false
find_unused_parameters: false
fast_stat_sync: false
broadcast_buffers: false
distributed_wrapper: "DDP"
slowmo_momentum: null
slowmo_algorithm: "LocalSGD"
localsgd_frequency: 3
dataset:
num_workers: 1
skip_invalid_size_inputs_valid_test: false
max_tokens: null
batch_size: ${params.dataset.batch_size}
required_batch_size_multiple: 8
dataset_impl: null
data_buffer_size: 10
train_subset: "train"
valid_subset: "valid"
validate_interval: 1
fixed_validation_seed: null
disable_validation: false
curriculum: 0
gen_subset: "test"
num_shards: 1
shard_id: 0
max_tokens_valid: ${params.dataset.max_tokens}
batch_size_valid: ${params.dataset.batch_size}
optimization:
max_epoch: 0
max_update: 0
clip_norm: 25.0
sentence_avg: false
update_freq: [1]
lr: [0.25]
min_lr: -1.0
use_bmuf: false
checkpoint:
save_dir: "checkpoints"
restore_file: "checkpoint_last.pt"
reset_dataloader: false
reset_lr_scheduler: false
reset_meters: false
reset_optimizer: false
optimizer_overrides: "{}"
save_interval: 1
save_interval_updates: 0
keep_interval_updates: -1
keep_last_epochs: -1
keep_best_checkpoints: -1
no_save: false
no_epoch_checkpoints: false
no_last_checkpoints: false
no_save_optimizer_state: false
best_checkpoint_metric: "loss"
maximize_best_checkpoint_metric: false
patience: -1
common_eval:
path: null
remove_bpe: null
quiet: false
model_overrides: "{}"
results_path: null
eval_lm:
output_word_probs: false
output_word_stats: false
context_window: 0
bmuf:
block_lr: 1
block_momentum: 0.875
global_sync_iter: 50
warmup_iterations: 500
use_nbm: false
average_sync: false

View File

@ -1,95 +0,0 @@
# @package _group_
common:
no_progress_bar: false
log_interval: 100
log_format: null
tensorboard_logdir: null
seed: 1
cpu: false
fp16: false
memory_efficient_fp16: false
fp16_no_flatten_grads: false
fp16_init_scale: 128
fp16_scale_window: null
fp16_scale_tolerance: 0.0
min_loss_scale: 1.0e-4
threshold_loss_scale: null
user_dir: null
empty_cache_freq: 0
all_gather_list_size: 16384
model_parallel_size: 1
checkpoint_suffix: ""
quantization_config_path: null
distributed_training:
distributed_rank: 0
distributed_backend: "nccl"
distributed_init_method: null
distributed_port: -1
device_id: 0
local_rank: 0
distributed_no_spawn: false
ddp_backend: "c10d"
bucket_cap_mb: 25
fix_batches_to_gpus: false
find_unused_parameters: false
fast_stat_sync: false
broadcast_buffers: false
distributed_wrapper: "DDP"
slowmo_momentum: null
slowmo_algorithm: "LocalSGD"
localsgd_frequency: 3
dataset:
num_workers: 1
skip_invalid_size_inputs_valid_test: false
max_tokens: null
batch_size: ${params.dataset.batch_size}
required_batch_size_multiple: 8
dataset_impl: null
data_buffer_size: 10
train_subset: "train"
valid_subset: "valid"
validate_interval: 1
fixed_validation_seed: null
disable_validation: false
curriculum: 0
gen_subset: "test"
num_shards: 1
shard_id: 0
max_tokens_valid: ${params.dataset.max_tokens}
batch_size_valid: ${params.dataset.batch_size}
optimization:
max_epoch: 0
max_update: 0
clip_norm: 25.0
sentence_avg: false
update_freq: [1]
lr: [0.25]
min_lr: -1.0
use_bmuf: false
checkpoint:
save_dir: "checkpoints"
restore_file: "checkpoint_last.pt"
reset_dataloader: false
reset_lr_scheduler: false
reset_meters: false
reset_optimizer: false
optimizer_overrides: "{}"
save_interval: 1
save_interval_updates: 0
keep_interval_updates: -1
keep_last_epochs: -1
keep_best_checkpoints: -1
no_save: false
no_epoch_checkpoints: false
no_last_checkpoints: false
no_save_optimizer_state: false
best_checkpoint_metric: "loss"
maximize_best_checkpoint_metric: false
patience: -1
bmuf:
block_lr: 1
block_momentum: 0.875
global_sync_iter: 50
warmup_iterations: 500
use_nbm: false
average_sync: false

View File

@ -13,7 +13,6 @@ For example, if we'd like to train a language model with transformer, we could p
```
defaults:
-  - params: training_params
  - task: language_modeling
  - model: transformer_lm
  - criterion: cross_entropy
@ -21,7 +20,7 @@
  - lr_scheduler: inverse_sqrt
```
-- Provide generic parameters common across different training jobs: `config/params/training_params.yaml`
+- Provide generic parameters common across different jobs: `config.yaml`
- Provide task parameters: `config/task/language_modeling.yaml`
- Provide model parameters: `config/model/transformer_lm.yaml`
- Provide criterion parameters: `config/criterion/cross_entropy.yaml`
@ -41,7 +40,6 @@ Alternatively, if we need to override certain params from the command line, we c
```
python fairseq_cli/train_hydra.py
-params=training_params \
task=language_modeling \
task.data=/private/home/abaevski/data/wiki103 \
task.tokens_per_sample=512 \
@ -56,17 +54,17 @@ lr_scheduler=inverse_sqrt \
lr_scheduler.warmup_updates=4000 \
lr_scheduler.warmup_init_lr=1e-07 \
criterion=cross_entropy \
-params.common.fp16=true \
-params.common.log_format=json \
-params.common.log_interval=1 \
-params.dataset.max_tokens=1024 \
-params.dataset.num_workers=4 \
-params.optimization.update_freq=[16] \
-params.optimization.max_update=50000 \
-params.optimization.clip_norm=0.0 \
-params.optimization.lr=[0.0005] \
-params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
-params.checkpoint.save_interval_updates=10
+common.fp16=true \
+common.log_format=json \
+common.log_interval=1 \
+dataset.max_tokens=1024 \
+dataset.num_workers=4 \
+optimization.update_freq=[16] \
+optimization.max_update=50000 \
+optimization.clip_norm=0.0 \
+optimization.lr=[0.0005] \
+checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
+checkpoint.save_interval_updates=10
```
## Migrate existing/Creating new modules to hydra interface

View File

@ -212,7 +212,7 @@ following contents::
@register_task('simple_classification')
-class SimpleClassificationTask(FairseqTask):
+class SimpleClassificationTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):

View File

@ -27,7 +27,13 @@ def score_target_hypo(
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c) print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args) gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
dict = dictionary.Dictionary() dict = dictionary.Dictionary()
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) scorer = scorer = bleu.Scorer(
bleu.BleuConfig(
pad=dict.pad(),
eos=dict.eos(),
unk=dict.unk(),
)
)
ordered_hypos = {} ordered_hypos = {}
ordered_targets = {} ordered_targets = {}

View File

@ -20,8 +20,8 @@ class WSCCriterion(LegacyFairseqCriterion):
self.prediction_h = open(self.args.save_predictions, "w")
else:
self.prediction_h = None
-self.bpe = encoders.build_bpe(args)
-self.tokenizer = encoders.build_tokenizer(args)
+self.bpe = encoders.build_bpe(args.bpe)
+self.tokenizer = encoders.build_tokenizer(args.tokenizer)
def __del__(self):
if self.prediction_h is not None:

View File

@ -85,7 +85,7 @@ Produce model scores for the generated translations using `--retain-dropout` opt
```
CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt --beam 5
--source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout
---retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer
+--retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]'
TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out
grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores

View File

@ -3,36 +3,42 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import ast
import collections import collections
import logging import logging
import os import os
import re import re
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
from typing import Union from typing import Optional, Union
import torch import torch
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
overwrite_args_by_name,
)
from fairseq.file_io import PathManager from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, open_dict
from torch.serialization import default_restore_location from torch.serialization import default_restore_location
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def save_checkpoint(args, trainer, epoch_itr, val_loss): def save_checkpoint(cfg: DictConfig, trainer, epoch_itr, val_loss):
from fairseq import distributed_utils, meters from fairseq import meters
# only one worker should attempt to create the required dir # only one worker should attempt to create the required dir
if args.distributed_rank == 0: if cfg.distributed_rank == 0:
os.makedirs(args.save_dir, exist_ok=True) os.makedirs(cfg.save_dir, exist_ok=True)
prev_best = getattr(save_checkpoint, "best", val_loss) prev_best = getattr(save_checkpoint, "best", val_loss)
if val_loss is not None: if val_loss is not None:
best_function = max if args.maximize_best_checkpoint_metric else min best_function = max if cfg.maximize_best_checkpoint_metric else min
save_checkpoint.best = best_function(val_loss, prev_best) save_checkpoint.best = best_function(val_loss, prev_best)
if args.no_save: if cfg.no_save:
return return
trainer.consolidate_optimizer() trainer.consolidate_optimizer()
@ -41,7 +47,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
return return
def is_better(a, b): def is_better(a, b):
return a >= b if args.maximize_best_checkpoint_metric else a <= b return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
write_timer = meters.StopwatchMeter() write_timer = meters.StopwatchMeter()
write_timer.start() write_timer.start()
@ -50,38 +56,36 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
end_of_epoch = epoch_itr.end_of_epoch() end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates() updates = trainer.get_num_updates()
suffix = getattr(args, "checkpoint_suffix", "") suffix = cfg.checkpoint_suffix or ""
checkpoint_conds = collections.OrderedDict() checkpoint_conds = collections.OrderedDict()
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
end_of_epoch end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
and not args.no_epoch_checkpoints
and epoch % args.save_interval == 0
) )
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
not end_of_epoch not end_of_epoch
and args.save_interval_updates > 0 and cfg.save_interval_updates > 0
and updates % args.save_interval_updates == 0 and updates % cfg.save_interval_updates == 0
) )
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
not hasattr(save_checkpoint, "best") not hasattr(save_checkpoint, "best")
or is_better(val_loss, save_checkpoint.best) or is_better(val_loss, save_checkpoint.best)
) )
if val_loss is not None and args.keep_best_checkpoints > 0: if val_loss is not None and cfg.keep_best_checkpoints > 0:
checkpoint_conds[ checkpoint_conds[
"checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss) "checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss)
] = not hasattr(save_checkpoint, "best") or is_better( ] = not hasattr(save_checkpoint, "best") or is_better(
val_loss, save_checkpoint.best val_loss, save_checkpoint.best
) )
checkpoint_conds[ checkpoint_conds[
"checkpoint_last{}.pt".format(suffix) "checkpoint_last{}.pt".format(suffix)
] = not args.no_last_checkpoints ] = not cfg.no_last_checkpoints
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
if hasattr(save_checkpoint, "best"): if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best}) extra_state.update({"best": save_checkpoint.best})
checkpoints = [ checkpoints = [
os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
] ]
if len(checkpoints) > 0: if len(checkpoints) > 0:
trainer.save_checkpoint(checkpoints[0], extra_state) trainer.save_checkpoint(checkpoints[0], extra_state)
@ -95,51 +99,52 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
) )
) )
if not end_of_epoch and args.keep_interval_updates > 0: if not end_of_epoch and cfg.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order # remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths( checkpoints = checkpoint_paths(
args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt" cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
) )
for old_chk in checkpoints[args.keep_interval_updates :]: for old_chk in checkpoints[cfg.keep_interval_updates :]:
if os.path.lexists(old_chk): if os.path.lexists(old_chk):
os.remove(old_chk) os.remove(old_chk)
if args.keep_last_epochs > 0: if cfg.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order # remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt") checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt")
for old_chk in checkpoints[args.keep_last_epochs :]: for old_chk in checkpoints[cfg.keep_last_epochs :]:
if os.path.lexists(old_chk): if os.path.lexists(old_chk):
os.remove(old_chk) os.remove(old_chk)
if args.keep_best_checkpoints > 0: if cfg.keep_best_checkpoints > 0:
# only keep the best N checkpoints according to validation metric # only keep the best N checkpoints according to validation metric
checkpoints = checkpoint_paths( checkpoints = checkpoint_paths(
args.save_dir, cfg.save_dir,
pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
args.best_checkpoint_metric cfg.best_checkpoint_metric
), ),
) )
if not args.maximize_best_checkpoint_metric: if not cfg.maximize_best_checkpoint_metric:
checkpoints = checkpoints[::-1] checkpoints = checkpoints[::-1]
for old_chk in checkpoints[args.keep_best_checkpoints :]: for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
if os.path.lexists(old_chk): if os.path.lexists(old_chk):
os.remove(old_chk) os.remove(old_chk)
def load_checkpoint(args, trainer, **passthrough_args): def load_checkpoint(cfg: DictConfig, trainer, **passthrough_args):
""" """
Load a checkpoint and restore the training iterator. Load a checkpoint and restore the training iterator.
*passthrough_args* will be passed through to *passthrough_args* will be passed through to
``trainer.get_train_iterator``. ``trainer.get_train_iterator``.
""" """
reset_optimizer = args.reset_optimizer
reset_lr_scheduler = args.reset_lr_scheduler
optimizer_overrides = eval(args.optimizer_overrides)
reset_meters = args.reset_meters
reset_dataloader = args.reset_dataloader
if getattr(args, "finetune_from_model", None) is not None and ( reset_optimizer = cfg.reset_optimizer
reset_lr_scheduler = cfg.reset_lr_scheduler
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
reset_meters = cfg.reset_meters
reset_dataloader = cfg.reset_dataloader
if cfg.finetune_from_model is not None and (
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
): ):
raise ValueError( raise ValueError(
@ -147,19 +152,19 @@ def load_checkpoint(args, trainer, **passthrough_args):
" or reset_lr_scheduler or reset_meters or reset_dataloader" " or reset_lr_scheduler or reset_meters or reset_dataloader"
) )
suffix = getattr(args, "checkpoint_suffix", "") suffix = cfg.checkpoint_suffix
if ( if (
args.restore_file == "checkpoint_last.pt" cfg.restore_file == "checkpoint_last.pt"
): # default value of restore_file is 'checkpoint_last.pt' ): # default value of restore_file is 'checkpoint_last.pt'
checkpoint_path = os.path.join( checkpoint_path = os.path.join(
args.save_dir, "checkpoint_last{}.pt".format(suffix) cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
) )
first_launch = not PathManager.exists(checkpoint_path) first_launch = not PathManager.exists(checkpoint_path)
if getattr(args, "finetune_from_model", None) is not None and first_launch: if cfg.finetune_from_model is not None and first_launch:
# if there is no last checkpoint to restore, start the finetune from pretrained model # if there is no last checkpoint to restore, start the finetune from pretrained model
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
if PathManager.exists(args.finetune_from_model): if PathManager.exists(cfg.finetune_from_model):
checkpoint_path = args.finetune_from_model checkpoint_path = cfg.finetune_from_model
reset_optimizer = True reset_optimizer = True
reset_lr_scheduler = True reset_lr_scheduler = True
reset_meters = True reset_meters = True
@ -170,19 +175,17 @@ def load_checkpoint(args, trainer, **passthrough_args):
) )
else: else:
raise ValueError( raise ValueError(
f"--funetune-from-model {args.finetune_from_model} does not exist" f"--funetune-from-model {cfg.finetune_from_model} does not exist"
) )
elif getattr(args, "model_parallel_size", 1) > 1: elif cfg.model_parallel_size > 1:
checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt") checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
else: else:
checkpoint_path = args.restore_file checkpoint_path = cfg.restore_file
if args.restore_file != "checkpoint_last.pt" and getattr( if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
args, "finetune_from_model", None
):
raise ValueError( raise ValueError(
"--finetune-from-model and --restore-file (non-default value) " "--finetune-from-model and --restore-file (non-default value) "
"can not be specified together: " + str(args) "can not be specified together: " + str(cfg)
) )
extra_state = trainer.load_checkpoint( extra_state = trainer.load_checkpoint(
@ -225,10 +228,14 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
f, map_location=lambda s, l: default_restore_location(s, "cpu") f, map_location=lambda s, l: default_restore_location(s, "cpu")
) )
args = state["args"] if "args" in state and state["args"] is not None and arg_overrides is not None:
if arg_overrides is not None: args = state["args"]
for arg_name, arg_val in arg_overrides.items(): for arg_name, arg_val in arg_overrides.items():
setattr(args, arg_name, arg_val) setattr(args, arg_name, arg_val)
if "cfg" in state and state["cfg"] is not None and arg_overrides is not None:
overwrite_args_by_name(state["cfg"], arg_overrides)
state = _upgrade_state_dict(state) state = _upgrade_state_dict(state)
return state return state
@ -274,19 +281,28 @@ def load_model_ensemble_and_task(
filename = filename.replace(".pt", suffix + ".pt") filename = filename.replace(".pt", suffix + ".pt")
else: else:
filename = orig_filename[:-3] + f"_part{shard_idx}.pt" filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
if not PathManager.exists(filename): if not PathManager.exists(filename):
raise IOError("Model file not found: {}".format(filename)) raise IOError("Model file not found: {}".format(filename))
state = load_checkpoint_to_cpu(filename, arg_overrides) state = load_checkpoint_to_cpu(filename, arg_overrides)
if shard_idx == 0: if "args" in state and state["args"] is not None:
args = state["args"] cfg = convert_namespace_to_omegaconf(state["args"])
if task is None: elif "cfg" in state and state["cfg"] is not None:
task = tasks.setup_task(args) cfg = state["cfg"]
else:
raise RuntimeError(
f"Neither args nor cfg exist in state keys = {state.keys()}"
)
# build model for ensemble if task is None:
model = task.build_model(args) task = tasks.setup_task(cfg.task)
model.load_state_dict(state["model"], strict=strict, args=args)
# build model for ensemble
model = task.build_model(cfg.model)
model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
ensemble.append(model) ensemble.append(model)
return ensemble, args, task return ensemble, cfg, task
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
@ -323,7 +339,7 @@ def torch_persistent_save(obj, f):
def save_state( def save_state(
filename, filename,
args, cfg: DictConfig,
model_state_dict, model_state_dict,
criterion, criterion,
optimizer, optimizer,
@ -331,6 +347,7 @@ def save_state(
num_updates, num_updates,
optim_history=None, optim_history=None,
extra_state=None, extra_state=None,
**kwargs,
): ):
from fairseq import utils from fairseq import utils
@ -339,7 +356,8 @@ def save_state(
if extra_state is None: if extra_state is None:
extra_state = {} extra_state = {}
state_dict = { state_dict = {
"args": args, "cfg": cfg,
"args": kwargs.get("args", None),
"model": model_state_dict or {}, "model": model_state_dict or {},
"optimizer_history": optim_history "optimizer_history": optim_history
+ [ + [
@ -354,11 +372,17 @@ def save_state(
} }
if utils.has_parameters(criterion): if utils.has_parameters(criterion):
state_dict["criterion"] = criterion.state_dict() state_dict["criterion"] = criterion.state_dict()
if not args.no_save_optimizer_state:
state_dict["last_optimizer_state"] = optimizer.state_dict()
# convert all state to CPU if cfg is None:
state_dict = utils.move_to_cpu(state_dict) cfg = state_dict["args"]
assert cfg is not None, "must provide cfg or args"
if isinstance(cfg, DictConfig):
no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
else:
no_save_optimizer_state = cfg.no_save_optimizer_state
if not no_save_optimizer_state:
state_dict["last_optimizer_state"] = optimizer.state_dict()
with PathManager.open(filename, "wb") as f: with PathManager.open(filename, "wb") as f:
torch_persistent_save(state_dict, f) torch_persistent_save(state_dict, f)
@ -403,46 +427,49 @@ def _upgrade_state_dict(state):
# keep track of number of updates # keep track of number of updates
if "num_updates" not in state["optimizer_history"][-1]: if "num_updates" not in state["optimizer_history"][-1]:
state["optimizer_history"][-1]["num_updates"] = 0 state["optimizer_history"][-1]["num_updates"] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state["args"], "max_positions") and not hasattr(
state["args"], "max_source_positions"
):
state["args"].max_source_positions = state["args"].max_positions
state["args"].max_target_positions = state["args"].max_positions
# use stateful training data iterator # use stateful training data iterator
if "train_iterator" not in state["extra_state"]: if "train_iterator" not in state["extra_state"]:
state["extra_state"]["train_iterator"] = { state["extra_state"]["train_iterator"] = {
"epoch": state["extra_state"]["epoch"], "epoch": state["extra_state"]["epoch"],
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0), "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
} }
# default to translation task
if not hasattr(state["args"], "task"):
state["args"].task = "translation"
# --raw-text and --lazy-load are deprecated
if getattr(state["args"], "raw_text", False):
state["args"].dataset_impl = "raw"
elif getattr(state["args"], "lazy_load", False):
state["args"].dataset_impl = "lazy"
# epochs start at 1
if state["extra_state"]["train_iterator"] is not None:
state["extra_state"]["train_iterator"]["epoch"] = max(
state["extra_state"]["train_iterator"].get("epoch", 1),
1,
)
# set any missing default values in the task, model or other registries # old model checkpoints may not have separate source/target positions
registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task]) # backward compatibility, cfg updates
registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch]) if "args" in state and state["args"] is not None:
for registry_name, REGISTRY in registry.REGISTRIES.items(): # default to translation task
choice = getattr(state["args"], registry_name, None) if not hasattr(state["args"], "task"):
if choice is not None: state["args"].task = "translation"
cls = REGISTRY["registry"][choice] # --raw-text and --lazy-load are deprecated
registry.set_defaults(state["args"], cls) if getattr(state["args"], "raw_text", False):
state["args"].dataset_impl = "raw"
elif getattr(state["args"], "lazy_load", False):
state["args"].dataset_impl = "lazy"
# epochs start at 1
if state["extra_state"]["train_iterator"] is not None:
state["extra_state"]["train_iterator"]["epoch"] = max(
state["extra_state"]["train_iterator"].get("epoch", 1), 1
)
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
if "cfg" in state and state["cfg"] is not None:
with open_dict(state["cfg"]):
if state["cfg"].task is not None:
if hasattr(state["cfg"].task, "max_positions") and not hasattr(
state["cfg"].task, "max_source_positions"
):
state["cfg"].task.max_source_positions = state[
"cfg"
].task.max_positions
state["cfg"].task.max_target_positions = state[
"cfg"
].task.max_positions
return state return state
def prune_state_dict(state_dict, args): def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
"""Prune the given state_dict if desired for LayerDrop """Prune the given state_dict if desired for LayerDrop
(https://arxiv.org/abs/1909.11556). (https://arxiv.org/abs/1909.11556).
@ -453,16 +480,20 @@ def prune_state_dict(state_dict, args):
It's called by functions that load models from checkpoints and does not It's called by functions that load models from checkpoints and does not
need to be called directly. need to be called directly.
""" """
if not args or args.arch == "ptt_transformer": arch = None
if model_cfg is not None:
arch = (
model_cfg._name
if isinstance(model_cfg, DictConfig)
else getattr(model_cfg, "arch", None)
)
if not model_cfg or arch is None or arch == "ptt_transformer":
# args should not be none, but don't crash if it is. # args should not be none, but don't crash if it is.
return state_dict return state_dict
encoder_layers_to_keep = ( encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
)
decoder_layers_to_keep = (
args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
)
if not encoder_layers_to_keep and not decoder_layers_to_keep: if not encoder_layers_to_keep and not decoder_layers_to_keep:
return state_dict return state_dict
@ -474,7 +505,7 @@ def prune_state_dict(state_dict, args):
def create_pruning_pass(layers_to_keep, layer_name): def create_pruning_pass(layers_to_keep, layer_name):
keep_layers = sorted( keep_layers = sorted(
[int(layer_string) for layer_string in layers_to_keep.split(",")] int(layer_string) for layer_string in layers_to_keep.split(",")
) )
mapping_dict = {} mapping_dict = {}
for i in range(len(keep_layers)): for i in range(len(keep_layers)):
@ -518,10 +549,12 @@ def prune_state_dict(state_dict, args):
# Since layers are now pruned, *_layers_to_keep are no longer needed. # Since layers are now pruned, *_layers_to_keep are no longer needed.
# This is more of "It would make it work fix" rather than a proper fix. # This is more of "It would make it work fix" rather than a proper fix.
if "encoder_layers_to_keep" in vars(args):
args.encoder_layers_to_keep = None with open_dict(model_cfg):
if "decoder_layers_to_keep" in vars(args): if hasattr(model_cfg, "encoder_layers_to_keep"):
args.decoder_layers_to_keep = None model_cfg.encoder_layers_to_keep = None
if hasattr(model_cfg, "decoder_layers_to_keep"):
model_cfg.decoder_layers_to_keep = None
return new_state_dict return new_state_dict

View File

@ -6,8 +6,6 @@
import importlib
import os
-from argparse import Namespace
-from typing import Union
from fairseq import registry
from fairseq.criterions.fairseq_criterion import ( # noqa
@ -27,8 +25,8 @@ from omegaconf import DictConfig
)
-def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task):
-return build_criterion_(criterion_cfg, task)
+def build_criterion(cfg: DictConfig, task):
+return build_criterion_(cfg, task)
# automatically import any Python files in the criterions/ directory

View File

@ -11,13 +11,13 @@ from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
-from omegaconf import II
+from omegaconf import II, DictConfig
@dataclass
class AdaptiveLossConfig(FairseqDataclass):
-sentence_avg: bool = II("params.optimization.sentence_avg")
-ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend")
+sentence_avg: bool = II("optimization.sentence_avg")
+ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")
@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
@ -31,14 +31,14 @@ class AdaptiveLoss(FairseqCriterion):
self.sentence_avg = sentence_avg
@classmethod
-def build_criterion(cls, args, task):
-if getattr(args, "ddp_backend", None) == "c10d":
+def build_criterion(cls, cfg: DictConfig, task):
+if cfg.ddp_backend == "c10d":
raise Exception(
"AdaptiveLoss is not compatible with the c10d "
"version of DistributedDataParallel. Please use "
"`--ddp-backend=no_c10d` instead."
)
-return cls(task, args.sentence_avg)
+return cls(task, cfg.sentence_avg)
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.

View File

@ -15,7 +15,7 @@ from omegaconf import II
@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
-sentence_avg: bool = II("params.optimization.sentence_avg")
+sentence_avg: bool = II("optimization.sentence_avg")
@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)

View File

@ -10,24 +10,24 @@ from argparse import Namespace
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import metrics, utils from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data.data_utils import post_process from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round from fairseq.logging.meters import safe_round
@register_criterion("ctc") @register_criterion("ctc")
class CtcCriterion(FairseqCriterion): class CtcCriterion(LegacyFairseqCriterion):
def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe): def __init__(self, args, task):
super().__init__(task) super().__init__(args, task)
self.blank_idx = task.target_dictionary.bos() self.blank_idx = task.target_dictionary.bos()
self.pad_idx = task.target_dictionary.pad() self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos() self.eos_idx = task.target_dictionary.eos()
self.post_process = remove_bpe if remove_bpe else "letter" self.post_process = args.remove_bpe if args.remove_bpe else "letter"
if wer_args is not None: if args.wer_args is not None:
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args) wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args)
dec_args = Namespace() dec_args = Namespace()
dec_args.nbest = 1 dec_args.nbest = 1
@ -46,8 +46,8 @@ class CtcCriterion(FairseqCriterion):
else: else:
self.w2l_decoder = None self.w2l_decoder = None
self.zero_infinity = zero_infinity self.zero_infinity = args.zero_infinity
self.sentence_avg = sentence_avg self.sentence_avg = args.sentence_avg
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List
from fairseq import metrics, utils from fairseq import metrics, utils
from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.dataclass.utils import gen_parser_from_dataclass
from omegaconf import DictConfig
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
@ -27,10 +28,8 @@ class FairseqCriterion(_Loss):
gen_parser_from_dataclass(parser, dc()) gen_parser_from_dataclass(parser, dc())
@classmethod @classmethod
def build_criterion(cls, args, task): def build_criterion(cls, cfg: DictConfig, task):
"""Construct a criterion from command-line args.""" """Construct a criterion from command-line args."""
# Criterions can override this, but for convenience we also try
# to automatically map argparse.Namespace keys to corresponding
# arguments in the __init__. # arguments in the __init__.
init_args = {} init_args = {}
for p in inspect.signature(cls).parameters.values(): for p in inspect.signature(cls).parameters.values():
@ -47,8 +46,8 @@ class FairseqCriterion(_Loss):
if p.name == "task": if p.name == "task":
init_args["task"] = task init_args["task"] = task
elif hasattr(args, p.name): elif hasattr(cfg, p.name):
init_args[p.name] = getattr(args, p.name) init_args[p.name] = getattr(cfg, p.name)
elif p.default != p.empty: elif p.default != p.empty:
pass # we'll use the default value pass # we'll use the default value
else: else:
@ -70,7 +69,7 @@ class FairseqCriterion(_Loss):
@staticmethod @staticmethod
def aggregate_logging_outputs( def aggregate_logging_outputs(
logging_outputs: List[Dict[str, Any]], logging_outputs: List[Dict[str, Any]]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
utils.deprecation_warning( utils.deprecation_warning(

View File

@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils from fairseq import file_utils
from fairseq.data.encoders import register_bpe from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import ( from fairseq.data.encoders.byte_utils import (
@ -12,19 +14,20 @@ from fairseq.data.encoders.byte_utils import (
byte_encode, byte_encode,
smart_byte_decode, smart_byte_decode,
) )
from fairseq.dataclass import FairseqDataclass
@register_bpe("byte_bpe") @dataclass
class ByteBpeConfig(FairseqDataclass):
sentencepiece_model_path: str = field(
default="???", metadata={"help": "path to sentencepiece model"}
)
@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
class ByteBPE(object): class ByteBPE(object):
@staticmethod def __init__(self, cfg):
def add_args(parser): vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
# fmt: off
parser.add_argument('--sentencepiece-model-path', type=str,
help='path to sentencepiece model')
# fmt: on
def __init__(self, args):
vocab = file_utils.cached_path(args.sentencepiece_model_path)
try: try:
import sentencepiece as spm import sentencepiece as spm

View File

@ -15,7 +15,7 @@ from fairseq.data.encoders.byte_utils import (
@register_bpe("bytes") @register_bpe("bytes")
class Bytes(object): class Bytes(object):
def __init__(self, args): def __init__(self, *unused):
pass pass
@staticmethod @staticmethod

View File

@ -13,7 +13,7 @@ SPACE_ESCAPE = chr(9601)
@register_bpe("characters") @register_bpe("characters")
class Characters(object): class Characters(object):
def __init__(self, args): def __init__(self, *unused):
pass pass
@staticmethod @staticmethod

View File

@ -3,23 +3,24 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils from fairseq import file_utils
from fairseq.data.encoders import register_bpe from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("fastbpe") @dataclass
class fastBPEConfig(FairseqDataclass):
bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})
@register_bpe("fastbpe", dataclass=fastBPEConfig)
class fastBPE(object): class fastBPE(object):
@staticmethod def __init__(self, cfg):
def add_args(parser): if cfg.bpe_codes is None:
# fmt: off
parser.add_argument('--bpe-codes', type=str,
help='path to fastBPE BPE')
# fmt: on
def __init__(self, args):
if args.bpe_codes is None:
raise ValueError("--bpe-codes is required for --bpe=fastbpe") raise ValueError("--bpe-codes is required for --bpe=fastbpe")
codes = file_utils.cached_path(args.bpe_codes) codes = file_utils.cached_path(cfg.bpe_codes)
try: try:
import fastBPE import fastBPE

View File

@ -3,8 +3,11 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils from fairseq import file_utils
from fairseq.data.encoders import register_bpe from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from .gpt2_bpe_utils import get_encoder from .gpt2_bpe_utils import get_encoder
@ -13,26 +16,21 @@ DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe" DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@register_bpe("gpt2") @dataclass
class GPT2BPE(object): class GPT2BPEConfig(FairseqDataclass):
@staticmethod gpt2_encoder_json: str = field(
def add_args(parser): default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
# fmt: off )
parser.add_argument('--gpt2-encoder-json', type=str, gpt2_vocab_bpe: str = field(
default=DEFAULT_ENCODER_JSON, default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
help='path to encoder.json') )
parser.add_argument('--gpt2-vocab-bpe', type=str,
default=DEFAULT_VOCAB_BPE,
help='path to vocab.bpe')
# fmt: on
def __init__(self, args):
encoder_json = file_utils.cached_path( @register_bpe("gpt2", dataclass=GPT2BPEConfig)
getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON) class GPT2BPE(object):
) def __init__(self, cfg):
vocab_bpe = file_utils.cached_path( encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE) vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
)
self.bpe = get_encoder(encoder_json, vocab_bpe) self.bpe = get_encoder(encoder_json, vocab_bpe)
def encode(self, x: str) -> str: def encode(self, x: str) -> str:

View File

@ -3,22 +3,24 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import Optional
from fairseq.data.encoders import register_bpe from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("bert") @dataclass
class BertBPEConfig(FairseqDataclass):
bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"})
bpe_vocab_file: Optional[str] = field(
default=None, metadata={"help": "bpe vocab file"}
)
@register_bpe("bert", dataclass=BertBPEConfig)
class BertBPE(object): class BertBPE(object):
@staticmethod def __init__(self, cfg):
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-cased', action='store_true',
help='set for cased BPE',
default=False)
parser.add_argument('--bpe-vocab-file', type=str,
help='bpe vocab file.')
# fmt: on
def __init__(self, args):
try: try:
from transformers import BertTokenizer from transformers import BertTokenizer
except ImportError: except ImportError:
@ -26,13 +28,13 @@ class BertBPE(object):
"Please install transformers with: pip install transformers" "Please install transformers with: pip install transformers"
) )
if "bpe_vocab_file" in args: if cfg.bpe_vocab_file:
self.bert_tokenizer = BertTokenizer( self.bert_tokenizer = BertTokenizer(
args.bpe_vocab_file, do_lower_case=not args.bpe_cased cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased
) )
else: else:
vocab_file_name = ( vocab_file_name = (
"bert-base-cased" if args.bpe_cased else "bert-base-uncased" "bert-base-cased" if cfg.bpe_cased else "bert-base-uncased"
) )
self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name) self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)

View File

@ -3,21 +3,24 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.data.encoders import register_bpe from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("hf_byte_bpe") @dataclass
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"})
bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"})
bpe_add_prefix_space: bool = field(
default=False, metadata={"help": "add prefix space before encoding"}
)
@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
class HuggingFaceByteLevelBPE(object): class HuggingFaceByteLevelBPE(object):
@staticmethod def __init__(self, cfg):
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-merges', help='path to merges.txt')
parser.add_argument('--bpe-vocab', help='path to vocab.json')
parser.add_argument('--bpe-add-prefix-space', action='store_true',
help='add prefix space before encoding')
# fmt: on
def __init__(self, args):
try: try:
from tokenizers import ByteLevelBPETokenizer from tokenizers import ByteLevelBPETokenizer
except ImportError: except ImportError:
@ -26,9 +29,9 @@ class HuggingFaceByteLevelBPE(object):
) )
self.bpe = ByteLevelBPETokenizer( self.bpe = ByteLevelBPETokenizer(
args.bpe_vocab, cfg.bpe_vocab,
args.bpe_merges, cfg.bpe_merges,
add_prefix_space=getattr(args, "bpe_add_prefix_space", False), add_prefix_space=cfg.bpe_add_prefix_space,
) )
def encode(self, x: str) -> str: def encode(self, x: str) -> str:

View File

@ -3,37 +3,35 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.data.encoders import register_tokenizer from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass
@register_tokenizer("moses") @dataclass
class MosesTokenizerConfig(FairseqDataclass):
source_lang: str = field(default="en", metadata={"help": "source language"})
target_lang: str = field(default="en", metadata={"help": "target language"})
moses_no_dash_splits: bool = field(
default=False, metadata={"help": "don't apply dash split rules"}
)
moses_no_escape: bool = field(
default=False,
metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
)
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object): class MosesTokenizer(object):
@staticmethod def __init__(self, cfg):
def add_args(parser): self.cfg = cfg
# fmt: off
parser.add_argument('--moses-source-lang', metavar='SRC',
help='source language')
parser.add_argument('--moses-target-lang', metavar='TARGET',
help='target language')
parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
help='don\'t apply dash split rules')
parser.add_argument('--moses-no-escape', action='store_true', default=False,
help='don\'t perform HTML escaping on apostrophy, quotes, etc.')
# fmt: on
def __init__(self, args):
self.args = args
if getattr(args, "moses_source_lang", None) is None:
args.moses_source_lang = getattr(args, "source_lang", "en")
if getattr(args, "moses_target_lang", None) is None:
args.moses_target_lang = getattr(args, "target_lang", "en")
try: try:
from sacremoses import MosesTokenizer, MosesDetokenizer from sacremoses import MosesTokenizer, MosesDetokenizer
self.tok = MosesTokenizer(args.moses_source_lang) self.tok = MosesTokenizer(cfg.source_lang)
self.detok = MosesDetokenizer(args.moses_target_lang) self.detok = MosesDetokenizer(cfg.target_lang)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Please install Moses tokenizer with: pip install sacremoses" "Please install Moses tokenizer with: pip install sacremoses"
@ -42,9 +40,9 @@ class MosesTokenizer(object):
def encode(self, x: str) -> str: def encode(self, x: str) -> str:
return self.tok.tokenize( return self.tok.tokenize(
x, x,
aggressive_dash_splits=(not self.args.moses_no_dash_splits), aggressive_dash_splits=(not self.cfg.moses_no_dash_splits),
return_str=True, return_str=True,
escape=(not self.args.moses_no_escape), escape=(not self.cfg.moses_no_escape),
) )
def decode(self, x: str) -> str: def decode(self, x: str) -> str:

View File

@ -8,7 +8,7 @@ from fairseq.data.encoders import register_tokenizer
@register_tokenizer("nltk") @register_tokenizer("nltk")
class NLTKTokenizer(object): class NLTKTokenizer(object):
def __init__(self, source_lang=None, target_lang=None): def __init__(self, *unused):
try: try:
from nltk.tokenize import word_tokenize from nltk.tokenize import word_tokenize

View File

@ -3,21 +3,24 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils from fairseq import file_utils
from fairseq.data.encoders import register_bpe from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("sentencepiece") @dataclass
class SentencepieceConfig(FairseqDataclass):
sentencepiece_model: str = field(
default="???", metadata={"help": "path to sentencepiece model"}
)
@register_bpe("sentencepiece", dataclass=SentencepieceConfig)
class SentencepieceBPE(object): class SentencepieceBPE(object):
@staticmethod def __init__(self, cfg):
def add_args(parser): sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model)
# fmt: off
parser.add_argument('--sentencepiece-model', type=str,
help='path to sentencepiece model')
# fmt: on
def __init__(self, args):
sentencepiece_model = file_utils.cached_path(args.sentencepiece_model)
try: try:
import sentencepiece as spm import sentencepiece as spm

View File

@ -10,7 +10,7 @@ from fairseq.data.encoders import register_tokenizer
@register_tokenizer("space") @register_tokenizer("space")
class SpaceTokenizer(object): class SpaceTokenizer(object):
def __init__(self, source_lang=None, target_lang=None): def __init__(self, *unused):
self.space_tok = re.compile(r"\s+") self.space_tok = re.compile(r"\s+")
def encode(self, x: str) -> str: def encode(self, x: str) -> str:

View File

@ -3,25 +3,25 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils from fairseq import file_utils
from fairseq.data.encoders import register_bpe from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("subword_nmt") @dataclass
class SubwordNMTBPEConfig(FairseqDataclass):
bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"})
bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"})
@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig)
class SubwordNMTBPE(object): class SubwordNMTBPE(object):
@staticmethod def __init__(self, cfg):
def add_args(parser): if cfg.bpe_codes is None:
# fmt: off
parser.add_argument('--bpe-codes', type=str,
help='path to subword NMT BPE')
parser.add_argument('--bpe-separator', default='@@',
help='BPE separator')
# fmt: on
def __init__(self, args):
if args.bpe_codes is None:
raise ValueError("--bpe-codes is required for --bpe=subword_nmt") raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
codes = file_utils.cached_path(args.bpe_codes) codes = file_utils.cached_path(cfg.bpe_codes)
try: try:
from subword_nmt import apply_bpe from subword_nmt import apply_bpe
@ -31,7 +31,7 @@ class SubwordNMTBPE(object):
"--codes", "--codes",
codes, codes,
"--separator", "--separator",
args.bpe_separator, cfg.bpe_separator,
] ]
) )
self.bpe = apply_bpe.BPE( self.bpe = apply_bpe.BPE(

View File

@ -9,5 +9,7 @@ from fairseq.dataclass.utils import ChoiceEnum
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"])
DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"])
+GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
+GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(["unigram", "ensemble", "vote", "dp", "bs"])
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])


@ -3,32 +3,37 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import logging
import sys import sys
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass, field from dataclasses import _MISSING_TYPE, dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
import torch import torch
from fairseq.criterions import CRITERION_DATACLASS_REGISTRY
from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass.constants import ( from fairseq.dataclass.constants import (
DDP_BACKEND_CHOICES, DDP_BACKEND_CHOICES,
DISTRIBUTED_WRAPPER_CHOICES, DISTRIBUTED_WRAPPER_CHOICES,
GENERATION_CONSTRAINTS_CHOICES,
GENERATION_DECODING_FORMAT_CHOICES,
LOG_FORMAT_CHOICES, LOG_FORMAT_CHOICES,
PIPELINE_CHECKPOINT_CHOICES, PIPELINE_CHECKPOINT_CHOICES,
ZERO_SHARDING_CHOICES, ZERO_SHARDING_CHOICES,
) )
from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass
from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY
from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY
from fairseq.optim.bmuf import FairseqBMUFConfig from fairseq.optim.bmuf import FairseqBMUFConfig
from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY from fairseq.registry import REGISTRIES
from fairseq.tasks import TASK_DATACLASS_REGISTRY from fairseq.tasks import TASK_DATACLASS_REGISTRY
from hydra.core.config_store import ConfigStore from hydra.core.config_store import ConfigStore
from omegaconf import II
logger = logging.getLogger(__name__)
@dataclass @dataclass
class CommonParams(FairseqDataclass): class CommonConfig(FairseqDataclass):
# This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were # This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were
# used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc. # used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc.
no_progress_bar: bool = field( no_progress_bar: bool = field(
@ -109,18 +114,6 @@ class CommonParams(FairseqDataclass):
model_parallel_size: int = field( model_parallel_size: int = field(
default=1, metadata={"help": "total number of GPUs to parallelize model over"} default=1, metadata={"help": "total number of GPUs to parallelize model over"}
) )
checkpoint_suffix: str = field(
default="", metadata={"help": "suffix to add to the checkpoint file name"}
)
checkpoint_shard_count: int = field(
default=1,
metadata={
"help": "Number of shards containing the checkpoint - "
"if the checkpoint is over 300GB, it is preferable "
"to split it into shards to prevent OOM on CPU while loading "
"the checkpoint"
},
)
quantization_config_path: Optional[str] = field( quantization_config_path: Optional[str] = field(
default=None, metadata={"help": "path to quantization config file"} default=None, metadata={"help": "path to quantization config file"}
) )
@ -130,7 +123,7 @@ class CommonParams(FairseqDataclass):
@dataclass @dataclass
class DistributedTrainingParams(FairseqDataclass): class DistributedTrainingConfig(FairseqDataclass):
distributed_world_size: int = field( distributed_world_size: int = field(
default=max(1, torch.cuda.device_count()), default=max(1, torch.cuda.device_count()),
metadata={ metadata={
@ -229,7 +222,7 @@ class DistributedTrainingParams(FairseqDataclass):
default=False, default=False,
metadata={"help": "if set, use pipeline model parallelism across GPUs"}, metadata={"help": "if set, use pipeline model parallelism across GPUs"},
) )
pipeline_balance: str = field( pipeline_balance: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "partition the model into N_K pieces, where each piece " "help": "partition the model into N_K pieces, where each piece "
@ -237,7 +230,7 @@ class DistributedTrainingParams(FairseqDataclass):
"should equal the total number of layers in the model" "should equal the total number of layers in the model"
}, },
) )
pipeline_devices: str = field( pipeline_devices: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "a list of device indices indicating which device to place " "help": "a list of device indices indicating which device to place "
@ -245,10 +238,10 @@ class DistributedTrainingParams(FairseqDataclass):
"equal the length of the --pipeline-balance argument" "equal the length of the --pipeline-balance argument"
}, },
) )
pipeline_chunks: int = field( pipeline_chunks: Optional[int] = field(
default=0, metadata={"help": "microbatch count for pipeline model parallelism"} default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
) )
pipeline_encoder_balance: str = field( pipeline_encoder_balance: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "partition the pipeline parallel encoder into N_K pieces, where each piece " "help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
@ -256,7 +249,7 @@ class DistributedTrainingParams(FairseqDataclass):
"should equal the total number of encoder layers in the model" "should equal the total number of encoder layers in the model"
}, },
) )
pipeline_encoder_devices: str = field( pipeline_encoder_devices: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "a list of device indices indicating which device to place " "help": "a list of device indices indicating which device to place "
@ -264,7 +257,7 @@ class DistributedTrainingParams(FairseqDataclass):
"equal the length of the --pipeline-encoder-balance argument" "equal the length of the --pipeline-encoder-balance argument"
}, },
) )
pipeline_decoder_balance: str = field( pipeline_decoder_balance: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "partition the pipeline parallel decoder into N_K pieces, where each piece " "help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
@ -272,7 +265,7 @@ class DistributedTrainingParams(FairseqDataclass):
"should equal the total number of decoder layers in the model" "should equal the total number of decoder layers in the model"
}, },
) )
pipeline_decoder_devices: str = field( pipeline_decoder_devices: Optional[str] = field(
default=None, default=None,
metadata={ metadata={
"help": "a list of device indices indicating which device to place " "help": "a list of device indices indicating which device to place "
@ -287,10 +280,11 @@ class DistributedTrainingParams(FairseqDataclass):
zero_sharding: ZERO_SHARDING_CHOICES = field( zero_sharding: ZERO_SHARDING_CHOICES = field(
default="none", metadata={"help": "ZeRO sharding"} default="none", metadata={"help": "ZeRO sharding"}
) )
tpu: bool = II("common.tpu")
@dataclass @dataclass
class DatasetParams(FairseqDataclass): class DatasetConfig(FairseqDataclass):
num_workers: int = field( num_workers: int = field(
default=1, metadata={"help": "how many subprocesses to use for data loading"} default=1, metadata={"help": "how many subprocesses to use for data loading"}
) )
@ -374,7 +368,7 @@ class DatasetParams(FairseqDataclass):
@dataclass @dataclass
class OptimizationParams(FairseqDataclass): class OptimizationConfig(FairseqDataclass):
max_epoch: int = field( max_epoch: int = field(
default=0, metadata={"help": "force stop training at specified epoch"} default=0, metadata={"help": "force stop training at specified epoch"}
) )
@ -421,7 +415,7 @@ class OptimizationParams(FairseqDataclass):
@dataclass @dataclass
class CheckpointParams(FairseqDataclass): class CheckpointConfig(FairseqDataclass):
save_dir: str = field( save_dir: str = field(
default="checkpoints", metadata={"help": "path to save checkpoints"} default="checkpoints", metadata={"help": "path to save checkpoints"}
) )
@ -514,12 +508,217 @@ class CheckpointParams(FairseqDataclass):
) )
}, },
) )
checkpoint_suffix: str = field(
default="", metadata={"help": "suffix to add to the checkpoint file name"}
)
checkpoint_shard_count: int = field(
default=1,
metadata={
"help": "Number of shards containing the checkpoint - "
"if the checkpoint is over 300GB, it is preferable "
"to split it into shards to prevent OOM on CPU while loading "
"the checkpoint"
},
)
model_parallel_size: int = II("common.model_parallel_size")
distributed_rank: int = II("distributed_training.distributed_rank")
@dataclass @dataclass
class CommonEvalParams(FairseqDataclass): class GenerationConfig(FairseqDataclass):
beam: int = field(
default=5,
metadata={"help": "beam size"},
)
nbest: int = field(
default=1,
metadata={"help": "number of hypotheses to output"},
)
max_len_a: float = field(
default=0,
metadata={
"help": "generate sequences of maximum length ax + b, where x is the source length"
},
)
max_len_b: int = field(
default=200,
metadata={
"help": "generate sequences of maximum length ax + b, where x is the source length"
},
)
min_len: int = field(
default=1,
metadata={"help": "minimum generation length"},
)
match_source_len: bool = field(
default=False,
metadata={"help": "generations should match the source length"},
)
unnormalized: bool = field(
default=False,
metadata={"help": "compare unnormalized hypothesis scores"},
)
no_early_stop: bool = field(
default=False,
metadata={"help": "deprecated"},
)
no_beamable_mm: bool = field(
default=False,
metadata={"help": "don't use BeamableMM in attention layers"},
)
lenpen: float = field(
default=1,
metadata={
"help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
},
)
unkpen: float = field(
default=0,
metadata={
"help": "unknown word penalty: <0 produces more unks, >0 produces fewer"
},
)
replace_unk: Optional[str] = field(
default=None,
metadata={
"help": "perform unknown replacement (optionally with alignment dictionary)",
"argparse_const": "@@ ",
},
)
sacrebleu: bool = field(
default=False,
metadata={"help": "score with sacrebleu"},
)
score_reference: bool = field(
default=False,
metadata={"help": "just score the reference translation"},
)
prefix_size: int = field(
default=0,
metadata={"help": "initialize generation by target prefix of given length"},
)
no_repeat_ngram_size: int = field(
default=0,
metadata={
"help": "ngram blocking such that this size ngram cannot be repeated in the generation"
},
)
sampling: bool = field(
default=False,
metadata={"help": "sample hypotheses instead of using beam search"},
)
sampling_topk: int = field(
default=-1,
metadata={"help": "sample from top K likely next words instead of all words"},
)
sampling_topp: float = field(
default=-1.0,
metadata={
"help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
},
)
constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(
default=None,
metadata={
"help": "enables lexically constrained decoding",
"argparse_const": "ordered",
},
)
temperature: float = field(
default=1.0,
metadata={"help": "temperature for generation"},
)
diverse_beam_groups: int = field(
default=-1,
metadata={"help": "number of groups for Diverse Beam Search"},
)
diverse_beam_strength: float = field(
default=0.5,
metadata={"help": "strength of diversity penalty for Diverse Beam Search"},
)
diversity_rate: float = field(
default=-1.0,
metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
)
print_alignment: bool = field(
default=False,
metadata={
"help": "if set, uses attention feedback to compute and print alignment to source tokens"
},
)
print_step: bool = field(
default=False,
metadata={"help": "print steps"},
)
lm_path: Optional[str] = field(
default=None,
metadata={"help": "path to lm checkpoint for lm fusion"},
)
lm_weight: float = field(
default=0.0,
metadata={"help": "weight for lm probs for lm fusion"},
)
# arguments for iterative refinement generator
iter_decode_eos_penalty: float = field(
default=0.0,
metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
)
iter_decode_max_iter: int = field(
default=10,
metadata={"help": "maximum iterations for iterative refinement."},
)
iter_decode_force_max_iter: bool = field(
default=False,
metadata={
"help": "if set, run exact the maximum number of iterations without early stop"
},
)
iter_decode_with_beam: int = field(
default=1,
metadata={
"help": "if > 1, model will generate translations varying by the lengths."
},
)
iter_decode_with_external_reranker: bool = field(
default=False,
metadata={
"help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations"
},
)
retain_iter_history: bool = field(
default=False,
metadata={
"help": "if set, decoding returns the whole history of iterative refinement"
},
)
retain_dropout: bool = field(
default=False,
metadata={"help": "Use dropout at inference time"},
)
retain_dropout_modules: Optional[List[str]] = field(
default=None,
metadata={
"help": "if set, only retain dropout for the specified modules; "
"if not set, then dropout will be retained for all modules"
},
)
# special decoding format for advanced decoding.
decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(
default=None,
metadata={"help": "special decoding format for advanced decoding."},
)
no_seed_provided: bool = field(
default=False,
metadata={"help": "if set, dont use seed for initializing random generators"},
)
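Because all of these options now live under a dedicated generation group, consumers read cfg.generation.beam rather than args.beam. A hedged sketch of building the node on its own with OmegaConf (assuming GenerationConfig is importable from fairseq.dataclass.data_class in this version):

    from omegaconf import OmegaConf

    from fairseq.dataclass.data_class import GenerationConfig

    gen_cfg = OmegaConf.structured(GenerationConfig())
    gen_cfg.beam = 10
    gen_cfg.lenpen = 1.2
    print(OmegaConf.to_yaml(gen_cfg))  # dumps every generation option with its default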
@dataclass
class CommonEvalConfig(FairseqDataclass):
path: Optional[str] = field( path: Optional[str] = field(
default=None, metadata={"help": "path(s) to model file(s), colon separated"} default=None,
metadata={"help": "path(s) to model file(s), colon separated"},
) )
remove_bpe: Optional[str] = field( remove_bpe: Optional[str] = field(
default=None, default=None,
@ -541,7 +740,7 @@ class CommonEvalParams(FairseqDataclass):
@dataclass @dataclass
class EvalLMParams(FairseqDataclass): class EvalLMConfig(FairseqDataclass):
output_word_probs: bool = field( output_word_probs: bool = field(
default=False, default=False,
metadata={ metadata={
@ -569,37 +768,31 @@ class EvalLMParams(FairseqDataclass):
@dataclass @dataclass
class TrainingConfig(FairseqDataclass): class InteractiveConfig(FairseqDataclass):
"""Config for training, a composition of training params""" buffer_size: int = field(
default=0,
common: CommonParams = CommonParams() metadata={
distributed_training: DistributedTrainingParams = DistributedTrainingParams() "help": "read this many sentences into a buffer before processing them"
dataset: DatasetParams = DatasetParams() },
optimization: OptimizationParams = OptimizationParams() )
checkpoint: CheckpointParams = CheckpointParams() input: str = field(
bmuf: FairseqBMUFConfig = FairseqBMUFConfig() default="-",
metadata={"help": "file to read from; use - for stdin"},
)
@dataclass CONFIGS = {
class EvalLMConfig(FairseqDataclass): "common": CommonConfig,
"""Config for eval lm, a composition of eval_lm params""" "common_eval": CommonEvalConfig,
"distributed_training": DistributedTrainingConfig,
common: CommonParams = CommonParams() "dataset": DatasetConfig,
distributed_training: DistributedTrainingParams = DistributedTrainingParams() "optimization": OptimizationConfig,
dataset: DatasetParams = DatasetParams() "checkpoint": CheckpointConfig,
optimization: OptimizationParams = OptimizationParams() "bmuf": FairseqBMUFConfig,
checkpoint: CheckpointParams = CheckpointParams() "generation": GenerationConfig,
bmuf: FairseqBMUFConfig = FairseqBMUFConfig() "eval_lm": EvalLMConfig,
common_eval: CommonEvalParams = CommonEvalParams() "interactive": InteractiveConfig,
eval_lm: EvalLMParams = EvalLMParams() }
def register_params_dataclass(
cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass]
) -> None:
"""register params dataclass in config store"""
node_ = data_class(_name=data_class.name())
cs.store(name=name, group=group, node=node_)
def register_module_dataclass( def register_module_dataclass(
@ -608,100 +801,67 @@ def register_module_dataclass(
"""register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc.""" """register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc."""
# note that if `group == model`, we register all model archs, not the model name. # note that if `group == model`, we register all model archs, not the model name.
for k, v in registry.items(): for k, v in registry.items():
if v is not None: node_ = v()
node_ = v(_name=k) node_._name = k
cs.store(name=k, group=group, node=node_) cs.store(name=k, group=group, node=node_, provider="fairseq")
def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
"""cs: config store instance, register common training configs""" """cs: config store instance, register common training configs"""
register_params_dataclass( for k, v in CONFIGS.items():
cs, name="training_params", group="params", data_class=TrainingConfig try:
) cs.store(name=k, node=v())
except BaseException:
logger.error(f"{k} - {v()}")
raise
register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task") register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model") register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model")
register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion")
register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer")
register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler")
for k, v in REGISTRIES.items():
def register_eval_lm_hydra_cfg(cs: ConfigStore, name: str = "default") -> None: register_module_dataclass(cs, v["dataclass_registry"], k)
"""cs: config store instance, register common training configs"""
register_params_dataclass(
cs, name="eval_lm_params", group="params", data_class=EvalLMConfig
)
register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion")
register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer")
register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler")
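The effect of register_hydra_cfg is that every group in CONFIGS, plus every dataclass-backed task, model, criterion, and secondary-registry component, lands in Hydra's ConfigStore and becomes composable by name. A sketch of what register_module_dataclass does for a single (hypothetical) entry:

    from dataclasses import dataclass, field

    from hydra.core.config_store import ConfigStore

    from fairseq.dataclass import FairseqDataclass


    @dataclass
    class MyTaskConfig(FairseqDataclass):
        data: str = field(default="???", metadata={"help": "path to data directory"})


    cs = ConfigStore.instance()
    node = MyTaskConfig()
    node._name = "my_task"  # same bookkeeping as register_module_dataclass above
    cs.store(name="my_task", group="task", node=node, provider="fairseq")
    # Hydra can now resolve the override "task=my_task" from this stored node.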
def _override_attr( def _override_attr(
sub_node: str, data_class: Type[FairseqDataclass], args: Namespace sub_node: str, data_class: Type[FairseqDataclass], args: Namespace
) -> List[str]: ) -> List[str]:
overrides = [] overrides = []
for k in data_class.__dataclass_fields__.keys():
if k == "_name": def get_default(f):
if not isinstance(f.default_factory, _MISSING_TYPE):
return f.default_factory()
return f.default
for k, v in data_class.__dataclass_fields__.items():
if k.startswith("_"):
# private member, skip # private member, skip
continue continue
if not hasattr(args, k):
# print(f"cannot override {sub_node}.{k} since args does not have attribute {k}") val = get_default(v) if not hasattr(args, k) else getattr(args, k)
continue
if getattr(args, k) is None: if val is None:
overrides.append("{}.{}=null".format(sub_node, k)) overrides.append("{}.{}=null".format(sub_node, k))
elif getattr(args, k) == "": elif val == "":
overrides.append("{}.{}=''".format(sub_node, k)) overrides.append("{}.{}=''".format(sub_node, k))
elif isinstance(getattr(args, k), str): elif isinstance(val, str):
if ( overrides.append("{}.{}='{}'".format(sub_node, k, val))
getattr(args, k).startswith("[")
or getattr(args, k).startswith("(")
or getattr(args, k).startswith("{")
or ("," in getattr(args, k))
):
overrides.append("{}.{}='{}'".format(sub_node, k, getattr(args, k)))
else:
overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k)))
else: else:
overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k))) overrides.append("{}.{}={}".format(sub_node, k, val))
return overrides return overrides
def override_training_args(args: Namespace) -> Tuple[List[str], List[str]]: def migrate_registry(
overrides = [] name, value, registry, args, overrides, deletes, use_name_as_val=False
):
overrides.extend(_override_attr("params.common", CommonParams, args)) if value in registry:
overrides.extend(_override_attr("params.dataset", DatasetParams, args)) overrides.append("{}={}".format(name, value))
overrides.extend( overrides.append("{}._name={}".format(name, value))
_override_attr("params.distributed_training", DistributedTrainingParams, args) overrides.extend(_override_attr(name, registry[value], args))
) elif use_name_as_val and value is not None:
overrides.extend(_override_attr("params.optimization", OptimizationParams, args)) overrides.append("{}={}".format(name, value))
overrides.extend(_override_attr("params.checkpoint", CheckpointParams, args)) else:
overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args)) deletes.append(name)
module_overrides, module_deletes = override_module_args(args)
overrides.extend(module_overrides)
return overrides, module_deletes
def override_eval_lm_args(args: Namespace) -> Tuple[List[str], List[str]]:
overrides = []
overrides.extend(_override_attr("params.common", CommonParams, args))
overrides.extend(_override_attr("params.dataset", DatasetParams, args))
overrides.extend(
_override_attr("params.distributed_training", DistributedTrainingParams, args)
)
overrides.extend(_override_attr("params.common_eval", CommonEvalParams, args))
overrides.extend(_override_attr("params.eval_lm", EvalLMParams, args))
overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args))
module_overrides, module_deletes = override_module_args(args)
overrides.extend(module_overrides)
return overrides, module_deletes
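To make the override mechanics concrete: _override_attr turns every field of a config dataclass into a Hydra override string such as common.seed=1, taking the value from the legacy args when present and falling back to the dataclass default otherwise, while migrate_registry either selects a migrated component (e.g. task=language_modeling) or marks the group for deletion so it can later be filled from the flat Namespace. Below is a simplified, self-contained re-implementation of the override-building idea, not the fairseq helper itself:

    from argparse import Namespace
    from dataclasses import dataclass, field
    from typing import List, Optional


    @dataclass
    class CommonStub:
        seed: int = 1
        log_format: Optional[str] = None


    def build_overrides(sub_node: str, data_class, args: Namespace) -> List[str]:
        overrides = []
        for k, f in data_class.__dataclass_fields__.items():
            if k.startswith("_"):
                continue
            val = getattr(args, k, f.default)  # fall back to the dataclass default
            if val is None:
                overrides.append(f"{sub_node}.{k}=null")
            elif isinstance(val, str):
                overrides.append(f"{sub_node}.{k}='{val}'")
            else:
                overrides.append(f"{sub_node}.{k}={val}")
        return overrides


    args = Namespace(seed=7, log_format="json", task="language_modeling")
    print(build_overrides("common", CommonStub, args))
    # prints the two override strings: common.seed=7 and common.log_format='json'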
def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]: def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
@ -709,53 +869,34 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
overrides = [] overrides = []
deletes = [] deletes = []
for k, v in CONFIGS.items():
overrides.extend(_override_attr(k, v, args))
if args is not None: if args is not None:
assert ( if hasattr(args, "task"):
hasattr(args, "task") migrate_registry(
and hasattr(args, "criterion") "task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes
and hasattr(args, "optimizer")
and hasattr(args, "lr_scheduler")
)
if args.task in TASK_DATACLASS_REGISTRY:
overrides.append("task={}".format(args.task))
overrides.append("task._name={}".format(args.task))
overrides.extend(
_override_attr("task", TASK_DATACLASS_REGISTRY[args.task], args)
) )
else: else:
deletes.append("task") deletes.append("task")
if args.criterion in CRITERION_DATACLASS_REGISTRY:
overrides.append("criterion={}".format(args.criterion)) # these options will be set to "None" if they have not yet been migrated
overrides.append("criterion._name={}".format(args.criterion)) # so we can populate them with the entire flat args
overrides.extend( CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"}
_override_attr(
"criterion", CRITERION_DATACLASS_REGISTRY[args.criterion], args for k, v in REGISTRIES.items():
) if hasattr(args, k):
) migrate_registry(
else: k,
deletes.append("criterion") getattr(args, k),
if args.optimizer in OPTIMIZER_DATACLASS_REGISTRY: v["dataclass_registry"],
overrides.append("optimizer={}".format(args.optimizer))
overrides.append("optimizer._name={}".format(args.optimizer))
overrides.extend(
_override_attr(
"optimizer", OPTIMIZER_DATACLASS_REGISTRY[args.optimizer], args
)
)
else:
deletes.append("optimizer")
if args.lr_scheduler in LR_SCHEDULER_DATACLASS_REGISTRY:
overrides.append("lr_scheduler={}".format(args.lr_scheduler))
overrides.append("lr_scheduler._name={}".format(args.lr_scheduler))
overrides.extend(
_override_attr(
"lr_scheduler",
LR_SCHEDULER_DATACLASS_REGISTRY[args.lr_scheduler],
args, args,
overrides,
deletes,
use_name_as_val=k not in CORE_REGISTRIES,
) )
) else:
else: deletes.append(k)
deletes.append("lr_scheduler")
no_dc = True no_dc = True
if hasattr(args, "arch"): if hasattr(args, "arch"):


@ -3,17 +3,24 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from argparse import ArgumentParser import ast
from dataclasses import MISSING, dataclass from argparse import ArgumentParser, Namespace
from dataclasses import _MISSING_TYPE, MISSING, dataclass
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import DictConfig, OmegaConf, open_dict
def eval_str_list(x, x_type=float): def eval_str_list(x, x_type=float):
if x is None: if x is None:
return None return None
if isinstance(x, str): if isinstance(x, str):
x = eval(x) if len(x) == 0:
return []
x = ast.literal_eval(x)
try: try:
return list(map(x_type, x)) return list(map(x_type, x))
except TypeError: except TypeError:
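Replacing eval with ast.literal_eval is a small but real safety fix: the helper now accepts only Python literals rather than arbitrary expressions, and an empty string yields an empty list. A self-contained version for illustration (the scalar fallback in the except branch is assumed, since the tail of the function is not shown in this hunk):

    import ast


    def eval_str_list(x, x_type=float):
        if x is None:
            return None
        if isinstance(x, str):
            if len(x) == 0:
                return []
            x = ast.literal_eval(x)  # literals only; arbitrary expressions now raise
        try:
            return list(map(x_type, x))
        except TypeError:
            return [x_type(x)]  # assumed scalar fallback


    print(eval_str_list("[0.1, 0.2]"))   # [0.1, 0.2]
    print(eval_str_list("0.25", float))  # [0.25]
    print(eval_str_list("", int))        # []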
@ -70,22 +77,11 @@ class FairseqDataclass:
!= self.__dataclass_fields__[attribute_name].default != self.__dataclass_fields__[attribute_name].default
): ):
return getattr(self, attribute_name) return getattr(self, attribute_name)
return self.__dataclass_fields__[attribute_name].default
def _get_default_factory(self, attribute_name: str) -> Any: f = self.__dataclass_fields__[attribute_name]
if hasattr(self, attribute_name): if not isinstance(f.default_factory, _MISSING_TYPE):
if str(getattr(self, attribute_name)).startswith("${"): return f.default_factory()
return str(getattr(self, attribute_name)) return f.default
elif str(self.__dataclass_fields__[attribute_name].default).startswith(
"${"
):
return str(self.__dataclass_fields__[attribute_name].default)
elif (
getattr(self, attribute_name)
!= self.__dataclass_fields__[attribute_name].default_factory()
):
return getattr(self, attribute_name)
return self.__dataclass_fields__[attribute_name].default_factory()
def _get_type(self, attribute_name: str) -> Any: def _get_type(self, attribute_name: str) -> Any:
return self.__dataclass_fields__[attribute_name].type return self.__dataclass_fields__[attribute_name].type
@ -119,7 +115,7 @@ def gen_parser_from_dataclass(
def interpret_dc_type(field_type): def interpret_dc_type(field_type):
if isinstance(field_type, str): if isinstance(field_type, str):
raise RuntimeError() raise RuntimeError("field should be a type")
typestring = str(field_type) typestring = str(field_type)
if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring): if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring):
return field_type.__args__[0] return field_type.__args__[0]
@ -129,12 +125,13 @@ def gen_parser_from_dataclass(
dataclass_instance: FairseqDataclass, k: str dataclass_instance: FairseqDataclass, k: str
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""k: dataclass attributes""" """k: dataclass attributes"""
kwargs = {}
field_type = dataclass_instance._get_type(k) field_type = dataclass_instance._get_type(k)
inter_type = interpret_dc_type(field_type) inter_type = interpret_dc_type(field_type)
if isinstance(inter_type, type) and issubclass(inter_type, List):
field_default = dataclass_instance._get_default_factory(k) field_default = dataclass_instance._get_default(k)
else:
field_default = dataclass_instance._get_default(k)
if isinstance(inter_type, type) and issubclass(inter_type, Enum): if isinstance(inter_type, type) and issubclass(inter_type, Enum):
field_choices = [t.value for t in list(inter_type)] field_choices = [t.value for t in list(inter_type)]
@ -143,7 +140,7 @@ def gen_parser_from_dataclass(
field_help = dataclass_instance._get_help(k) field_help = dataclass_instance._get_help(k)
field_const = dataclass_instance._get_argparse_const(k) field_const = dataclass_instance._get_argparse_const(k)
kwargs = {}
if isinstance(field_default, str) and field_default.startswith("${"): if isinstance(field_default, str) and field_default.startswith("${"):
kwargs["default"] = field_default kwargs["default"] = field_default
else: else:
@ -163,7 +160,11 @@ def gen_parser_from_dataclass(
else: else:
raise NotImplementedError() raise NotImplementedError()
if field_default is not MISSING: if field_default is not MISSING:
kwargs["default"] = ",".join(map(str, field_default)) kwargs["default"] = (
",".join(map(str, field_default))
if field_default is not None
else None
)
elif ( elif (
isinstance(inter_type, type) and issubclass(inter_type, Enum) isinstance(inter_type, type) and issubclass(inter_type, Enum)
) or "Enum" in str(inter_type): ) or "Enum" in str(inter_type):
@ -187,6 +188,7 @@ def gen_parser_from_dataclass(
if field_const is not None: if field_const is not None:
kwargs["const"] = field_const kwargs["const"] = field_const
kwargs["nargs"] = "?" kwargs["nargs"] = "?"
return kwargs return kwargs
for k in dataclass_instance._get_all_attributes(): for k in dataclass_instance._get_all_attributes():
@ -194,8 +196,122 @@ def gen_parser_from_dataclass(
if field_name is None: if field_name is None:
continue continue
kwargs = get_kwargs_from_dc(dataclass_instance, k) kwargs = get_kwargs_from_dc(dataclass_instance, k)
if isinstance(kwargs["default"], str) and kwargs["default"].startswith("${"):
continue if "default" in kwargs:
if delete_default: if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
del kwargs["default"] "${"
):
continue
if delete_default:
del kwargs["default"]
parser.add_argument(field_name, **kwargs) parser.add_argument(field_name, **kwargs)
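gen_parser_from_dataclass is what keeps the legacy CLI alive during the migration: each dataclass field is turned into an argparse option, with help text, choices, and constants pulled from the field metadata. A hedged usage sketch, assuming the helper and FairseqDataclass import paths used elsewhere in this diff:

    import argparse
    from dataclasses import dataclass, field

    from fairseq.dataclass import FairseqDataclass
    from fairseq.dataclass.utils import gen_parser_from_dataclass


    @dataclass
    class DemoConfig(FairseqDataclass):
        beam: int = field(default=5, metadata={"help": "beam size"})
        lenpen: float = field(default=1.0, metadata={"help": "length penalty"})


    parser = argparse.ArgumentParser()
    gen_parser_from_dataclass(parser, DemoConfig())
    args = parser.parse_args(["--beam", "10"])
    print(args.beam, args.lenpen)  # 10 1.0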
def _set_legacy_defaults(args, cls):
"""Helper to set default arguments based on *add_args*."""
if not hasattr(cls, "add_args"):
return
import argparse
parser = argparse.ArgumentParser(
argument_default=argparse.SUPPRESS, allow_abbrev=False
)
cls.add_args(parser)
# copied from argparse.py:
defaults = argparse.Namespace()
for action in parser._actions:
if action.dest is not argparse.SUPPRESS:
if not hasattr(defaults, action.dest):
if action.default is not argparse.SUPPRESS:
setattr(defaults, action.dest, action.default)
for key, default_value in vars(defaults).items():
if not hasattr(args, key):
setattr(args, key, default_value)
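_set_legacy_defaults is the glue for components that have not been migrated yet: it runs the component's add_args into a throwaway parser and copies any missing defaults onto the legacy Namespace, so a Namespace stored inside the DictConfig still carries every option the legacy class expects. A small illustration with an invented legacy component (LegacyTokenizer is hypothetical):

    from argparse import Namespace

    from fairseq.dataclass.utils import _set_legacy_defaults


    class LegacyTokenizer:
        # stand-in for a not-yet-migrated component with an add_args hook
        @staticmethod
        def add_args(parser):
            parser.add_argument("--lowercase", action="store_true", default=False)
            parser.add_argument("--max-len", type=int, default=512)


    args = Namespace(lowercase=True)            # user set one option explicitly
    _set_legacy_defaults(args, LegacyTokenizer)
    print(args.lowercase, args.max_len)         # True 512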
def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
from fairseq.dataclass.data_class import override_module_args
# Here we are using field values provided in args to override counterparts inside config object
overrides, deletes = override_module_args(args)
cfg_name = "config"
cfg_path = f"../../{cfg_name}"
if not GlobalHydra().is_initialized():
initialize(config_path=cfg_path)
composed_cfg = compose(cfg_name, overrides=overrides, strict=False)
for k in deletes:
composed_cfg[k] = None
cfg = OmegaConf.create(
OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
)
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
# omegaconf version that supports object flags, or when we migrate all existing models
from omegaconf import _utils
old_primitive = _utils.is_primitive_type
_utils.is_primitive_type = lambda _: True
if cfg.task is None and getattr(args, "task", None):
cfg.task = Namespace(**vars(args))
from fairseq.tasks import TASK_REGISTRY
_set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
cfg.task._name = args.task
if cfg.model is None and getattr(args, "arch", None):
cfg.model = Namespace(**vars(args))
from fairseq.models import ARCH_MODEL_REGISTRY
_set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
cfg.model._name = args.arch
if cfg.optimizer is None and getattr(args, "optimizer", None):
cfg.optimizer = Namespace(**vars(args))
from fairseq.optim import OPTIMIZER_REGISTRY
_set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
cfg.optimizer._name = args.optimizer
if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
cfg.lr_scheduler = Namespace(**vars(args))
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
_set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler])
cfg.lr_scheduler._name = args.lr_scheduler
if cfg.criterion is None and getattr(args, "criterion", None):
cfg.criterion = Namespace(**vars(args))
from fairseq.criterions import CRITERION_REGISTRY
_set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
cfg.criterion._name = args.criterion
_utils.is_primitive_type = old_primitive
OmegaConf.set_struct(cfg, True)
return cfg
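convert_namespace_to_omegaconf is the bridge used by every legacy entry point: a flat argparse Namespace goes in, a structured DictConfig with common, dataset, task, model, and the other groups comes out, with unmigrated components carried along as Namespace objects. A hedged usage sketch; the particular argument values are illustrative, and anything missing from args falls back to the dataclass defaults:

    from argparse import Namespace

    from fairseq.dataclass.utils import convert_namespace_to_omegaconf

    args = Namespace(
        task="language_modeling",
        arch="transformer_lm",
        criterion="cross_entropy",
        optimizer="adam",
        lr_scheduler="inverse_sqrt",
        seed=1,
    )
    cfg = convert_namespace_to_omegaconf(args)
    print(cfg.common.seed)         # 1
    print(cfg.task._name)          # language_modeling
    print(cfg.dataset.batch_size)  # None, i.e. the DatasetConfig default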
def populate_dataclass(
args: Namespace, dataclass: FairseqDataclass
) -> FairseqDataclass:
for k in dataclass.__dataclass_fields__.keys():
if k.startswith("_"):
# private member, skip
continue
if hasattr(args, k):
setattr(dataclass, k, getattr(args, k))
return dataclass
def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]):
# this will be deprecated when we get rid of argparse and model_overrides logic
with open_dict(cfg):
for k in cfg.keys():
if isinstance(cfg[k], DictConfig):
overwrite_args_by_name(cfg[k], overrides)
elif k in overrides:
cfg[k] = overrides[k]
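overwrite_args_by_name recursively pushes a flat {name: value} mapping into every matching key of the nested config, which is how the old model_overrides behaviour keeps working on a DictConfig. The same logic, reproduced standalone so it can be run against a toy config:

    from omegaconf import DictConfig, OmegaConf, open_dict


    def overwrite_args_by_name(cfg: DictConfig, overrides):
        # same logic as the helper above, duplicated here for a runnable demo
        with open_dict(cfg):
            for k in cfg.keys():
                if isinstance(cfg[k], DictConfig):
                    overwrite_args_by_name(cfg[k], overrides)
                elif k in overrides:
                    cfg[k] = overrides[k]


    cfg = OmegaConf.create({"common": {"seed": 1}, "dataset": {"batch_size": None}})
    overwrite_args_by_name(cfg, {"seed": 7, "batch_size": 8})
    print(cfg.common.seed, cfg.dataset.batch_size)  # 7 8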


@ -11,35 +11,38 @@ import socket
import struct import struct
import subprocess import subprocess
import warnings import warnings
from argparse import Namespace
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Dict, Mapping from typing import Any, Dict, Mapping
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from fairseq import utils from fairseq import utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from omegaconf import DictConfig, open_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def is_master(args): def is_master(cfg: DictConfig):
return args.distributed_rank == 0 return cfg.distributed_rank == 0
def infer_init_method(args, force_distributed=False): def infer_init_method(cfg: DictConfig, force_distributed=False):
if args.distributed_init_method is not None or getattr(args, "tpu", False): if cfg.distributed_init_method is not None or cfg.tpu:
return return
if args.pipeline_model_parallel: if cfg.pipeline_model_parallel:
balance_exists = ( balance_exists = (
args.pipeline_balance is not None cfg.pipeline_balance is not None
or args.pipeline_encoder_balance is not None or cfg.pipeline_encoder_balance is not None
or args.pipeline_decoder_balance is not None or cfg.pipeline_decoder_balance is not None
) )
devices_exist = ( devices_exist = (
args.pipeline_devices is not None cfg.pipeline_devices is not None
or args.pipeline_encoder_devices is not None or cfg.pipeline_encoder_devices is not None
or args.pipeline_decoder_devices is not None or cfg.pipeline_decoder_devices is not None
) )
if not balance_exists: if not balance_exists:
raise ValueError( raise ValueError(
@ -50,19 +53,19 @@ def infer_init_method(args, force_distributed=False):
"--pipeline-devices is currently required for pipeline model parallelism" "--pipeline-devices is currently required for pipeline model parallelism"
) )
args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int) cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int)
if args.pipeline_devices is not None: if cfg.pipeline_devices is not None:
args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int) cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int)
num_pipeline_devices = len(set(args.pipeline_devices)) num_pipeline_devices = len(set(cfg.pipeline_devices))
else: else:
args.pipeline_encoder_devices = utils.eval_str_list( cfg.pipeline_encoder_devices = utils.eval_str_list(
args.pipeline_encoder_devices, type=int cfg.pipeline_encoder_devices, type=int
) )
args.pipeline_decoder_devices = utils.eval_str_list( cfg.pipeline_decoder_devices = utils.eval_str_list(
args.pipeline_decoder_devices, type=int cfg.pipeline_decoder_devices, type=int
) )
num_pipeline_devices = len( num_pipeline_devices = len(
set(args.pipeline_encoder_devices + args.pipeline_decoder_devices) set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices)
) )
gpus_per_node = torch.cuda.device_count() gpus_per_node = torch.cuda.device_count()
assert ( assert (
@ -79,14 +82,14 @@ def infer_init_method(args, force_distributed=False):
key in os.environ key in os.environ
for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
): ):
args.distributed_init_method = "env://" cfg.distributed_init_method = "env://"
args.distributed_world_size = int(os.environ["WORLD_SIZE"]) cfg.distributed_world_size = int(os.environ["WORLD_SIZE"])
args.distributed_rank = int(os.environ["RANK"]) cfg.distributed_rank = int(os.environ["RANK"])
# processes are created by torch.distributed.launch # processes are created by torch.distributed.launch
args.distributed_no_spawn = True cfg.distributed_no_spawn = True
# we can determine the init method automatically for Slurm # we can determine the init method automatically for Slurm
elif args.distributed_port > 0: elif cfg.distributed_port > 0:
node_list = os.environ.get("SLURM_STEP_NODELIST") node_list = os.environ.get("SLURM_STEP_NODELIST")
if node_list is None: if node_list is None:
node_list = os.environ.get("SLURM_JOB_NODELIST") node_list = os.environ.get("SLURM_JOB_NODELIST")
@ -95,9 +98,9 @@ def infer_init_method(args, force_distributed=False):
hostnames = subprocess.check_output( hostnames = subprocess.check_output(
["scontrol", "show", "hostnames", node_list] ["scontrol", "show", "hostnames", node_list]
) )
args.distributed_init_method = "tcp://{host}:{port}".format( cfg.distributed_init_method = "tcp://{host}:{port}".format(
host=hostnames.split()[0].decode("utf-8"), host=hostnames.split()[0].decode("utf-8"),
port=args.distributed_port, port=cfg.distributed_port,
) )
nnodes = int(os.environ.get("SLURM_NNODES")) nnodes = int(os.environ.get("SLURM_NNODES"))
ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
@ -111,88 +114,94 @@ def infer_init_method(args, force_distributed=False):
if ntasks_per_node == 1: if ntasks_per_node == 1:
gpus_per_node = torch.cuda.device_count() gpus_per_node = torch.cuda.device_count()
node_id = int(os.environ.get("SLURM_NODEID")) node_id = int(os.environ.get("SLURM_NODEID"))
args.distributed_rank = node_id * gpus_per_node cfg.distributed_rank = node_id * gpus_per_node
args.distributed_world_size = nnodes * gpus_per_node cfg.distributed_world_size = nnodes * gpus_per_node
elif args.pipeline_model_parallel: elif cfg.pipeline_model_parallel:
assert ntasks_per_node == num_pipelines_per_node, ( assert ntasks_per_node == num_pipelines_per_node, (
"SLURM --ntasks-per-node must match number of pipelines per " "SLURM --ntasks-per-node must match number of pipelines per "
"node (={})".format(num_pipelines_per_node) "node (={})".format(num_pipelines_per_node)
) )
args.distributed_no_spawn = True cfg.distributed_no_spawn = True
# For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
# the first node, [1, 2] on the second node, etc. This # the first node, [1, 2] on the second node, etc. This
# matches torch.distributed.launch. # matches torch.distributed.launch.
node_id = int(os.environ.get("SLURM_NODEID")) node_id = int(os.environ.get("SLURM_NODEID"))
local_id = int(os.environ.get("SLURM_LOCALID")) local_id = int(os.environ.get("SLURM_LOCALID"))
args.distributed_rank = node_id * num_pipelines_per_node + local_id cfg.distributed_rank = node_id * num_pipelines_per_node + local_id
# In the above example, device_id will always be in [0, 1], # In the above example, device_id will always be in [0, 1],
# which also matches torch.distributed.launch. # which also matches torch.distributed.launch.
args.device_id = local_id cfg.device_id = local_id
# We also want to set distributed_world_size to be the total # We also want to set distributed_world_size to be the total
# number of pipelines across all nodes. # number of pipelines across all nodes.
args.distributed_world_size = nnodes * num_pipelines_per_node cfg.distributed_world_size = nnodes * num_pipelines_per_node
else: else:
assert ntasks_per_node == args.distributed_world_size // nnodes assert ntasks_per_node == cfg.distributed_world_size // nnodes
args.distributed_no_spawn = True cfg.distributed_no_spawn = True
args.distributed_rank = int(os.environ.get("SLURM_PROCID")) cfg.distributed_rank = int(os.environ.get("SLURM_PROCID"))
args.device_id = int(os.environ.get("SLURM_LOCALID")) cfg.device_id = int(os.environ.get("SLURM_LOCALID"))
except subprocess.CalledProcessError as e: # scontrol failed except subprocess.CalledProcessError as e: # scontrol failed
raise e raise e
except FileNotFoundError: # Slurm is not installed except FileNotFoundError: # Slurm is not installed
pass pass
elif args.distributed_world_size > 1 or force_distributed: elif cfg.distributed_world_size > 1 or force_distributed:
# fallback for single node with multiple GPUs # fallback for single node with multiple GPUs
assert args.distributed_world_size <= torch.cuda.device_count() assert cfg.distributed_world_size <= torch.cuda.device_count()
port = random.randint(10000, 20000) port = random.randint(10000, 20000)
args.distributed_init_method = "tcp://localhost:{port}".format(port=port) cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)
if args.pipeline_model_parallel: if cfg.pipeline_model_parallel:
if not args.distributed_no_spawn: if not cfg.distributed_no_spawn:
# When distributed_no_spawn is False, we expect distributed_rank and # When distributed_no_spawn is False, we expect distributed_rank and
# distributed_world_size to be based on the total number of GPUs, so # distributed_world_size to be based on the total number of GPUs, so
# we need to correct them to be based on the number of pipelines. # we need to correct them to be based on the number of pipelines.
assert args.distributed_world_size % num_pipeline_devices == 0 assert cfg.distributed_world_size % num_pipeline_devices == 0
args.distributed_world_size = ( cfg.distributed_world_size = (
args.distributed_world_size // num_pipeline_devices cfg.distributed_world_size // num_pipeline_devices
) )
# In the case of 4-way MP on nodes with 8 GPUs, we want # In the case of 4-way MP on nodes with 8 GPUs, we want
# distributed_rank to be the starting GPU index for each pipeline # distributed_rank to be the starting GPU index for each pipeline
# i.e., 0, 2, ... # i.e., 0, 2, ...
assert args.distributed_rank % gpus_per_node == 0 assert cfg.distributed_rank % gpus_per_node == 0
assert args.distributed_rank % num_pipeline_devices == 0 assert cfg.distributed_rank % num_pipeline_devices == 0
args.distributed_rank = args.distributed_rank // num_pipeline_devices
# launch one process per pipeline with open_dict(cfg):
args.distributed_num_procs = num_pipelines_per_node cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices
# launch one process per pipeline
cfg.distributed_num_procs = num_pipelines_per_node
# if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0 # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0
# and 4, indicating the starting device IDs for each pipeline # and 4, indicating the starting device IDs for each pipeline
args.device_id *= num_pipeline_devices cfg.device_id *= num_pipeline_devices
if args.device_id > 0: if cfg.device_id > 0:
# if there's multiple pipelines on a node (e.g., 4-way MP on an 8 # if there's multiple pipelines on a node (e.g., 4-way MP on an 8
# GPU node), we need to adjust pipeline_devices accordingly # GPU node), we need to adjust pipeline_devices accordingly
logger.debug( logger.debug(
"setting CUDA device={} on rank {}".format( "setting CUDA device={} on rank {}".format(
args.device_id, args.distributed_rank cfg.device_id, cfg.distributed_rank
) )
) )
torch.cuda.set_device(args.device_id) torch.cuda.set_device(cfg.device_id)
args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices] with open_dict(cfg):
cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices]
logger.info( logger.info(
"setting pipeline_devices={} on rank {}".format( "setting pipeline_devices={} on rank {}".format(
args.pipeline_devices, args.distributed_rank cfg.pipeline_devices, cfg.distributed_rank
), )
)
elif not cfg.distributed_no_spawn:
with open_dict(cfg):
cfg.distributed_num_procs = min(
torch.cuda.device_count(), cfg.distributed_world_size
) )
elif not args.distributed_no_spawn:
args.distributed_num_procs = min(
torch.cuda.device_count(),
args.distributed_world_size,
)
def distributed_init(args): def distributed_init(cfg: DictConfig):
if not getattr(args, "tpu", False): if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
if not cfg.common.tpu:
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
warnings.warn( warnings.warn(
"Distributed is already initialized, cannot initialize twice!" "Distributed is already initialized, cannot initialize twice!"
@ -200,20 +209,20 @@ def distributed_init(args):
else: else:
logger.info( logger.info(
"distributed init (rank {}): {}".format( "distributed init (rank {}): {}".format(
args.distributed_rank, cfg.distributed_training.distributed_rank,
args.distributed_init_method, cfg.distributed_training.distributed_init_method,
) )
) )
dist.init_process_group( dist.init_process_group(
backend=args.distributed_backend, backend=cfg.distributed_training.distributed_backend,
init_method=args.distributed_init_method, init_method=cfg.distributed_training.distributed_init_method,
world_size=args.distributed_world_size, world_size=cfg.distributed_training.distributed_world_size,
rank=args.distributed_rank, rank=cfg.distributed_training.distributed_rank,
) )
logger.info( logger.info(
"initialized host {} as rank {}".format( "initialized host {} as rank {}".format(
socket.gethostname(), socket.gethostname(),
args.distributed_rank, cfg.distributed_training.distributed_rank,
) )
) )
@ -221,20 +230,22 @@ def distributed_init(args):
if torch.cuda.is_available(): if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda()) dist.all_reduce(torch.zeros(1).cuda())
args.distributed_rank = torch.distributed.get_rank() cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
else: else:
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
assert xm.xrt_world_size() == args.distributed_world_size assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
args.device_id = xm.get_local_ordinal() cfg.distributed_training.device_id = xm.get_local_ordinal()
args.distributed_rank = xm.get_ordinal() cfg.distributed_training.distributed_rank = xm.get_ordinal()
xm.rendezvous("distributed_init") # wait for all workers xm.rendezvous("distributed_init") # wait for all workers
xm.mark_step() xm.mark_step()
if not is_master(args): if is_master(cfg.distributed_training):
logging.getLogger().setLevel(logging.INFO)
else:
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)
if args.model_parallel_size > 1: if cfg.common.model_parallel_size > 1:
try: try:
from fairseq.model_parallel.megatron.mpu import ( from fairseq.model_parallel.megatron.mpu import (
get_model_parallel_rank, get_model_parallel_rank,
@ -247,58 +258,61 @@ def distributed_init(args):
"\n\n git submodule update --init " "\n\n git submodule update --init "
"fairseq/model_parallel/megatron" "fairseq/model_parallel/megatron"
) )
initialize_model_parallel(args.model_parallel_size) initialize_model_parallel(cfg.common.model_parallel_size)
model_parallel_cuda_manual_seed(args.seed) model_parallel_cuda_manual_seed(cfg.common.seed)
model_part_number = get_model_parallel_rank() model_part_number = get_model_parallel_rank()
args.checkpoint_suffix += "-model_part-{0}".format(model_part_number) cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
return args.distributed_rank return cfg.distributed_training.distributed_rank
def distributed_main(i, main, args, kwargs): def distributed_main(i, main, cfg: DictConfig, kwargs):
args.device_id = i cfg.distributed_training.device_id = i
if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False): if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu:
torch.cuda.set_device(args.device_id) torch.cuda.set_device(cfg.distributed_training.device_id)
if args.distributed_rank is None: # torch.multiprocessing.spawn if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn
args.distributed_rank = kwargs.pop("start_rank", 0) + i cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i
args.distributed_rank = distributed_init(args) cfg.distributed_training.distributed_rank = distributed_init(cfg)
after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
if after_distributed_init_fn: if after_distributed_init_fn:
args = after_distributed_init_fn(args) cfg = after_distributed_init_fn(cfg)
main(args, **kwargs) main(cfg, **kwargs)
def call_main(args, main, **kwargs): def call_main(cfg: DictConfig, main, **kwargs):
if args.distributed_init_method is None: if cfg.distributed_training.distributed_init_method is None:
infer_init_method(args) infer_init_method(cfg.distributed_training)
if args.distributed_init_method is not None: if cfg.distributed_training.distributed_init_method is not None:
# distributed training # distributed training
if not args.distributed_no_spawn: if not cfg.distributed_training.distributed_no_spawn:
start_rank = args.distributed_rank start_rank = cfg.distributed_training.distributed_rank
args.distributed_rank = None # assign automatically cfg.distributed_training.distributed_rank = None # assign automatically
kwargs["start_rank"] = start_rank kwargs["start_rank"] = start_rank
torch.multiprocessing.spawn( torch.multiprocessing.spawn(
fn=distributed_main, fn=distributed_main,
args=(main, args, kwargs), args=(main, cfg, kwargs),
nprocs=args.distributed_num_procs, nprocs=min(
torch.cuda.device_count(),
cfg.distributed_training.distributed_world_size,
),
) )
else: else:
distributed_main(args.device_id, main, args, kwargs) distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
elif getattr(args, "tpu", False) and args.distributed_world_size > 1: elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1:
import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.distributed.xla_multiprocessing as xmp
torch.multiprocessing.set_sharing_strategy("file_system") torch.multiprocessing.set_sharing_strategy("file_system")
xmp.spawn( xmp.spawn(
fn=distributed_main, fn=distributed_main,
args=(main, args, kwargs), args=(main, cfg, kwargs),
nprocs=8, # use all 8 TPU cores nprocs=8, # use all 8 TPU cores
) )
else: else:
# single GPU main # single GPU main
main(args, **kwargs) main(cfg, **kwargs)
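With distributed_utils migrated, an entry point converts its legacy args once and hands the resulting DictConfig to call_main, which spawns one process per GPU (or TPU core) and routes everything through cfg.distributed_training. A hedged sketch of that wiring; main_fn is a hypothetical worker, and the options helpers are the standard fairseq argparse entry points:

    from fairseq import distributed_utils, options
    from fairseq.dataclass.utils import convert_namespace_to_omegaconf


    def main_fn(cfg):
        # hypothetical per-worker body; one copy runs per spawned process
        print("running on rank", cfg.distributed_training.distributed_rank)


    if __name__ == "__main__":
        parser = options.get_training_parser()
        args = options.parse_args_and_arch(parser)
        cfg = convert_namespace_to_omegaconf(args)
        distributed_utils.call_main(cfg, main_fn)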
def get_rank(): def get_rank():
@ -392,11 +406,7 @@ def all_gather_list(data, group=None, max_size=16384):
) )
def all_reduce_dict( def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, Any]:
data: Mapping[str, Any],
device,
group=None,
) -> Dict[str, Any]:
""" """
AllReduce a dictionary of values across workers. We separately AllReduce a dictionary of values across workers. We separately
reduce items that are already on the device and items on CPU for reduce items that are already on the device and items on CPU for


@ -8,11 +8,12 @@ import argparse
import copy import copy
import logging import logging
import os import os
from typing import Any, Dict, Iterator, List, Tuple from typing import Any, Dict, Iterator, List
import torch import torch
from fairseq import utils from fairseq import utils
from fairseq.data import encoders from fairseq.data import encoders
from omegaconf import open_dict
from torch import nn from torch import nn
@ -85,9 +86,9 @@ class GeneratorHubInterface(nn.Module):
translation or language model. translation or language model.
""" """
def __init__(self, args, task, models): def __init__(self, cfg, task, models):
super().__init__() super().__init__()
self.args = args self.cfg = cfg
self.task = task self.task = task
self.models = nn.ModuleList(models) self.models = nn.ModuleList(models)
self.src_dict = task.source_dictionary self.src_dict = task.source_dictionary
@ -95,14 +96,14 @@ class GeneratorHubInterface(nn.Module):
# optimize model for generation # optimize model for generation
for model in self.models: for model in self.models:
model.prepare_for_inference_(args) model.prepare_for_inference_(cfg)
# Load alignment dictionary for unknown word replacement # Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None)) self.align_dict = utils.load_align_dict(cfg.generation.replace_unk)
self.tokenizer = encoders.build_tokenizer(args) self.tokenizer = encoders.build_tokenizer(cfg.tokenizer)
self.bpe = encoders.build_bpe(args) self.bpe = encoders.build_bpe(cfg.bpe)
self.max_positions = utils.resolve_max_positions( self.max_positions = utils.resolve_max_positions(
self.task.max_positions(), *[model.max_positions() for model in models] self.task.max_positions(), *[model.max_positions() for model in models]
@ -156,10 +157,11 @@ class GeneratorHubInterface(nn.Module):
)[0] )[0]
# build generator using current args as well as any kwargs # build generator using current args as well as any kwargs
gen_args = copy.copy(self.args) gen_args = copy.copy(self.cfg)
gen_args.beam = beam with open_dict(gen_args):
for k, v in kwargs.items(): gen_args.beam = beam
setattr(gen_args, k, v) for k, v in kwargs.items():
setattr(gen_args, k, v)
generator = self.task.build_generator(self.models, gen_args) generator = self.task.build_generator(self.models, gen_args)
inference_step_args = inference_step_args or {} inference_step_args = inference_step_args or {}
@ -253,8 +255,8 @@ class GeneratorHubInterface(nn.Module):
lengths = torch.LongTensor([t.numel() for t in tokens]) lengths = torch.LongTensor([t.numel() for t in tokens])
batch_iterator = self.task.get_batch_iterator( batch_iterator = self.task.get_batch_iterator(
dataset=self.task.build_dataset_for_inference(tokens, lengths), dataset=self.task.build_dataset_for_inference(tokens, lengths),
max_tokens=self.args.max_tokens, max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.args.batch_size, max_sentences=self.cfg.dataset.batch_size,
max_positions=self.max_positions, max_positions=self.max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs, ignore_invalid_inputs=skip_invalid_size_inputs,
disable_iterator_cache=True, disable_iterator_cache=True,
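The open_dict wrapper around gen_args above is needed because convert_namespace_to_omegaconf finishes with OmegaConf.set_struct(cfg, True): once a config is in struct mode, adding keys (for example an ad-hoc sampling flag passed through kwargs) is only allowed inside open_dict. A minimal standalone sketch of that pattern:

    from omegaconf import OmegaConf, open_dict

    cfg = OmegaConf.create({"generation": {"beam": 5, "lenpen": 1.0}})
    OmegaConf.set_struct(cfg, True)  # mirrors what convert_namespace_to_omegaconf does

    with open_dict(cfg):
        cfg.generation.beam = 10        # overriding an existing key
        cfg.generation.sampling = True  # adding a new key, rejected outside open_dict
    print(cfg.generation.beam, cfg.generation.sampling)  # 10 True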


@ -9,6 +9,7 @@ Train a network across multiple GPUs.
from fairseq import distributed_utils from fairseq import distributed_utils
from fairseq.trainer import Trainer from fairseq.trainer import Trainer
from omegaconf import DictConfig
try: try:
@ -28,14 +29,14 @@ except (ImportError, ModuleNotFoundError):
class MegatronTrainer(Trainer): class MegatronTrainer(Trainer):
"""Main class for model parallel with data parallel training.""" """Main class for model parallel with data parallel training."""
def __init__(self, args, task, model, criterion): def __init__(self, cfg: DictConfig, task, model, criterion, **kwargs):
if not has_megatron_submodule: if not has_megatron_submodule:
raise ImportError( raise ImportError(
"\n\nPlease install the megatron submodule:" "\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init " "\n\n git submodule update --init "
"fairseq/model_parallel/megatron" "fairseq/model_parallel/megatron"
) )
super().__init__(args, task, model, criterion) super().__init__(cfg, task, model, criterion, **kwargs)
@property @property
def data_parallel_world_size(self): def data_parallel_world_size(self):
@ -96,7 +96,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
encoder_output_tuple = self.encoder(input) encoder_output_tuple = self.encoder(input)
return self.decoder(encoder_output_tuple) return self.decoder(encoder_output_tuple)
def prepare_for_inference_(self, args): def prepare_for_inference_(self, cfg):
if self.encoder is not None and self.decoder is not None: if self.encoder is not None and self.decoder is not None:
logger.info("Encoder and Decoder already initialized") logger.info("Encoder and Decoder already initialized")
return return
@ -111,9 +111,9 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
decoder_module_list.append(module) decoder_module_list.append(module)
module_count += 1 module_count += 1
self.model = None self.model = None
self.encoder = TransformerEncoder(args, None, None, encoder_module_list) self.encoder = TransformerEncoder(cfg.model, None, None, encoder_module_list)
self.decoder = TransformerDecoder( self.decoder = TransformerDecoder(
args, None, None, decoder_module_list=decoder_module_list cfg.model, None, None, decoder_module_list=decoder_module_list
) )
@staticmethod @staticmethod
@ -320,7 +320,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
"""Maximum length supported by the decoder.""" """Maximum length supported by the decoder."""
return self.decoder_max_positions return self.decoder_max_positions
def load_state_dict(self, state_dict, strict=True, args=None): def load_state_dict(self, state_dict, strict=True, cfg=None):
"""Copies parameters and buffers from *state_dict* into this module and """Copies parameters and buffers from *state_dict* into this module and
its descendants. its descendants.
@ -72,6 +72,10 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
) )
return cls(decoder) return cls(decoder)
@staticmethod
def add_args(parser):
TransformerLanguageModel.add_args(parser)
@classmethod @classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None): def build_embedding(cls, args, dictionary, embed_dim, path=None):
def _vocab_init(tensor, **kwargs): def _vocab_init(tensor, **kwargs):
@ -7,8 +7,6 @@
import argparse import argparse
import importlib import importlib
import os import os
from argparse import Namespace
from typing import Union
import fairseq import fairseq
from fairseq.dataclass import FairseqDataclass from fairseq.dataclass import FairseqDataclass
@ -52,10 +50,10 @@ __all__ = [
] ]
def build_model(model_cfg: Union[DictConfig, Namespace], task): def build_model(cfg: DictConfig, task):
if isinstance(model_cfg, DictConfig): if isinstance(cfg, DictConfig):
return ARCH_MODEL_REGISTRY[model_cfg._name].build_model(model_cfg, task) return ARCH_MODEL_REGISTRY[cfg._name].build_model(cfg, task)
return ARCH_MODEL_REGISTRY[model_cfg.arch].build_model(model_cfg, task) return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task)
def register_model(name, dataclass=None): def register_model(name, dataclass=None):
@ -92,7 +90,8 @@ def register_model(name, dataclass=None):
) )
cls.__dataclass = dataclass cls.__dataclass = dataclass
MODEL_DATACLASS_REGISTRY[name] = dataclass if dataclass is not None:
MODEL_DATACLASS_REGISTRY[name] = dataclass
return cls return cls
return register_model_cls return register_model_cls
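Note on the hunk above: model registration now treats the config dataclass as optional and only records it in the dataclass registry when present, which is what lets OmegaConf build a typed default config per architecture. A toy sketch of the idea, assuming nothing from fairseq (the registry names and ToyLM below are hypothetical):

from dataclasses import dataclass
from omegaconf import OmegaConf

MODEL_REGISTRY = {}             # stand-ins for fairseq's registries
MODEL_DATACLASS_REGISTRY = {}

def register_toy_model(name, dataclass=None):
    def wrapper(cls):
        MODEL_REGISTRY[name] = cls
        if dataclass is not None:           # only models with a schema get one
            MODEL_DATACLASS_REGISTRY[name] = dataclass
        return cls
    return wrapper

@dataclass
class ToyLMConfig:
    embed_dim: int = 16

@register_toy_model("toy_lm", dataclass=ToyLMConfig)
class ToyLM:
    def __init__(self, cfg):
        self.embed_dim = cfg.embed_dim

# Typed defaults come from the dataclass registry...
cfg = OmegaConf.structured(MODEL_DATACLASS_REGISTRY["toy_lm"])
assert ToyLM(cfg).embed_dim == 16
# ...and building from a DictConfig keys off "_name", as in build_model above.
cfg = OmegaConf.create({"_name": "toy_lm", "embed_dim": 32})
assert MODEL_REGISTRY[cfg._name](cfg).embed_dim == 32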
@ -108,14 +107,13 @@ def register_model_architecture(model_name, arch_name):
For example:: For example::
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de') @register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args): def lstm_luong_wmt_en_de(cfg):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000) args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000)
(...) (...)
The decorated function should take a single argument *args*, which is a The decorated function should take a single argument *cfg*, which is a
:class:`argparse.Namespace` of arguments parsed from the command-line. The :class:`omegaconf.DictConfig`. The decorated function should modify these
decorated function should modify these arguments in-place to match the arguments in-place to match the desired architecture.
desired architecture.
Args: Args:
model_name (str): the name of the Model (Model must already be model_name (str): the name of the Model (Model must already be
@ -13,6 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from fairseq.data import encoders from fairseq.data import encoders
from omegaconf import open_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -24,13 +25,13 @@ class BARTHubInterface(nn.Module):
Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
""" """
def __init__(self, args, task, model): def __init__(self, cfg, task, model):
super().__init__() super().__init__()
self.args = args self.cfg = cfg
self.task = task self.task = task
self.model = model self.model = model
self.bpe = encoders.build_bpe(args) self.bpe = encoders.build_bpe(cfg.bpe)
self.max_positions = min( self.max_positions = min(
utils.resolve_max_positions( utils.resolve_max_positions(
@ -120,10 +121,11 @@ class BARTHubInterface(nn.Module):
sample = self._build_sample(tokens) sample = self._build_sample(tokens)
# build generator using current args as well as any kwargs # build generator using current args as well as any kwargs
gen_args = copy.copy(self.args) gen_args = copy.copy(self.cfg)
gen_args.beam = beam with open_dict(gen_args):
for k, v in kwargs.items(): gen_args.beam = beam
setattr(gen_args, k, v) for k, v in kwargs.items():
setattr(gen_args, k, v)
generator = self.task.build_generator([self.model], gen_args) generator = self.task.build_generator([self.model], gen_args)
translations = self.task.inference_step( translations = self.task.inference_step(
generator, generator,
@ -144,7 +144,9 @@ class BARTModel(TransformerModel):
num_classes=num_classes, num_classes=num_classes,
activation_fn=self.args.pooler_activation_fn, activation_fn=self.args.pooler_activation_fn,
pooler_dropout=self.args.pooler_dropout, pooler_dropout=self.args.pooler_dropout,
do_spectral_norm=self.args.spectral_norm_classification_head, do_spectral_norm=getattr(
self.args, "spectral_norm_classification_head", False
),
) )
def upgrade_state_dict_named(self, state_dict, name): def upgrade_state_dict_named(self, state_dict, name):
@ -7,6 +7,7 @@ Base classes for various fairseq models.
""" """
import logging import logging
from argparse import Namespace
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
@ -15,8 +16,12 @@ import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from fairseq.checkpoint_utils import prune_state_dict from fairseq.checkpoint_utils import prune_state_dict
from fairseq.data import Dictionary from fairseq.data import Dictionary
from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor from torch import Tensor
@ -86,15 +91,26 @@ class BaseFairseqModel(nn.Module):
"""Maximum length supported by the model.""" """Maximum length supported by the model."""
return None return None
def load_state_dict(self, state_dict, strict=True, args=None): def load_state_dict(
self,
state_dict,
strict=True,
model_cfg: Optional[DictConfig] = None,
args: Optional[Namespace] = None,
):
"""Copies parameters and buffers from *state_dict* into this module and """Copies parameters and buffers from *state_dict* into this module and
its descendants. its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints. this additionally "upgrades" *state_dicts* from old checkpoints.
""" """
if model_cfg is None and args is not None:
logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
model_cfg = convert_namespace_to_omegaconf(args).model
self.upgrade_state_dict(state_dict) self.upgrade_state_dict(state_dict)
new_state_dict = prune_state_dict(state_dict, args) new_state_dict = prune_state_dict(state_dict, model_cfg)
return super().load_state_dict(new_state_dict, strict) return super().load_state_dict(new_state_dict, strict)
def upgrade_state_dict(self, state_dict): def upgrade_state_dict(self, state_dict):
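Note on the hunk above: load_state_dict now prefers a model-level DictConfig and converts a legacy Namespace only as a deprecated fallback. A rough sketch of that fallback shape; the conversion helper below is a stand-in, not fairseq's convert_namespace_to_omegaconf:

import warnings
from argparse import Namespace
from typing import Optional

from omegaconf import DictConfig, OmegaConf

def resolve_model_cfg(model_cfg: Optional[DictConfig] = None,
                      args: Optional[Namespace] = None) -> Optional[DictConfig]:
    """Prefer the new DictConfig; convert a legacy Namespace only as a fallback."""
    if model_cfg is None and args is not None:
        warnings.warn("'args' is deprecated, please pass a DictConfig instead")
        model_cfg = OmegaConf.create(vars(args))  # stand-in for the real conversion
    return model_cfg

legacy = Namespace(arch="transformer", dropout=0.1)
cfg = resolve_model_cfg(args=legacy)   # warns, then behaves like the new path
assert cfg.dropout == 0.1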
@ -133,18 +149,18 @@ class BaseFairseqModel(nn.Module):
self.apply(_apply) self.apply(_apply)
def prepare_for_inference_(self, args): def prepare_for_inference_(self, cfg: DictConfig):
"""Prepare model for inference.""" """Prepare model for inference."""
kwargs = {} kwargs = {}
kwargs["beamable_mm_beam_size"] = ( kwargs["beamable_mm_beam_size"] = (
None if getattr(args, "no_beamable_mm", False) else getattr(args, "beam", 5) None
if getattr(cfg.generation, "no_beamable_mm", False)
else getattr(cfg.generation, "beam", 5)
) )
kwargs["need_attn"] = getattr(args, "print_alignment", False) kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
if hasattr(args, "retain_dropout"): if getattr(cfg.generation, "retain_dropout", False):
kwargs["retain_dropout"] = args.retain_dropout kwargs["retain_dropout"] = cfg.generation.retain_dropout
kwargs["retain_dropout_modules"] = getattr( kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
args, "retain_dropout_modules", None
)
self.make_generation_fast_(**kwargs) self.make_generation_fast_(**kwargs)
def make_generation_fast_(self, **kwargs): def make_generation_fast_(self, **kwargs):
@ -437,15 +453,26 @@ class FairseqMultiModel(BaseFairseqModel):
def forward_decoder(self, prev_output_tokens, **kwargs): def forward_decoder(self, prev_output_tokens, **kwargs):
return self.decoder(prev_output_tokens, **kwargs) return self.decoder(prev_output_tokens, **kwargs)
def load_state_dict(self, state_dict, strict=True, args=None): def load_state_dict(
self,
state_dict,
strict=True,
model_cfg=None,
args: Optional[Namespace] = None,
):
"""Copies parameters and buffers from *state_dict* into this module and """Copies parameters and buffers from *state_dict* into this module and
its descendants. its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints. this additionally "upgrades" *state_dicts* from old checkpoints.
""" """
if model_cfg is None and args is not None:
logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
model_cfg = convert_namespace_to_omegaconf(args).model
self.upgrade_state_dict(state_dict) self.upgrade_state_dict(state_dict)
new_state_dict = prune_state_dict(state_dict, args) new_state_dict = prune_state_dict(state_dict, model_cfg)
return super().load_state_dict(new_state_dict, strict) return super().load_state_dict(new_state_dict, strict)
@ -194,14 +194,14 @@ class MultilingualTransformerModel(FairseqMultiModel):
module_class = TransformerEncoder if is_encoder else TransformerDecoder module_class = TransformerEncoder if is_encoder else TransformerDecoder
return module_class(args, lang_dict, embed_tokens) return module_class(args, lang_dict, embed_tokens)
def load_state_dict(self, state_dict, strict=True, args=None): def load_state_dict(self, state_dict, strict=True, model_cfg=None):
state_dict_subset = state_dict.copy() state_dict_subset = state_dict.copy()
for k, _ in state_dict.items(): for k, _ in state_dict.items():
assert k.startswith("models.") assert k.startswith("models.")
lang_pair = k.split(".")[1] lang_pair = k.split(".")[1]
if lang_pair not in self.models: if lang_pair not in self.models:
del state_dict_subset[k] del state_dict_subset[k]
super().load_state_dict(state_dict_subset, strict=strict, args=args) super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)
@register_model_architecture("multilingual_transformer", "multilingual_transformer") @register_model_architecture("multilingual_transformer", "multilingual_transformer")
@ -17,13 +17,13 @@ class RobertaHubInterface(nn.Module):
Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta
""" """
def __init__(self, args, task, model): def __init__(self, cfg, task, model):
super().__init__() super().__init__()
self.args = args self.cfg = cfg
self.task = task self.task = task
self.model = model self.model = model
self.bpe = encoders.build_bpe(args) self.bpe = encoders.build_bpe(cfg.bpe)
# this is useful for determining the device # this is useful for determining the device
self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
@ -494,7 +494,7 @@ def base_architecture(args):
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.spectral_norm_classification_head = getattr( args.spectral_norm_classification_head = getattr(
args, "spectral_nrom_classification_head", False args, "spectral_norm_classification_head", False
) )
@ -578,10 +578,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if embed_dim != input_embed_dim if embed_dim != input_embed_dim
else None else None
) )
self.embed_positions = ( self.embed_positions = (
PositionalEmbedding( PositionalEmbedding(
args.max_target_positions, self.max_target_positions,
embed_dim, embed_dim,
self.padding_idx, self.padding_idx,
learned=args.decoder_learned_pos, learned=args.decoder_learned_pos,
@ -963,6 +962,14 @@ def base_architecture(args):
args.layernorm_embedding = getattr(args, "layernorm_embedding", False) args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False) args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
@register_model_architecture("transformer", "transformer_iwslt_de_en") @register_model_architecture("transformer", "transformer_iwslt_de_en")
def transformer_iwslt_de_en(args): def transformer_iwslt_de_en(args):
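Note on the hunk above: the added base_architecture lines follow the usual backfill pattern, getattr with a default, so args objects restored from older checkpoints still get sane values for fields that did not exist yet. A tiny illustration (field names are borrowed from the hunk, the helper itself is hypothetical):

from argparse import Namespace

def backfill_transformer_defaults(args: Namespace) -> Namespace:
    # Only fill in fields that the (possibly old) args object is missing.
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    return args

old_args = Namespace(encoder_embed_dim=512)  # e.g. restored from an old checkpoint
backfill_transformer_defaults(old_args)
assert old_args.quant_noise_pq == 0 and old_args.encoder_embed_dim == 512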
@ -159,7 +159,7 @@ class TransformerLanguageModelConfig(FairseqDataclass):
add_bos_token: bool = II("task.add_bos_token") add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample") tokens_per_sample: int = II("task.tokens_per_sample")
max_target_positions: Optional[int] = II("task.max_target_positions") max_target_positions: Optional[int] = II("task.max_target_positions")
tpu: bool = II("params.common.tpu") tpu: bool = II("common.tpu")
@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig) @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
@ -32,20 +32,20 @@ class TransformerEncoderLayer(nn.Module):
def __init__(self, args): def __init__(self, args):
super().__init__() super().__init__()
self.embed_dim = args.encoder_embed_dim self.embed_dim = args.encoder_embed_dim
self.quant_noise = getattr(args, "quant_noise_pq", 0) self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8) self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.self_attn = self.build_self_attention(self.embed_dim, args) self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn_layer_norm = LayerNorm(self.embed_dim) self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout( self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__ args.dropout, module_name=self.__class__.__name__
) )
self.activation_fn = utils.get_activation_fn( self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu") activation=getattr(args, 'activation_fn', 'relu') or "relu"
) )
activation_dropout_p = getattr(args, "activation_dropout", 0) activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0: if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout # for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout( self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__ float(activation_dropout_p), module_name=self.__class__.__name__
) )
@ -197,10 +197,10 @@ class TransformerDecoderLayer(nn.Module):
if getattr(args, "activation_fn", None) is not None if getattr(args, "activation_fn", None) is not None
else "relu" else "relu"
) )
activation_dropout_p = getattr(args, "activation_dropout", 0) activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0: if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout # for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0) activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout( self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__ float(activation_dropout_p), module_name=self.__class__.__name__
) )
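Note on the hunk above: the repeated "or 0" / "or 'relu'" guards are needed because dataclass-backed configs can carry an explicit None where old argparse namespaces simply lacked the attribute, so getattr's default never applies. A plain-Python illustration of the difference:

from argparse import Namespace

old_args = Namespace()                          # attribute missing entirely
new_cfg = Namespace(activation_dropout=None)    # attribute present but None

assert getattr(old_args, "activation_dropout", 0) == 0        # default applies
assert getattr(new_cfg, "activation_dropout", 0) is None      # default does NOT apply
assert (getattr(new_cfg, "activation_dropout", 0) or 0) == 0  # the "or 0" guard fixes it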
@ -6,8 +6,6 @@
import importlib import importlib
import os import os
from argparse import Namespace
from typing import Union
from fairseq import registry from fairseq import registry
from fairseq.optim.bmuf import FairseqBMUF # noqa from fairseq.optim.bmuf import FairseqBMUF # noqa
@ -19,7 +17,6 @@ from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optim
from fairseq.optim.shard import shard_ from fairseq.optim.shard import shard_
from omegaconf import DictConfig from omegaconf import DictConfig
__all__ = [ __all__ = [
"FairseqOptimizer", "FairseqOptimizer",
"FP16Optimizer", "FP16Optimizer",
@ -27,7 +24,6 @@ __all__ = [
"shard_", "shard_",
] ]
( (
_build_optimizer, _build_optimizer,
register_optimizer, register_optimizer,
@ -37,12 +33,12 @@ __all__ = [
def build_optimizer( def build_optimizer(
optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs cfg: DictConfig, params, *extra_args, **extra_kwargs
): ):
if all(isinstance(p, dict) for p in params): if all(isinstance(p, dict) for p in params):
params = [t for p in params for t in p.values()] params = [t for p in params for t in p.values()]
params = list(filter(lambda p: p.requires_grad, params)) params = list(filter(lambda p: p.requires_grad, params))
return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs) return _build_optimizer(cfg, params, *extra_args, **extra_kwargs)
# automatically import any Python files in the optim/ directory # automatically import any Python files in the optim/ directory
@ -5,6 +5,7 @@
import logging import logging
import math import math
from collections import Collection
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List from typing import List
@ -14,7 +15,7 @@ import torch.optim
from fairseq.dataclass import FairseqDataclass from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer from fairseq.optim import FairseqOptimizer, register_optimizer
from fairseq.optim.fused_adam import get_fused_adam_class from fairseq.optim.fused_adam import get_fused_adam_class
from omegaconf import II from omegaconf import II, DictConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,8 +34,8 @@ class FairseqAdamConfig(FairseqDataclass):
default=False, metadata={"help": "Use fairseq.optim.adam.Adam"} default=False, metadata={"help": "Use fairseq.optim.adam.Adam"}
) )
# TODO common vars below in parent # TODO common vars below in parent
tpu: bool = II("params.common.tpu") tpu: bool = II("common.tpu")
lr: List[float] = II("params.optimization.lr") lr: List[float] = II("optimization.lr")
@register_optimizer("adam", dataclass=FairseqAdamConfig) @register_optimizer("adam", dataclass=FairseqAdamConfig)
@ -46,15 +47,15 @@ class FairseqAdam(FairseqOptimizer):
analogous to torch.optim.AdamW from PyTorch. analogous to torch.optim.AdamW from PyTorch.
""" """
def __init__(self, args, params): def __init__(self, cfg: DictConfig, params):
super().__init__(args) super().__init__(cfg)
fused_adam_cls = get_fused_adam_class() fused_adam_cls = get_fused_adam_class()
use_fused_adam = ( use_fused_adam = (
not getattr(args, "use_old_adam", False) not getattr(cfg, "use_old_adam", False)
and fused_adam_cls is not None and fused_adam_cls is not None
and torch.cuda.is_available() and torch.cuda.is_available()
) )
if getattr(args, "tpu", False): if getattr(cfg, "tpu", False):
# on TPUs we use the Adam defined here, since it # on TPUs we use the Adam defined here, since it
# automatically casts gradients to FP32 # automatically casts gradients to FP32
self._optimizer = Adam(params, **self.optimizer_config) self._optimizer = Adam(params, **self.optimizer_config)
@ -73,10 +74,12 @@ class FairseqAdam(FairseqOptimizer):
different learning rate. different learning rate.
""" """
return { return {
"lr": self.args.lr[0], "lr": self.cfg.lr[0]
"betas": eval(self.args.adam_betas), if isinstance(self.cfg.lr, Collection)
"eps": self.args.adam_eps, else self.cfg.lr,
"weight_decay": self.args.weight_decay, "betas": eval(self.cfg.adam_betas),
"eps": self.cfg.adam_eps,
"weight_decay": self.cfg.weight_decay,
} }
def average_params(self): def average_params(self):
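Note on the hunk above: optimizer_config has to accept cfg.lr either as the usual list from optimization.lr or as a bare float when the optimizer config is used standalone, hence the Collection check. A sketch of the same branch (the hunk imports Collection from collections, which only works on older Pythons; collections.abc is the stable location):

from collections.abc import Collection
from omegaconf import OmegaConf

def first_lr(cfg):
    # optimization.lr is normally a list like [0.25]; a standalone optimizer
    # config may carry a bare float, so index only when it is a container.
    return cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr

assert first_lr(OmegaConf.create({"lr": [0.25]})) == 0.25
assert first_lr(OmegaConf.create({"lr": 0.25})) == 0.25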
@ -10,7 +10,7 @@ import torch.distributed as dist
from fairseq.dataclass import FairseqDataclass from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.fairseq_optimizer import FairseqOptimizer from fairseq.optim.fairseq_optimizer import FairseqOptimizer
from omegaconf import II from omegaconf import II, DictConfig
@dataclass @dataclass
@ -38,7 +38,7 @@ class FairseqBMUFConfig(FairseqDataclass):
}, },
) )
distributed_world_size: int = II( distributed_world_size: int = II(
"params.distributed_training.distributed_world_size" "distributed_training.distributed_world_size"
) )
@ -52,20 +52,19 @@ class FairseqBMUF(FairseqOptimizer):
model-update filtering model-update filtering
""" """
def __init__(self, args, optimizer): def __init__(self, cfg: DictConfig, optimizer):
super().__init__(cfg)
super().__init__(args)
self._optimizer = optimizer self._optimizer = optimizer
self._num_updates = 0 self._num_updates = 0
self.sync_iter = self.args.global_sync_iter self.sync_iter = cfg.global_sync_iter
self.block_momentum = self.args.block_momentum self.block_momentum = cfg.block_momentum
self.block_lr = self.args.block_lr self.block_lr = cfg.block_lr
self._reset_local_data() self._reset_local_data()
self.warmup_iteration = self.args.warmup_iterations self.warmup_iteration = cfg.warmup_iterations
self.use_nbm = self.args.use_nbm self.use_nbm = cfg.use_nbm
self.initial_state = self._optimizer.state_dict() self.initial_state = self._optimizer.state_dict()
self.average_sync = self.args.average_sync self.average_sync = self.cfg.average_sync
self.world_size = self.args.distributed_world_size self.world_size = self.cfg.distributed_world_size
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
@ -9,9 +9,9 @@ from fairseq.dataclass.utils import gen_parser_from_dataclass
class FairseqOptimizer(object): class FairseqOptimizer(object):
def __init__(self, args): def __init__(self, cfg):
super().__init__() super().__init__()
self.args = args self.cfg = cfg
@classmethod @classmethod
def add_args(cls, parser): def add_args(cls, parser):
@ -7,7 +7,8 @@ from collections import defaultdict
from itertools import chain from itertools import chain
import torch import torch
from fairseq import optim, utils from fairseq import optim
from omegaconf import DictConfig
from .dynamic_loss_scaler import DynamicLossScaler from .dynamic_loss_scaler import DynamicLossScaler
@ -211,7 +212,7 @@ class _FP16OptimizerMixin(object):
for fp32_params in self.fp32_params.values(): for fp32_params in self.fp32_params.values():
fp32_params.grad.zero_() fp32_params.grad.zero_()
else: else:
raise ("self.fp32_params must be a tensor or dict") raise RuntimeError("self.fp32_params must be a tensor or dict")
else: else:
for p32 in self.fp32_params: for p32 in self.fp32_params:
p32.grad.zero_() p32.grad.zero_()
@ -226,58 +227,60 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
Wrap an *optimizer* to support FP16 (mixed precision) training. Wrap an *optimizer* to support FP16 (mixed precision) training.
""" """
def __init__(self, args, params, fp32_optimizer, fp32_params): def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs):
super().__init__(args) super().__init__(cfg.optimizer)
self.fp16_params = params self.fp16_params = params
self.fp32_optimizer = fp32_optimizer self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params self.fp32_params = fp32_params
if getattr(args, "fp16_scale_window", None) is None: if getattr(cfg.common, "fp16_scale_window", None) is None:
if len(args.update_freq) > 1: if len(cfg.optimization.update_freq) > 1:
raise ValueError( raise ValueError(
"--fp16-scale-window must be given explicitly when using a " "--fp16-scale-window must be given explicitly when using a "
"custom --update-freq schedule" "custom --update-freq schedule"
) )
data_parallel_size = int( data_parallel_size = int(
args.distributed_world_size / args.model_parallel_size cfg.distributed_training.distributed_world_size
/ cfg.common.model_parallel_size
)
scale_window = int(
2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0]
) )
scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0])
else: else:
scale_window = args.fp16_scale_window scale_window = cfg.common.fp16_scale_window
if not getattr(args, "bf16", False): if not getattr(cfg.common, "bf16", False):
self.scaler = DynamicLossScaler( self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale, init_scale=cfg.common.fp16_init_scale,
scale_window=scale_window, scale_window=scale_window,
tolerance=args.fp16_scale_tolerance, tolerance=cfg.common.fp16_scale_tolerance,
threshold=args.threshold_loss_scale, threshold=cfg.common.threshold_loss_scale,
min_loss_scale=args.min_loss_scale, min_loss_scale=cfg.common.min_loss_scale,
) )
else: else:
# disable loss scaling for bfloat16 # disable loss scaling for bfloat16
self.scaler = None self.scaler = None
@classmethod @classmethod
def build_optimizer(cls, args, params): def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
""" """
Args: Args:
args (argparse.Namespace): fairseq args cfg (omegaconf.DictConfig): fairseq args
params (iterable): iterable of parameters to optimize params (iterable): iterable of parameters to optimize
""" """
flatten = not getattr(args, "fp16_no_flatten_grads", False) flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False)
if getattr(args, "bf16", False): if getattr(cfg.common, "bf16", False):
flatten = False # mixed precision is faster on TPUs without flat grads flatten = False # mixed precision is faster on TPUs without flat grads
fp32_params = cls.build_fp32_params(args, params, flatten=flatten) fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten)
if flatten: if flatten:
fp32_optimizer = optim.build_optimizer(args, [fp32_params]) fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params])
else: else:
fp32_optimizer = optim.build_optimizer(args, fp32_params) fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params)
if flatten and not fp32_optimizer.supports_flat_params: if flatten and not fp32_optimizer.supports_flat_params:
raise RuntimeError( raise RuntimeError(
"chosen optimizer does not support flat params, " f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads"
"please set --fp16-no-flatten-grads"
) )
return cls(args, params, fp32_optimizer, fp32_params) return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs)
@property @property
def optimizer(self): def optimizer(self):
@ -427,49 +430,52 @@ class MemoryEfficientFP16Optimizer(
*supports_memory_efficient_fp16* property. *supports_memory_efficient_fp16* property.
""" """
def __init__(self, args, params, optimizer): def __init__(self, cfg: DictConfig, params, optimizer, **kwargs):
if not optimizer.supports_memory_efficient_fp16: if not optimizer.supports_memory_efficient_fp16:
raise ValueError( raise ValueError(
"Unsupported optimizer: {}".format(optimizer.__class__.__name__) "Unsupported optimizer: {}".format(optimizer.__class__.__name__)
) )
super().__init__(args) super().__init__(cfg.optimizer)
self.wrapped_optimizer = optimizer self.wrapped_optimizer = optimizer
if getattr(args, "fp16_scale_window", None) is None: if getattr(cfg.common, "fp16_scale_window", None) is None:
if len(args.update_freq) > 1: if len(cfg.optimization.update_freq) > 1:
raise ValueError( raise ValueError(
"--fp16-scale-window must be given explicitly when using a " "--fp16-scale-window must be given explicitly when using a "
"custom --update-freq schedule" "custom --update-freq schedule"
) )
data_parallel_size = int( data_parallel_size = int(
args.distributed_world_size / args.model_parallel_size cfg.distributed_training.distributed_world_size
/ cfg.common.model_parallel_size
)
scale_window = (
2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0]
) )
scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0]
else: else:
scale_window = args.fp16_scale_window scale_window = cfg.common.fp16_scale_window
if not getattr(args, "bf16", False): if not getattr(cfg.common, "bf16", False):
self.scaler = DynamicLossScaler( self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale, init_scale=cfg.common.fp16_init_scale,
scale_window=scale_window, scale_window=scale_window,
tolerance=args.fp16_scale_tolerance, tolerance=cfg.common.fp16_scale_tolerance,
threshold=args.threshold_loss_scale, threshold=cfg.common.threshold_loss_scale,
min_loss_scale=args.min_loss_scale, min_loss_scale=cfg.common.min_loss_scale,
) )
else: else:
# disable loss scaling for bfloat16 # disable loss scaling for bfloat16
self.scaler = None self.scaler = None
@classmethod @classmethod
def build_optimizer(cls, args, params): def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
""" """
Args: Args:
args (argparse.Namespace): fairseq args args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize params (iterable): iterable of parameters to optimize
""" """
fp16_optimizer = optim.build_optimizer(args, params) fp16_optimizer = optim.build_optimizer(cfg.optimizer, params)
return cls(args, params, fp16_optimizer) return cls(cfg, params, fp16_optimizer, **kwargs)
@property @property
def optimizer(self): def optimizer(self):
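Note on the hunks above: when --fp16-scale-window is unset, both FP16 wrappers derive it from the effective data-parallel layout, 2**14 divided by the data-parallel world size and by update_freq[0]. A worked sketch of just that arithmetic (the numbers are made up):

def default_scale_window(distributed_world_size: int,
                         model_parallel_size: int,
                         update_freq0: int) -> int:
    # Mirrors the hunk: data-parallel size excludes model-parallel replicas.
    data_parallel_size = int(distributed_world_size / model_parallel_size)
    return int(2 ** 14 / data_parallel_size / update_freq0)

# 64 GPUs, model-parallel size 1, update_freq [2] -> 16384 / 64 / 2 = 128 updates
assert default_scale_window(64, 1, 2) == 128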
@ -6,8 +6,6 @@
import importlib import importlib
import os import os
from argparse import Namespace
from typing import Union
from fairseq import registry from fairseq import registry
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa
@ -27,8 +25,8 @@ from omegaconf import DictConfig
) )
def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer): def build_lr_scheduler(cfg: DictConfig, optimizer):
return build_lr_scheduler_(lr_scheduler_cfg, optimizer) return build_lr_scheduler_(cfg, optimizer)
# automatically import any Python files in the optim/lr_scheduler/ directory # automatically import any Python files in the optim/lr_scheduler/ directory
@ -4,11 +4,12 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math import math
from collections import Collection
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List from typing import List
from fairseq.dataclass import FairseqDataclass from fairseq.dataclass import FairseqDataclass
from omegaconf import II from omegaconf import II, DictConfig
from . import FairseqLRScheduler, register_lr_scheduler from . import FairseqLRScheduler, register_lr_scheduler
@ -38,8 +39,8 @@ class CosineConfig(FairseqDataclass):
default=0.1, metadata={"help": "shrink factor for annealing"} default=0.1, metadata={"help": "shrink factor for annealing"}
) )
# TODO common var for parent class # TODO common var for parent class
lr: List[float] = II("params.optimization.lr") lr: List[float] = II("optimization.lr")
max_update: int = II("params.optimization.max_update") max_update: int = II("optimization.max_update")
@register_lr_scheduler("cosine", dataclass=CosineConfig) @register_lr_scheduler("cosine", dataclass=CosineConfig)
@ -66,43 +67,51 @@ class CosineSchedule(FairseqLRScheduler):
after every iteration. after every iteration.
""" """
def __init__(self, args, optimizer): def __init__(
super().__init__(args, optimizer) self, cfg: DictConfig, fairseq_optimizer
if len(args.lr) > 1: ):
super().__init__(cfg, fairseq_optimizer)
if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
raise ValueError( raise ValueError(
"Cannot use a fixed learning rate schedule with cosine." "Cannot use a fixed learning rate schedule with cosine."
" Consider --lr-scheduler=fixed instead." " Consider --lr-scheduler=fixed instead."
) )
warmup_end_lr = args.max_lr warmup_end_lr = cfg.max_lr
if args.warmup_init_lr < 0: lr = (
args.warmup_init_lr = args.lr[0] cfg.lr[0]
if isinstance(cfg.lr, Collection)
self.min_lr = args.lr[0] else cfg.lr
self.max_lr = args.max_lr )
if cfg.warmup_init_lr < 0:
cfg.warmup_init_lr = lr
self.min_lr = lr
self.max_lr = cfg.max_lr
assert self.max_lr > self.min_lr, "max_lr must be more than lr" assert self.max_lr > self.min_lr, "max_lr must be more than lr"
self.t_mult = args.t_mult self.t_mult = cfg.t_mult
self.period = args.lr_period_updates self.period = cfg.lr_period_updates
if self.period <= 0: if self.period <= 0:
assert ( assert (
args.max_update >= 0 cfg.max_update >= 0
), "Either --max_update or --lr-period-updates must be set" ), "Either --max_update or --lr-period-updates must be set"
self.period = args.max_update - args.warmup_updates self.period = cfg.max_update - cfg.warmup_updates
if args.warmup_updates > 0: if cfg.warmup_updates > 0:
# linearly warmup for the first args.warmup_updates # linearly warmup for the first args.warmup_updates
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates self.lr_step = (
warmup_end_lr - cfg.warmup_init_lr
) / cfg.warmup_updates
else: else:
self.lr_step = 1 self.lr_step = 1
self.warmup_updates = args.warmup_updates self.warmup_updates = cfg.warmup_updates
self.lr_shrink = args.lr_shrink self.lr_shrink = cfg.lr_shrink
# initial learning rate # initial learning rate
self.lr = args.warmup_init_lr self.lr = cfg.warmup_init_lr
self.optimizer.set_lr(self.lr) self.optimizer.set_lr(self.lr)
def step(self, epoch, val_loss=None): def step(self, epoch, val_loss=None):
@ -113,10 +122,10 @@ class CosineSchedule(FairseqLRScheduler):
def step_update(self, num_updates): def step_update(self, num_updates):
"""Update the learning rate after each update.""" """Update the learning rate after each update."""
if num_updates < self.args.warmup_updates: if num_updates < self.cfg.warmup_updates:
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
else: else:
curr_updates = num_updates - self.args.warmup_updates curr_updates = num_updates - self.cfg.warmup_updates
if self.t_mult != 1: if self.t_mult != 1:
i = math.floor( i = math.floor(
math.log( math.log(
@ -11,11 +11,11 @@ from .. import FairseqOptimizer
class FairseqLRScheduler(object): class FairseqLRScheduler(object):
def __init__(self, args, optimizer): def __init__(self, cfg, optimizer):
super().__init__() super().__init__()
if not isinstance(optimizer, FairseqOptimizer): if not isinstance(optimizer, FairseqOptimizer):
raise ValueError("optimizer must be an instance of FairseqOptimizer") raise ValueError("optimizer must be an instance of FairseqOptimizer")
self.args = args self.cfg = cfg
self.optimizer = optimizer self.optimizer = optimizer
self.best = None self.best = None
@ -3,11 +3,12 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from collections import Collection
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List from typing import List
from fairseq.dataclass import FairseqDataclass from fairseq.dataclass import FairseqDataclass
from omegaconf import II from omegaconf import II, DictConfig
from . import FairseqLRScheduler, register_lr_scheduler from . import FairseqLRScheduler, register_lr_scheduler
@ -25,7 +26,7 @@ class InverseSquareRootScheduleConfig(FairseqDataclass):
}, },
) )
# TODO common vars at parent class # TODO common vars at parent class
lr: List[float] = II("params.optimization.lr") lr: List[float] = II("optimization.lr")
@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig) @register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig)
@ -48,25 +49,33 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
lr = decay_factor / sqrt(update_num) lr = decay_factor / sqrt(update_num)
""" """
def __init__(self, args, optimizer): def __init__(self, cfg: DictConfig, optimizer):
super().__init__(args, optimizer) super().__init__(cfg, optimizer)
if len(args.lr) > 1: if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
raise ValueError( raise ValueError(
"Cannot use a fixed learning rate schedule with inverse_sqrt." "Cannot use a fixed learning rate schedule with inverse_sqrt."
" Consider --lr-scheduler=fixed instead." " Consider --lr-scheduler=fixed instead."
) )
warmup_end_lr = args.lr[0] warmup_end_lr = (
if args.warmup_init_lr < 0: cfg.lr[0]
args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr if isinstance(cfg.lr, Collection)
else cfg.lr
)
if cfg.warmup_init_lr < 0:
cfg.warmup_init_lr = (
0 if cfg.warmup_updates > 0 else warmup_end_lr
)
# linearly warmup for the first args.warmup_updates # linearly warmup for the first args.warmup_updates
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates self.lr_step = (
warmup_end_lr - cfg.warmup_init_lr
) / cfg.warmup_updates
# then, decay prop. to the inverse square root of the update number # then, decay prop. to the inverse square root of the update number
self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5 self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5
# initial learning rate # initial learning rate
self.lr = args.warmup_init_lr self.lr = cfg.warmup_init_lr
self.optimizer.set_lr(self.lr) self.optimizer.set_lr(self.lr)
def step(self, epoch, val_loss=None): def step(self, epoch, val_loss=None):
@ -77,8 +86,8 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
def step_update(self, num_updates): def step_update(self, num_updates):
"""Update the learning rate after each update.""" """Update the learning rate after each update."""
if num_updates < self.args.warmup_updates: if num_updates < self.cfg.warmup_updates:
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
else: else:
self.lr = self.decay_factor * num_updates ** -0.5 self.lr = self.decay_factor * num_updates ** -0.5
self.optimizer.set_lr(self.lr) self.optimizer.set_lr(self.lr)
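Note on the hunk above: numerically the schedule itself is untouched by the cfg migration: linear warmup from warmup_init_lr to the peak lr over warmup_updates steps, then decay_factor * step**-0.5 with decay_factor = lr * warmup_updates**0.5. A standalone sketch with illustrative defaults (fairseq's actual defaults differ, e.g. warmup_init_lr defaults to -1 and is resolved as in the hunk):

def inverse_sqrt_lr(step: int, lr: float = 5e-4,
                    warmup_updates: int = 4000,
                    warmup_init_lr: float = 1e-7) -> float:
    if step < warmup_updates:
        lr_step = (lr - warmup_init_lr) / warmup_updates   # linear warmup
        return warmup_init_lr + step * lr_step
    decay_factor = lr * warmup_updates ** 0.5              # peak lr is hit at the boundary
    return decay_factor * step ** -0.5

assert abs(inverse_sqrt_lr(4000) - 5e-4) < 1e-9                       # peak after warmup
assert abs(inverse_sqrt_lr(16000) - inverse_sqrt_lr(4000) / 2) < 1e-9  # 4x steps -> half lr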
@ -3,12 +3,13 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from collections import Collection
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List from typing import List
import torch import torch
from fairseq.dataclass import FairseqDataclass from fairseq.dataclass import FairseqDataclass
from omegaconf import II from omegaconf import II, DictConfig
from torch.optim.optimizer import Optimizer, required from torch.optim.optimizer import Optimizer, required
from . import FairseqOptimizer, register_optimizer from . import FairseqOptimizer, register_optimizer
@ -19,13 +20,13 @@ class FairseqNAGConfig(FairseqDataclass):
momentum: float = field(default=0.99, metadata={"help": "momentum factor"}) momentum: float = field(default=0.99, metadata={"help": "momentum factor"})
weight_decay: float = field(default=0.0, metadata={"help": "weight decay"}) weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
# TODO common vars in parent class # TODO common vars in parent class
lr: List[float] = II("params.optimization.lr") lr: List[float] = II("optimization.lr")
@register_optimizer("nag", dataclass=FairseqNAGConfig) @register_optimizer("nag", dataclass=FairseqNAGConfig)
class FairseqNAG(FairseqOptimizer): class FairseqNAG(FairseqOptimizer):
def __init__(self, args, params): def __init__(self, cfg: DictConfig, params):
super().__init__(args) super().__init__(cfg)
self._optimizer = NAG(params, **self.optimizer_config) self._optimizer = NAG(params, **self.optimizer_config)
@property @property
@ -37,9 +38,11 @@ class FairseqNAG(FairseqOptimizer):
different learning rate. different learning rate.
""" """
return { return {
"lr": self.args.lr[0], "lr": self.cfg.lr[0]
"momentum": self.args.momentum, if isinstance(self.cfg.lr, Collection)
"weight_decay": self.args.weight_decay, else self.cfg.lr,
"momentum": self.cfg.momentum,
"weight_decay": self.cfg.weight_decay,
} }
@ -12,7 +12,7 @@ except ImportError:
_has_fairscale = False _has_fairscale = False
def shard_(args, optimizer, group): def shard_(optimizer, group):
if not _has_fairscale: if not _has_fairscale:
raise ImportError( raise ImportError(
"\n\nPlease install the fairscale package:" "\n\n pip install fairscale" "\n\nPlease install the fairscale package:" "\n\n pip install fairscale"
@ -10,13 +10,15 @@ import torch
from fairseq import utils from fairseq import utils
from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass.data_class import ( from fairseq.dataclass.data_class import (
CheckpointParams, CheckpointConfig,
CommonEvalParams, CommonConfig,
CommonParams, CommonEvalConfig,
DatasetParams, DatasetConfig,
DistributedTrainingParams, DistributedTrainingConfig,
EvalLMParams, EvalLMConfig,
OptimizationParams, GenerationConfig,
InteractiveConfig,
OptimizationConfig,
) )
from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.dataclass.utils import gen_parser_from_dataclass
@ -45,6 +47,7 @@ def get_generation_parser(interactive=False, default_task="translation"):
add_dataset_args(parser, gen=True) add_dataset_args(parser, gen=True)
add_distributed_training_args(parser, default_world_size=1) add_distributed_training_args(parser, default_world_size=1)
add_generation_args(parser) add_generation_args(parser)
add_checkpoint_args(parser)
if interactive: if interactive:
add_interactive_args(parser) add_interactive_args(parser)
return parser return parser
@ -67,7 +70,7 @@ def get_validation_parser(default_task=None):
add_dataset_args(parser, train=True) add_dataset_args(parser, train=True)
add_distributed_training_args(parser, default_world_size=1) add_distributed_training_args(parser, default_world_size=1)
group = parser.add_argument_group("Evaluation") group = parser.add_argument_group("Evaluation")
gen_parser_from_dataclass(group, CommonEvalParams()) gen_parser_from_dataclass(group, CommonEvalConfig())
return parser return parser
@ -210,7 +213,7 @@ def get_parser(desc, default_task="translation"):
utils.import_user_module(usr_args) utils.import_user_module(usr_args)
parser = argparse.ArgumentParser(allow_abbrev=False) parser = argparse.ArgumentParser(allow_abbrev=False)
gen_parser_from_dataclass(parser, CommonParams()) gen_parser_from_dataclass(parser, CommonConfig())
from fairseq.registry import REGISTRIES from fairseq.registry import REGISTRIES
@ -283,7 +286,7 @@ def add_preprocess_args(parser):
def add_dataset_args(parser, train=False, gen=False): def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group("dataset_data_loading") group = parser.add_argument_group("dataset_data_loading")
gen_parser_from_dataclass(group, DatasetParams()) gen_parser_from_dataclass(group, DatasetConfig())
# fmt: on # fmt: on
return group return group
@ -293,7 +296,7 @@ def add_distributed_training_args(parser, default_world_size=None):
if default_world_size is None: if default_world_size is None:
default_world_size = max(1, torch.cuda.device_count()) default_world_size = max(1, torch.cuda.device_count())
gen_parser_from_dataclass( gen_parser_from_dataclass(
group, DistributedTrainingParams(distributed_world_size=default_world_size) group, DistributedTrainingConfig(distributed_world_size=default_world_size)
) )
return group return group
@ -301,7 +304,7 @@ def add_distributed_training_args(parser, default_world_size=None):
def add_optimization_args(parser): def add_optimization_args(parser):
group = parser.add_argument_group("optimization") group = parser.add_argument_group("optimization")
# fmt: off # fmt: off
gen_parser_from_dataclass(group, OptimizationParams()) gen_parser_from_dataclass(group, OptimizationConfig())
# fmt: on # fmt: on
return group return group
@ -309,117 +312,31 @@ def add_optimization_args(parser):
def add_checkpoint_args(parser): def add_checkpoint_args(parser):
group = parser.add_argument_group("checkpoint") group = parser.add_argument_group("checkpoint")
# fmt: off # fmt: off
gen_parser_from_dataclass(group, CheckpointParams()) gen_parser_from_dataclass(group, CheckpointConfig())
# fmt: on # fmt: on
return group return group
def add_common_eval_args(group): def add_common_eval_args(group):
gen_parser_from_dataclass(group, CommonEvalParams()) gen_parser_from_dataclass(group, CommonEvalConfig())
def add_eval_lm_args(parser): def add_eval_lm_args(parser):
group = parser.add_argument_group("LM Evaluation") group = parser.add_argument_group("LM Evaluation")
add_common_eval_args(group) add_common_eval_args(group)
gen_parser_from_dataclass(group, EvalLMParams()) gen_parser_from_dataclass(group, EvalLMConfig())
def add_generation_args(parser): def add_generation_args(parser):
group = parser.add_argument_group("Generation") group = parser.add_argument_group("Generation")
add_common_eval_args(group) add_common_eval_args(group)
# fmt: off gen_parser_from_dataclass(group, GenerationConfig())
group.add_argument('--beam', default=5, type=int, metavar='N',
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
help='number of hypotheses to output')
group.add_argument('--max-len-a', default=0, type=float, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--min-len', default=1, type=float, metavar='N',
help=('minimum generation length'))
group.add_argument('--match-source-len', default=False, action='store_true',
help=('generations should match the source length'))
group.add_argument('--no-early-stop', action='store_true',
help='deprecated')
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--no-beamable-mm', action='store_true',
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unkpen', default=0, type=float,
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
help='perform unknown replacement (optionally with alignment dictionary)')
group.add_argument('--sacrebleu', action='store_true',
help='score with sacrebleu')
group.add_argument('--score-reference', action='store_true',
help='just score the reference translation')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help='initialize generation by target prefix of given length')
group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N',
help='ngram blocking such that this size ngram cannot be repeated in the generation')
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS',
help='sample from the smallest set whose cumulative probability mass exceeds p for next words')
group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"],
help='enables lexically constrained decoding')
group.add_argument('--temperature', default=1., type=float, metavar='N',
help='temperature for generation')
group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
help='number of groups for Diverse Beam Search')
group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
help='strength of diversity penalty for Diverse Beam Search')
group.add_argument('--diversity-rate', default=-1.0, type=float, metavar='N',
help='strength of diversity penalty for Diverse Siblings Search')
group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--print-step', action='store_true')
group.add_argument('--lm-path', default=None, type=str, metavar='PATH',
help='path to lm checkpoint for lm fusion')
group.add_argument('--lm-weight', default=0.0, type=float, metavar='N',
help='weight for lm probs for lm fusion')
# arguments for iterative refinement generator
group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
help='if > 0.0, it penalized early-stopping in decoding.')
group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N',
help='maximum iterations for iterative refinement.')
group.add_argument('--iter-decode-force-max-iter', action='store_true',
help='if set, run exact the maximum number of iterations without early stop')
group.add_argument('--iter-decode-with-beam', default=1, type=int, metavar='N',
help='if > 1, model will generate translations varying by the lengths.')
group.add_argument('--iter-decode-with-external-reranker', action='store_true',
help='if set, the last checkpoint are assumed to be a reranker to rescore the translations'),
group.add_argument('--retain-iter-history', action='store_true',
help='if set, decoding returns the whole history of iterative refinement')
group.add_argument('--retain-dropout', action='store_true',
help='Use dropout at inference time')
group.add_argument('--retain-dropout-modules', default=None, nargs='+', type=str,
help='if set, only retain dropout for the specified modules; '
'if not set, then dropout will be retained for all modules')
# special decoding format for advanced decoding.
group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs'])
# fmt: on
return group return group
def add_interactive_args(parser): def add_interactive_args(parser):
group = parser.add_argument_group("Interactive") group = parser.add_argument_group("Interactive")
# fmt: off gen_parser_from_dataclass(group, InteractiveConfig())
group.add_argument('--buffer-size', default=0, type=int, metavar='N',
help='read this many sentences into a buffer before processing them')
group.add_argument('--input', default='-', type=str, metavar='FILE',
help='file to read from; use - for stdin')
# fmt: on
def add_model_args(parser): def add_model_args(parser):
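Note on the hunks above: add_generation_args and add_interactive_args now generate their flags from dataclasses instead of hand-written add_argument calls. A very rough, hypothetical analogue of what gen_parser_from_dataclass does (this is not fairseq's implementation; ToyGenerationConfig and add_args_from_dataclass are made up, and real handling of bools, choices and metavars is more involved):

import argparse
from dataclasses import dataclass, field, fields

@dataclass
class ToyGenerationConfig:                       # stand-in for GenerationConfig
    beam: int = field(default=5, metadata={"help": "beam size"})
    lenpen: float = field(default=1.0, metadata={"help": "length penalty"})

def add_args_from_dataclass(parser: argparse.ArgumentParser, cfg) -> None:
    # One --flag per dataclass field, with type, default and help taken
    # from the field definition.
    for f in fields(cfg):
        parser.add_argument(
            "--" + f.name.replace("_", "-"),
            type=f.type if callable(f.type) else str,
            default=getattr(cfg, f.name),
            help=f.metadata.get("help"),
        )

parser = argparse.ArgumentParser()
add_args_from_dataclass(parser, ToyGenerationConfig())
args = parser.parse_args(["--beam", "10"])
assert args.beam == 10 and args.lenpen == 1.0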
@ -6,13 +6,14 @@
import logging import logging
from fairseq.modules.quantization import pq, quantization_options, scalar from fairseq.modules.quantization import pq, quantization_options, scalar
from omegaconf import DictConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def quantize_model_scalar(model, args): def quantize_model_scalar(model, model_cfg: DictConfig):
quant_noise_scalar = getattr(args, "quant_noise_scalar", 0) quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0
if quant_noise_scalar > 0: if quant_noise_scalar > 0:
# quantize_model edits the model in place # quantize_model edits the model in place
scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000) scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000)
@ -3,14 +3,13 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import argparse
from argparse import Namespace from argparse import Namespace
from typing import Union from typing import Union
from fairseq.dataclass import FairseqDataclass from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import populate_dataclass
from omegaconf import DictConfig from omegaconf import DictConfig
REGISTRIES = {} REGISTRIES = {}
@ -25,33 +24,30 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F
# maintain a registry of all registries # maintain a registry of all registries
if registry_name in REGISTRIES: if registry_name in REGISTRIES:
return # registry already exists return # registry already exists
REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default} REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default, "dataclass_registry": DATACLASS_REGISTRY}
def build_x(args: Union[DictConfig, Namespace], *extra_args, **extra_kwargs): def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs):
if isinstance(args, DictConfig): if isinstance(cfg, DictConfig):
if getattr(args, "_name", None) is not None: choice = cfg._name
choice = args._name elif isinstance(cfg, str):
elif hasattr(args, registry_name): choice = cfg
choice = args.registry_name
else:
raise RuntimeError(
f"Neither _name nor {registry_name} in args, args = {args}"
)
else: else:
choice = getattr(args, registry_name, None) choice = getattr(cfg, registry_name, None)
if choice in DATACLASS_REGISTRY:
cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]())
if choice is None: if choice is None:
if required: if required:
raise ValueError("--{} is required!".format(registry_name)) raise ValueError('{} is required!'.format(registry_name))
return None return None
cls = REGISTRY[choice] cls = REGISTRY[choice]
if hasattr(cls, "build_" + registry_name): if hasattr(cls, "build_" + registry_name):
builder = getattr(cls, "build_" + registry_name) builder = getattr(cls, "build_" + registry_name)
else: else:
builder = cls builder = cls
if isinstance(args, Namespace):
set_defaults(args, cls) return builder(cfg, *extra_args, **extra_kwargs)
return builder(args, *extra_args, **extra_kwargs)
def register_x(name, dataclass=None): def register_x(name, dataclass=None):
def register_x_cls(cls): def register_x_cls(cls):
@ -77,30 +73,10 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F
cls.__dataclass = dataclass cls.__dataclass = dataclass
REGISTRY[name] = cls REGISTRY[name] = cls
DATACLASS_REGISTRY[name] = cls.__dataclass if cls.__dataclass is not None:
REGISTRY_CLASS_NAMES.add(cls.__name__) DATACLASS_REGISTRY[name] = cls.__dataclass
return cls return cls
return register_x_cls return register_x_cls
return build_x, register_x, REGISTRY, DATACLASS_REGISTRY return build_x, register_x, REGISTRY, DATACLASS_REGISTRY
def set_defaults(args: Namespace, cls):
"""Helper to set default arguments based on *add_args*."""
if not hasattr(cls, "add_args"):
return
parser = argparse.ArgumentParser(
argument_default=argparse.SUPPRESS, allow_abbrev=False
)
cls.add_args(parser)
# copied from argparse.py:
defaults = argparse.Namespace()
for action in parser._actions:
if action.dest is not argparse.SUPPRESS:
if not hasattr(defaults, action.dest):
if action.default is not argparse.SUPPRESS:
setattr(defaults, action.dest, action.default)
for key, default_value in vars(defaults).items():
if not hasattr(args, key):
setattr(args, key, default_value)
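The removed set_defaults() helper is no longer needed because defaults now come from per-component dataclasses tracked in DATACLASS_REGISTRY. The sketch below imitates the new dispatch shape of build_x with a toy registry; all names (register_tokenizer, SpaceTokenizer, SpaceTokenizerConfig) are invented for illustration, and populate_dataclass is deliberately left out:

from dataclasses import dataclass

REGISTRY, DATACLASS_REGISTRY = {}, {}

def register_tokenizer(name, dataclass=None):
    def wrapper(cls):
        REGISTRY[name] = cls
        if dataclass is not None:            # only register a config dataclass when one exists
            DATACLASS_REGISTRY[name] = dataclass
        return cls
    return wrapper

@dataclass
class SpaceTokenizerConfig:
    lowercase: bool = False

@register_tokenizer("space", dataclass=SpaceTokenizerConfig)
class SpaceTokenizer:
    def __init__(self, cfg):
        self.cfg = cfg
    def encode(self, line):
        return (line.lower() if self.cfg.lowercase else line).split()

def build_tokenizer(cfg):
    # Same dispatch shape as the new build_x: a DictConfig-like object with _name,
    # a bare string choice, or a legacy Namespace carrying the registry key.
    if hasattr(cfg, "_name"):
        choice = cfg._name
    elif isinstance(cfg, str):
        choice, cfg = cfg, DATACLASS_REGISTRY[cfg]()   # fall back to dataclass defaults
    else:
        choice = getattr(cfg, "tokenizer", None)
    return REGISTRY[choice](cfg)

print(build_tokenizer("space").encode("Hello Hydra"))   # ['Hello', 'Hydra']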

View File

@ -9,11 +9,12 @@ import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from fairseq import registry from fairseq import registry
from omegaconf import DictConfig
class BaseScorer(ABC): class BaseScorer(ABC):
def __init__(self, args): def __init__(self, cfg):
self.args = args self.cfg = cfg
self.ref = [] self.ref = []
self.pred = [] self.pred = []
@ -39,19 +40,17 @@ _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry(
) )
def build_scorer(args, tgt_dict): def build_scorer(choice, tgt_dict):
from fairseq import utils if isinstance(choice, DictConfig):
choice = choice._name
if args.sacrebleu: if choice == "bleu":
utils.deprecation_warning(
"--sacrebleu is deprecated. Please use --scoring sacrebleu instead."
)
args.scoring = "sacrebleu"
if args.scoring == "bleu":
from fairseq.scoring import bleu from fairseq.scoring import bleu
return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk()) return bleu.Scorer(
return _build_scorer(args) bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk())
)
return _build_scorer(choice)
# automatically import any Python files in the current directory # automatically import any Python files in the current directory

View File

@ -6,8 +6,10 @@
import ctypes import ctypes
import math import math
import sys import sys
from dataclasses import dataclass, field
import torch import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer from fairseq.scoring.tokenizer import EvaluationTokenizer
@ -27,31 +29,32 @@ class BleuStat(ctypes.Structure):
] ]
@register_scorer("sacrebleu") @dataclass
class SacrebleuConfig(FairseqDataclass):
sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
default="13a", metadata={"help": "tokenizer"}
)
sacrebleu_lowercase: bool = field(
default=False, metadata={"help": "apply lowercasing"}
)
sacrebleu_char_level: bool = field(
default=False, metadata={"help": "evaluate at character level"}
)
@register_scorer("sacrebleu", dataclass=SacrebleuConfig)
class SacrebleuScorer(BaseScorer): class SacrebleuScorer(BaseScorer):
def __init__(self, args): def __init__(self, cfg):
super(SacrebleuScorer, self).__init__(args) super(SacrebleuScorer, self).__init__(cfg)
import sacrebleu import sacrebleu
self.sacrebleu = sacrebleu self.sacrebleu = sacrebleu
self.tokenizer = EvaluationTokenizer( self.tokenizer = EvaluationTokenizer(
tokenizer_type=self.args.sacrebleu_tokenizer, tokenizer_type=cfg.sacrebleu_tokenizer,
lowercase=self.args.sacrebleu_lowercase, lowercase=cfg.sacrebleu_lowercase,
character_tokenization=self.args.sacrebleu_char_level, character_tokenization=cfg.sacrebleu_char_level,
) )
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--sacrebleu-tokenizer', type=str, default='13a',
choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES,
help='tokenizer')
parser.add_argument('--sacrebleu-lowercase', type=str, default=False,
help='apply lowercasing')
parser.add_argument('--sacrebleu-char-level', action='store_true',
help='evaluate at character level')
# fmt: on
def add_string(self, ref, pred): def add_string(self, ref, pred):
self.ref.append(self.tokenizer.tokenize(ref)) self.ref.append(self.tokenizer.tokenize(ref))
self.pred.append(self.tokenizer.tokenize(pred)) self.pred.append(self.tokenizer.tokenize(pred))
@ -68,13 +71,20 @@ class SacrebleuScorer(BaseScorer):
).format() ).format()
@register_scorer("bleu") @dataclass
class BleuConfig(FairseqDataclass):
pad: int = field(default=1, metadata={"help": "padding index"})
eos: int = field(default=2, metadata={"help": "eos index"})
unk: int = field(default=3, metadata={"help": "unk index"})
@register_scorer("bleu", dataclass=BleuConfig)
class Scorer(object): class Scorer(object):
def __init__(self, pad, eos, unk): def __init__(self, cfg):
self.stat = BleuStat() self.stat = BleuStat()
self.pad = pad self.pad = cfg.pad
self.eos = eos self.eos = cfg.eos
self.unk = unk self.unk = cfg.unk
try: try:
from fairseq import libbleu from fairseq import libbleu
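With BleuConfig the scorer receives one config object instead of positional indices plus an add_args() hook. A minimal imitation of that pattern (stub classes, not the fairseq ones; the 1/2/3 defaults simply mirror the dataclass above):

from dataclasses import dataclass, field

@dataclass
class BleuConfigStub:
    pad: int = field(default=1, metadata={"help": "padding index"})
    eos: int = field(default=2, metadata={"help": "eos index"})
    unk: int = field(default=3, metadata={"help": "unk index"})

class ScorerStub:
    # The scorer constructor now unpacks a single config object.
    def __init__(self, cfg: BleuConfigStub):
        self.pad, self.eos, self.unk = cfg.pad, cfg.eos, cfg.unk

scorer = ScorerStub(BleuConfigStub(pad=0, eos=2, unk=3))
print(scorer.pad, scorer.eos, scorer.unk)   # 0 2 3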

View File

@ -5,6 +5,8 @@
import unicodedata import unicodedata
from fairseq.dataclass.utils import ChoiceEnum
class EvaluationTokenizer(object): class EvaluationTokenizer(object):
"""A generic evaluation-time tokenizer, which leverages built-in tokenizers """A generic evaluation-time tokenizer, which leverages built-in tokenizers
@ -22,7 +24,7 @@ class EvaluationTokenizer(object):
SPACE = chr(32) SPACE = chr(32)
SPACE_ESCAPE = chr(9601) SPACE_ESCAPE = chr(9601)
ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"] ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
def __init__( def __init__(
self, self,
@ -33,7 +35,7 @@ class EvaluationTokenizer(object):
): ):
from sacrebleu.tokenizers import TOKENIZERS from sacrebleu.tokenizers import TOKENIZERS
assert tokenizer_type in self.ALL_TOKENIZER_TYPES assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
self.lowercase = lowercase self.lowercase = lowercase
self.punctuation_removal = punctuation_removal self.punctuation_removal = punctuation_removal
self.character_tokenization = character_tokenization self.character_tokenization = character_tokenization
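ALL_TOKENIZER_TYPES is now a ChoiceEnum so it can serve as the type of a dataclass field (as in SacrebleuConfig and WerScorerConfig in this commit) and constrain the allowed choices at config time. Assuming ChoiceEnum is essentially an Enum built from a list of strings, a stand-in looks like this (choice_enum and ScorerConfigStub are invented names):

from dataclasses import dataclass, field
from enum import Enum

def choice_enum(choices):
    # Assumed stand-in for fairseq's ChoiceEnum: an Enum whose member names and
    # values are the allowed strings.
    return Enum("Choices", {c: c for c in choices})

TOKENIZER_TYPES = choice_enum(["none", "13a", "intl", "zh", "ja-mecab"])

@dataclass
class ScorerConfigStub:
    # The enum type documents and constrains the choice; the default stays a plain string.
    sacrebleu_tokenizer: TOKENIZER_TYPES = field(default="13a", metadata={"help": "tokenizer"})

print([m.value for m in TOKENIZER_TYPES])      # ['none', '13a', 'intl', 'zh', 'ja-mecab']
print(ScorerConfigStub().sacrebleu_tokenizer)  # 13a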

View File

@ -3,14 +3,31 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer from fairseq.scoring.tokenizer import EvaluationTokenizer
@register_scorer("wer") @dataclass
class WerScorerConfig(FairseqDataclass):
wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"}
)
wer_remove_punct: bool = field(
default=False, metadata={"help": "remove punctuation"}
)
wer_char_level: bool = field(
default=False, metadata={"help": "evaluate at character level"}
)
wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"})
@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer): class WerScorer(BaseScorer):
def __init__(self, args): def __init__(self, cfg):
super().__init__(args) super().__init__(cfg)
self.reset() self.reset()
try: try:
import editdistance as ed import editdistance as ed
@ -18,26 +35,12 @@ class WerScorer(BaseScorer):
raise ImportError("Please install editdistance to use WER scorer") raise ImportError("Please install editdistance to use WER scorer")
self.ed = ed self.ed = ed
self.tokenizer = EvaluationTokenizer( self.tokenizer = EvaluationTokenizer(
tokenizer_type=self.args.wer_tokenizer, tokenizer_type=self.cfg.wer_tokenizer,
lowercase=self.args.wer_lowercase, lowercase=self.cfg.wer_lowercase,
punctuation_removal=self.args.wer_remove_punct, punctuation_removal=self.cfg.wer_remove_punct,
character_tokenization=self.args.wer_char_level, character_tokenization=self.cfg.wer_char_level,
) )
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--wer-tokenizer', type=str, default='none',
choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES,
help='sacreBLEU tokenizer to use for evaluation')
parser.add_argument('--wer-remove-punct', action='store_true',
help='remove punctuation')
parser.add_argument('--wer-char-level', action='store_true',
help='evaluate at character level')
parser.add_argument('--wer-lowercase', action='store_true',
help='lowercasing')
# fmt: on
def reset(self): def reset(self):
self.distance = 0 self.distance = 0
self.ref_length = 0 self.ref_length = 0
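WerScorer accumulates an edit distance and a reference length via the editdistance package. For readers without that dependency, a pure-Python sketch of the same computation, with WER reported the usual way as distance over reference length (edit_distance and wer are illustrative helpers, not fairseq APIs):

def edit_distance(ref, hyp):
    # Levenshtein distance between two token sequences (unit cost for sub/ins/del).
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(prev[j] + 1,             # delete r
                         cur[j - 1] + 1,          # insert h
                         prev[j - 1] + (r != h))  # substitute (free if tokens match)
        prev = cur
    return prev[-1]

def wer(ref_line, hyp_line):
    ref, hyp = ref_line.split(), hyp_line.split()
    return 100.0 * edit_distance(ref, hyp) / max(len(ref), 1)

print(round(wer("the cat sat on the mat", "the cat sit on mat"), 2))   # 33.33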

View File

@ -7,8 +7,6 @@
import argparse import argparse
import importlib import importlib
import os import os
from argparse import Namespace
from typing import Union
from fairseq.dataclass import FairseqDataclass from fairseq.dataclass import FairseqDataclass
from omegaconf import DictConfig from omegaconf import DictConfig
@ -22,10 +20,10 @@ TASK_REGISTRY = {}
TASK_CLASS_NAMES = set() TASK_CLASS_NAMES = set()
def setup_task(task_cfg: Union[DictConfig, Namespace], **kwargs): def setup_task(cfg: DictConfig, **kwargs):
if isinstance(task_cfg, DictConfig): if isinstance(cfg, DictConfig):
return TASK_REGISTRY[task_cfg._name].setup_task(task_cfg, **kwargs) return TASK_REGISTRY[cfg._name].setup_task(cfg, **kwargs)
return TASK_REGISTRY[task_cfg.task].setup_task(task_cfg, **kwargs) return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs)
def register_task(name, dataclass=None): def register_task(name, dataclass=None):
@ -70,7 +68,8 @@ def register_task(name, dataclass=None):
) )
cls.__dataclass = dataclass cls.__dataclass = dataclass
TASK_DATACLASS_REGISTRY[name] = dataclass if dataclass is not None:
TASK_DATACLASS_REGISTRY[name] = dataclass
return cls return cls

View File

@ -79,7 +79,7 @@ class AudioPretrainingTask(LegacyFairseqTask):
"""Setup the task (e.g., load dictionaries). """Setup the task (e.g., load dictionaries).
Args: Args:
args (argparse.Namespace): parsed command-line arguments args (omegaconf.DictConfig): parsed command-line arguments
""" """
return cls(args) return cls(args)

View File

@ -12,6 +12,7 @@ import torch
from fairseq import metrics, search, tokenizer, utils from fairseq import metrics, search, tokenizer, utils
from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
from fairseq.dataclass.utils import gen_parser_from_dataclass from fairseq.dataclass.utils import gen_parser_from_dataclass
from omegaconf import DictConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,8 +40,8 @@ class FairseqTask(object):
""" """
return criterion.logging_outputs_can_be_summed() return criterion.logging_outputs_can_be_summed()
def __init__(self, args): def __init__(self, cfg: DictConfig, **kwargs):
self.args = args self.cfg = cfg
self.datasets = {} self.datasets = {}
self.dataset_to_epoch_iter = {} self.dataset_to_epoch_iter = {}
@ -78,16 +79,16 @@ class FairseqTask(object):
return d return d
@classmethod @classmethod
def setup_task(cls, args, **kwargs): def setup_task(cls, cfg: DictConfig, **kwargs):
"""Setup the task (e.g., load dictionaries). """Setup the task (e.g., load dictionaries).
Args: Args:
args (argparse.Namespace): parsed command-line arguments cfg (omegaconf.DictConfig): parsed command-line arguments
""" """
return cls(args, **kwargs) return cls(cfg, **kwargs)
def has_sharded_data(self, split): def has_sharded_data(self, split):
return os.pathsep in getattr(self.args, "data", "") return os.pathsep in getattr(self.cfg, "data", "")
def load_dataset(self, split, combine=False, **kwargs): def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split. """Load a given dataset split.
@ -254,39 +255,39 @@ class FairseqTask(object):
return epoch_iter return epoch_iter
def build_model(self, args): def build_model(self, cfg: DictConfig):
""" """
Build the :class:`~fairseq.models.BaseFairseqModel` instance for this Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
task. task.
Args: Args:
args (argparse.Namespace): parsed command-line arguments cfg (omegaconf.DictConfig): configuration object
Returns: Returns:
a :class:`~fairseq.models.BaseFairseqModel` instance a :class:`~fairseq.models.BaseFairseqModel` instance
""" """
from fairseq import models, quantization_utils from fairseq import models, quantization_utils
model = models.build_model(args, self) model = models.build_model(cfg, self)
if getattr(args, "tpu", False): if getattr(cfg, "tpu", False):
model.prepare_for_tpu_() model.prepare_for_tpu_()
model = quantization_utils.quantize_model_scalar(model, args) model = quantization_utils.quantize_model_scalar(model, cfg)
return model return model
def build_criterion(self, args): def build_criterion(self, cfg: DictConfig):
""" """
Build the :class:`~fairseq.criterions.FairseqCriterion` instance for Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
this task. this task.
Args: Args:
args (argparse.Namespace): parsed command-line arguments cfg (omegaconf.DictConfig): configuration object
Returns: Returns:
a :class:`~fairseq.criterions.FairseqCriterion` instance a :class:`~fairseq.criterions.FairseqCriterion` instance
""" """
from fairseq import criterions from fairseq import criterions
return criterions.build_criterion(args, self) return criterions.build_criterion(cfg, self)
def build_generator( def build_generator(
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None

View File

@ -28,7 +28,7 @@ from fairseq.data import (
from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.data.shorten_dataset import maybe_shorten_dataset from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task from fairseq.tasks import LegacyFairseqTask, register_task
from omegaconf import II from omegaconf import II
@ -85,16 +85,16 @@ class LanguageModelingConfig(FairseqDataclass):
}, },
) )
# TODO common vars below add to parent # TODO common vars below add to parent
seed: int = II("params.common.seed") seed: int = II("common.seed")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II( dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"params.dataset.dataset_impl" "dataset.dataset_impl"
) )
data_buffer_size: int = II("params.dataset.data_buffer_size") data_buffer_size: int = II("dataset.data_buffer_size")
tpu: bool = II("params.common.tpu") tpu: bool = II("common.tpu")
@register_task("language_modeling", dataclass=LanguageModelingConfig) @register_task("language_modeling", dataclass=LanguageModelingConfig)
class LanguageModelingTask(FairseqTask): class LanguageModelingTask(LegacyFairseqTask):
""" """
Train a language model. Train a language model.
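The II() interpolations in LanguageModelingConfig now point at top-level groups (common.seed, dataset.dataset_impl, and so on) instead of params.*, matching the flattened config layout. A quick OmegaConf check of how such an interpolation resolves, using toy values:

from omegaconf import II, OmegaConf

cfg = OmegaConf.create({
    "common": {"seed": 7, "tpu": False},
    "task": {
        "seed": II("common.seed"),   # stored as the string "${common.seed}"
        "tpu": II("common.tpu"),
    },
})
print(cfg.task.seed, cfg.task.tpu)   # 7 False, resolved from the top-level groups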

View File

@ -117,7 +117,7 @@ class MultilingualTranslationTask(LegacyFairseqTask):
return cls(args, dicts, training) return cls(args, dicts, training)
@classmethod @classmethod
def prepare(cls, args, **kargs): def update_args(cls, args):
args.left_pad_source = utils.eval_bool(args.left_pad_source) args.left_pad_source = utils.eval_bool(args.left_pad_source)
args.left_pad_target = utils.eval_bool(args.left_pad_target) args.left_pad_target = utils.eval_bool(args.left_pad_target)
@ -127,6 +127,10 @@ class MultilingualTranslationTask(LegacyFairseqTask):
) )
if isinstance(args.lang_pairs, str): if isinstance(args.lang_pairs, str):
args.lang_pairs = args.lang_pairs.split(",") args.lang_pairs = args.lang_pairs.split(",")
@classmethod
def prepare(cls, args, **kargs):
cls.update_args(args)
sorted_langs = sorted( sorted_langs = sorted(
list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")}) list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")})
) )
@ -298,6 +302,10 @@ class MultilingualTranslationTask(LegacyFairseqTask):
if len(messages) > 0: if len(messages) > 0:
raise ValueError(" ".join(messages)) raise ValueError(" ".join(messages))
# Re-apply update_args: the args object mutated in the constructor
# is not necessarily the same object that gets passed in here
self.update_args(args)
# Check if task args are consistent with model args # Check if task args are consistent with model args
check_args() check_args()
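Moving the normalization out of prepare() into an update_args() classmethod lets the constructor re-apply it, since the args object it receives is not necessarily the one prepare() mutated. A toy version of the pattern (stub class and invented names, reduced to the lang_pairs handling shown above):

from argparse import Namespace

class MultilingualTaskStub:
    @classmethod
    def update_args(cls, args):
        # Idempotent normalization: only split when lang_pairs is still a string.
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(",")

    @classmethod
    def prepare(cls, args, **kwargs):
        cls.update_args(args)
        return sorted({lang for pair in args.lang_pairs for lang in pair.split("-")})

    def __init__(self, args):
        self.update_args(args)   # args here may be a different object than the one prepare() saw
        self.args = args

args = Namespace(lang_pairs="en-de,en-fr")
print(MultilingualTaskStub.prepare(args))          # ['de', 'en', 'fr']
print(MultilingualTaskStub(args).args.lang_pairs)  # ['en-de', 'en-fr']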

View File

@ -13,7 +13,7 @@ from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset, SpeechToTextDataset,
SpeechToTextDatasetCreator, SpeechToTextDatasetCreator,
) )
from fairseq.tasks import FairseqTask, register_task from fairseq.tasks import LegacyFairseqTask, register_task
logging.basicConfig( logging.basicConfig(
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
@register_task("speech_to_text") @register_task("speech_to_text")
class SpeechToTextTask(FairseqTask): class SpeechToTextTask(LegacyFairseqTask):
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
parser.add_argument("data", help="manifest root path") parser.add_argument("data", help="manifest root path")

View File

@ -11,15 +11,18 @@ import contextlib
import logging import logging
import sys import sys
import time import time
from argparse import Namespace
from itertools import chain from itertools import chain
from typing import Any, Dict, List from typing import Any, Dict, List
import torch import torch
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.file_io import PathManager from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics from fairseq.logging import meters, metrics
from fairseq.nan_detector import NanDetector from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler from fairseq.optim import lr_scheduler
from omegaconf import DictConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,19 +38,25 @@ class Trainer(object):
communication of the gradients across workers. communication of the gradients across workers.
""" """
def __init__(self, args, task, model, criterion, quantizer=None): def __init__(self, cfg: DictConfig, task, model, criterion, quantizer=None):
self.args = args
if isinstance(cfg, Namespace):
logger.warning(
"argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
)
cfg = convert_namespace_to_omegaconf(cfg)
self.cfg = cfg
self.task = task self.task = task
# catalog shared parameters # catalog shared parameters
shared_params = _catalog_shared_params(model) shared_params = _catalog_shared_params(model)
self.tpu = cfg.common.tpu
self.tpu = getattr(args, "tpu", False) self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
if self.cuda: if self.cuda:
self.device = torch.device("cuda") self.device = torch.device("cuda")
elif self.tpu: elif self.tpu:
self.device = utils.get_tpu_device(args) self.device = utils.get_tpu_device()
else: else:
self.device = torch.device("cpu") self.device = torch.device("cpu")
@ -58,19 +67,21 @@ class Trainer(object):
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
self._model = xm.send_cpu_data_to_device(self._model, self.device) self._model = xm.send_cpu_data_to_device(self._model, self.device)
if args.fp16: if cfg.common.fp16:
self._criterion = self._criterion.half() self._criterion = self._criterion.half()
self._model = self._model.half() self._model = self._model.half()
elif args.bf16: elif cfg.common.bf16:
self._criterion = self._criterion.to(dtype=torch.bfloat16) self._criterion = self._criterion.to(dtype=torch.bfloat16)
self._model = self._model.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16)
if not args.pipeline_model_parallel: if not cfg.distributed_training.pipeline_model_parallel:
self._criterion = self._criterion.to(device=self.device) self._criterion = self._criterion.to(device=self.device)
self._model = self._model.to(device=self.device) self._model = self._model.to(device=self.device)
self.pipeline_model_parallel = args.pipeline_model_parallel self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
self.last_device = None self.last_device = None
if self.cuda and self.pipeline_model_parallel: if self.cuda and self.pipeline_model_parallel:
self.last_device = torch.device(args.pipeline_devices[-1]) self.last_device = torch.device(
cfg.distributed_training.pipeline_devices[-1]
)
# check that shared parameters are preserved after device transfer # check that shared parameters are preserved after device transfer
for shared_param in shared_params: for shared_param in shared_params:
@ -129,7 +140,7 @@ class Trainer(object):
@property @property
def data_parallel_world_size(self): def data_parallel_world_size(self):
return self.args.distributed_world_size return self.cfg.distributed_training.distributed_world_size
@property @property
def data_parallel_process_group(self): def data_parallel_process_group(self):
@ -140,11 +151,11 @@ class Trainer(object):
@property @property
def data_parallel_rank(self): def data_parallel_rank(self):
return self.args.distributed_rank return self.cfg.distributed_training.distributed_rank
@property @property
def is_data_parallel_master(self): def is_data_parallel_master(self):
return distributed_utils.is_master(self.args) return distributed_utils.is_master(self.cfg.distributed_training)
@property @property
def criterion(self): def criterion(self):
@ -152,11 +163,11 @@ class Trainer(object):
if ( if (
utils.has_parameters(self._criterion) utils.has_parameters(self._criterion)
and self.data_parallel_world_size > 1 and self.data_parallel_world_size > 1
and not self.args.use_bmuf and not self.cfg.optimization.use_bmuf
and not self.tpu and not self.tpu
): ):
self._wrapped_criterion = models.DistributedFairseqModel( self._wrapped_criterion = models.DistributedFairseqModel(
self.args, self.cfg.distributed_training,
self._criterion, self._criterion,
process_group=self.data_parallel_process_group, process_group=self.data_parallel_process_group,
) )
@ -169,11 +180,11 @@ class Trainer(object):
if self._wrapped_model is None: if self._wrapped_model is None:
if ( if (
self.data_parallel_world_size > 1 self.data_parallel_world_size > 1
and not self.args.use_bmuf and not self.cfg.optimization.use_bmuf
and not self.tpu and not self.tpu
): ):
self._wrapped_model = models.DistributedFairseqModel( self._wrapped_model = models.DistributedFairseqModel(
self.args, self.cfg.distributed_training,
self._model, self._model,
process_group=self.data_parallel_process_group, process_group=self.data_parallel_process_group,
) )
@ -201,44 +212,51 @@ class Trainer(object):
) )
) )
if self.args.fp16 or self.args.bf16: if self.cfg.common.fp16 or self.cfg.common.bf16:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
logger.info( logger.info(
"NOTE: your device does NOT support faster training with --fp16, " "NOTE: your device does NOT support faster training with --fp16, "
"please switch to FP32 which is likely to be faster" "please switch to FP32 which is likely to be faster"
) )
if self.args.memory_efficient_fp16 or self.args.memory_efficient_bf16: if (
self.cfg.common.memory_efficient_fp16
or self.cfg.common.memory_efficient_bf16
):
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
self.args, params self.cfg, params
) )
else: else:
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
else: else:
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
logger.info("NOTE: your device may support faster training with --fp16") logger.info("NOTE: your device may support faster training with --fp16")
self._optimizer = optim.build_optimizer(self.args, params) self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
if self.args.use_bmuf: if self.cfg.optimization.use_bmuf:
self._optimizer = optim.FairseqBMUF(self.args, self._optimizer) self._optimizer = optim.FairseqBMUF(
self.cfg.bmuf,
self._optimizer,
)
if self.args.zero_sharding == "os": if self.cfg.distributed_training.zero_sharding == "os":
if ( if (
self.args.fp16 self.cfg.common.fp16
and not self.args.memory_efficient_fp16 and not self.cfg.common.memory_efficient_fp16
and not self.args.memory_efficient_bf16 and not self.cfg.common.memory_efficient_bf16
) and not self.args.fp16_no_flatten_grads: ) and not self.cfg.common.fp16_no_flatten_grads:
raise ValueError( raise ValueError(
"ZeRO is incomptabile with fp16 and flattened grads. " "ZeRO is incomptabile with fp16 and flattened grads. "
"Please use --fp16-no-flatten-grads" "Please use --fp16-no-flatten-grads"
) )
else: else:
optim.shard_( optim.shard_(self._optimizer, self.data_parallel_process_group)
self.args, self._optimizer, self.data_parallel_process_group
)
# We should initialize the learning rate scheduler immediately after # We should initialize the learning rate scheduler immediately after
# building the optimizer, so that the initial learning rate is set. # building the optimizer, so that the initial learning rate is set.
self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer) self._lr_scheduler = lr_scheduler.build_lr_scheduler(
self.cfg.lr_scheduler,
self.optimizer,
)
self._lr_scheduler.step_update(0) self._lr_scheduler.step_update(0)
def consolidate_optimizer(self): def consolidate_optimizer(self):
@ -253,7 +271,7 @@ class Trainer(object):
extra_state["previous_training_time"] = self.cumulative_training_time() extra_state["previous_training_time"] = self.cumulative_training_time()
checkpoint_utils.save_state( checkpoint_utils.save_state(
filename, filename,
self.args, self.cfg,
self.get_model().state_dict(), self.get_model().state_dict(),
self.get_criterion(), self.get_criterion(),
self.optimizer, self.optimizer,
@ -277,11 +295,10 @@ class Trainer(object):
bexists = PathManager.isfile(filename) bexists = PathManager.isfile(filename)
if bexists: if bexists:
state = checkpoint_utils.load_checkpoint_to_cpu(filename) state = checkpoint_utils.load_checkpoint_to_cpu(filename)
# load model parameters # load model parameters
try: try:
self.get_model().load_state_dict( self.get_model().load_state_dict(
state["model"], strict=True, args=self.args state["model"], strict=True, model_cfg=self.cfg.model
) )
if utils.has_parameters(self.get_criterion()): if utils.has_parameters(self.get_criterion()):
self.get_criterion().load_state_dict( self.get_criterion().load_state_dict(
@ -355,28 +372,28 @@ class Trainer(object):
if load_dataset: if load_dataset:
logger.info("loading train data for epoch {}".format(epoch)) logger.info("loading train data for epoch {}".format(epoch))
self.task.load_dataset( self.task.load_dataset(
self.args.train_subset, self.cfg.dataset.train_subset,
epoch=epoch, epoch=epoch,
combine=combine, combine=combine,
data_selector=data_selector, data_selector=data_selector,
) )
batch_iterator = self.task.get_batch_iterator( batch_iterator = self.task.get_batch_iterator(
dataset=self.task.dataset(self.args.train_subset), dataset=self.task.dataset(self.cfg.dataset.train_subset),
max_tokens=self.args.max_tokens, max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.args.batch_size, max_sentences=self.cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions( max_positions=utils.resolve_max_positions(
self.task.max_positions(), self.task.max_positions(),
self.model.max_positions(), self.model.max_positions(),
self.args.max_tokens, self.cfg.dataset.max_tokens,
), ),
ignore_invalid_inputs=True, ignore_invalid_inputs=True,
required_batch_size_multiple=self.args.required_batch_size_multiple, required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
seed=self.args.seed, seed=self.cfg.common.seed,
num_shards=self.data_parallel_world_size if shard_batch_itr else 1, num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
shard_id=self.data_parallel_rank if shard_batch_itr else 0, shard_id=self.data_parallel_rank if shard_batch_itr else 0,
num_workers=self.args.num_workers, num_workers=self.cfg.dataset.num_workers,
epoch=epoch, epoch=epoch,
data_buffer_size=self.args.data_buffer_size, data_buffer_size=self.cfg.dataset.data_buffer_size,
disable_iterator_cache=disable_iterator_cache, disable_iterator_cache=disable_iterator_cache,
) )
self.reset_dummy_batch(batch_iterator.first_batch) self.reset_dummy_batch(batch_iterator.first_batch)
@ -390,19 +407,19 @@ class Trainer(object):
"""Return an EpochBatchIterator over given validation subset for a given epoch.""" """Return an EpochBatchIterator over given validation subset for a given epoch."""
batch_iterator = self.task.get_batch_iterator( batch_iterator = self.task.get_batch_iterator(
dataset=self.task.dataset(subset), dataset=self.task.dataset(subset),
max_tokens=self.args.max_tokens_valid, max_tokens=self.cfg.dataset.max_tokens_valid,
max_sentences=self.args.batch_size_valid, max_sentences=self.cfg.dataset.batch_size_valid,
max_positions=utils.resolve_max_positions( max_positions=utils.resolve_max_positions(
self.task.max_positions(), self.task.max_positions(),
self.model.max_positions(), self.model.max_positions(),
), ),
ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test, ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=self.args.required_batch_size_multiple, required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
seed=self.args.seed, seed=self.cfg.common.seed,
num_shards=self.data_parallel_world_size, num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank, shard_id=self.data_parallel_rank,
num_workers=self.args.num_workers, num_workers=self.cfg.dataset.num_workers,
data_buffer_size=self.args.data_buffer_size, data_buffer_size=self.cfg.dataset.data_buffer_size,
disable_iterator_cache=disable_iterator_cache, disable_iterator_cache=disable_iterator_cache,
) )
self.reset_dummy_batch(batch_iterator.first_batch) self.reset_dummy_batch(batch_iterator.first_batch)
@ -504,7 +521,7 @@ class Trainer(object):
self.zero_grad() self.zero_grad()
if self.cuda: if self.cuda:
torch.cuda.empty_cache() torch.cuda.empty_cache()
if self.args.distributed_world_size == 1: if self.cfg.distributed_training.distributed_world_size == 1:
return None return None
else: else:
raise e raise e
@ -565,7 +582,7 @@ class Trainer(object):
# multiply gradients by (# GPUs / sample_size) since DDP # multiply gradients by (# GPUs / sample_size) since DDP
# already normalizes by the number of GPUs. Thus we get # already normalizes by the number of GPUs. Thus we get
# (sum_of_gradients / sample_size). # (sum_of_gradients / sample_size).
if not self.args.use_bmuf: if not self.cfg.optimization.use_bmuf:
self.optimizer.multiply_grads( self.optimizer.multiply_grads(
self.data_parallel_world_size / sample_size self.data_parallel_world_size / sample_size
) )
@ -575,12 +592,12 @@ class Trainer(object):
with torch.autograd.profiler.record_function("clip-grads"): with torch.autograd.profiler.record_function("clip-grads"):
# clip grads # clip grads
grad_norm = self.clip_grad_norm(self.args.clip_norm) grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)
# check that grad norms are consistent across workers # check that grad norms are consistent across workers
if ( if (
not self.args.use_bmuf not self.cfg.optimization.use_bmuf
and self.args.distributed_wrapper != "SlowMo" and self.cfg.distributed_training.distributed_wrapper != "SlowMo"
and not self.tpu and not self.tpu
): ):
self._check_grad_norms(grad_norm) self._check_grad_norms(grad_norm)
@ -624,7 +641,10 @@ class Trainer(object):
self.optimizer.optimizer self.optimizer.optimizer
) )
if not overflow or self.args.distributed_wrapper == "SlowMo": if (
not overflow
or self.cfg.distributed_training.distributed_wrapper == "SlowMo"
):
self.set_num_updates(self.get_num_updates() + 1) self.set_num_updates(self.get_num_updates() + 1)
if self.tpu: if self.tpu:
@ -636,7 +656,7 @@ class Trainer(object):
# only log stats every log_interval steps # only log stats every log_interval steps
# this causes wps to be misreported when log_interval > 1 # this causes wps to be misreported when log_interval > 1
logging_output = {} logging_output = {}
if self.get_num_updates() % self.args.log_interval == 0: if self.get_num_updates() % self.cfg.common.log_interval == 0:
# log memory usage # log memory usage
mem_info = xm.get_memory_info(self.device) mem_info = xm.get_memory_info(self.device)
gb_free = mem_info["kb_free"] / 1024 / 1024 gb_free = mem_info["kb_free"] / 1024 / 1024
@ -677,16 +697,16 @@ class Trainer(object):
# clear CUDA cache to reduce memory fragmentation # clear CUDA cache to reduce memory fragmentation
if ( if (
self.cuda self.cuda
and self.args.empty_cache_freq > 0 and self.cfg.common.empty_cache_freq > 0
and ( and (
(self.get_num_updates() + self.args.empty_cache_freq - 1) (self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
% self.args.empty_cache_freq % self.cfg.common.empty_cache_freq
) )
== 0 == 0
): ):
torch.cuda.empty_cache() torch.cuda.empty_cache()
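As a side note on the cache-clearing condition just above: (num_updates + empty_cache_freq - 1) % empty_cache_freq == 0 fires on the first update and then every empty_cache_freq updates, for example:

empty_cache_freq = 4
updates = [u for u in range(1, 13)
           if (u + empty_cache_freq - 1) % empty_cache_freq == 0]
print(updates)   # [1, 5, 9] -> first update, then every 4th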
if self.args.fp16: if self.cfg.common.fp16:
metrics.log_scalar( metrics.log_scalar(
"loss_scale", "loss_scale",
self.optimizer.scaler.loss_scale, self.optimizer.scaler.loss_scale,
@ -883,10 +903,10 @@ class Trainer(object):
return t.to(dtype=torch.bfloat16) return t.to(dtype=torch.bfloat16)
return t return t
if self.args.fp16: if self.cfg.common.fp16:
sample = utils.apply_to_sample(apply_half, sample) sample = utils.apply_to_sample(apply_half, sample)
if self.args.bf16: if self.cfg.common.bf16:
sample = utils.apply_to_sample(apply_bfloat16, sample) sample = utils.apply_to_sample(apply_bfloat16, sample)
return sample return sample
@ -894,7 +914,7 @@ class Trainer(object):
def _set_seed(self): def _set_seed(self):
# Set seed based on args.seed and the update number so that we get # Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints # reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates() seed = self.cfg.common.seed + self.get_num_updates()
utils.set_torch_seed(seed) utils.set_torch_seed(seed)
def _sync_stats(self): def _sync_stats(self):
@ -902,10 +922,12 @@ class Trainer(object):
# BMUF and it's a bmuf sync with warmup iterations completed before. # BMUF and it's a bmuf sync with warmup iterations completed before.
if self.data_parallel_world_size == 1: if self.data_parallel_world_size == 1:
return False return False
elif self.args.use_bmuf: elif self.cfg.optimization.use_bmuf:
return (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and ( return (
self.get_num_updates() + 1 self.get_num_updates() + 1
) > self.args.warmup_iterations ) % self.cfg.bmuf.global_sync_iter == 0 and (
self.get_num_updates() + 1
) > self.cfg.bmuf.warmup_iterations
else: else:
return True return True
@ -950,7 +972,7 @@ class Trainer(object):
zip( zip(
*distributed_utils.all_gather_list( *distributed_utils.all_gather_list(
[logging_outputs] + list(extra_stats_to_sum), [logging_outputs] + list(extra_stats_to_sum),
max_size=getattr(self.args, "all_gather_list_size", 16384), max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
group=self.data_parallel_process_group, group=self.data_parallel_process_group,
) )
) )
@ -1038,11 +1060,11 @@ class Trainer(object):
if grad_norm is not None: if grad_norm is not None:
metrics.log_speed("ups", 1.0, priority=100, round=2) metrics.log_speed("ups", 1.0, priority=100, round=2)
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
if self.args.clip_norm > 0: if self.cfg.optimization.clip_norm > 0:
metrics.log_scalar( metrics.log_scalar(
"clip", "clip",
torch.where( torch.where(
grad_norm > self.args.clip_norm, grad_norm > self.cfg.optimization.clip_norm,
grad_norm.new_tensor(100), grad_norm.new_tensor(100),
grad_norm.new_tensor(0), grad_norm.new_tensor(0),
), ),
@ -1087,7 +1109,7 @@ class Trainer(object):
logger.warning( logger.warning(
"XLA compilation detected on device #{}; too many of these can lead " "XLA compilation detected on device #{}; too many of these can lead "
"to slow training, but we expect a few in the beginning".format( "to slow training, but we expect a few in the beginning".format(
self.args.distributed_rank self.cfg.distributed_training.distributed_rank
) )
) )
self._num_xla_compiles = num_xla_compiles self._num_xla_compiles = num_xla_compiles
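Throughout the Trainer, flat attribute lookups on args become grouped lookups on cfg (cfg.common.*, cfg.dataset.*, cfg.optimization.*, cfg.distributed_training.*). A small illustration of the shape change, using arbitrary example values:

from argparse import Namespace
from omegaconf import OmegaConf

# Legacy flat namespace vs. the grouped config the new Trainer expects.
args = Namespace(fp16=True, clip_norm=0.1, max_tokens=2048)

cfg = OmegaConf.create({
    "common": {"fp16": True},
    "optimization": {"clip_norm": 0.1},
    "dataset": {"max_tokens": 2048},
})

assert args.fp16 == cfg.common.fp16
assert args.clip_norm == cfg.optimization.clip_norm
assert args.max_tokens == cfg.dataset.max_tokens
print("same values, different access paths")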

View File

@ -11,13 +11,19 @@ Evaluate the perplexity of a trained language model.
import logging import logging
import math import math
import os import os
from argparse import Namespace
import torch import torch
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import LMContextWindowDataset from fairseq.data import LMContextWindowDataset
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer from fairseq.sequence_scorer import SequenceScorer
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from omegaconf import DictConfig
logging.basicConfig( logging.basicConfig(
@ -60,65 +66,60 @@ class WordStat(object):
) )
def main(parsed_args, **unused_kwargs): def main(cfg: DictConfig, override_args=None, **unused_kwargs):
assert parsed_args.path is not None, "--path required for evaluation!" if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
if torch.cuda.is_available() and not parsed_args.cpu: utils.import_user_module(cfg.common)
torch.cuda.set_device(parsed_args.device_id)
utils.import_user_module(parsed_args) use_fp16 = cfg.common.fp16
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
logger.info(parsed_args) if use_cuda:
torch.cuda.set_device(cfg.distributed_training.device_id)
use_cuda = torch.cuda.is_available() and not parsed_args.cpu if override_args is not None:
overrides = vars(override_args)
overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
else:
overrides = None
task = tasks.setup_task(parsed_args) logger.info(cfg)
# Load ensemble # Load ensemble
logger.info("loading model(s) from {}".format(parsed_args.path)) logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, args = checkpoint_utils.load_model_ensemble(
parsed_args.path.split(os.pathsep),
arg_overrides=eval(parsed_args.model_overrides),
task=task,
suffix=getattr(parsed_args, "checkpoint_suffix", ""),
strict=(parsed_args.checkpoint_shard_count == 1),
num_shards=parsed_args.checkpoint_shard_count,
)
for arg in vars(parsed_args).keys():
if arg not in {
"self_target",
"future_target",
"past_target",
"tokens_per_sample",
"output_size_dictionary",
"add_bos_token",
}:
setattr(args, arg, getattr(parsed_args, arg))
# reduce tokens per sample by the required context window size # reduce tokens per sample by the required context window size
args.tokens_per_sample -= args.context_window cfg.task.tokens_per_sample -= cfg.eval_lm.context_window
task = tasks.setup_task(args)
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
[cfg.common_eval.path],
arg_overrides=overrides,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
# Load dataset splits # Load dataset splits
task.load_dataset(args.gen_subset) gen_subset = cfg.dataset.gen_subset
dataset = task.dataset(args.gen_subset) task.load_dataset(gen_subset)
if args.context_window > 0: dataset = task.dataset(gen_subset)
if cfg.eval_lm.context_window > 0:
dataset = LMContextWindowDataset( dataset = LMContextWindowDataset(
dataset=dataset, dataset=dataset,
tokens_per_sample=args.tokens_per_sample, tokens_per_sample=cfg.task.tokens_per_sample,
context_window=args.context_window, context_window=cfg.eval_lm.context_window,
pad_idx=task.source_dictionary.pad(), pad_idx=task.source_dictionary.pad(),
) )
logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset))) logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset)))
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models: for model in models:
if args.fp16: if use_fp16:
model.half() model.half()
if use_cuda and not args.pipeline_model_parallel: if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda() model.cuda()
model.prepare_for_inference_(args) model.prepare_for_inference_(cfg)
assert len(models) > 0 assert len(models) > 0
@ -128,35 +129,41 @@ def main(parsed_args, **unused_kwargs):
itr = task.get_batch_iterator( itr = task.get_batch_iterator(
dataset=dataset, dataset=dataset,
max_tokens=args.max_tokens or 36000, max_tokens=cfg.dataset.max_tokens or 36000,
max_sentences=args.batch_size, max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions( max_positions=utils.resolve_max_positions(
*[model.max_positions() for model in models] *[model.max_positions() for model in models]
), ),
ignore_invalid_inputs=True, ignore_invalid_inputs=True,
num_shards=args.num_shards, num_shards=max(
shard_id=args.shard_id, cfg.dataset.num_shards,
num_workers=args.num_workers, cfg.distributed_training.distributed_world_size,
data_buffer_size=args.data_buffer_size, ),
shard_id=max(
cfg.dataset.shard_id,
cfg.distributed_training.distributed_rank,
),
num_workers=cfg.dataset.num_workers,
data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar( progress = progress_bar.progress_bar(
itr, itr,
log_format=args.log_format, log_format=cfg.common.log_format,
log_interval=args.log_interval, log_interval=cfg.common.log_interval,
default_log_format=("tqdm" if not args.no_progress_bar else "none"), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
) )
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()
scorer = SequenceScorer(task.target_dictionary, args.softmax_batch) scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)
score_sum = 0.0 score_sum = 0.0
count = 0 count = 0
if args.remove_bpe is not None: if cfg.common_eval.remove_bpe is not None:
if args.remove_bpe == "sentencepiece": if cfg.common_eval.remove_bpe == "sentencepiece":
raise NotImplementedError raise NotImplementedError
else: else:
bpe_cont = args.remove_bpe.rstrip() bpe_cont = cfg.common_eval.remove_bpe.rstrip()
bpe_toks = { bpe_toks = {
i i
for i in range(len(task.source_dictionary)) for i in range(len(task.source_dictionary))
@ -189,7 +196,7 @@ def main(parsed_args, **unused_kwargs):
tgt_len = tokens.numel() tgt_len = tokens.numel()
pos_scores = hypo["positional_scores"].float() pos_scores = hypo["positional_scores"].float()
if getattr(args, "add_bos_token", False): if cfg.task.add_bos_token:
assert hypo["tokens"][0].item() == task.target_dictionary.bos() assert hypo["tokens"][0].item() == task.target_dictionary.bos()
tokens = tokens[1:] tokens = tokens[1:]
pos_scores = pos_scores[1:] pos_scores = pos_scores[1:]
@ -212,7 +219,7 @@ def main(parsed_args, **unused_kwargs):
score_sum += pos_scores.sum().cpu() score_sum += pos_scores.sum().cpu()
count += pos_scores.numel() - skipped_toks count += pos_scores.numel() - skipped_toks
if args.output_word_probs or args.output_word_stats: if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
w = "" w = ""
word_prob = [] word_prob = []
is_bpe = False is_bpe = False
@ -238,7 +245,7 @@ def main(parsed_args, **unused_kwargs):
) )
is_bpe = False is_bpe = False
w = "" w = ""
if args.output_word_probs: if cfg.eval_lm.output_word_probs:
logger.info( logger.info(
str(int(sample_id)) str(int(sample_id))
+ " " + " "
@ -264,7 +271,7 @@ def main(parsed_args, **unused_kwargs):
) )
) )
if args.output_word_stats: if cfg.eval_lm.output_word_stats:
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
logger.info(ws) logger.info(ws)
@ -272,8 +279,16 @@ def main(parsed_args, **unused_kwargs):
def cli_main(): def cli_main():
parser = options.get_eval_lm_parser() parser = options.get_eval_lm_parser()
args = options.parse_args_and_arch(parser) args = options.parse_args_and_arch(parser)
distributed_utils.call_main(args, main)
# only override args that are explicitly given on the command line
override_parser = options.get_validation_parser()
override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)
distributed_utils.call_main(args, main, override_args=override_args)
if __name__ == "__main__": if __name__ == "__main__":
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main() cli_main()
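The changes around cli_main rely on Hydra's ConfigStore plus initialize(); register_hydra_cfg() registers fairseq's structured configs there. A self-contained sketch of that mechanism, assuming the Hydra 1.0-era experimental API imported above (EvalLMStub and its defaults are invented; the field names echo the cfg.eval_lm accesses in this file):

from dataclasses import dataclass
from hydra.core.config_store import ConfigStore
from hydra.experimental import compose, initialize

@dataclass
class EvalLMStub:
    context_window: int = 0
    output_word_stats: bool = False

cs = ConfigStore.instance()
cs.store(name="eval_lm_stub", node=EvalLMStub)

initialize(config_path=None)   # no on-disk config directory needed for this sketch
cfg = compose(config_name="eval_lm_stub", overrides=["context_window=16"])
print(cfg.context_window, cfg.output_word_stats)   # 16 False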

View File

@ -12,33 +12,45 @@ import logging
import math import math
import os import os
import sys import sys
from argparse import Namespace
from itertools import chain from itertools import chain
import numpy as np import numpy as np
import torch import torch
from fairseq import checkpoint_utils, options, scoring, tasks, utils from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.data import encoders
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter from fairseq.logging.meters import StopwatchMeter, TimeMeter
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from omegaconf import DictConfig
def main(args): def main(cfg: DictConfig):
assert args.path is not None, "--path required for generation!"
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
assert cfg.common_eval.path is not None, "--path required for generation!"
assert ( assert (
not args.sampling or args.nbest == args.beam not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
), "--sampling requires --nbest to be equal to --beam" ), "--sampling requires --nbest to be equal to --beam"
assert ( assert (
args.replace_unk is None or args.dataset_impl == "raw" cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
), "--replace-unk requires a raw text dataset (--dataset-impl=raw)" ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
if args.results_path is not None: if cfg.common_eval.results_path is not None:
os.makedirs(args.results_path, exist_ok=True) os.makedirs(cfg.common_eval.results_path, exist_ok=True)
output_path = os.path.join( output_path = os.path.join(
args.results_path, "generate-{}.txt".format(args.gen_subset) cfg.common_eval.results_path,
"generate-{}.txt".format(cfg.dataset.gen_subset),
) )
with open(output_path, "w", buffering=1, encoding="utf-8") as h: with open(output_path, "w", buffering=1, encoding="utf-8") as h:
return _main(args, h) return _main(cfg, h)
else: else:
return _main(args, sys.stdout) return _main(cfg, sys.stdout)
def get_symbols_to_strip_from_output(generator): def get_symbols_to_strip_from_output(generator):
@ -48,7 +60,7 @@ def get_symbols_to_strip_from_output(generator):
return {generator.eos} return {generator.eos}
def _main(args, output_file): def _main(cfg: DictConfig, output_file):
logging.basicConfig( logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
@ -57,22 +69,22 @@ def _main(args, output_file):
) )
logger = logging.getLogger("fairseq_cli.generate") logger = logging.getLogger("fairseq_cli.generate")
utils.import_user_module(args) utils.import_user_module(cfg.common)
if args.max_tokens is None and args.batch_size is None: if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
args.max_tokens = 12000 cfg.dataset.max_tokens = 12000
logger.info(args) logger.info(cfg)
# Fix seed for stochastic decoding # Fix seed for stochastic decoding
if args.seed is not None and not args.no_seed_provided: if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
np.random.seed(args.seed) np.random.seed(cfg.common.seed)
utils.set_torch_seed(args.seed) utils.set_torch_seed(cfg.common.seed)
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not cfg.common.cpu
# Load dataset splits # Load dataset splits
task = tasks.setup_task(args) task = tasks.setup_task(cfg.task)
task.load_dataset(args.gen_subset) task.load_dataset(cfg.dataset.gen_subset)
# Set dictionaries # Set dictionaries
try: try:
@ -81,32 +93,30 @@ def _main(args, output_file):
src_dict = None src_dict = None
tgt_dict = task.target_dictionary tgt_dict = task.target_dictionary
overrides = ast.literal_eval(args.model_overrides) overrides = ast.literal_eval(cfg.common_eval.model_overrides)
# Load ensemble # Load ensemble
logger.info("loading model(s) from {}".format(args.path)) logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, _model_args = checkpoint_utils.load_model_ensemble( models, _model_args = checkpoint_utils.load_model_ensemble(
utils.split_paths(args.path), utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides, arg_overrides=overrides,
task=task, task=task,
suffix=getattr(args, "checkpoint_suffix", ""), suffix=cfg.checkpoint.checkpoint_suffix,
strict=(args.checkpoint_shard_count == 1), strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=args.checkpoint_shard_count, num_shards=cfg.checkpoint.checkpoint_shard_count,
) )
if args.lm_path is not None: if cfg.generation.lm_path is not None:
overrides["data"] = args.data overrides["data"] = cfg.task.data
try: try:
lms, _ = checkpoint_utils.load_model_ensemble( lms, _ = checkpoint_utils.load_model_ensemble(
[args.lm_path], [cfg.generation.lm_path], arg_overrides=overrides, task=None
arg_overrides=overrides,
task=None,
) )
except: except:
logger.warning( logger.warning(
f"Failed to load language model! Please make sure that the language model dict is the same " f"Failed to load language model! Please make sure that the language model dict is the same "
f"as target dict and is located in the data dir ({args.data})" f"as target dict and is located in the data dir ({cfg.task.data})"
) )
raise raise
@ -118,49 +128,50 @@ def _main(args, output_file):
for model in chain(models, lms): for model in chain(models, lms):
if model is None: if model is None:
continue continue
if args.fp16: if cfg.common.fp16:
model.half() model.half()
if use_cuda and not args.pipeline_model_parallel: if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda() model.cuda()
model.prepare_for_inference_(args) model.prepare_for_inference_(cfg)
# Load alignment dictionary for unknown word replacement # Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk) align_dict = utils.load_align_dict(cfg.generation.replace_unk)
# Load dataset (possibly sharded) # Load dataset (possibly sharded)
itr = task.get_batch_iterator( itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset), dataset=task.dataset(cfg.dataset.gen_subset),
max_tokens=args.max_tokens, max_tokens=cfg.dataset.max_tokens,
max_sentences=args.batch_size, max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions( max_positions=utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models] task.max_positions(), *[m.max_positions() for m in models]
), ),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple, required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
num_shards=args.num_shards, seed=cfg.common.seed,
shard_id=args.shard_id, num_shards=cfg.distributed_training.distributed_world_size,
num_workers=args.num_workers, shard_id=cfg.distributed_training.distributed_rank,
data_buffer_size=args.data_buffer_size, num_workers=cfg.dataset.num_workers,
data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar( progress = progress_bar.progress_bar(
itr, itr,
log_format=args.log_format, log_format=cfg.common.log_format,
log_interval=args.log_interval, log_interval=cfg.common.log_interval,
default_log_format=("tqdm" if not args.no_progress_bar else "none"), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
) )
# Initialize generator # Initialize generator
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight} extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
generator = task.build_generator( generator = task.build_generator(
models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs models, cfg.task, extra_gen_cls_kwargs=extra_gen_cls_kwargs
) )
# Handle tokenization and BPE # Handle tokenization and BPE
tokenizer = task.build_tokenizer(args) tokenizer = encoders.build_tokenizer(cfg.tokenizer)
bpe = task.build_bpe(args) bpe = encoders.build_bpe(cfg.bpe)
def decode_fn(x): def decode_fn(x):
if bpe is not None: if bpe is not None:
@ -169,7 +180,7 @@ def _main(args, output_file):
x = tokenizer.decode(x) x = tokenizer.decode(x)
return x return x
scorer = scoring.build_scorer(args, tgt_dict) scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
num_sentences = 0 num_sentences = 0
has_target = True has_target = True
@ -180,8 +191,8 @@ def _main(args, output_file):
continue continue
prefix_tokens = None prefix_tokens = None
if args.prefix_size > 0: if cfg.generation.prefix_size > 0:
prefix_tokens = sample["target"][:, : args.prefix_size] prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
constraints = None constraints = None
if "constraints" in sample: if "constraints" in sample:
@ -217,19 +228,21 @@ def _main(args, output_file):
# Either retrieve the original sentences or regenerate them from tokens. # Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None: if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id) src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
target_str = task.dataset(args.gen_subset).tgt.get_original_text( sample_id
)
target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
sample_id sample_id
) )
else: else:
if src_dict is not None: if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe) src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe)
else: else:
src_str = "" src_str = ""
if has_target: if has_target:
target_str = tgt_dict.string( target_str = tgt_dict.string(
target_tokens, target_tokens,
args.remove_bpe, cfg.common_eval.remove_bpe,
escape_unk=True, escape_unk=True,
extra_symbols_to_ignore=get_symbols_to_strip_from_output( extra_symbols_to_ignore=get_symbols_to_strip_from_output(
generator generator
@ -240,25 +253,25 @@ def _main(args, output_file):
if has_target: if has_target:
target_str = decode_fn(target_str) target_str = decode_fn(target_str)
if not args.quiet: if not cfg.common_eval.quiet:
if src_dict is not None: if src_dict is not None:
print("S-{}\t{}".format(sample_id, src_str), file=output_file) print("S-{}\t{}".format(sample_id, src_str), file=output_file)
if has_target: if has_target:
print("T-{}\t{}".format(sample_id, target_str), file=output_file) print("T-{}\t{}".format(sample_id, target_str), file=output_file)
# Process top predictions # Process top predictions
for j, hypo in enumerate(hypos[i][: args.nbest]): for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo["tokens"].int().cpu(), hypo_tokens=hypo["tokens"].int().cpu(),
src_str=src_str, src_str=src_str,
alignment=hypo["alignment"], alignment=hypo["alignment"],
align_dict=align_dict, align_dict=align_dict,
tgt_dict=tgt_dict, tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe, remove_bpe=cfg.common_eval.remove_bpe,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
) )
detok_hypo_str = decode_fn(hypo_str) detok_hypo_str = decode_fn(hypo_str)
if not args.quiet: if not cfg.common_eval.quiet:
score = hypo["score"] / math.log(2) # convert to base 2 score = hypo["score"] / math.log(2) # convert to base 2
# original hypothesis (after tokenization and BPE) # original hypothesis (after tokenization and BPE)
print( print(
@ -286,7 +299,7 @@ def _main(args, output_file):
file=output_file, file=output_file,
) )
if args.print_alignment: if cfg.generation.print_alignment:
print( print(
"A-{}\t{}".format( "A-{}\t{}".format(
sample_id, sample_id,
@ -300,13 +313,13 @@ def _main(args, output_file):
file=output_file, file=output_file,
) )
if args.print_step: if cfg.generation.print_step:
print( print(
"I-{}\t{}".format(sample_id, hypo["steps"]), "I-{}\t{}".format(sample_id, hypo["steps"]),
file=output_file, file=output_file,
) )
if getattr(args, "retain_iter_history", False): if cfg.generation.retain_iter_history:
for step, h in enumerate(hypo["history"]): for step, h in enumerate(hypo["history"]):
_, h_str, _ = utils.post_process_prediction( _, h_str, _ = utils.post_process_prediction(
hypo_tokens=h["tokens"].int().cpu(), hypo_tokens=h["tokens"].int().cpu(),
@ -323,7 +336,7 @@ def _main(args, output_file):
# Score only the top hypothesis # Score only the top hypothesis
if has_target and j == 0: if has_target and j == 0:
if align_dict is not None or args.remove_bpe is not None: if align_dict is not None or cfg.common_eval.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE # Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tgt_dict.encode_line( target_tokens = tgt_dict.encode_line(
target_str, add_if_not_exist=True target_str, add_if_not_exist=True
@ -353,8 +366,8 @@ def _main(args, output_file):
) )
) )
if has_target: if has_target:
if args.bpe and not args.sacrebleu: if cfg.bpe and not cfg.generation.sacrebleu:
if args.remove_bpe: if cfg.common_eval.remove_bpe:
logger.warning( logger.warning(
"BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization" "BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
) )
@ -365,7 +378,7 @@ def _main(args, output_file):
# use print to be consistent with other main outputs: S-, H-, T-, D- and so on # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
print( print(
"Generate {} with beam={}: {}".format( "Generate {} with beam={}: {}".format(
args.gen_subset, args.beam, scorer.result_string() cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
), ),
file=output_file, file=output_file,
) )
@ -380,4 +393,7 @@ def cli_main():
if __name__ == "__main__": if __name__ == "__main__":
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main() cli_main()
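
The `__main__` block above is the boilerplate every migrated entry point gains in this commit. A consolidated sketch of that pattern, assuming `cli_main` is the function defined in the file above and using only the imports this diff already introduces:

```python
# Hydra bootstrap for a migrated fairseq entry point (sketch, not a drop-in file).
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from fairseq.dataclass.data_class import register_hydra_cfg

if __name__ == "__main__":
    cs = ConfigStore.instance()                        # process-wide store for structured configs
    register_hydra_cfg(cs)                             # register fairseq dataclasses as config groups
    initialize(config_path="../config", strict=True)   # point Hydra at the bundled config directory
    cli_main()                                         # defined in the entry-point module above
```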
@ -7,20 +7,27 @@
Translate raw text with a trained model. Batches data on-the-fly. Translate raw text with a trained model. Batches data on-the-fly.
""" """
import ast
import fileinput import fileinput
import logging import logging
import math import math
import os import os
import sys import sys
import time import time
from argparse import Namespace
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import encoders from fairseq.data import encoders
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output from fairseq_cli.generate import get_symbols_to_strip_from_output
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from omegaconf import DictConfig
logging.basicConfig( logging.basicConfig(
@ -49,11 +56,11 @@ def buffered_read(input, buffer_size):
yield buffer yield buffer
def make_batches(lines, args, task, max_positions, encode_fn): def make_batches(lines, cfg, task, max_positions, encode_fn):
def encode_fn_target(x): def encode_fn_target(x):
return encode_fn(x) return encode_fn(x)
if args.constraints: if cfg.generation.constraints:
# Strip (tab-delimited) constraints, if present, from input lines, # Strip (tab-delimited) constraints, if present, from input lines,
# store them in batch_constraints # store them in batch_constraints
batch_constraints = [list() for _ in lines] batch_constraints = [list() for _ in lines]
@ -79,7 +86,7 @@ def make_batches(lines, args, task, max_positions, encode_fn):
for src_str in lines for src_str in lines
] ]
if args.constraints: if cfg.generation.constraints:
constraints_tensor = pack_constraints(batch_constraints) constraints_tensor = pack_constraints(batch_constraints)
else: else:
constraints_tensor = None constraints_tensor = None
@ -89,10 +96,10 @@ def make_batches(lines, args, task, max_positions, encode_fn):
dataset=task.build_dataset_for_inference( dataset=task.build_dataset_for_inference(
tokens, lengths, constraints=constraints_tensor tokens, lengths, constraints=constraints_tensor
), ),
max_tokens=args.max_tokens, max_tokens=cfg.dataset.max_tokens,
max_sentences=args.batch_size, max_sentences=cfg.dataset.batch_size,
max_positions=max_positions, max_positions=max_positions,
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
for batch in itr: for batch in itr:
ids = batch["id"] ids = batch["id"]
@ -108,45 +115,50 @@ def make_batches(lines, args, task, max_positions, encode_fn):
) )
def main(args): def main(cfg: DictConfig):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
start_time = time.time() start_time = time.time()
total_translate_time = 0 total_translate_time = 0
utils.import_user_module(args) utils.import_user_module(cfg.common)
if args.buffer_size < 1: if cfg.interactive.buffer_size < 1:
args.buffer_size = 1 cfg.interactive.buffer_size = 1
if args.max_tokens is None and args.batch_size is None: if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
args.batch_size = 1 cfg.dataset.batch_size = 1
assert ( assert (
not args.sampling or args.nbest == args.beam not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
), "--sampling requires --nbest to be equal to --beam" ), "--sampling requires --nbest to be equal to --beam"
assert ( assert (
not args.batch_size or args.batch_size <= args.buffer_size not cfg.dataset.batch_size
or cfg.dataset.batch_size <= cfg.interactive.buffer_size
), "--batch-size cannot be larger than --buffer-size" ), "--batch-size cannot be larger than --buffer-size"
logger.info(args) logger.info(cfg)
# Fix seed for stochastic decoding # Fix seed for stochastic decoding
if args.seed is not None and not args.no_seed_provided: if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
np.random.seed(args.seed) np.random.seed(cfg.common.seed)
utils.set_torch_seed(args.seed) utils.set_torch_seed(cfg.common.seed)
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not cfg.common.cpu
# Setup task, e.g., translation # Setup task, e.g., translation
task = tasks.setup_task(args) task = tasks.setup_task(cfg.task)
# Load ensemble # Load ensemble
logger.info("loading model(s) from {}".format(args.path)) overrides = ast.literal_eval(cfg.common_eval.model_overrides)
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, _model_args = checkpoint_utils.load_model_ensemble( models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(os.pathsep), utils.split_paths(cfg.common_eval.path),
arg_overrides=eval(args.model_overrides), arg_overrides=overrides,
task=task, task=task,
suffix=getattr(args, "checkpoint_suffix", ""), suffix=cfg.checkpoint.checkpoint_suffix,
strict=(args.checkpoint_shard_count == 1), strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=args.checkpoint_shard_count, num_shards=cfg.checkpoint.checkpoint_shard_count,
) )
# Set dictionaries # Set dictionaries
@ -155,18 +167,20 @@ def main(args):
# Optimize ensemble for generation # Optimize ensemble for generation
for model in models: for model in models:
if args.fp16: if model is None:
continue
if cfg.common.fp16:
model.half() model.half()
if use_cuda and not args.pipeline_model_parallel: if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda() model.cuda()
model.prepare_for_inference_(args) model.prepare_for_inference_(cfg)
# Initialize generator # Initialize generator
generator = task.build_generator(models, args) generator = task.build_generator(models, cfg.task)
# Handle tokenization and BPE # Handle tokenization and BPE
tokenizer = encoders.build_tokenizer(args) tokenizer = encoders.build_tokenizer(cfg.tokenizer)
bpe = encoders.build_bpe(args) bpe = encoders.build_bpe(cfg.bpe)
def encode_fn(x): def encode_fn(x):
if tokenizer is not None: if tokenizer is not None:
@ -184,25 +198,25 @@ def main(args):
# Load alignment dictionary for unknown word replacement # Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk) align_dict = utils.load_align_dict(cfg.generation.replace_unk)
max_positions = utils.resolve_max_positions( max_positions = utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models] task.max_positions(), *[model.max_positions() for model in models]
) )
if args.constraints: if cfg.generation.constraints:
logger.warning( logger.warning(
"NOTE: Constrained decoding currently assumes a shared subword vocabulary." "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
) )
if args.buffer_size > 1: if cfg.interactive.buffer_size > 1:
logger.info("Sentence buffer size: %s", args.buffer_size) logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("NOTE: hypothesis and token scores are output in base 2")
logger.info("Type the input sentence and press return:") logger.info("Type the input sentence and press return:")
start_id = 0 start_id = 0
for inputs in buffered_read(args.input, args.buffer_size): for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size):
results = [] results = []
for batch in make_batches(inputs, args, task, max_positions, encode_fn): for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
bsz = batch.src_tokens.size(0) bsz = batch.src_tokens.size(0)
src_tokens = batch.src_tokens src_tokens = batch.src_tokens
src_lengths = batch.src_lengths src_lengths = batch.src_lengths
@ -226,7 +240,7 @@ def main(args):
translate_time = time.time() - translate_start_time translate_time = time.time() - translate_start_time
total_translate_time += translate_time total_translate_time += translate_time
list_constraints = [[] for _ in range(bsz)] list_constraints = [[] for _ in range(bsz)]
if args.constraints: if cfg.generation.constraints:
list_constraints = [unpack_constraints(c) for c in constraints] list_constraints = [unpack_constraints(c) for c in constraints]
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
@ -246,25 +260,25 @@ def main(args):
# sort output to match input order # sort output to match input order
for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
if src_dict is not None: if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe) src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe)
print("S-{}\t{}".format(id_, src_str)) print("S-{}\t{}".format(id_, src_str))
print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
for constraint in info["constraints"]: for constraint in info["constraints"]:
print( print(
"C-{}\t{}".format( "C-{}\t{}".format(
id_, tgt_dict.string(constraint, args.remove_bpe) id_, tgt_dict.string(constraint, cfg.common_eval.remove_bpe)
) )
) )
# Process top predictions # Process top predictions
for hypo in hypos[: min(len(hypos), args.nbest)]: for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo["tokens"].int().cpu(), hypo_tokens=hypo["tokens"].int().cpu(),
src_str=src_str, src_str=src_str,
alignment=hypo["alignment"], alignment=hypo["alignment"],
align_dict=align_dict, align_dict=align_dict,
tgt_dict=tgt_dict, tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe, remove_bpe=cfg.common_eval.remove_bpe,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
) )
detok_hypo_str = decode_fn(hypo_str) detok_hypo_str = decode_fn(hypo_str)
@ -285,7 +299,7 @@ def main(args):
), ),
) )
) )
if args.print_alignment: if cfg.generation.print_alignment:
alignment_str = " ".join( alignment_str = " ".join(
["{}-{}".format(src, tgt) for src, tgt in alignment] ["{}-{}".format(src, tgt) for src, tgt in alignment]
) )
@ -308,4 +322,7 @@ def cli_main():
if __name__ == "__main__": if __name__ == "__main__":
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main() cli_main()
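
Each migrated `main()` in this commit keeps accepting the legacy `argparse.Namespace` and upgrades it on entry, so existing callers keep working. A minimal, schematic sketch of that bridge (the function body is elided; the grouped attribute names in the comment are examples taken from the hunks above):

```python
from argparse import Namespace

from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from omegaconf import DictConfig


def main(cfg: DictConfig):
    # Legacy callers may still pass an argparse.Namespace; normalize it first.
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)
    # From here on, options live under grouped nodes, e.g. cfg.common.seed,
    # cfg.dataset.batch_size, cfg.generation.nbest, instead of flat args.*.
    ...
```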
@ -78,7 +78,13 @@ def cli_main():
def score(fdsys): def score(fdsys):
with open(args.ref) as fdref: with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk()) scorer = bleu.Scorer(
bleu.BleuConfig(
pad=dict.pad(),
eos=dict.eos(),
unk=dict.unk(),
)
)
for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)): for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
sys_tok = dict.encode_line(sys_tok) sys_tok = dict.encode_line(sys_tok)
ref_tok = dict.encode_line(ref_tok) ref_tok = dict.encode_line(ref_tok)
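
The scorer construction above is the other recurring migration in this diff: components that used to take loose arguments now take a small config object, here `bleu.BleuConfig`. A sketch of the new call; the `from fairseq.scoring import bleu` import path and the freshly constructed `Dictionary` are assumptions for illustration, since the hunk does not show score.py's imports:

```python
from fairseq.data import Dictionary
from fairseq.scoring import bleu  # import path assumed; not shown in the hunk above

dictionary = Dictionary()  # stand-in for the dictionary score.py already loads

# Scorer now takes a BleuConfig dataclass instead of positional pad/eos/unk ids.
scorer = bleu.Scorer(
    bleu.BleuConfig(
        pad=dictionary.pad(),
        eos=dictionary.eos(),
        unk=dictionary.unk(),
    )
)
```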
@ -11,11 +11,13 @@ import argparse
import logging import logging
import math import math
import os import os
import random
import sys import sys
from typing import Dict, Optional, Any, List, Tuple, Callable
import numpy as np import numpy as np
import torch import torch
from hydra.core.config_store import ConfigStore
from fairseq import ( from fairseq import (
checkpoint_utils, checkpoint_utils,
distributed_utils, distributed_utils,
@ -25,8 +27,12 @@ from fairseq import (
utils, utils,
) )
from fairseq.data import iterators from fairseq.data import iterators
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import meters, metrics, progress_bar from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from omegaconf import DictConfig
from hydra.experimental import initialize
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.trainer import Trainer from fairseq.trainer import Trainer
@ -39,90 +45,86 @@ logging.basicConfig(
logger = logging.getLogger("fairseq_cli.train") logger = logging.getLogger("fairseq_cli.train")
def main(args): def main(cfg: DictConfig) -> None:
utils.import_user_module(args) if isinstance(cfg, argparse.Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
assert ( utils.import_user_module(cfg.common)
args.max_tokens is not None or args.batch_size is not None
), "Must specify batch size either with --max-tokens or --batch-size"
assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \
'Must specify batch size either with --max-tokens or --batch-size'
metrics.reset() metrics.reset()
np.random.seed(args.seed) np.random.seed(cfg.common.seed)
utils.set_torch_seed(args.seed) utils.set_torch_seed(cfg.common.seed)
if distributed_utils.is_master(args): if distributed_utils.is_master(cfg.distributed_training):
checkpoint_utils.verify_checkpoint_directory(args.save_dir) checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
# Print args # Print args
logger.info(args) logger.info(cfg)
# Setup task, e.g., translation, language modeling, etc. # Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(args) task = tasks.setup_task(cfg.task)
# Load valid dataset (we load training data below, based on the latest checkpoint) # Load valid dataset (we load training data below, based on the latest checkpoint)
for valid_sub_split in args.valid_subset.split(","): for valid_sub_split in cfg.dataset.valid_subset.split(','):
task.load_dataset(valid_sub_split, combine=False, epoch=1) task.load_dataset(valid_sub_split, combine=False, epoch=1)
# Build model and criterion # Build model and criterion
model = task.build_model(args) model = task.build_model(cfg.model)
criterion = task.build_criterion(args) criterion = task.build_criterion(cfg.criterion)
logger.info(model) logger.info(model)
logger.info("task: {} ({})".format(args.task, task.__class__.__name__)) logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__))
logger.info("model: {} ({})".format(args.arch, model.__class__.__name__)) logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__))
logger.info( logger.info(
"criterion: {} ({})".format(args.criterion, criterion.__class__.__name__) "criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__)
)
logger.info(
"num. model params: {} (num. trained: {})".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
)
) )
logger.info("num. model params: {} (num. trained: {})".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
))
# (optionally) Configure quantization # (optionally) Configure quantization
if args.quantization_config_path is not None: if cfg.common.quantization_config_path is not None:
quantizer = quantization_utils.Quantizer( quantizer = quantization_utils.Quantizer(
config_path=args.quantization_config_path, config_path=cfg.common.quantization_config_path,
max_epoch=args.max_epoch, max_epoch=cfg.optimization.max_epoch,
max_update=args.max_update, max_update=cfg.optimization.max_update,
) )
else: else:
quantizer = None quantizer = None
# Build trainer # Build trainer
if args.model_parallel_size == 1: if cfg.common.model_parallel_size == 1:
trainer = Trainer(args, task, model, criterion, quantizer) trainer = Trainer(cfg, task, model, criterion, quantizer)
else: else:
trainer = MegatronTrainer(args, task, model, criterion) trainer = MegatronTrainer(cfg, task, model, criterion)
logger.info( logger.info('training on {} devices (GPUs/TPUs)'.format(cfg.distributed_training.distributed_world_size))
"training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) logger.info('max tokens per GPU = {} and batch size per GPU = {}'.format(
) cfg.dataset.max_tokens,
logger.info( cfg.dataset.batch_size,
"max tokens per GPU = {} and max sentences per GPU = {}".format( ))
args.max_tokens, args.batch_size
)
)
# Load the latest checkpoint if one is available and restore the # Load the latest checkpoint if one is available and restore the
# corresponding train iterator # corresponding train iterator
extra_state, epoch_itr = checkpoint_utils.load_checkpoint( extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
args, cfg.checkpoint,
trainer, trainer,
# don't cache epoch iterators for sharded datasets # don't cache epoch iterators for sharded datasets
disable_iterator_cache=task.has_sharded_data("train"), disable_iterator_cache=task.has_sharded_data("train"),
) )
# Train until the learning rate gets too small max_epoch = cfg.optimization.max_epoch or math.inf
max_epoch = args.max_epoch or math.inf
lr = trainer.get_lr() lr = trainer.get_lr()
train_meter = meters.StopwatchMeter() train_meter = meters.StopwatchMeter()
train_meter.start() train_meter.start()
while (
while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch: lr > cfg.optimization.min_lr
and epoch_itr.next_epoch_idx <= max_epoch
):
# train for one epoch # train for one epoch
valid_losses, should_stop = train(args, trainer, task, epoch_itr) valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
if should_stop: if should_stop:
break break
@ -140,15 +142,15 @@ def main(args):
logger.info("done training in {:.1f} seconds".format(train_meter.sum)) logger.info("done training in {:.1f} seconds".format(train_meter.sum))
def should_stop_early(args, valid_loss): def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
# skip check if no validation was done in the current epoch # skip check if no validation was done in the current epoch
if valid_loss is None: if valid_loss is None:
return False return False
if args.patience <= 0: if cfg.checkpoint.patience <= 0:
return False return False
def is_better(a, b): def is_better(a, b):
return a > b if args.maximize_best_checkpoint_metric else a < b return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
prev_best = getattr(should_stop_early, "best", None) prev_best = getattr(should_stop_early, "best", None)
if prev_best is None or is_better(valid_loss, prev_best): if prev_best is None or is_better(valid_loss, prev_best):
@ -157,48 +159,43 @@ def should_stop_early(args, valid_loss):
return False return False
else: else:
should_stop_early.num_runs += 1 should_stop_early.num_runs += 1
if should_stop_early.num_runs >= args.patience: if should_stop_early.num_runs >= cfg.checkpoint.patience:
logger.info( logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(cfg.checkpoint.patience))
"early stop since valid performance hasn't improved for last {} runs".format(
args.patience
)
)
return True return True
else: else:
return False return False
@metrics.aggregate("train") @metrics.aggregate("train")
def train(args, trainer, task, epoch_itr): def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]:
"""Train the model for one epoch and return validation losses.""" """Train the model for one epoch and return validation losses."""
# Initialize data iterator # Initialize data iterator
itr = epoch_itr.next_epoch_itr( itr = epoch_itr.next_epoch_itr(
fix_batches_to_gpus=args.fix_batches_to_gpus, fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
shuffle=(epoch_itr.next_epoch_idx > args.curriculum), shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
) )
update_freq = ( update_freq = (
args.update_freq[epoch_itr.epoch - 1] cfg.optimization.update_freq[epoch_itr.epoch - 1]
if epoch_itr.epoch <= len(args.update_freq) if epoch_itr.epoch <= len(cfg.optimization.update_freq)
else args.update_freq[-1] else cfg.optimization.update_freq[-1]
) )
itr = iterators.GroupedIterator(itr, update_freq) itr = iterators.GroupedIterator(itr, update_freq)
if getattr(args, "tpu", False): if getattr(cfg.common, "tpu", False):
itr = utils.tpu_data_loader(itr) itr = utils.tpu_data_loader(itr)
progress = progress_bar.progress_bar( progress = progress_bar.progress_bar(
itr, itr,
log_format=args.log_format, log_format=cfg.common.log_format,
log_interval=args.log_interval, log_interval=cfg.common.log_interval,
epoch=epoch_itr.epoch, epoch=epoch_itr.epoch,
tensorboard_logdir=( tensorboard_logdir=(
args.tensorboard_logdir if distributed_utils.is_master(args) else None cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
), ),
default_log_format=("tqdm" if not args.no_progress_bar else "simple"), default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
) )
trainer.begin_epoch(epoch_itr.epoch) trainer.begin_epoch(epoch_itr.epoch)
valid_losses = [None] valid_subsets = cfg.dataset.valid_subset.split(',')
valid_subsets = args.valid_subset.split(",")
should_stop = False should_stop = False
num_updates = trainer.get_num_updates() num_updates = trainer.get_num_updates()
for i, samples in enumerate(progress): for i, samples in enumerate(progress):
@ -210,7 +207,7 @@ def train(args, trainer, task, epoch_itr):
if log_output is not None: # not OOM, overflow, ... if log_output is not None: # not OOM, overflow, ...
# log mid-epoch stats # log mid-epoch stats
num_updates = trainer.get_num_updates() num_updates = trainer.get_num_updates()
if num_updates % args.log_interval == 0: if num_updates % cfg.common.log_interval == 0:
stats = get_training_stats(metrics.get_smoothed_values("train_inner")) stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
progress.log(stats, tag="train_inner", step=num_updates) progress.log(stats, tag="train_inner", step=num_updates)
@ -220,7 +217,7 @@ def train(args, trainer, task, epoch_itr):
end_of_epoch = not itr.has_next() end_of_epoch = not itr.has_next()
valid_losses, should_stop = validate_and_save( valid_losses, should_stop = validate_and_save(
args, trainer, task, epoch_itr, valid_subsets, end_of_epoch cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
) )
if should_stop: if should_stop:
@ -236,64 +233,64 @@ def train(args, trainer, task, epoch_itr):
return valid_losses, should_stop return valid_losses, should_stop
def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch): def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]:
num_updates = trainer.get_num_updates() num_updates = trainer.get_num_updates()
max_update = args.max_update or math.inf max_update = cfg.optimization.max_update or math.inf
do_save = ( do_save = (
(end_of_epoch and epoch_itr.epoch % args.save_interval == 0) (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
or num_updates >= max_update or num_updates >= max_update
or ( or (
args.save_interval_updates > 0 cfg.checkpoint.save_interval_updates > 0
and num_updates > 0 and num_updates > 0
and num_updates % args.save_interval_updates == 0 and num_updates % cfg.checkpoint.save_interval_updates == 0
and num_updates >= args.validate_after_updates and num_updates >= cfg.dataset.validate_after_updates
) )
) )
do_validate = ( do_validate = (
(not end_of_epoch and do_save) # validate during mid-epoch saves (not end_of_epoch and do_save) # validate during mid-epoch saves
or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0) or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
or num_updates >= max_update or num_updates >= max_update
or ( or (
args.validate_interval_updates > 0 cfg.dataset.validate_interval_updates > 0
and num_updates > 0 and num_updates > 0
and num_updates % args.validate_interval_updates == 0 and num_updates % cfg.dataset.validate_interval_updates == 0
) )
) and not args.disable_validation ) and not cfg.dataset.disable_validation
# Validate # Validate
valid_losses = [None] valid_losses = [None]
if do_validate: if do_validate:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets) valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
# Stopping conditions # Stopping conditions
should_stop = ( should_stop = (
should_stop_early(args, valid_losses[0]) should_stop_early(cfg, valid_losses[0])
or num_updates >= max_update or num_updates >= max_update
or ( or (
args.stop_time_hours > 0 cfg.optimization.stop_time_hours > 0
and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours
) )
) )
# Save checkpoint # Save checkpoint
if do_save or should_stop: if do_save or should_stop:
logger.info("begin save checkpoint") logger.info("begin save checkpoint")
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0]) checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0])
return valid_losses, should_stop return valid_losses, should_stop
def get_training_stats(stats): def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0) stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
return stats return stats
def validate(args, trainer, task, epoch_itr, subsets): def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]:
"""Evaluate the model on the validation set(s) and return the losses.""" """Evaluate the model on the validation set(s) and return the losses."""
if args.fixed_validation_seed is not None: if cfg.dataset.fixed_validation_seed is not None:
# set fixed seed for every validation # set fixed seed for every validation
utils.set_torch_seed(args.fixed_validation_seed) utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
trainer.begin_valid_epoch(epoch_itr.epoch) trainer.begin_valid_epoch(epoch_itr.epoch)
valid_losses = [] valid_losses = []
@ -302,18 +299,18 @@ def validate(args, trainer, task, epoch_itr, subsets):
# Initialize data iterator # Initialize data iterator
itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False) itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
if getattr(args, "tpu", False): if cfg.common.tpu:
itr = utils.tpu_data_loader(itr) itr = utils.tpu_data_loader(itr)
progress = progress_bar.progress_bar( progress = progress_bar.progress_bar(
itr, itr,
log_format=args.log_format, log_format=cfg.common.log_format,
log_interval=args.log_interval, log_interval=cfg.common.log_interval,
epoch=epoch_itr.epoch, epoch=epoch_itr.epoch,
prefix=f"valid on '{subset}' subset", prefix=f"valid on '{subset}' subset",
tensorboard_logdir=( tensorboard_logdir=(
args.tensorboard_logdir if distributed_utils.is_master(args) else None cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
), ),
default_log_format=("tqdm" if not args.no_progress_bar else "simple"), default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
) )
# create a new root metrics aggregator so validation metrics # create a new root metrics aggregator so validation metrics
@ -323,34 +320,40 @@ def validate(args, trainer, task, epoch_itr, subsets):
trainer.valid_step(sample) trainer.valid_step(sample)
# log validation stats # log validation stats
stats = get_valid_stats(args, trainer, agg.get_smoothed_values()) stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
progress.print(stats, tag=subset, step=trainer.get_num_updates()) progress.print(stats, tag=subset, step=trainer.get_num_updates())
valid_losses.append(stats[args.best_checkpoint_metric]) valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
return valid_losses return valid_losses
def get_valid_stats(args, trainer, stats): def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]:
stats["num_updates"] = trainer.get_num_updates() stats["num_updates"] = trainer.get_num_updates()
if hasattr(checkpoint_utils.save_checkpoint, "best"): if hasattr(checkpoint_utils.save_checkpoint, "best"):
key = "best_{0}".format(args.best_checkpoint_metric) key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
best_function = max if args.maximize_best_checkpoint_metric else min best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
stats[key] = best_function( stats[key] = best_function(
checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric] checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric]
) )
return stats return stats
def cli_main(modify_parser=None): def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None:
parser = options.get_training_parser() parser = options.get_training_parser()
args = options.parse_args_and_arch(parser, modify_parser=modify_parser) args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
cfg = convert_namespace_to_omegaconf(args)
if args.profile: if args.profile:
with torch.cuda.profiler.profile(): with torch.cuda.profiler.profile():
with torch.autograd.profiler.emit_nvtx(): with torch.autograd.profiler.emit_nvtx():
distributed_utils.call_main(args, main) distributed_utils.call_main(cfg, main)
else: else:
distributed_utils.call_main(args, main) distributed_utils.call_main(cfg, main)
if __name__ == "__main__": if __name__ == '__main__':
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main() cli_main()
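
The train.py hunks above follow one rule: every flat `args.<name>` lookup becomes a lookup under a config group (`common`, `dataset`, `optimization`, `checkpoint`, `distributed_training`). A self-contained toy sketch of that access pattern; the values are placeholders, not fairseq defaults:

```python
from omegaconf import OmegaConf

# Toy config with the same group layout used throughout train.py above.
cfg = OmegaConf.create(
    {
        "common": {"seed": 42, "log_interval": 50, "no_progress_bar": False},
        "dataset": {"batch_size": 8, "valid_subset": "valid,valid1"},
        "optimization": {"max_epoch": 3, "max_update": 0, "lr": [0.001]},
        "checkpoint": {"save_dir": "/tmp/ckpts", "patience": 5},
    }
)

# Grouped access replaces the old flat args.* lookups:
for valid_sub_split in cfg.dataset.valid_subset.split(","):
    print("valid subset:", valid_sub_split)
print("seed:", cfg.common.seed, "max_epoch:", cfg.optimization.max_epoch)
print("checkpoints go to:", cfg.checkpoint.save_dir)
```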
@ -1,5 +1,5 @@
#!/usr/bin/env python3 -u #!/usr/bin/env python3 -u
#!/usr/bin/env python3 -u # !/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
@ -8,11 +8,17 @@
import logging import logging
import os import os
import sys import sys
from argparse import Namespace
from itertools import chain from itertools import chain
import torch import torch
from fairseq import checkpoint_utils, distributed_utils, options, utils from fairseq import checkpoint_utils, distributed_utils, options, utils
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import metrics, progress_bar from fairseq.logging import metrics, progress_bar
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from omegaconf import DictConfig
logging.basicConfig( logging.basicConfig(
@ -24,18 +30,21 @@ logging.basicConfig(
logger = logging.getLogger("fairseq_cli.validate") logger = logging.getLogger("fairseq_cli.validate")
def main(args, override_args=None): def main(cfg: DictConfig, override_args=None):
utils.import_user_module(args) if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
utils.import_user_module(cfg.common)
assert ( assert (
args.max_tokens is not None or args.batch_size is not None cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
), "Must specify batch size either with --max-tokens or --batch-size" ), "Must specify batch size either with --max-tokens or --batch-size"
use_fp16 = args.fp16 use_fp16 = cfg.common.fp16
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not cfg.common.cpu
if use_cuda: if use_cuda:
torch.cuda.set_device(args.device_id) torch.cuda.set_device(cfg.distributed_training.device_id)
if override_args is not None: if override_args is not None:
overrides = vars(override_args) overrides = vars(override_args)
@ -44,11 +53,11 @@ def main(args, override_args=None):
overrides = None overrides = None
# Load ensemble # Load ensemble
logger.info("loading model(s) from {}".format(args.path)) logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
[args.path], [cfg.common_eval.path],
arg_overrides=overrides, arg_overrides=overrides,
suffix=getattr(args, "checkpoint_suffix", ""), suffix=cfg.checkpoint.checkpoint_suffix,
) )
model = models[0] model = models[0]
@ -63,10 +72,10 @@ def main(args, override_args=None):
logger.info(model_args) logger.info(model_args)
# Build criterion # Build criterion
criterion = task.build_criterion(model_args) criterion = task.build_criterion(model_args.criterion)
criterion.eval() criterion.eval()
for subset in args.valid_subset.split(","): for subset in cfg.dataset.valid_subset.split(","):
try: try:
task.load_dataset(subset, combine=False, epoch=1) task.load_dataset(subset, combine=False, epoch=1)
dataset = task.dataset(subset) dataset = task.dataset(subset)
@ -76,26 +85,26 @@ def main(args, override_args=None):
# Initialize data iterator # Initialize data iterator
itr = task.get_batch_iterator( itr = task.get_batch_iterator(
dataset=dataset, dataset=dataset,
max_tokens=args.max_tokens, max_tokens=cfg.dataset.max_tokens,
max_sentences=args.batch_size, max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions( max_positions=utils.resolve_max_positions(
task.max_positions(), task.max_positions(),
*[m.max_positions() for m in models], *[m.max_positions() for m in models],
), ),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple, required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
seed=args.seed, seed=cfg.common.seed,
num_shards=args.distributed_world_size, num_shards=cfg.distributed_training.distributed_world_size,
shard_id=args.distributed_rank, shard_id=cfg.distributed_training.distributed_rank,
num_workers=args.num_workers, num_workers=cfg.dataset.num_workers,
data_buffer_size=args.data_buffer_size, data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False) ).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar( progress = progress_bar.progress_bar(
itr, itr,
log_format=args.log_format, log_format=cfg.common.log_format,
log_interval=args.log_interval, log_interval=cfg.common.log_interval,
prefix=f"valid on '{subset}' subset", prefix=f"valid on '{subset}' subset",
default_log_format=("tqdm" if not args.no_progress_bar else "simple"), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
) )
log_outputs = [] log_outputs = []
@ -105,10 +114,10 @@ def main(args, override_args=None):
progress.log(log_output, step=i) progress.log(log_output, step=i)
log_outputs.append(log_output) log_outputs.append(log_output)
if args.distributed_world_size > 1: if cfg.distributed_training.distributed_world_size > 1:
log_outputs = distributed_utils.all_gather_list( log_outputs = distributed_utils.all_gather_list(
log_outputs, log_outputs,
max_size=getattr(args, "all_gather_list_size", 16384), max_size=cfg.common.all_gather_list_size,
) )
log_outputs = list(chain.from_iterable(log_outputs)) log_outputs = list(chain.from_iterable(log_outputs))
@ -131,4 +140,7 @@ def cli_main():
if __name__ == "__main__": if __name__ == "__main__":
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main() cli_main()
@ -272,6 +272,7 @@ class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
model_cls.add_args(parser) model_cls.add_args(parser)
args = parser.parse_args([]) args = parser.parse_args([])
if extra_args_setters is not None: if extra_args_setters is not None:
for args_setter in extra_args_setters: for args_setter in extra_args_setters:
args_setter(args) args_setter(args)
@ -515,9 +516,7 @@ class CrossEntropyCriterionTestBase(unittest.TestCase):
def setUp(self): def setUp(self):
args = self.setUpArgs() args = self.setUpArgs()
self.model = DummyEncoderModel(encoder=DummyEncoder()) self.model = DummyEncoderModel(encoder=DummyEncoder())
self.criterion = self.criterion_cls.build_criterion( self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args))
args=args, task=DummyTask(args)
)
def get_src_tokens(self, correct_prediction, aggregate): def get_src_tokens(self, correct_prediction, aggregate):
""" """
@ -11,7 +11,7 @@ from multiprocessing import Manager
import torch import torch
import torch.nn as nn import torch.nn as nn
from fairseq import distributed_utils, optim from fairseq import distributed_utils, optim
from omegaconf import OmegaConf
class Model(nn.Module): class Model(nn.Module):
def __init__(self, input_size, output_size): def __init__(self, input_size, output_size):
@ -23,13 +23,14 @@ class Model(nn.Module):
return output return output
def setup_model_loss_criterion(args, rank, is_cuda): def setup_model_loss_criterion(cfg, args, rank, is_cuda):
""" """
setup model, criterion and optimizer based on input args setup model, criterion and optimizer based on input args
""" """
args.distributed_rank = rank args.distributed_rank = rank
if args.distributed_world_size > 1: cfg.distributed_training.distributed_rank = args.distributed_rank
distributed_utils.distributed_init(args) if cfg.distributed_training.distributed_world_size > 1:
distributed_utils.distributed_init(cfg)
torch.manual_seed(1) torch.manual_seed(1)
model = Model(args.input_size, args.nb_classes) model = Model(args.input_size, args.nb_classes)
loss_fn = nn.CrossEntropyLoss() loss_fn = nn.CrossEntropyLoss()
@ -38,7 +39,10 @@ def setup_model_loss_criterion(args, rank, is_cuda):
loss_fn = loss_fn.cuda() loss_fn = loss_fn.cuda()
optimizer = optim.sgd.SGD(args, model.parameters()) optimizer = optim.sgd.SGD(args, model.parameters())
optimizer = optim.FairseqBMUF(args, optimizer) optimizer = optim.FairseqBMUF(
cfg=cfg.bmuf,
optimizer=optimizer
)
return model, loss_fn, optimizer return model, loss_fn, optimizer
@ -52,13 +56,13 @@ def train_step(input, target, model, loss_fn, optimizer, **unused):
optimizer.step() optimizer.step()
def single_gpu_training(args, rank, iterations, shared_results): def single_gpu_training(cfg, args, rank, iterations, shared_results):
is_cuda = torch.cuda.is_available() is_cuda = torch.cuda.is_available()
if is_cuda: if is_cuda:
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
model, loss_fn, optimizer = setup_model_loss_criterion(args, rank, is_cuda) model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)
for _ in range(iterations): for _ in range(iterations):
input = torch.randn(1, args.input_size) input = torch.randn(1, args.input_size)
@ -103,18 +107,44 @@ def setup_args():
args.distributed_init_host = "localhost" args.distributed_init_host = "localhost"
args.distributed_port = port + 1 args.distributed_port = port + 1
args.local_world_size = args.distributed_world_size args.local_world_size = args.distributed_world_size
return args
cfg = OmegaConf.create()
cfg.optimization = OmegaConf.create()
cfg.common = OmegaConf.create()
cfg.distributed_training = OmegaConf.create()
cfg.dataset = OmegaConf.create()
cfg.bmuf = OmegaConf.create()
cfg.optimizer = OmegaConf.create()
cfg.bmuf.global_sync_iter = args.global_sync_iter
cfg.bmuf.block_momentum = args.block_momentum
cfg.bmuf.block_lr = args.block_lr
cfg.dataset.batch_size = args.batch_size
cfg.optimization.lr = args.lr
cfg.optimizer.momentum = args.momentum
cfg.optimizer.weight_decay = args.weight_decay
cfg.bmuf.warmup_iterations = args.warmup_iterations
cfg.bmuf.use_nbm = args.use_nbm
cfg.bmuf.average_sync = args.average_sync
cfg.common.model_parallel_size = args.model_parallel_size
cfg.distributed_training.distributed_backend = args.distributed_backend
cfg.distributed_training.distributed_world_size = args.distributed_world_size
cfg.bmuf.distributed_world_size = args.distributed_world_size
cfg.distributed_training.distributed_init_method = args.distributed_init_method
cfg.distributed_training.distributed_port = args.distributed_port
return cfg, args
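
`setup_args()` above builds the grouped config node by node with repeated `OmegaConf.create()` calls. The other tests touched by this diff build the same shape in a single call from a nested dict; a sketch of that equivalent construction (field names mirror the assignments above, values are placeholders):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "common": {"model_parallel_size": 1},
        "dataset": {"batch_size": 2},
        "optimization": {"lr": [0.001]},
        "optimizer": {"momentum": 0.9, "weight_decay": 0.0},
        "distributed_training": {
            "distributed_backend": "gloo",
            "distributed_world_size": 2,
            "distributed_init_method": None,
            "distributed_port": -1,
        },
        "bmuf": {
            "global_sync_iter": 10,
            "block_momentum": 0.9,
            "block_lr": 1.0,
            "warmup_iterations": 10,
            "use_nbm": True,
            "average_sync": True,
            "distributed_world_size": 2,
        },
    }
)
```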
@unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs") @unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs")
class TestBMUF(unittest.TestCase): class TestBMUF(unittest.TestCase):
def bmuf_process(self, args, iterations): def bmuf_process(self, cfg, args, iterations):
processes = [] processes = []
results = Manager().dict() results = Manager().dict()
ctx = torch.multiprocessing.get_context("spawn") ctx = torch.multiprocessing.get_context("spawn")
for rank in range(args.distributed_world_size): for rank in range(args.distributed_world_size):
p = ctx.Process( p = ctx.Process(
target=single_gpu_training, args=(args, rank, iterations, results) target=single_gpu_training, args=(cfg, args, rank, iterations, results)
) )
p.start() p.start()
processes.append(p) processes.append(p)
@ -125,19 +155,20 @@ class TestBMUF(unittest.TestCase):
def test_bmuf_sync(self): def test_bmuf_sync(self):
# Train model for 1 iteration and do bmuf sync without doing warmup # Train model for 1 iteration and do bmuf sync without doing warmup
args = setup_args() cfg, args = setup_args()
iterations = 1 iterations = 1
results = self.bmuf_process(args, iterations) results = self.bmuf_process(cfg, args, iterations)
# Make sure params in both machines are same # Make sure params in both machines are same
assert len(results) == 2 assert len(results) == 2
self.assertAlmostEqual(results[0], results[1]) self.assertAlmostEqual(results[0], results[1])
def test_warmup_sync(self): def test_warmup_sync(self):
# Train model for 20 iteration and do warmup sync without doing bmuf sync # Train model for 20 iteration and do warmup sync without doing bmuf sync
args = setup_args() cfg, args = setup_args()
args.warmup_iterations = 20 args.warmup_iterations = 20
cfg.bmuf.warmup_iterations = args.warmup_iterations
iterations = 20 iterations = 20
results = self.bmuf_process(args, iterations) results = self.bmuf_process(cfg, args, iterations)
# Make sure params in both machines are same # Make sure params in both machines are same
assert len(results) == 2 assert len(results) == 2
self.assertAlmostEqual(results[0], results[1]) self.assertAlmostEqual(results[0], results[1])
@ -145,22 +176,27 @@ class TestBMUF(unittest.TestCase):
def test_warmup_sync_bmuf_sync(self): def test_warmup_sync_bmuf_sync(self):
# Train model for 25 iteration and do warmup sync after 20 iteration # Train model for 25 iteration and do warmup sync after 20 iteration
# and bmuf sync after 25 iteration # and bmuf sync after 25 iteration
args = setup_args() cfg, args = setup_args()
args.warmup_iterations = 20 args.warmup_iterations = 20
args.global_sync_iter = 5 args.global_sync_iter = 5
cfg.bmuf.warmup_iterations = args.warmup_iterations
cfg.bmuf.global_sync_iter = args.global_sync_iter
iterations = 25 iterations = 25
results = self.bmuf_process(args, iterations) results = self.bmuf_process(cfg, args, iterations)
# Make sure params in both machines are same # Make sure params in both machines are same
assert len(results) == 2 assert len(results) == 2
self.assertAlmostEqual(results[0], results[1]) self.assertAlmostEqual(results[0], results[1])
def test_single_gpu_bmuf(self): def test_single_gpu_bmuf(self):
# Train model for 5 iterations and use GPU 1 # Train model for 5 iterations and use GPU 1
args = setup_args() cfg, args = setup_args()
args.distributed_world_size = 1 args.distributed_world_size = 1
args.warmup_iterations = 5 args.warmup_iterations = 5
cfg.distributed_training.distributed_world_size = args.distributed_world_size
cfg.bmuf.distributed_world_size = args.distributed_world_size
cfg.bmuf.warmup_iterations = args.warmup_iterations
iterations = 20 iterations = 20
results = self.bmuf_process(args, iterations) results = self.bmuf_process(cfg, args, iterations)
assert len(results) == 1 assert len(results) == 1
def assertAlmostEqual(self, t1, t2): def assertAlmostEqual(self, t1, t2):
@ -9,6 +9,7 @@ import unittest
import torch import torch
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
from omegaconf import OmegaConf
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU") @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
@ -27,17 +28,23 @@ class TestGradientScaling(unittest.TestCase):
self.model.cuda().half() self.model.cuda().half()
self.params = list(self.model.parameters()) self.params = list(self.model.parameters())
self.namespace_dls = argparse.Namespace( self.cfg_dls = OmegaConf.create(
optimizer="adam", {
lr=[0.1], "optimizer": {
adam_betas="(0.9, 0.999)", "_name": "adam",
adam_eps=1e-8, "lr": [0.1],
weight_decay=0.0, "adam_betas": "(0.9, 0.999)",
fp16_init_scale=1, "adam_eps": 1e-8,
fp16_scale_window=1, "weight_decay": 0.0,
fp16_scale_tolerance=1, },
threshold_loss_scale=1, "common": {
min_loss_scale=1e-4, "fp16_init_scale": 1,
"fp16_scale_window": 1,
"fp16_scale_tolerance": 1,
"threshold_loss_scale": 1,
"min_loss_scale": 1e-4,
},
}
) )
def run_iter(self, model, params, optimizer): def run_iter(self, model, params, optimizer):
@ -68,7 +75,7 @@ class TestGradientScaling(unittest.TestCase):
def test_mixed_precision(self): def test_mixed_precision(self):
model = copy.deepcopy(self.model) model = copy.deepcopy(self.model)
params = list(model.parameters()) params = list(model.parameters())
optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params) optimizer = FP16Optimizer.build_optimizer(self.cfg_dls, params)
self.run_iter(model, params, optimizer) self.run_iter(model, params, optimizer)
self.assertTrue( self.assertTrue(
@ -87,9 +94,7 @@ class TestGradientScaling(unittest.TestCase):
def test_memory_efficient(self): def test_memory_efficient(self):
model = copy.deepcopy(self.model) model = copy.deepcopy(self.model)
params = list(model.parameters()) params = list(model.parameters())
optimizer = MemoryEfficientFP16Optimizer.build_optimizer( optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.cfg_dls, params)
self.namespace_dls, params
)
self.run_iter(model, params, optimizer) self.run_iter(model, params, optimizer)
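
Both FP16 wrappers in the tests above are now built from one DictConfig: optimizer hyper-parameters under `optimizer`, loss-scaling knobs under `common`. A minimal sketch mirroring the test's values, with a toy linear layer standing in for the test's model (a GPU is required, as in the test itself):

```python
import torch
from omegaconf import OmegaConf

from fairseq.optim.fp16_optimizer import FP16Optimizer

cfg = OmegaConf.create(
    {
        "optimizer": {
            "_name": "adam",
            "lr": [0.1],
            "adam_betas": "(0.9, 0.999)",
            "adam_eps": 1e-8,
            "weight_decay": 0.0,
        },
        "common": {
            "fp16_init_scale": 1,
            "fp16_scale_window": 1,
            "fp16_scale_tolerance": 1,
            "threshold_loss_scale": 1,
            "min_loss_scale": 1e-4,
        },
    }
)

model = torch.nn.Linear(5, 5).cuda().half()  # toy FP16 model in place of the test's model
params = list(model.parameters())

optimizer = FP16Optimizer.build_optimizer(cfg, params)
# MemoryEfficientFP16Optimizer.build_optimizer(cfg, params) accepts the same cfg,
# as test_memory_efficient above shows.
```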
@ -6,6 +6,7 @@
import logging import logging
import unittest import unittest
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.transformer import TransformerModel from fairseq.models.transformer import TransformerModel
from tests.test_sequence_generator import get_dummy_task_and_parser from tests.test_sequence_generator import get_dummy_task_and_parser
@ -25,7 +26,8 @@ class TestInferenceDropout(unittest.TestCase):
def test_sets_inference_dropout_to_true(self): def test_sets_inference_dropout_to_true(self):
self.args.retain_dropout = True self.args.retain_dropout = True
self.transformer_model = TransformerModel.build_model(self.args, self.task) self.transformer_model = TransformerModel.build_model(self.args, self.task)
self.transformer_model.prepare_for_inference_(self.args) cfg = convert_namespace_to_omegaconf(self.args)
self.transformer_model.prepare_for_inference_(cfg)
assert self.transformer_model.encoder.dropout_module.apply_during_inference assert self.transformer_model.encoder.dropout_module.apply_during_inference
assert self.transformer_model.decoder.dropout_module.apply_during_inference assert self.transformer_model.decoder.dropout_module.apply_during_inference
for layer in self.transformer_model.encoder.layers: for layer in self.transformer_model.encoder.layers:
@ -33,7 +35,8 @@ class TestInferenceDropout(unittest.TestCase):
def test_inference_dropout_false_by_default(self): def test_inference_dropout_false_by_default(self):
self.transformer_model = TransformerModel.build_model(self.args, self.task) self.transformer_model = TransformerModel.build_model(self.args, self.task)
self.transformer_model.prepare_for_inference_(self.args) cfg = convert_namespace_to_omegaconf(self.args)
self.transformer_model.prepare_for_inference_(cfg)
assert not self.transformer_model.encoder.dropout_module.apply_during_inference assert not self.transformer_model.encoder.dropout_module.apply_during_inference
assert not self.transformer_model.decoder.dropout_module.apply_during_inference assert not self.transformer_model.decoder.dropout_module.apply_during_inference
for layer in self.transformer_model.encoder.layers: for layer in self.transformer_model.encoder.layers:
@ -59,7 +62,8 @@ class TestInferenceDropout(unittest.TestCase):
"TransformerEncoderLayer", "TransformerEncoderLayer",
] ]
self.transformer_model = TransformerModel.build_model(self.args, self.task) self.transformer_model = TransformerModel.build_model(self.args, self.task)
self.transformer_model.prepare_for_inference_(self.args) cfg = convert_namespace_to_omegaconf(self.args)
self.transformer_model.prepare_for_inference_(cfg)
assert self.transformer_model.encoder.dropout_module.apply_during_inference assert self.transformer_model.encoder.dropout_module.apply_during_inference
assert not self.transformer_model.decoder.dropout_module.apply_during_inference assert not self.transformer_model.decoder.dropout_module.apply_during_inference
for layer in self.transformer_model.decoder.layers: for layer in self.transformer_model.decoder.layers:
@@ -10,6 +10,7 @@ import unittest
 import torch
 from fairseq.optim.adam import FairseqAdam
 from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
+from omegaconf import OmegaConf


 @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
@@ -26,25 +27,36 @@ class TestMemoryEfficientFP16(unittest.TestCase):
         params = list(model.parameters())

         # initialize memory efficient FP16 optimizer
+        # with pseudo DictConfigs
         optimizer = FairseqAdam(
-            argparse.Namespace(
-                lr=[0.00001],
-                adam_betas="(0.9, 0.999)",
-                adam_eps=1e-8,
-                weight_decay=0.0,
+            cfg=OmegaConf.create(
+                vars(
+                    argparse.Namespace(
+                        adam_betas="(0.9, 0.999)",
+                        adam_eps=1e-8,
+                        weight_decay=0.0,
+                        lr=[0.00001],
+                    )
+                )
             ),
-            params,
+            params=params,
         )
         me_optimizer = MemoryEfficientFP16Optimizer(
-            argparse.Namespace(
-                fp16_init_scale=1,
-                fp16_scale_window=1,
-                fp16_scale_tolerance=1,
-                threshold_loss_scale=1,
-                min_loss_scale=1e-4,
+            cfg=OmegaConf.create(
+                {
+                    "common": vars(
+                        argparse.Namespace(
+                            fp16_init_scale=1,
+                            fp16_scale_window=1,
+                            fp16_scale_tolerance=1,
+                            threshold_loss_scale=1,
+                            min_loss_scale=1e-4,
+                        )
+                    )
+                }
             ),
-            params,
-            optimizer,
+            params=params,
+            optimizer=optimizer,
         )

         # optimizer state is created in the first step

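The optimizer hunk above swaps the raw argparse namespaces for pseudo DictConfigs built from them. Below is a minimal, standalone sketch of that construction pattern, using only omegaconf and argparse; the variable names (adam_cfg, fp16_cfg) are illustrative, the field values simply mirror the test, and no fairseq import is required.

import argparse

from omegaconf import OmegaConf

# Flat namespace -> flat DictConfig (the style of cfg handed to FairseqAdam above).
adam_cfg = OmegaConf.create(
    vars(
        argparse.Namespace(
            adam_betas="(0.9, 0.999)",
            adam_eps=1e-8,
            weight_decay=0.0,
            lr=[0.00001],
        )
    )
)

# Namespace nested under a group key -> grouped DictConfig
# (the "common" group that the memory efficient FP16 wrapper reads above).
fp16_cfg = OmegaConf.create(
    {
        "common": vars(
            argparse.Namespace(
                fp16_init_scale=1,
                fp16_scale_window=1,
                fp16_scale_tolerance=1,
                threshold_loss_scale=1,
                min_loss_scale=1e-4,
            )
        )
    }
)

print(adam_cfg.lr[0])                   # 1e-05
print(fp16_cfg.common.fp16_init_scale)  # 1

Attribute access and key access are interchangeable on a DictConfig, which is what lets components that used to read fields off the args object keep the same accessors after the migration.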

@@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
 import torch

 from fairseq import checkpoint_utils, data
+from omegaconf import OmegaConf


 def mock_trainer(epoch, num_updates, iterations_in_epoch):
@@ -56,21 +57,29 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
     return trainer, epoch_itr


-def get_mock_args(finetune_from_model=None):
-    args_mock = MagicMock()
-    args_mock.optimizer_overrides = "{}"
-    args_mock.reset_dataloader = False
-    args_mock.reset_meters = False
-    args_mock.reset_optimizer = False
-    args_mock.reset_lr_scheduler = False
-    args_mock.finetune_from_model = finetune_from_model
-    args_mock.model_parallel_size = 1
-    return args_mock
+def get_mock_cfg(finetune_from_model):
+    cfg_mock = OmegaConf.create(
+        {
+            "checkpoint": {
+                "optimizer_overrides": "{}",
+                "reset_dataloader": False,
+                "reset_meters": False,
+                "reset_optimizer": False,
+                "reset_lr_scheduler": False,
+                "finetune_from_model": finetune_from_model,
+                "model_parallel_size": 1,
+            },
+            "common": {
+                "model_parallel_size": 1,
+            },
+        }
+    )
+    return cfg_mock


 class TestLoadCheckpoint(unittest.TestCase):
     def setUp(self):
-        self.args_mock = get_mock_args()
+        self.cfg_mock = get_mock_cfg(None)
         self.patches = {
             "os.makedirs": MagicMock(),
             "os.path.join": MagicMock(),
@@ -91,7 +100,9 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )

             self.assertEqual(epoch_itr.epoch, 2)
             self.assertEqual(epoch_itr.iterations_in_epoch, 50)
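get_mock_cfg above builds a real nested DictConfig in place of the MagicMock args, and the checkpoint tests now pass only its checkpoint sub-config to checkpoint_utils.load_checkpoint. The snippet below is a small, self-contained sketch of that grouped-config pattern using only omegaconf; the keys echo a subset of the helper, while the checkpoint_cfg name and printed values are illustrative.

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "checkpoint": {
            "optimizer_overrides": "{}",
            "reset_optimizer": False,
            "finetune_from_model": None,
        },
        "common": {"model_parallel_size": 1},
    }
)

# Dotted and bracket access are interchangeable, and because this config is not
# in struct mode, new keys (like restore_file) can be added on the fly, just as
# the tests do in the hunks that follow.
cfg.checkpoint.restore_file = "checkpoint_last.pt"
cfg["checkpoint"]["reset_optimizer"] = True

checkpoint_cfg = cfg.checkpoint        # the sub-config handed to load_checkpoint
print(checkpoint_cfg.restore_file)     # checkpoint_last.pt
print(checkpoint_cfg.reset_optimizer)  # True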
@@ -120,7 +131,9 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )

             itr = epoch_itr.next_epoch_itr(shuffle=False)
             self.assertEqual(epoch_itr.epoch, 3)
@@ -133,7 +146,9 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             self.patches["os.path.isfile"].return_value = False

-            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
+            _, epoch_itr = checkpoint_utils.load_checkpoint(
+                self.cfg_mock.checkpoint, trainer
+            )

             itr = epoch_itr.next_epoch_itr(shuffle=False)
             self.assertEqual(epoch_itr.epoch, 1)
@@ -152,10 +167,12 @@ class TestLoadCheckpoint(unittest.TestCase):
                "reset_dataloader",
            ]:
                with self.subTest(arg=arg):
-                    args_mock = get_mock_args("/temp/checkpoint_pretrained.pt")
-                    setattr(args_mock, arg, True)
+                    cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
+                    cfg_mock["checkpoint"][arg] = True
                    with self.assertRaises(Exception) as context:
-                        _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+                        _, _ = checkpoint_utils.load_checkpoint(
+                            cfg_mock.checkpoint, trainer
+                        )

                    self.assertTrue(
                        "--finetune-from-model can not be set together with either --reset-optimizer"
@@ -168,8 +185,6 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             from_model_path = "/temp/checkpoint_pretrained.pt"
-            args_mock = get_mock_args(from_model_path)
-            args_mock.restore_file = "checkpoint_last.pt"

             def mock_finetune_exist(path):
                 if path == from_model_path:
@@ -180,7 +195,9 @@ class TestLoadCheckpoint(unittest.TestCase):
             self.patches[
                 "fairseq.file_io.PathManager.exists"
             ].side_effect = mock_finetune_exist
-            _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+            cfg_mock = get_mock_cfg(from_model_path)
+            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
+            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
             (
                 checkpoint_path,
                 reset_optimizer,
@@ -197,8 +214,6 @@ class TestLoadCheckpoint(unittest.TestCase):
             trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
             trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
             from_model_path = "/temp/checkpoint_pretrained.pt"
-            args_mock = get_mock_args(from_model_path)
-            args_mock.restore_file = "checkpoint_last.pt"

             # launch second time
             # both restore_file=checkpoint_last.pt and finetune_from_model are set
@@ -211,7 +226,9 @@ class TestLoadCheckpoint(unittest.TestCase):
             self.patches[
                 "fairseq.file_io.PathManager.exists"
             ].side_effect = mock_finetune_exist
-            _, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
+            cfg_mock = get_mock_cfg(from_model_path)
+            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
+            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
             (
                 checkpoint_path,
                 reset_optimizer,


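The hunk below narrows the test-utility import to LegacyFairseqTask, the base class kept for components that still consume an argparse namespace. As a hedged illustration of that convention, the toy task sketched here keeps the old args-based construction; ToyLegacyTask is made up for this sketch and assumes a fairseq checkout that includes this change.

from fairseq.tasks import LegacyFairseqTask


class ToyLegacyTask(LegacyFairseqTask):
    """A made-up task that keeps the old argparse-based construction."""

    def __init__(self, args):
        super().__init__(args)  # Legacy* bases still accept the args namespace

    @classmethod
    def setup_task(cls, args, **kwargs):
        # Legacy tasks are still set up from the parsed argparse namespace;
        # migrated tasks receive an omegaconf DictConfig instead.
        return cls(args)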
@@ -20,7 +20,7 @@ from fairseq.models import (
     FairseqIncrementalDecoder,
 )
 from fairseq.models.fairseq_encoder import EncoderOut
-from fairseq.tasks import FairseqTask, LegacyFairseqTask
+from fairseq.tasks import LegacyFairseqTask
 from fairseq_cli import generate, interactive, preprocess, train, validate