Enable Hydra configs in fairseq (#1343) (#1510)

Summary:
Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1510

This is the main PR that switches on Hydra functionality in fairseq.

We migrate the legacy "args" Namespace into an OmegaConf "DictConfig" at all legacy entry points.
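
For illustration, here is a minimal sketch (not fairseq's actual `convert_namespace_to_omegaconf`) of what that conversion looks like with plain OmegaConf; the grouping and field names below are assumptions:

```
# Minimal, illustrative sketch only -- not fairseq's implementation.
from argparse import Namespace

from omegaconf import DictConfig, OmegaConf


def namespace_to_cfg(args: Namespace) -> DictConfig:
    # group flat argparse attributes under the config sections introduced here
    # (common/dataset/optimization/...); this particular grouping is assumed
    return OmegaConf.create(
        {
            "common": {"seed": args.seed, "fp16": args.fp16},
            "dataset": {"max_tokens": args.max_tokens, "num_workers": args.num_workers},
            "optimization": {"lr": args.lr, "max_update": args.max_update},
        }
    )


args = Namespace(seed=1, fp16=False, max_tokens=1024, num_workers=4,
                 lr=[0.25], max_update=50000)
cfg = namespace_to_cfg(args)
assert cfg.dataset.max_tokens == 1024  # fields are now addressed as cfg.<group>.<name>
```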

In addition, this migrates various components from secondary registries (like BPE encoders and tokenizers) to make the migration smoother.
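
The registration pattern (mirroring the gpt2/fastbpe/moses changes in this diff) pairs each component with a dataclass config and passes that `cfg` in place of the full args namespace; the names `MyBPEConfig`/`my_bpe` below are hypothetical, sketched for illustration:

```
# Hypothetical example of the dataclass-based registration used in this PR.
from dataclasses import dataclass, field

from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class MyBPEConfig(FairseqDataclass):
    my_bpe_codes: str = field(default="???", metadata={"help": "path to BPE codes"})


@register_bpe("my_bpe", dataclass=MyBPEConfig)
class MyBPE(object):
    def __init__(self, cfg):
        # no add_args()/argparse plumbing; fields come straight from the dataclass
        self.codes = cfg.my_bpe_codes

    def encode(self, x: str) -> str:
        return x  # a real encoder would apply BPE here

    def decode(self, x: str) -> str:
        return x
```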

I am going through code that references the migrated fairseq components and changing it to inherit from the "Legacy*" components instead; hopefully the tests will catch most of this.
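
As a sketch of that change (the task name and arguments here are hypothetical), an args-based component just swaps its base class for the Legacy variant and keeps its argparse interface:

```
# Hypothetical legacy-style task kept on the old args-based interface.
from fairseq.tasks import LegacyFairseqTask, register_task


@register_task("my_old_style_task")
class MyOldStyleTask(LegacyFairseqTask):  # was: FairseqTask
    @staticmethod
    def add_args(parser):
        parser.add_argument("data", help="path to data directory")

    @classmethod
    def setup_task(cls, args, **kwargs):
        # legacy tasks keep receiving the argparse Namespace unchanged
        return cls(args)
```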

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1343

Reviewed By: myleott

Differential Revision: D23973928

Pulled By: alexeib

fbshipit-source-id: dd9554981fff51ea75c1ff343874d1d6e61793c9
alexeib 2020-10-20 00:31:00 -07:00 committed by Facebook GitHub Bot
parent c76cb6dfb9
commit 3b27ed7996
85 changed files with 2034 additions and 1681 deletions

View File

@ -1,7 +1,111 @@
# @package _group_
common:
no_progress_bar: false
log_interval: 100
log_format: null
tensorboard_logdir: null
seed: 1
cpu: false
tpu: false
bf16: false
fp16: false
memory_efficient_fp16: false
memory_efficient_bf16: false
fp16_no_flatten_grads: false
fp16_init_scale: 128
fp16_scale_window: null
fp16_scale_tolerance: 0.0
min_loss_scale: 1.0e-4
threshold_loss_scale: null
user_dir: null
empty_cache_freq: 0
all_gather_list_size: 16384
model_parallel_size: 1
quantization_config_path: null
profile: false
distributed_training:
distributed_rank: 0
distributed_backend: "nccl"
distributed_init_method: null
distributed_port: -1
device_id: 0
local_rank: 0
distributed_no_spawn: false
ddp_backend: "c10d"
bucket_cap_mb: 25
fix_batches_to_gpus: false
find_unused_parameters: false
fast_stat_sync: false
broadcast_buffers: false
distributed_wrapper: "DDP"
slowmo_momentum: null
slowmo_algorithm: "LocalSGD"
localsgd_frequency: 3
dataset:
num_workers: 1
skip_invalid_size_inputs_valid_test: false
max_tokens: null
batch_size: null
required_batch_size_multiple: 8
dataset_impl: null
data_buffer_size: 10
train_subset: "train"
valid_subset: "valid"
validate_interval: 1
fixed_validation_seed: null
disable_validation: false
curriculum: 0
gen_subset: "test"
num_shards: 1
shard_id: 0
max_tokens_valid: ${dataset.max_tokens}
batch_size_valid: ${dataset.batch_size}
optimization:
max_epoch: 0
max_update: 0
clip_norm: 25.0
sentence_avg: false
update_freq: [ 1 ]
lr: [ 0.25 ]
min_lr: -1.0
use_bmuf: false
checkpoint:
save_dir: "checkpoints"
restore_file: "checkpoint_last.pt"
reset_dataloader: false
reset_lr_scheduler: false
reset_meters: false
reset_optimizer: false
optimizer_overrides: "{}"
save_interval: 1
save_interval_updates: 0
keep_interval_updates: -1
keep_last_epochs: -1
keep_best_checkpoints: -1
no_save: false
no_epoch_checkpoints: false
no_last_checkpoints: false
no_save_optimizer_state: false
best_checkpoint_metric: "loss"
maximize_best_checkpoint_metric: false
patience: -1
checkpoint_suffix: ""
bmuf:
block_lr: 1
block_momentum: 0.875
global_sync_iter: 50
warmup_iterations: 500
use_nbm: false
average_sync: false
defaults:
- params: training_params
- task: language_modeling
- model: transformer_lm
- criterion: cross_entropy
- optimizer: adam
- lr_scheduler: inverse_sqrt
- task: language_modeling
- model: null
- criterion: null
- optimizer: null
- lr_scheduler: null
- bpe: null
- tokenizer: null
- scoring: null
- generation: null
- common_eval: null
- eval_lm: null

View File

@ -1,7 +0,0 @@
defaults:
- params: eval_lm_params
- task: language_modeling
- model: transformer_lm
- criterion: cross_entropy
- optimizer: adam
- lr_scheduler: inverse_sqrt

View File

@ -1,3 +1,3 @@
# @package _group_
sentence_avg: ${params.optimization.sentence_avg}
ddp_backend: ${params.distributed_training.ddp_backend}
sentence_avg: ${optimization.sentence_avg}
ddp_backend: ${distributed_training.ddp_backend}

View File

@ -1,3 +1,2 @@
# @package _group_
sentence_avg: ${params.optimization.sentence_avg}
ddp_backend: ${params.distributed_training.ddp_backend}
sentence_avg: ${optimization.sentence_avg}

View File

@ -1,105 +0,0 @@
# @package _group_
common:
no_progress_bar: false
log_interval: 100
log_format: null
tensorboard_logdir: null
seed: 1
cpu: false
fp16: false
memory_efficient_fp16: false
fp16_no_flatten_grads: false
fp16_init_scale: 128
fp16_scale_window: null
fp16_scale_tolerance: 0.0
min_loss_scale: 1.0e-4
threshold_loss_scale: null
user_dir: null
empty_cache_freq: 0
all_gather_list_size: 16384
model_parallel_size: 1
checkpoint_suffix: ""
quantization_config_path: null
distributed_training:
distributed_rank: 0
distributed_backend: "nccl"
distributed_init_method: null
distributed_port: -1
device_id: 0
local_rank: 0
distributed_no_spawn: false
ddp_backend: "c10d"
bucket_cap_mb: 25
fix_batches_to_gpus: false
find_unused_parameters: false
fast_stat_sync: false
broadcast_buffers: false
distributed_wrapper: "DDP"
slowmo_momentum: null
slowmo_algorithm: "LocalSGD"
localsgd_frequency: 3
dataset:
num_workers: 1
skip_invalid_size_inputs_valid_test: false
max_tokens: null
batch_size: ${params.dataset.batch_size}
required_batch_size_multiple: 8
dataset_impl: null
data_buffer_size: 10
train_subset: "train"
valid_subset: "valid"
validate_interval: 1
fixed_validation_seed: null
disable_validation: false
curriculum: 0
gen_subset: "test"
num_shards: 1
shard_id: 0
max_tokens_valid: ${params.dataset.max_tokens}
batch_size_valid: ${params.dataset.batch_size}
optimization:
max_epoch: 0
max_update: 0
clip_norm: 25.0
sentence_avg: false
update_freq: [1]
lr: [0.25]
min_lr: -1.0
use_bmuf: false
checkpoint:
save_dir: "checkpoints"
restore_file: "checkpoint_last.pt"
reset_dataloader: false
reset_lr_scheduler: false
reset_meters: false
reset_optimizer: false
optimizer_overrides: "{}"
save_interval: 1
save_interval_updates: 0
keep_interval_updates: -1
keep_last_epochs: -1
keep_best_checkpoints: -1
no_save: false
no_epoch_checkpoints: false
no_last_checkpoints: false
no_save_optimizer_state: false
best_checkpoint_metric: "loss"
maximize_best_checkpoint_metric: false
patience: -1
common_eval:
path: null
remove_bpe: null
quiet: false
model_overrides: "{}"
results_path: null
eval_lm:
output_word_probs: false
output_word_stats: false
context_window: 0
bmuf:
block_lr: 1
block_momentum: 0.875
global_sync_iter: 50
warmup_iterations: 500
use_nbm: false
average_sync: false

View File

@ -1,95 +0,0 @@
# @package _group_
common:
no_progress_bar: false
log_interval: 100
log_format: null
tensorboard_logdir: null
seed: 1
cpu: false
fp16: false
memory_efficient_fp16: false
fp16_no_flatten_grads: false
fp16_init_scale: 128
fp16_scale_window: null
fp16_scale_tolerance: 0.0
min_loss_scale: 1.0e-4
threshold_loss_scale: null
user_dir: null
empty_cache_freq: 0
all_gather_list_size: 16384
model_parallel_size: 1
checkpoint_suffix: ""
quantization_config_path: null
distributed_training:
distributed_rank: 0
distributed_backend: "nccl"
distributed_init_method: null
distributed_port: -1
device_id: 0
local_rank: 0
distributed_no_spawn: false
ddp_backend: "c10d"
bucket_cap_mb: 25
fix_batches_to_gpus: false
find_unused_parameters: false
fast_stat_sync: false
broadcast_buffers: false
distributed_wrapper: "DDP"
slowmo_momentum: null
slowmo_algorithm: "LocalSGD"
localsgd_frequency: 3
dataset:
num_workers: 1
skip_invalid_size_inputs_valid_test: false
max_tokens: null
batch_size: ${params.dataset.batch_size}
required_batch_size_multiple: 8
dataset_impl: null
data_buffer_size: 10
train_subset: "train"
valid_subset: "valid"
validate_interval: 1
fixed_validation_seed: null
disable_validation: false
curriculum: 0
gen_subset: "test"
num_shards: 1
shard_id: 0
max_tokens_valid: ${params.dataset.max_tokens}
batch_size_valid: ${params.dataset.batch_size}
optimization:
max_epoch: 0
max_update: 0
clip_norm: 25.0
sentence_avg: false
update_freq: [1]
lr: [0.25]
min_lr: -1.0
use_bmuf: false
checkpoint:
save_dir: "checkpoints"
restore_file: "checkpoint_last.pt"
reset_dataloader: false
reset_lr_scheduler: false
reset_meters: false
reset_optimizer: false
optimizer_overrides: "{}"
save_interval: 1
save_interval_updates: 0
keep_interval_updates: -1
keep_last_epochs: -1
keep_best_checkpoints: -1
no_save: false
no_epoch_checkpoints: false
no_last_checkpoints: false
no_save_optimizer_state: false
best_checkpoint_metric: "loss"
maximize_best_checkpoint_metric: false
patience: -1
bmuf:
block_lr: 1
block_momentum: 0.875
global_sync_iter: 50
warmup_iterations: 500
use_nbm: false
average_sync: false

View File

@ -13,7 +13,6 @@ For example, if we'd like to train a language model with transformer, we could p
```
defaults:
- params: training_params
- task: language_modeling
- model: transformer_lm
- criterion: cross_entropy
@ -21,7 +20,7 @@ defaults:
- lr_scheduler: inverse_sqrt
```
- Provide generic parameters common across different training jobs: `config/params/training_params.yaml`
- Provide generic parameters common across different jobs: `config.yaml`
- Provide task parameters: `config/task/language_modeling.yaml`
- Provide model parameters: `config/model/transformer_lm.yaml`
- Provide criterion parameters: `config/criterion/cross_entropy.yaml`
@ -41,7 +40,6 @@ Alternatively, if we need to override certain params from the command line, we c
```
python fairseq_cli/train_hydra.py
params=training_params \
task=language_modeling \
task.data=/private/home/abaevski/data/wiki103 \
task.tokens_per_sample=512 \
@ -56,17 +54,17 @@ lr_scheduler=inverse_sqrt \
lr_scheduler.warmup_updates=4000 \
lr_scheduler.warmup_init_lr=1e-07 \
criterion=cross_entropy \
params.common.fp16=true \
params.common.log_format=json \
params.common.log_interval=1 \
params.dataset.max_tokens=1024 \
params.dataset.num_workers=4 \
params.optimization.update_freq=[16] \
params.optimization.max_update=50000 \
params.optimization.clip_norm=0.0 \
params.optimization.lr=[0.0005] \
params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
params.checkpoint.save_interval_updates=10
common.fp16=true \
common.log_format=json \
common.log_interval=1 \
dataset.max_tokens=1024 \
dataset.num_workers=4 \
optimization.update_freq=[16] \
optimization.max_update=50000 \
optimization.clip_norm=0.0 \
optimization.lr=[0.0005] \
checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
checkpoint.save_interval_updates=10
```
## Migrating existing modules / creating new modules for the hydra interface

View File

@ -212,7 +212,7 @@ following contents::
@register_task('simple_classification')
class SimpleClassificationTask(FairseqTask):
class SimpleClassificationTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):

View File

@ -27,7 +27,13 @@ def score_target_hypo(
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
dict = dictionary.Dictionary()
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
scorer = bleu.Scorer(
bleu.BleuConfig(
pad=dict.pad(),
eos=dict.eos(),
unk=dict.unk(),
)
)
ordered_hypos = {}
ordered_targets = {}

View File

@ -20,8 +20,8 @@ class WSCCriterion(LegacyFairseqCriterion):
self.prediction_h = open(self.args.save_predictions, "w")
else:
self.prediction_h = None
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(args)
self.bpe = encoders.build_bpe(args.bpe)
self.tokenizer = encoders.build_tokenizer(args.tokenizer)
def __del__(self):
if self.prediction_h is not None:

View File

@ -85,7 +85,7 @@ Produce model scores for the generated translations using `--retain-dropout` opt
```
CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt --beam 5
--source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout
--retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer
--retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]'
TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out
grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores

View File

@ -3,36 +3,42 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ast
import collections
import logging
import os
import re
import traceback
from collections import OrderedDict
from typing import Union
from typing import Optional, Union
import torch
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
overwrite_args_by_name,
)
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, open_dict
from torch.serialization import default_restore_location
logger = logging.getLogger(__name__)
def save_checkpoint(args, trainer, epoch_itr, val_loss):
from fairseq import distributed_utils, meters
def save_checkpoint(cfg: DictConfig, trainer, epoch_itr, val_loss):
from fairseq import meters
# only one worker should attempt to create the required dir
if args.distributed_rank == 0:
os.makedirs(args.save_dir, exist_ok=True)
if cfg.distributed_rank == 0:
os.makedirs(cfg.save_dir, exist_ok=True)
prev_best = getattr(save_checkpoint, "best", val_loss)
if val_loss is not None:
best_function = max if args.maximize_best_checkpoint_metric else min
best_function = max if cfg.maximize_best_checkpoint_metric else min
save_checkpoint.best = best_function(val_loss, prev_best)
if args.no_save:
if cfg.no_save:
return
trainer.consolidate_optimizer()
@ -41,7 +47,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
return
def is_better(a, b):
return a >= b if args.maximize_best_checkpoint_metric else a <= b
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
write_timer = meters.StopwatchMeter()
write_timer.start()
@ -50,38 +56,36 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
suffix = getattr(args, "checkpoint_suffix", "")
suffix = cfg.checkpoint_suffix or ""
checkpoint_conds = collections.OrderedDict()
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
end_of_epoch
and not args.no_epoch_checkpoints
and epoch % args.save_interval == 0
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
)
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
not end_of_epoch
and args.save_interval_updates > 0
and updates % args.save_interval_updates == 0
and cfg.save_interval_updates > 0
and updates % cfg.save_interval_updates == 0
)
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
not hasattr(save_checkpoint, "best")
or is_better(val_loss, save_checkpoint.best)
)
if val_loss is not None and args.keep_best_checkpoints > 0:
if val_loss is not None and cfg.keep_best_checkpoints > 0:
checkpoint_conds[
"checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
"checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss)
] = not hasattr(save_checkpoint, "best") or is_better(
val_loss, save_checkpoint.best
)
checkpoint_conds[
"checkpoint_last{}.pt".format(suffix)
] = not args.no_last_checkpoints
] = not cfg.no_last_checkpoints
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best})
checkpoints = [
os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
]
if len(checkpoints) > 0:
trainer.save_checkpoint(checkpoints[0], extra_state)
@ -95,51 +99,52 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
)
)
if not end_of_epoch and args.keep_interval_updates > 0:
if not end_of_epoch and cfg.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(
args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
)
for old_chk in checkpoints[args.keep_interval_updates :]:
for old_chk in checkpoints[cfg.keep_interval_updates :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
if cfg.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
for old_chk in checkpoints[args.keep_last_epochs :]:
checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt")
for old_chk in checkpoints[cfg.keep_last_epochs :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_best_checkpoints > 0:
if cfg.keep_best_checkpoints > 0:
# only keep the best N checkpoints according to validation metric
checkpoints = checkpoint_paths(
args.save_dir,
cfg.save_dir,
pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
args.best_checkpoint_metric
cfg.best_checkpoint_metric
),
)
if not args.maximize_best_checkpoint_metric:
if not cfg.maximize_best_checkpoint_metric:
checkpoints = checkpoints[::-1]
for old_chk in checkpoints[args.keep_best_checkpoints :]:
for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
if os.path.lexists(old_chk):
os.remove(old_chk)
def load_checkpoint(args, trainer, **passthrough_args):
def load_checkpoint(cfg: DictConfig, trainer, **passthrough_args):
"""
Load a checkpoint and restore the training iterator.
*passthrough_args* will be passed through to
``trainer.get_train_iterator``.
"""
reset_optimizer = args.reset_optimizer
reset_lr_scheduler = args.reset_lr_scheduler
optimizer_overrides = eval(args.optimizer_overrides)
reset_meters = args.reset_meters
reset_dataloader = args.reset_dataloader
if getattr(args, "finetune_from_model", None) is not None and (
reset_optimizer = cfg.reset_optimizer
reset_lr_scheduler = cfg.reset_lr_scheduler
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
reset_meters = cfg.reset_meters
reset_dataloader = cfg.reset_dataloader
if cfg.finetune_from_model is not None and (
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
):
raise ValueError(
@ -147,19 +152,19 @@ def load_checkpoint(args, trainer, **passthrough_args):
" or reset_lr_scheduler or reset_meters or reset_dataloader"
)
suffix = getattr(args, "checkpoint_suffix", "")
suffix = cfg.checkpoint_suffix
if (
args.restore_file == "checkpoint_last.pt"
cfg.restore_file == "checkpoint_last.pt"
): # default value of restore_file is 'checkpoint_last.pt'
checkpoint_path = os.path.join(
args.save_dir, "checkpoint_last{}.pt".format(suffix)
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
)
first_launch = not PathManager.exists(checkpoint_path)
if getattr(args, "finetune_from_model", None) is not None and first_launch:
if cfg.finetune_from_model is not None and first_launch:
# if there is no last checkpoint to restore, start the finetune from pretrained model
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
if PathManager.exists(args.finetune_from_model):
checkpoint_path = args.finetune_from_model
if PathManager.exists(cfg.finetune_from_model):
checkpoint_path = cfg.finetune_from_model
reset_optimizer = True
reset_lr_scheduler = True
reset_meters = True
@ -170,19 +175,17 @@ def load_checkpoint(args, trainer, **passthrough_args):
)
else:
raise ValueError(
f"--funetune-from-model {args.finetune_from_model} does not exist"
f"--funetune-from-model {cfg.finetune_from_model} does not exist"
)
elif getattr(args, "model_parallel_size", 1) > 1:
checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
elif cfg.model_parallel_size > 1:
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
else:
checkpoint_path = args.restore_file
checkpoint_path = cfg.restore_file
if args.restore_file != "checkpoint_last.pt" and getattr(
args, "finetune_from_model", None
):
if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
raise ValueError(
"--finetune-from-model and --restore-file (non-default value) "
"can not be specified together: " + str(args)
"can not be specified together: " + str(cfg)
)
extra_state = trainer.load_checkpoint(
@ -225,10 +228,14 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
f, map_location=lambda s, l: default_restore_location(s, "cpu")
)
args = state["args"]
if arg_overrides is not None:
if "args" in state and state["args"] is not None and arg_overrides is not None:
args = state["args"]
for arg_name, arg_val in arg_overrides.items():
setattr(args, arg_name, arg_val)
if "cfg" in state and state["cfg"] is not None and arg_overrides is not None:
overwrite_args_by_name(state["cfg"], arg_overrides)
state = _upgrade_state_dict(state)
return state
@ -274,19 +281,28 @@ def load_model_ensemble_and_task(
filename = filename.replace(".pt", suffix + ".pt")
else:
filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
if not PathManager.exists(filename):
raise IOError("Model file not found: {}".format(filename))
state = load_checkpoint_to_cpu(filename, arg_overrides)
if shard_idx == 0:
args = state["args"]
if task is None:
task = tasks.setup_task(args)
if "args" in state and state["args"] is not None:
cfg = convert_namespace_to_omegaconf(state["args"])
elif "cfg" in state and state["cfg"] is not None:
cfg = state["cfg"]
else:
raise RuntimeError(
f"Neither args nor cfg exist in state keys = {state.keys()}"
)
# build model for ensemble
model = task.build_model(args)
model.load_state_dict(state["model"], strict=strict, args=args)
if task is None:
task = tasks.setup_task(cfg.task)
# build model for ensemble
model = task.build_model(cfg.model)
model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
ensemble.append(model)
return ensemble, args, task
return ensemble, cfg, task
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
@ -323,7 +339,7 @@ def torch_persistent_save(obj, f):
def save_state(
filename,
args,
cfg: DictConfig,
model_state_dict,
criterion,
optimizer,
@ -331,6 +347,7 @@ def save_state(
num_updates,
optim_history=None,
extra_state=None,
**kwargs,
):
from fairseq import utils
@ -339,7 +356,8 @@ def save_state(
if extra_state is None:
extra_state = {}
state_dict = {
"args": args,
"cfg": cfg,
"args": kwargs.get("args", None),
"model": model_state_dict or {},
"optimizer_history": optim_history
+ [
@ -354,11 +372,17 @@ def save_state(
}
if utils.has_parameters(criterion):
state_dict["criterion"] = criterion.state_dict()
if not args.no_save_optimizer_state:
state_dict["last_optimizer_state"] = optimizer.state_dict()
# convert all state to CPU
state_dict = utils.move_to_cpu(state_dict)
if cfg is None:
cfg = state_dict["args"]
assert cfg is not None, "must provide cfg or args"
if isinstance(cfg, DictConfig):
no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
else:
no_save_optimizer_state = cfg.no_save_optimizer_state
if not no_save_optimizer_state:
state_dict["last_optimizer_state"] = optimizer.state_dict()
with PathManager.open(filename, "wb") as f:
torch_persistent_save(state_dict, f)
@ -403,46 +427,49 @@ def _upgrade_state_dict(state):
# keep track of number of updates
if "num_updates" not in state["optimizer_history"][-1]:
state["optimizer_history"][-1]["num_updates"] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state["args"], "max_positions") and not hasattr(
state["args"], "max_source_positions"
):
state["args"].max_source_positions = state["args"].max_positions
state["args"].max_target_positions = state["args"].max_positions
# use stateful training data iterator
if "train_iterator" not in state["extra_state"]:
state["extra_state"]["train_iterator"] = {
"epoch": state["extra_state"]["epoch"],
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
}
# default to translation task
if not hasattr(state["args"], "task"):
state["args"].task = "translation"
# --raw-text and --lazy-load are deprecated
if getattr(state["args"], "raw_text", False):
state["args"].dataset_impl = "raw"
elif getattr(state["args"], "lazy_load", False):
state["args"].dataset_impl = "lazy"
# epochs start at 1
if state["extra_state"]["train_iterator"] is not None:
state["extra_state"]["train_iterator"]["epoch"] = max(
state["extra_state"]["train_iterator"].get("epoch", 1),
1,
)
# set any missing default values in the task, model or other registries
registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch])
for registry_name, REGISTRY in registry.REGISTRIES.items():
choice = getattr(state["args"], registry_name, None)
if choice is not None:
cls = REGISTRY["registry"][choice]
registry.set_defaults(state["args"], cls)
# old model checkpoints may not have separate source/target positions
# backward compatibility, cfg updates
if "args" in state and state["args"] is not None:
# default to translation task
if not hasattr(state["args"], "task"):
state["args"].task = "translation"
# --raw-text and --lazy-load are deprecated
if getattr(state["args"], "raw_text", False):
state["args"].dataset_impl = "raw"
elif getattr(state["args"], "lazy_load", False):
state["args"].dataset_impl = "lazy"
# epochs start at 1
if state["extra_state"]["train_iterator"] is not None:
state["extra_state"]["train_iterator"]["epoch"] = max(
state["extra_state"]["train_iterator"].get("epoch", 1), 1
)
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
if "cfg" in state and state["cfg"] is not None:
with open_dict(state["cfg"]):
if state["cfg"].task is not None:
if hasattr(state["cfg"].task, "max_positions") and not hasattr(
state["cfg"].task, "max_source_positions"
):
state["cfg"].task.max_source_positions = state[
"cfg"
].task.max_positions
state["cfg"].task.max_target_positions = state[
"cfg"
].task.max_positions
return state
def prune_state_dict(state_dict, args):
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
"""Prune the given state_dict if desired for LayerDrop
(https://arxiv.org/abs/1909.11556).
@ -453,16 +480,20 @@ def prune_state_dict(state_dict, args):
It's called by functions that load models from checkpoints and does not
need to be called directly.
"""
if not args or args.arch == "ptt_transformer":
arch = None
if model_cfg is not None:
arch = (
model_cfg._name
if isinstance(model_cfg, DictConfig)
else getattr(model_cfg, "arch", None)
)
if not model_cfg or arch is None or arch == "ptt_transformer":
# args should not be none, but don't crash if it is.
return state_dict
encoder_layers_to_keep = (
args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
)
decoder_layers_to_keep = (
args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
)
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
if not encoder_layers_to_keep and not decoder_layers_to_keep:
return state_dict
@ -474,7 +505,7 @@ def prune_state_dict(state_dict, args):
def create_pruning_pass(layers_to_keep, layer_name):
keep_layers = sorted(
[int(layer_string) for layer_string in layers_to_keep.split(",")]
int(layer_string) for layer_string in layers_to_keep.split(",")
)
mapping_dict = {}
for i in range(len(keep_layers)):
@ -518,10 +549,12 @@ def prune_state_dict(state_dict, args):
# Since layers are now pruned, *_layers_to_keep are no longer needed.
# This is more of "It would make it work fix" rather than a proper fix.
if "encoder_layers_to_keep" in vars(args):
args.encoder_layers_to_keep = None
if "decoder_layers_to_keep" in vars(args):
args.decoder_layers_to_keep = None
with open_dict(model_cfg):
if hasattr(model_cfg, "encoder_layers_to_keep"):
model_cfg.encoder_layers_to_keep = None
if hasattr(model_cfg, "decoder_layers_to_keep"):
model_cfg.decoder_layers_to_keep = None
return new_state_dict

View File

@ -6,8 +6,6 @@
import importlib
import os
from argparse import Namespace
from typing import Union
from fairseq import registry
from fairseq.criterions.fairseq_criterion import ( # noqa
@ -27,8 +25,8 @@ from omegaconf import DictConfig
)
def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task):
return build_criterion_(criterion_cfg, task)
def build_criterion(cfg: DictConfig, task):
return build_criterion_(cfg, task)
# automatically import any Python files in the criterions/ directory

View File

@ -11,13 +11,13 @@ from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
from omegaconf import II
from omegaconf import II, DictConfig
@dataclass
class AdaptiveLossConfig(FairseqDataclass):
sentence_avg: bool = II("params.optimization.sentence_avg")
ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend")
sentence_avg: bool = II("optimization.sentence_avg")
ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")
@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
@ -31,14 +31,14 @@ class AdaptiveLoss(FairseqCriterion):
self.sentence_avg = sentence_avg
@classmethod
def build_criterion(cls, args, task):
if getattr(args, "ddp_backend", None) == "c10d":
def build_criterion(cls, cfg: DictConfig, task):
if cfg.ddp_backend == "c10d":
raise Exception(
"AdaptiveLoss is not compatible with the c10d "
"version of DistributedDataParallel. Please use "
"`--ddp-backend=no_c10d` instead."
)
return cls(task, args.sentence_avg)
return cls(task, cfg.sentence_avg)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.

View File

@ -15,7 +15,7 @@ from omegaconf import II
@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
sentence_avg: bool = II("params.optimization.sentence_avg")
sentence_avg: bool = II("optimization.sentence_avg")
@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)

View File

@ -10,24 +10,24 @@ from argparse import Namespace
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import safe_round
@register_criterion("ctc")
class CtcCriterion(FairseqCriterion):
def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe):
super().__init__(task)
class CtcCriterion(LegacyFairseqCriterion):
def __init__(self, args, task):
super().__init__(args, task)
self.blank_idx = task.target_dictionary.bos()
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = remove_bpe if remove_bpe else "letter"
self.post_process = args.remove_bpe if args.remove_bpe else "letter"
if wer_args is not None:
if args.wer_args is not None:
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args)
wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args)
dec_args = Namespace()
dec_args.nbest = 1
@ -46,8 +46,8 @@ class CtcCriterion(FairseqCriterion):
else:
self.w2l_decoder = None
self.zero_infinity = zero_infinity
self.sentence_avg = sentence_avg
self.zero_infinity = args.zero_infinity
self.sentence_avg = args.sentence_avg
@staticmethod
def add_args(parser):

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List
from fairseq import metrics, utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from omegaconf import DictConfig
from torch.nn.modules.loss import _Loss
@ -27,10 +28,8 @@ class FairseqCriterion(_Loss):
gen_parser_from_dataclass(parser, dc())
@classmethod
def build_criterion(cls, args, task):
def build_criterion(cls, cfg: DictConfig, task):
"""Construct a criterion from command-line args."""
# Criterions can override this, but for convenience we also try
# to automatically map argparse.Namespace keys to corresponding
# arguments in the __init__.
init_args = {}
for p in inspect.signature(cls).parameters.values():
@ -47,8 +46,8 @@ class FairseqCriterion(_Loss):
if p.name == "task":
init_args["task"] = task
elif hasattr(args, p.name):
init_args[p.name] = getattr(args, p.name)
elif hasattr(cfg, p.name):
init_args[p.name] = getattr(cfg, p.name)
elif p.default != p.empty:
pass # we'll use the default value
else:
@ -70,7 +69,7 @@ class FairseqCriterion(_Loss):
@staticmethod
def aggregate_logging_outputs(
logging_outputs: List[Dict[str, Any]],
logging_outputs: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(

View File

@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
@ -12,19 +14,20 @@ from fairseq.data.encoders.byte_utils import (
byte_encode,
smart_byte_decode,
)
from fairseq.dataclass import FairseqDataclass
@register_bpe("byte_bpe")
@dataclass
class ByteBpeConfig(FairseqDataclass):
sentencepiece_model_path: str = field(
default="???", metadata={"help": "path to sentencepiece model"}
)
@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
class ByteBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--sentencepiece-model-path', type=str,
help='path to sentencepiece model')
# fmt: on
def __init__(self, args):
vocab = file_utils.cached_path(args.sentencepiece_model_path)
def __init__(self, cfg):
vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
try:
import sentencepiece as spm

View File

@ -15,7 +15,7 @@ from fairseq.data.encoders.byte_utils import (
@register_bpe("bytes")
class Bytes(object):
def __init__(self, args):
def __init__(self, *unused):
pass
@staticmethod

View File

@ -13,7 +13,7 @@ SPACE_ESCAPE = chr(9601)
@register_bpe("characters")
class Characters(object):
def __init__(self, args):
def __init__(self, *unused):
pass
@staticmethod

View File

@ -3,23 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("fastbpe")
@dataclass
class fastBPEConfig(FairseqDataclass):
bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})
@register_bpe("fastbpe", dataclass=fastBPEConfig)
class fastBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-codes', type=str,
help='path to fastBPE BPE')
# fmt: on
def __init__(self, args):
if args.bpe_codes is None:
def __init__(self, cfg):
if cfg.bpe_codes is None:
raise ValueError("--bpe-codes is required for --bpe=fastbpe")
codes = file_utils.cached_path(args.bpe_codes)
codes = file_utils.cached_path(cfg.bpe_codes)
try:
import fastBPE

View File

@ -3,8 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
from .gpt2_bpe_utils import get_encoder
@ -13,26 +16,21 @@ DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
@register_bpe("gpt2")
class GPT2BPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--gpt2-encoder-json', type=str,
default=DEFAULT_ENCODER_JSON,
help='path to encoder.json')
parser.add_argument('--gpt2-vocab-bpe', type=str,
default=DEFAULT_VOCAB_BPE,
help='path to vocab.bpe')
# fmt: on
@dataclass
class GPT2BPEConfig(FairseqDataclass):
gpt2_encoder_json: str = field(
default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
)
gpt2_vocab_bpe: str = field(
default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
)
def __init__(self, args):
encoder_json = file_utils.cached_path(
getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON)
)
vocab_bpe = file_utils.cached_path(
getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE)
)
@register_bpe("gpt2", dataclass=GPT2BPEConfig)
class GPT2BPE(object):
def __init__(self, cfg):
encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
self.bpe = get_encoder(encoder_json, vocab_bpe)
def encode(self, x: str) -> str:

View File

@ -3,22 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from typing import Optional
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("bert")
@dataclass
class BertBPEConfig(FairseqDataclass):
bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"})
bpe_vocab_file: Optional[str] = field(
default=None, metadata={"help": "bpe vocab file"}
)
@register_bpe("bert", dataclass=BertBPEConfig)
class BertBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-cased', action='store_true',
help='set for cased BPE',
default=False)
parser.add_argument('--bpe-vocab-file', type=str,
help='bpe vocab file.')
# fmt: on
def __init__(self, args):
def __init__(self, cfg):
try:
from transformers import BertTokenizer
except ImportError:
@ -26,13 +28,13 @@ class BertBPE(object):
"Please install transformers with: pip install transformers"
)
if "bpe_vocab_file" in args:
if cfg.bpe_vocab_file:
self.bert_tokenizer = BertTokenizer(
args.bpe_vocab_file, do_lower_case=not args.bpe_cased
cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased
)
else:
vocab_file_name = (
"bert-base-cased" if args.bpe_cased else "bert-base-uncased"
"bert-base-cased" if cfg.bpe_cased else "bert-base-uncased"
)
self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)

View File

@ -3,21 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("hf_byte_bpe")
@dataclass
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"})
bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"})
bpe_add_prefix_space: bool = field(
default=False, metadata={"help": "add prefix space before encoding"}
)
@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
class HuggingFaceByteLevelBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-merges', help='path to merges.txt')
parser.add_argument('--bpe-vocab', help='path to vocab.json')
parser.add_argument('--bpe-add-prefix-space', action='store_true',
help='add prefix space before encoding')
# fmt: on
def __init__(self, args):
def __init__(self, cfg):
try:
from tokenizers import ByteLevelBPETokenizer
except ImportError:
@ -26,9 +29,9 @@ class HuggingFaceByteLevelBPE(object):
)
self.bpe = ByteLevelBPETokenizer(
args.bpe_vocab,
args.bpe_merges,
add_prefix_space=getattr(args, "bpe_add_prefix_space", False),
cfg.bpe_vocab,
cfg.bpe_merges,
add_prefix_space=cfg.bpe_add_prefix_space,
)
def encode(self, x: str) -> str:

View File

@ -3,37 +3,35 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.data.encoders import register_tokenizer
from fairseq.dataclass import FairseqDataclass
@register_tokenizer("moses")
@dataclass
class MosesTokenizerConfig(FairseqDataclass):
source_lang: str = field(default="en", metadata={"help": "source language"})
target_lang: str = field(default="en", metadata={"help": "target language"})
moses_no_dash_splits: bool = field(
default=False, metadata={"help": "don't apply dash split rules"}
)
moses_no_escape: bool = field(
default=False,
metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
)
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
class MosesTokenizer(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--moses-source-lang', metavar='SRC',
help='source language')
parser.add_argument('--moses-target-lang', metavar='TARGET',
help='target language')
parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
help='don\'t apply dash split rules')
parser.add_argument('--moses-no-escape', action='store_true', default=False,
help='don\'t perform HTML escaping on apostrophy, quotes, etc.')
# fmt: on
def __init__(self, args):
self.args = args
if getattr(args, "moses_source_lang", None) is None:
args.moses_source_lang = getattr(args, "source_lang", "en")
if getattr(args, "moses_target_lang", None) is None:
args.moses_target_lang = getattr(args, "target_lang", "en")
def __init__(self, cfg):
self.cfg = cfg
try:
from sacremoses import MosesTokenizer, MosesDetokenizer
self.tok = MosesTokenizer(args.moses_source_lang)
self.detok = MosesDetokenizer(args.moses_target_lang)
self.tok = MosesTokenizer(cfg.source_lang)
self.detok = MosesDetokenizer(cfg.target_lang)
except ImportError:
raise ImportError(
"Please install Moses tokenizer with: pip install sacremoses"
@ -42,9 +40,9 @@ class MosesTokenizer(object):
def encode(self, x: str) -> str:
return self.tok.tokenize(
x,
aggressive_dash_splits=(not self.args.moses_no_dash_splits),
aggressive_dash_splits=(not self.cfg.moses_no_dash_splits),
return_str=True,
escape=(not self.args.moses_no_escape),
escape=(not self.cfg.moses_no_escape),
)
def decode(self, x: str) -> str:

View File

@ -8,7 +8,7 @@ from fairseq.data.encoders import register_tokenizer
@register_tokenizer("nltk")
class NLTKTokenizer(object):
def __init__(self, source_lang=None, target_lang=None):
def __init__(self, *unused):
try:
from nltk.tokenize import word_tokenize

View File

@ -3,21 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("sentencepiece")
@dataclass
class SentencepieceConfig(FairseqDataclass):
sentencepiece_model: str = field(
default="???", metadata={"help": "path to sentencepiece model"}
)
@register_bpe("sentencepiece", dataclass=SentencepieceConfig)
class SentencepieceBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--sentencepiece-model', type=str,
help='path to sentencepiece model')
# fmt: on
def __init__(self, args):
sentencepiece_model = file_utils.cached_path(args.sentencepiece_model)
def __init__(self, cfg):
sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model)
try:
import sentencepiece as spm

View File

@ -10,7 +10,7 @@ from fairseq.data.encoders import register_tokenizer
@register_tokenizer("space")
class SpaceTokenizer(object):
def __init__(self, source_lang=None, target_lang=None):
def __init__(self, *unused):
self.space_tok = re.compile(r"\s+")
def encode(self, x: str) -> str:

View File

@ -3,25 +3,25 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass
@register_bpe("subword_nmt")
@dataclass
class SubwordNMTBPEConfig(FairseqDataclass):
bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"})
bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"})
@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig)
class SubwordNMTBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-codes', type=str,
help='path to subword NMT BPE')
parser.add_argument('--bpe-separator', default='@@',
help='BPE separator')
# fmt: on
def __init__(self, args):
if args.bpe_codes is None:
def __init__(self, cfg):
if cfg.bpe_codes is None:
raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
codes = file_utils.cached_path(args.bpe_codes)
codes = file_utils.cached_path(cfg.bpe_codes)
try:
from subword_nmt import apply_bpe
@ -31,7 +31,7 @@ class SubwordNMTBPE(object):
"--codes",
codes,
"--separator",
args.bpe_separator,
cfg.bpe_separator,
]
)
self.bpe = apply_bpe.BPE(

View File

@ -9,5 +9,7 @@ from fairseq.dataclass.utils import ChoiceEnum
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"])
DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"])
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(["unigram", "ensemble", "vote", "dp", "bs"])
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])

View File

@ -3,32 +3,37 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import sys
from argparse import Namespace
from dataclasses import dataclass, field
from dataclasses import _MISSING_TYPE, dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from fairseq.criterions import CRITERION_DATACLASS_REGISTRY
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass.constants import (
DDP_BACKEND_CHOICES,
DISTRIBUTED_WRAPPER_CHOICES,
GENERATION_CONSTRAINTS_CHOICES,
GENERATION_DECODING_FORMAT_CHOICES,
LOG_FORMAT_CHOICES,
PIPELINE_CHECKPOINT_CHOICES,
ZERO_SHARDING_CHOICES,
)
from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass
from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY
from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY
from fairseq.optim.bmuf import FairseqBMUFConfig
from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY
from fairseq.registry import REGISTRIES
from fairseq.tasks import TASK_DATACLASS_REGISTRY
from hydra.core.config_store import ConfigStore
from omegaconf import II
logger = logging.getLogger(__name__)
@dataclass
class CommonParams(FairseqDataclass):
class CommonConfig(FairseqDataclass):
# This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were
# used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc.
no_progress_bar: bool = field(
@ -109,18 +114,6 @@ class CommonParams(FairseqDataclass):
model_parallel_size: int = field(
default=1, metadata={"help": "total number of GPUs to parallelize model over"}
)
checkpoint_suffix: str = field(
default="", metadata={"help": "suffix to add to the checkpoint file name"}
)
checkpoint_shard_count: int = field(
default=1,
metadata={
"help": "Number of shards containing the checkpoint - "
"if the checkpoint is over 300GB, it is preferable "
"to split it into shards to prevent OOM on CPU while loading "
"the checkpoint"
},
)
quantization_config_path: Optional[str] = field(
default=None, metadata={"help": "path to quantization config file"}
)
@ -130,7 +123,7 @@ class CommonParams(FairseqDataclass):
@dataclass
class DistributedTrainingParams(FairseqDataclass):
class DistributedTrainingConfig(FairseqDataclass):
distributed_world_size: int = field(
default=max(1, torch.cuda.device_count()),
metadata={
@ -229,7 +222,7 @@ class DistributedTrainingParams(FairseqDataclass):
default=False,
metadata={"help": "if set, use pipeline model parallelism across GPUs"},
)
pipeline_balance: str = field(
pipeline_balance: Optional[str] = field(
default=None,
metadata={
"help": "partition the model into N_K pieces, where each piece "
@ -237,7 +230,7 @@ class DistributedTrainingParams(FairseqDataclass):
"should equal the total number of layers in the model"
},
)
pipeline_devices: str = field(
pipeline_devices: Optional[str] = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
@ -245,10 +238,10 @@ class DistributedTrainingParams(FairseqDataclass):
"equal the length of the --pipeline-balance argument"
},
)
pipeline_chunks: int = field(
pipeline_chunks: Optional[int] = field(
default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
)
pipeline_encoder_balance: str = field(
pipeline_encoder_balance: Optional[str] = field(
default=None,
metadata={
"help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
@ -256,7 +249,7 @@ class DistributedTrainingParams(FairseqDataclass):
"should equal the total number of encoder layers in the model"
},
)
pipeline_encoder_devices: str = field(
pipeline_encoder_devices: Optional[str] = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
@ -264,7 +257,7 @@ class DistributedTrainingParams(FairseqDataclass):
"equal the length of the --pipeline-encoder-balance argument"
},
)
pipeline_decoder_balance: str = field(
pipeline_decoder_balance: Optional[str] = field(
default=None,
metadata={
"help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
@ -272,7 +265,7 @@ class DistributedTrainingParams(FairseqDataclass):
"should equal the total number of decoder layers in the model"
},
)
pipeline_decoder_devices: str = field(
pipeline_decoder_devices: Optional[str] = field(
default=None,
metadata={
"help": "a list of device indices indicating which device to place "
@ -287,10 +280,11 @@ class DistributedTrainingParams(FairseqDataclass):
zero_sharding: ZERO_SHARDING_CHOICES = field(
default="none", metadata={"help": "ZeRO sharding"}
)
tpu: bool = II("common.tpu")
@dataclass
class DatasetParams(FairseqDataclass):
class DatasetConfig(FairseqDataclass):
num_workers: int = field(
default=1, metadata={"help": "how many subprocesses to use for data loading"}
)
@ -374,7 +368,7 @@ class DatasetParams(FairseqDataclass):
@dataclass
class OptimizationParams(FairseqDataclass):
class OptimizationConfig(FairseqDataclass):
max_epoch: int = field(
default=0, metadata={"help": "force stop training at specified epoch"}
)
@ -421,7 +415,7 @@ class OptimizationParams(FairseqDataclass):
@dataclass
class CheckpointParams(FairseqDataclass):
class CheckpointConfig(FairseqDataclass):
save_dir: str = field(
default="checkpoints", metadata={"help": "path to save checkpoints"}
)
@ -514,12 +508,217 @@ class CheckpointParams(FairseqDataclass):
)
},
)
checkpoint_suffix: str = field(
default="", metadata={"help": "suffix to add to the checkpoint file name"}
)
checkpoint_shard_count: int = field(
default=1,
metadata={
"help": "Number of shards containing the checkpoint - "
"if the checkpoint is over 300GB, it is preferable "
"to split it into shards to prevent OOM on CPU while loading "
"the checkpoint"
},
)
model_parallel_size: int = II("common.model_parallel_size")
distributed_rank: int = II("distributed_training.distributed_rank")
@dataclass
class CommonEvalParams(FairseqDataclass):
class GenerationConfig(FairseqDataclass):
beam: int = field(
default=5,
metadata={"help": "beam size"},
)
nbest: int = field(
default=1,
metadata={"help": "number of hypotheses to output"},
)
max_len_a: float = field(
default=0,
metadata={
"help": "generate sequences of maximum length ax + b, where x is the source length"
},
)
max_len_b: int = field(
default=200,
metadata={
"help": "generate sequences of maximum length ax + b, where x is the source length"
},
)
min_len: int = field(
default=1,
metadata={"help": "minimum generation length"},
)
match_source_len: bool = field(
default=False,
metadata={"help": "generations should match the source length"},
)
unnormalized: bool = field(
default=False,
metadata={"help": "compare unnormalized hypothesis scores"},
)
no_early_stop: bool = field(
default=False,
metadata={"help": "deprecated"},
)
no_beamable_mm: bool = field(
default=False,
metadata={"help": "don't use BeamableMM in attention layers"},
)
lenpen: float = field(
default=1,
metadata={
"help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
},
)
unkpen: float = field(
default=0,
metadata={
"help": "unknown word penalty: <0 produces more unks, >0 produces fewer"
},
)
replace_unk: Optional[str] = field(
default=None,
metadata={
"help": "perform unknown replacement (optionally with alignment dictionary)",
"argparse_const": "@@ ",
},
)
sacrebleu: bool = field(
default=False,
metadata={"help": "score with sacrebleu"},
)
score_reference: bool = field(
default=False,
metadata={"help": "just score the reference translation"},
)
prefix_size: int = field(
default=0,
metadata={"help": "initialize generation by target prefix of given length"},
)
no_repeat_ngram_size: int = field(
default=0,
metadata={
"help": "ngram blocking such that this size ngram cannot be repeated in the generation"
},
)
sampling: bool = field(
default=False,
metadata={"help": "sample hypotheses instead of using beam search"},
)
sampling_topk: int = field(
default=-1,
metadata={"help": "sample from top K likely next words instead of all words"},
)
sampling_topp: float = field(
default=-1.0,
metadata={
"help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
},
)
constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(
default=None,
metadata={
"help": "enables lexically constrained decoding",
"argparse_const": "ordered",
},
)
temperature: float = field(
default=1.0,
metadata={"help": "temperature for generation"},
)
diverse_beam_groups: int = field(
default=-1,
metadata={"help": "number of groups for Diverse Beam Search"},
)
diverse_beam_strength: float = field(
default=0.5,
metadata={"help": "strength of diversity penalty for Diverse Beam Search"},
)
diversity_rate: float = field(
default=-1.0,
metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
)
print_alignment: bool = field(
default=False,
metadata={
"help": "if set, uses attention feedback to compute and print alignment to source tokens"
},
)
print_step: bool = field(
default=False,
metadata={"help": "print steps"},
)
lm_path: Optional[str] = field(
default=None,
metadata={"help": "path to lm checkpoint for lm fusion"},
)
lm_weight: float = field(
default=0.0,
metadata={"help": "weight for lm probs for lm fusion"},
)
# arguments for iterative refinement generator
iter_decode_eos_penalty: float = field(
default=0.0,
metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
)
iter_decode_max_iter: int = field(
default=10,
metadata={"help": "maximum iterations for iterative refinement."},
)
iter_decode_force_max_iter: bool = field(
default=False,
metadata={
"help": "if set, run exact the maximum number of iterations without early stop"
},
)
iter_decode_with_beam: int = field(
default=1,
metadata={
"help": "if > 1, model will generate translations varying by the lengths."
},
)
iter_decode_with_external_reranker: bool = field(
default=False,
metadata={
"help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations"
},
)
retain_iter_history: bool = field(
default=False,
metadata={
"help": "if set, decoding returns the whole history of iterative refinement"
},
)
retain_dropout: bool = field(
default=False,
metadata={"help": "Use dropout at inference time"},
)
retain_dropout_modules: Optional[List[str]] = field(
default=None,
metadata={
"help": "if set, only retain dropout for the specified modules; "
"if not set, then dropout will be retained for all modules"
},
)
# special decoding format for advanced decoding.
decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(
default=None,
metadata={"help": "special decoding format for advanced decoding."},
)
no_seed_provided: bool = field(
default=False,
metadata={"help": "if set, dont use seed for initializing random generators"},
)
@dataclass
class CommonEvalConfig(FairseqDataclass):
path: Optional[str] = field(
default=None, metadata={"help": "path(s) to model file(s), colon separated"}
default=None,
metadata={"help": "path(s) to model file(s), colon separated"},
)
remove_bpe: Optional[str] = field(
default=None,
@ -541,7 +740,7 @@ class CommonEvalParams(FairseqDataclass):
@dataclass
class EvalLMParams(FairseqDataclass):
class EvalLMConfig(FairseqDataclass):
output_word_probs: bool = field(
default=False,
metadata={
@ -569,37 +768,31 @@ class EvalLMParams(FairseqDataclass):
@dataclass
class TrainingConfig(FairseqDataclass):
"""Config for training, a composition of training params"""
common: CommonParams = CommonParams()
distributed_training: DistributedTrainingParams = DistributedTrainingParams()
dataset: DatasetParams = DatasetParams()
optimization: OptimizationParams = OptimizationParams()
checkpoint: CheckpointParams = CheckpointParams()
bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
class InteractiveConfig(FairseqDataclass):
buffer_size: int = field(
default=0,
metadata={
"help": "read this many sentences into a buffer before processing them"
},
)
input: str = field(
default="-",
metadata={"help": "file to read from; use - for stdin"},
)
@dataclass
class EvalLMConfig(FairseqDataclass):
"""Config for eval lm, a composition of eval_lm params"""
common: CommonParams = CommonParams()
distributed_training: DistributedTrainingParams = DistributedTrainingParams()
dataset: DatasetParams = DatasetParams()
optimization: OptimizationParams = OptimizationParams()
checkpoint: CheckpointParams = CheckpointParams()
bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
common_eval: CommonEvalParams = CommonEvalParams()
eval_lm: EvalLMParams = EvalLMParams()
def register_params_dataclass(
cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass]
) -> None:
"""register params dataclass in config store"""
node_ = data_class(_name=data_class.name())
cs.store(name=name, group=group, node=node_)
CONFIGS = {
"common": CommonConfig,
"common_eval": CommonEvalConfig,
"distributed_training": DistributedTrainingConfig,
"dataset": DatasetConfig,
"optimization": OptimizationConfig,
"checkpoint": CheckpointConfig,
"bmuf": FairseqBMUFConfig,
"generation": GenerationConfig,
"eval_lm": EvalLMConfig,
"interactive": InteractiveConfig,
}
def register_module_dataclass(
@ -608,100 +801,67 @@ def register_module_dataclass(
"""register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc."""
# note that if `group == model`, we register all model archs, not the model name.
for k, v in registry.items():
if v is not None:
node_ = v(_name=k)
cs.store(name=k, group=group, node=node_)
node_ = v()
node_._name = k
cs.store(name=k, group=group, node=node_, provider="fairseq")
def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
"""cs: config store instance, register common training configs"""
register_params_dataclass(
cs, name="training_params", group="params", data_class=TrainingConfig
)
for k, v in CONFIGS.items():
try:
cs.store(name=k, node=v())
except BaseException:
logger.error(f"{k} - {v()}")
raise
register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model")
register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion")
register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer")
register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler")
def register_eval_lm_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
"""cs: config store instance, register common training configs"""
register_params_dataclass(
cs, name="eval_lm_params", group="params", data_class=EvalLMConfig
)
register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion")
register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer")
register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler")
for k, v in REGISTRIES.items():
register_module_dataclass(cs, v["dataclass_registry"], k)
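Taken together, register_hydra_cfg stores every top-level group from CONFIGS plus the per-component dataclass registries. A rough sketch of wiring it up at startup (import paths as in this file; the exact call site may differ):
from hydra.core.config_store import ConfigStore
from fairseq.dataclass.data_class import register_hydra_cfg

cs = ConfigStore.instance()
register_hydra_cfg(cs)  # registers common, dataset, ... plus task/model/criterion/optimizer/lr_scheduler groups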
def _override_attr(
sub_node: str, data_class: Type[FairseqDataclass], args: Namespace
) -> List[str]:
overrides = []
for k in data_class.__dataclass_fields__.keys():
if k == "_name":
def get_default(f):
if not isinstance(f.default_factory, _MISSING_TYPE):
return f.default_factory()
return f.default
for k, v in data_class.__dataclass_fields__.items():
if k.startswith("_"):
# private member, skip
continue
if not hasattr(args, k):
# print(f"cannot override {sub_node}.{k} since args does not have attribute {k}")
continue
if getattr(args, k) is None:
val = get_default(v) if not hasattr(args, k) else getattr(args, k)
if val is None:
overrides.append("{}.{}=null".format(sub_node, k))
elif getattr(args, k) == "":
elif val == "":
overrides.append("{}.{}=''".format(sub_node, k))
elif isinstance(getattr(args, k), str):
if (
getattr(args, k).startswith("[")
or getattr(args, k).startswith("(")
or getattr(args, k).startswith("{")
or ("," in getattr(args, k))
):
overrides.append("{}.{}='{}'".format(sub_node, k, getattr(args, k)))
else:
overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k)))
elif isinstance(val, str):
overrides.append("{}.{}='{}'".format(sub_node, k, val))
else:
overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k)))
overrides.append("{}.{}={}".format(sub_node, k, val))
return overrides
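In effect, _override_attr flattens whatever the legacy namespace carries (falling back to the dataclass defaults) into Hydra override strings. An illustrative call with made-up attribute values:
from argparse import Namespace
from fairseq.dataclass.data_class import DatasetConfig, _override_attr

overrides = _override_attr("dataset", DatasetConfig, Namespace(num_workers=4, max_tokens=None))
# yields strings such as "dataset.num_workers=4" and "dataset.max_tokens=null",
# plus one entry per remaining DatasetConfig field taken from its default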
def override_training_args(args: Namespace) -> Tuple[List[str], List[str]]:
overrides = []
overrides.extend(_override_attr("params.common", CommonParams, args))
overrides.extend(_override_attr("params.dataset", DatasetParams, args))
overrides.extend(
_override_attr("params.distributed_training", DistributedTrainingParams, args)
)
overrides.extend(_override_attr("params.optimization", OptimizationParams, args))
overrides.extend(_override_attr("params.checkpoint", CheckpointParams, args))
overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args))
module_overrides, module_deletes = override_module_args(args)
overrides.extend(module_overrides)
return overrides, module_deletes
def override_eval_lm_args(args: Namespace) -> Tuple[List[str], List[str]]:
overrides = []
overrides.extend(_override_attr("params.common", CommonParams, args))
overrides.extend(_override_attr("params.dataset", DatasetParams, args))
overrides.extend(
_override_attr("params.distributed_training", DistributedTrainingParams, args)
)
overrides.extend(_override_attr("params.common_eval", CommonEvalParams, args))
overrides.extend(_override_attr("params.eval_lm", EvalLMParams, args))
overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args))
module_overrides, module_deletes = override_module_args(args)
overrides.extend(module_overrides)
return overrides, module_deletes
def migrate_registry(
name, value, registry, args, overrides, deletes, use_name_as_val=False
):
if value in registry:
overrides.append("{}={}".format(name, value))
overrides.append("{}._name={}".format(name, value))
overrides.extend(_override_attr(name, registry[value], args))
elif use_name_as_val and value is not None:
overrides.append("{}={}".format(name, value))
else:
deletes.append(name)
def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
@ -709,53 +869,34 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
overrides = []
deletes = []
for k, v in CONFIGS.items():
overrides.extend(_override_attr(k, v, args))
if args is not None:
assert (
hasattr(args, "task")
and hasattr(args, "criterion")
and hasattr(args, "optimizer")
and hasattr(args, "lr_scheduler")
)
if args.task in TASK_DATACLASS_REGISTRY:
overrides.append("task={}".format(args.task))
overrides.append("task._name={}".format(args.task))
overrides.extend(
_override_attr("task", TASK_DATACLASS_REGISTRY[args.task], args)
if hasattr(args, "task"):
migrate_registry(
"task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes
)
else:
deletes.append("task")
if args.criterion in CRITERION_DATACLASS_REGISTRY:
overrides.append("criterion={}".format(args.criterion))
overrides.append("criterion._name={}".format(args.criterion))
overrides.extend(
_override_attr(
"criterion", CRITERION_DATACLASS_REGISTRY[args.criterion], args
)
)
else:
deletes.append("criterion")
if args.optimizer in OPTIMIZER_DATACLASS_REGISTRY:
overrides.append("optimizer={}".format(args.optimizer))
overrides.append("optimizer._name={}".format(args.optimizer))
overrides.extend(
_override_attr(
"optimizer", OPTIMIZER_DATACLASS_REGISTRY[args.optimizer], args
)
)
else:
deletes.append("optimizer")
if args.lr_scheduler in LR_SCHEDULER_DATACLASS_REGISTRY:
overrides.append("lr_scheduler={}".format(args.lr_scheduler))
overrides.append("lr_scheduler._name={}".format(args.lr_scheduler))
overrides.extend(
_override_attr(
"lr_scheduler",
LR_SCHEDULER_DATACLASS_REGISTRY[args.lr_scheduler],
# these options will be set to "None" if they have not yet been migrated
# so we can populate them with the entire flat args
CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"}
for k, v in REGISTRIES.items():
if hasattr(args, k):
migrate_registry(
k,
getattr(args, k),
v["dataclass_registry"],
args,
overrides,
deletes,
use_name_as_val=k not in CORE_REGISTRIES,
)
)
else:
deletes.append("lr_scheduler")
else:
deletes.append(k)
no_dc = True
if hasattr(args, "arch"):

View File

@ -3,17 +3,24 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import ArgumentParser
from dataclasses import MISSING, dataclass
import ast
from argparse import ArgumentParser, Namespace
from dataclasses import _MISSING_TYPE, MISSING, dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
from hydra.core.global_hydra import GlobalHydra
from hydra.experimental import compose, initialize
from omegaconf import DictConfig, OmegaConf, open_dict
def eval_str_list(x, x_type=float):
if x is None:
return None
if isinstance(x, str):
x = eval(x)
if len(x) == 0:
return []
x = ast.literal_eval(x)
try:
return list(map(x_type, x))
except TypeError:
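For reference, the switch from eval to ast.literal_eval keeps the common cases working while refusing arbitrary expressions; a couple of illustrative calls:
# eval_str_list("[0.1, 0.2]", float)  -> [0.1, 0.2]
# eval_str_list("", float)            -> []
# eval_str_list(None)                 -> None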
@ -70,22 +77,11 @@ class FairseqDataclass:
!= self.__dataclass_fields__[attribute_name].default
):
return getattr(self, attribute_name)
return self.__dataclass_fields__[attribute_name].default
def _get_default_factory(self, attribute_name: str) -> Any:
if hasattr(self, attribute_name):
if str(getattr(self, attribute_name)).startswith("${"):
return str(getattr(self, attribute_name))
elif str(self.__dataclass_fields__[attribute_name].default).startswith(
"${"
):
return str(self.__dataclass_fields__[attribute_name].default)
elif (
getattr(self, attribute_name)
!= self.__dataclass_fields__[attribute_name].default_factory()
):
return getattr(self, attribute_name)
return self.__dataclass_fields__[attribute_name].default_factory()
f = self.__dataclass_fields__[attribute_name]
if not isinstance(f.default_factory, _MISSING_TYPE):
return f.default_factory()
return f.default
def _get_type(self, attribute_name: str) -> Any:
return self.__dataclass_fields__[attribute_name].type
@ -119,7 +115,7 @@ def gen_parser_from_dataclass(
def interpret_dc_type(field_type):
if isinstance(field_type, str):
raise RuntimeError()
raise RuntimeError("field should be a type")
typestring = str(field_type)
if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring):
return field_type.__args__[0]
@ -129,12 +125,13 @@ def gen_parser_from_dataclass(
dataclass_instance: FairseqDataclass, k: str
) -> Dict[str, Any]:
"""k: dataclass attributes"""
kwargs = {}
field_type = dataclass_instance._get_type(k)
inter_type = interpret_dc_type(field_type)
if isinstance(inter_type, type) and issubclass(inter_type, List):
field_default = dataclass_instance._get_default_factory(k)
else:
field_default = dataclass_instance._get_default(k)
field_default = dataclass_instance._get_default(k)
if isinstance(inter_type, type) and issubclass(inter_type, Enum):
field_choices = [t.value for t in list(inter_type)]
@ -143,7 +140,7 @@ def gen_parser_from_dataclass(
field_help = dataclass_instance._get_help(k)
field_const = dataclass_instance._get_argparse_const(k)
kwargs = {}
if isinstance(field_default, str) and field_default.startswith("${"):
kwargs["default"] = field_default
else:
@ -163,7 +160,11 @@ def gen_parser_from_dataclass(
else:
raise NotImplementedError()
if field_default is not MISSING:
kwargs["default"] = ",".join(map(str, field_default))
kwargs["default"] = (
",".join(map(str, field_default))
if field_default is not None
else None
)
elif (
isinstance(inter_type, type) and issubclass(inter_type, Enum)
) or "Enum" in str(inter_type):
@ -187,6 +188,7 @@ def gen_parser_from_dataclass(
if field_const is not None:
kwargs["const"] = field_const
kwargs["nargs"] = "?"
return kwargs
for k in dataclass_instance._get_all_attributes():
@ -194,8 +196,122 @@ def gen_parser_from_dataclass(
if field_name is None:
continue
kwargs = get_kwargs_from_dc(dataclass_instance, k)
if isinstance(kwargs["default"], str) and kwargs["default"].startswith("${"):
continue
if delete_default:
del kwargs["default"]
if "default" in kwargs:
if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
"${"
):
continue
if delete_default:
del kwargs["default"]
parser.add_argument(field_name, **kwargs)
def _set_legacy_defaults(args, cls):
"""Helper to set default arguments based on *add_args*."""
if not hasattr(cls, "add_args"):
return
import argparse
parser = argparse.ArgumentParser(
argument_default=argparse.SUPPRESS, allow_abbrev=False
)
cls.add_args(parser)
# copied from argparse.py:
defaults = argparse.Namespace()
for action in parser._actions:
if action.dest is not argparse.SUPPRESS:
if not hasattr(defaults, action.dest):
if action.default is not argparse.SUPPRESS:
setattr(defaults, action.dest, action.default)
for key, default_value in vars(defaults).items():
if not hasattr(args, key):
setattr(args, key, default_value)
def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
from fairseq.dataclass.data_class import override_module_args
# Here we are using field values provided in args to override counterparts inside config object
overrides, deletes = override_module_args(args)
cfg_name = "config"
cfg_path = f"../../{cfg_name}"
if not GlobalHydra().is_initialized():
initialize(config_path=cfg_path)
composed_cfg = compose(cfg_name, overrides=overrides, strict=False)
for k in deletes:
composed_cfg[k] = None
cfg = OmegaConf.create(
OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
)
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
# omegaconf version that supports object flags, or when we migrate all existing models
from omegaconf import _utils
old_primitive = _utils.is_primitive_type
_utils.is_primitive_type = lambda _: True
if cfg.task is None and getattr(args, "task", None):
cfg.task = Namespace(**vars(args))
from fairseq.tasks import TASK_REGISTRY
_set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
cfg.task._name = args.task
if cfg.model is None and getattr(args, "arch", None):
cfg.model = Namespace(**vars(args))
from fairseq.models import ARCH_MODEL_REGISTRY
_set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
cfg.model._name = args.arch
if cfg.optimizer is None and getattr(args, "optimizer", None):
cfg.optimizer = Namespace(**vars(args))
from fairseq.optim import OPTIMIZER_REGISTRY
_set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
cfg.optimizer._name = args.optimizer
if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
cfg.lr_scheduler = Namespace(**vars(args))
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
_set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler])
cfg.lr_scheduler._name = args.lr_scheduler
if cfg.criterion is None and getattr(args, "criterion", None):
cfg.criterion = Namespace(**vars(args))
from fairseq.criterions import CRITERION_REGISTRY
_set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
cfg.criterion._name = args.criterion
_utils.is_primitive_type = old_primitive
OmegaConf.set_struct(cfg, True)
return cfg
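This is the conversion used at the legacy entry points mentioned in the summary. A rough sketch of the flow, using the pre-existing argparse helpers in fairseq.options:
from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
cfg = convert_namespace_to_omegaconf(args)  # argparse.Namespace -> DictConfig
# downstream code now reads grouped values, e.g. cfg.common.seed,
# cfg.dataset.max_tokens, cfg.distributed_training.distributed_world_size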
def populate_dataclass(
args: Namespace, dataclass: FairseqDataclass
) -> FairseqDataclass:
for k in dataclass.__dataclass_fields__.keys():
if k.startswith("_"):
# private member, skip
continue
if hasattr(args, k):
setattr(dataclass, k, getattr(args, k))
return dataclass
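populate_dataclass just copies matching attributes from the flat namespace onto a config dataclass; an illustrative use with the GenerationConfig defined in this diff:
from argparse import Namespace
from fairseq.dataclass.data_class import GenerationConfig
from fairseq.dataclass.utils import populate_dataclass

gen_cfg = populate_dataclass(Namespace(sampling=True, temperature=0.8), GenerationConfig())
# gen_cfg.sampling == True, gen_cfg.temperature == 0.8; fields absent from the namespace keep their defaults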
def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]):
# this will be deprecated when we get rid of argparse and model_overrides logic
with open_dict(cfg):
for k in cfg.keys():
if isinstance(cfg[k], DictConfig):
overwrite_args_by_name(cfg[k], overrides)
elif k in overrides:
cfg[k] = overrides[k]
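A minimal, self-contained illustration of the recursive override (plain OmegaConf objects, values made up):
from omegaconf import OmegaConf
from fairseq.dataclass.utils import overwrite_args_by_name

cfg = OmegaConf.create({"generation": {"beam": 5}, "dataset": {"max_tokens": None}})
overwrite_args_by_name(cfg, {"beam": 10, "max_tokens": 4096})
assert cfg.generation.beam == 10 and cfg.dataset.max_tokens == 4096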

View File

@ -11,35 +11,38 @@ import socket
import struct
import subprocess
import warnings
from argparse import Namespace
from collections import OrderedDict
from typing import Any, Dict, Mapping
import torch
import torch.distributed as dist
from fairseq import utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from omegaconf import DictConfig, open_dict
logger = logging.getLogger(__name__)
def is_master(args):
return args.distributed_rank == 0
def is_master(cfg: DictConfig):
return cfg.distributed_rank == 0
def infer_init_method(args, force_distributed=False):
if args.distributed_init_method is not None or getattr(args, "tpu", False):
def infer_init_method(cfg: DictConfig, force_distributed=False):
if cfg.distributed_init_method is not None or cfg.tpu:
return
if args.pipeline_model_parallel:
if cfg.pipeline_model_parallel:
balance_exists = (
args.pipeline_balance is not None
or args.pipeline_encoder_balance is not None
or args.pipeline_decoder_balance is not None
cfg.pipeline_balance is not None
or cfg.pipeline_encoder_balance is not None
or cfg.pipeline_decoder_balance is not None
)
devices_exist = (
args.pipeline_devices is not None
or args.pipeline_encoder_devices is not None
or args.pipeline_decoder_devices is not None
cfg.pipeline_devices is not None
or cfg.pipeline_encoder_devices is not None
or cfg.pipeline_decoder_devices is not None
)
if not balance_exists:
raise ValueError(
@ -50,19 +53,19 @@ def infer_init_method(args, force_distributed=False):
"--pipeline-devices is currently required for pipeline model parallelism"
)
args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int)
if args.pipeline_devices is not None:
args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int)
num_pipeline_devices = len(set(args.pipeline_devices))
cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int)
if cfg.pipeline_devices is not None:
cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int)
num_pipeline_devices = len(set(cfg.pipeline_devices))
else:
args.pipeline_encoder_devices = utils.eval_str_list(
args.pipeline_encoder_devices, type=int
cfg.pipeline_encoder_devices = utils.eval_str_list(
cfg.pipeline_encoder_devices, type=int
)
args.pipeline_decoder_devices = utils.eval_str_list(
args.pipeline_decoder_devices, type=int
cfg.pipeline_decoder_devices = utils.eval_str_list(
cfg.pipeline_decoder_devices, type=int
)
num_pipeline_devices = len(
set(args.pipeline_encoder_devices + args.pipeline_decoder_devices)
set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices)
)
gpus_per_node = torch.cuda.device_count()
assert (
@ -79,14 +82,14 @@ def infer_init_method(args, force_distributed=False):
key in os.environ
for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
):
args.distributed_init_method = "env://"
args.distributed_world_size = int(os.environ["WORLD_SIZE"])
args.distributed_rank = int(os.environ["RANK"])
cfg.distributed_init_method = "env://"
cfg.distributed_world_size = int(os.environ["WORLD_SIZE"])
cfg.distributed_rank = int(os.environ["RANK"])
# processes are created by torch.distributed.launch
args.distributed_no_spawn = True
cfg.distributed_no_spawn = True
# we can determine the init method automatically for Slurm
elif args.distributed_port > 0:
elif cfg.distributed_port > 0:
node_list = os.environ.get("SLURM_STEP_NODELIST")
if node_list is None:
node_list = os.environ.get("SLURM_JOB_NODELIST")
@ -95,9 +98,9 @@ def infer_init_method(args, force_distributed=False):
hostnames = subprocess.check_output(
["scontrol", "show", "hostnames", node_list]
)
args.distributed_init_method = "tcp://{host}:{port}".format(
cfg.distributed_init_method = "tcp://{host}:{port}".format(
host=hostnames.split()[0].decode("utf-8"),
port=args.distributed_port,
port=cfg.distributed_port,
)
nnodes = int(os.environ.get("SLURM_NNODES"))
ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
@ -111,88 +114,94 @@ def infer_init_method(args, force_distributed=False):
if ntasks_per_node == 1:
gpus_per_node = torch.cuda.device_count()
node_id = int(os.environ.get("SLURM_NODEID"))
args.distributed_rank = node_id * gpus_per_node
args.distributed_world_size = nnodes * gpus_per_node
elif args.pipeline_model_parallel:
cfg.distributed_rank = node_id * gpus_per_node
cfg.distributed_world_size = nnodes * gpus_per_node
elif cfg.pipeline_model_parallel:
assert ntasks_per_node == num_pipelines_per_node, (
"SLURM --ntasks-per-node must match number of pipelines per "
"node (={})".format(num_pipelines_per_node)
)
args.distributed_no_spawn = True
cfg.distributed_no_spawn = True
# For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
# the first node, [2, 3] on the second node, etc. This
# matches torch.distributed.launch.
node_id = int(os.environ.get("SLURM_NODEID"))
local_id = int(os.environ.get("SLURM_LOCALID"))
args.distributed_rank = node_id * num_pipelines_per_node + local_id
cfg.distributed_rank = node_id * num_pipelines_per_node + local_id
# In the above example, device_id will always be in [0, 1],
# which also matches torch.distributed.launch.
args.device_id = local_id
cfg.device_id = local_id
# We also want to set distributed_world_size to be the total
# number of pipelines across all nodes.
args.distributed_world_size = nnodes * num_pipelines_per_node
cfg.distributed_world_size = nnodes * num_pipelines_per_node
else:
assert ntasks_per_node == args.distributed_world_size // nnodes
args.distributed_no_spawn = True
args.distributed_rank = int(os.environ.get("SLURM_PROCID"))
args.device_id = int(os.environ.get("SLURM_LOCALID"))
assert ntasks_per_node == cfg.distributed_world_size // nnodes
cfg.distributed_no_spawn = True
cfg.distributed_rank = int(os.environ.get("SLURM_PROCID"))
cfg.device_id = int(os.environ.get("SLURM_LOCALID"))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError: # Slurm is not installed
pass
elif args.distributed_world_size > 1 or force_distributed:
elif cfg.distributed_world_size > 1 or force_distributed:
# fallback for single node with multiple GPUs
assert args.distributed_world_size <= torch.cuda.device_count()
assert cfg.distributed_world_size <= torch.cuda.device_count()
port = random.randint(10000, 20000)
args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)
if args.pipeline_model_parallel:
if not args.distributed_no_spawn:
if cfg.pipeline_model_parallel:
if not cfg.distributed_no_spawn:
# When distributed_no_spawn is False, we expect distributed_rank and
# distributed_world_size to be based on the total number of GPUs, so
# we need to correct them to be based on the number of pipelines.
assert args.distributed_world_size % num_pipeline_devices == 0
args.distributed_world_size = (
args.distributed_world_size // num_pipeline_devices
assert cfg.distributed_world_size % num_pipeline_devices == 0
cfg.distributed_world_size = (
cfg.distributed_world_size // num_pipeline_devices
)
# In the case of 4-way MP on nodes with 8 GPUs, we want
# distributed_rank to be the starting GPU index for each pipeline
# i.e., 0, 2, ...
assert args.distributed_rank % gpus_per_node == 0
assert args.distributed_rank % num_pipeline_devices == 0
args.distributed_rank = args.distributed_rank // num_pipeline_devices
# launch one process per pipeline
args.distributed_num_procs = num_pipelines_per_node
assert cfg.distributed_rank % gpus_per_node == 0
assert cfg.distributed_rank % num_pipeline_devices == 0
with open_dict(cfg):
cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices
# launch one process per pipeline
cfg.distributed_num_procs = num_pipelines_per_node
# if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0
# and 4, indicating the starting device IDs for each pipeline
args.device_id *= num_pipeline_devices
cfg.device_id *= num_pipeline_devices
if args.device_id > 0:
if cfg.device_id > 0:
# if there's multiple pipelines on a node (e.g., 4-way MP on an 8
# GPU node), we need to adjust pipeline_devices accordingly
logger.debug(
"setting CUDA device={} on rank {}".format(
args.device_id, args.distributed_rank
cfg.device_id, cfg.distributed_rank
)
)
torch.cuda.set_device(args.device_id)
args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices]
torch.cuda.set_device(cfg.device_id)
with open_dict(cfg):
cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices]
logger.info(
"setting pipeline_devices={} on rank {}".format(
args.pipeline_devices, args.distributed_rank
),
cfg.pipeline_devices, cfg.distributed_rank
)
)
elif not cfg.distributed_no_spawn:
with open_dict(cfg):
cfg.distributed_num_procs = min(
torch.cuda.device_count(), cfg.distributed_world_size
)
elif not args.distributed_no_spawn:
args.distributed_num_procs = min(
torch.cuda.device_count(),
args.distributed_world_size,
)
def distributed_init(args):
if not getattr(args, "tpu", False):
def distributed_init(cfg: DictConfig):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
if not cfg.common.tpu:
if torch.distributed.is_initialized():
warnings.warn(
"Distributed is already initialized, cannot initialize twice!"
@ -200,20 +209,20 @@ def distributed_init(args):
else:
logger.info(
"distributed init (rank {}): {}".format(
args.distributed_rank,
args.distributed_init_method,
cfg.distributed_training.distributed_rank,
cfg.distributed_training.distributed_init_method,
)
)
dist.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
backend=cfg.distributed_training.distributed_backend,
init_method=cfg.distributed_training.distributed_init_method,
world_size=cfg.distributed_training.distributed_world_size,
rank=cfg.distributed_training.distributed_rank,
)
logger.info(
"initialized host {} as rank {}".format(
socket.gethostname(),
args.distributed_rank,
cfg.distributed_training.distributed_rank,
)
)
@ -221,20 +230,22 @@ def distributed_init(args):
if torch.cuda.is_available():
dist.all_reduce(torch.zeros(1).cuda())
args.distributed_rank = torch.distributed.get_rank()
cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
else:
import torch_xla.core.xla_model as xm
assert xm.xrt_world_size() == args.distributed_world_size
args.device_id = xm.get_local_ordinal()
args.distributed_rank = xm.get_ordinal()
assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
cfg.distributed_training.device_id = xm.get_local_ordinal()
cfg.distributed_training.distributed_rank = xm.get_ordinal()
xm.rendezvous("distributed_init") # wait for all workers
xm.mark_step()
if not is_master(args):
if is_master(cfg.distributed_training):
logging.getLogger().setLevel(logging.INFO)
else:
logging.getLogger().setLevel(logging.WARNING)
if args.model_parallel_size > 1:
if cfg.common.model_parallel_size > 1:
try:
from fairseq.model_parallel.megatron.mpu import (
get_model_parallel_rank,
@ -247,58 +258,61 @@ def distributed_init(args):
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
initialize_model_parallel(args.model_parallel_size)
model_parallel_cuda_manual_seed(args.seed)
initialize_model_parallel(cfg.common.model_parallel_size)
model_parallel_cuda_manual_seed(cfg.common.seed)
model_part_number = get_model_parallel_rank()
args.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
return args.distributed_rank
cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
return cfg.distributed_training.distributed_rank
def distributed_main(i, main, args, kwargs):
args.device_id = i
if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False):
torch.cuda.set_device(args.device_id)
if args.distributed_rank is None: # torch.multiprocessing.spawn
args.distributed_rank = kwargs.pop("start_rank", 0) + i
def distributed_main(i, main, cfg: DictConfig, kwargs):
cfg.distributed_training.device_id = i
if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu:
torch.cuda.set_device(cfg.distributed_training.device_id)
if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn
cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i
args.distributed_rank = distributed_init(args)
cfg.distributed_training.distributed_rank = distributed_init(cfg)
after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
if after_distributed_init_fn:
args = after_distributed_init_fn(args)
cfg = after_distributed_init_fn(cfg)
main(args, **kwargs)
main(cfg, **kwargs)
def call_main(args, main, **kwargs):
if args.distributed_init_method is None:
infer_init_method(args)
def call_main(cfg: DictConfig, main, **kwargs):
if cfg.distributed_training.distributed_init_method is None:
infer_init_method(cfg.distributed_training)
if args.distributed_init_method is not None:
if cfg.distributed_training.distributed_init_method is not None:
# distributed training
if not args.distributed_no_spawn:
start_rank = args.distributed_rank
args.distributed_rank = None # assign automatically
if not cfg.distributed_training.distributed_no_spawn:
start_rank = cfg.distributed_training.distributed_rank
cfg.distributed_training.distributed_rank = None # assign automatically
kwargs["start_rank"] = start_rank
torch.multiprocessing.spawn(
fn=distributed_main,
args=(main, args, kwargs),
nprocs=args.distributed_num_procs,
args=(main, cfg, kwargs),
nprocs=min(
torch.cuda.device_count(),
cfg.distributed_training.distributed_world_size,
),
)
else:
distributed_main(args.device_id, main, args, kwargs)
elif getattr(args, "tpu", False) and args.distributed_world_size > 1:
distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1:
import torch_xla.distributed.xla_multiprocessing as xmp
torch.multiprocessing.set_sharing_strategy("file_system")
xmp.spawn(
fn=distributed_main,
args=(main, args, kwargs),
args=(main, cfg, kwargs),
nprocs=8, # use all 8 TPU cores
)
else:
# single GPU main
main(args, **kwargs)
main(cfg, **kwargs)
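Putting the distributed pieces together, an entry point now hands the whole DictConfig to call_main. A rough sketch (main below is a stand-in for the real train/eval loop, and args comes from a legacy parser as shown earlier):
from fairseq import distributed_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

def main(cfg):
    # the real entry points run training / evaluation here, reading grouped values
    print("world size:", cfg.distributed_training.distributed_world_size)

cfg = convert_namespace_to_omegaconf(args)  # args produced by the legacy argparse path
distributed_utils.call_main(cfg, main)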
def get_rank():
@ -392,11 +406,7 @@ def all_gather_list(data, group=None, max_size=16384):
)
def all_reduce_dict(
data: Mapping[str, Any],
device,
group=None,
) -> Dict[str, Any]:
def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, Any]:
"""
AllReduce a dictionary of values across workers. We separately
reduce items that are already on the device and items on CPU for

View File

@ -8,11 +8,12 @@ import argparse
import copy
import logging
import os
from typing import Any, Dict, Iterator, List, Tuple
from typing import Any, Dict, Iterator, List
import torch
from fairseq import utils
from fairseq.data import encoders
from omegaconf import open_dict
from torch import nn
@ -85,9 +86,9 @@ class GeneratorHubInterface(nn.Module):
translation or language model.
"""
def __init__(self, args, task, models):
def __init__(self, cfg, task, models):
super().__init__()
self.args = args
self.cfg = cfg
self.task = task
self.models = nn.ModuleList(models)
self.src_dict = task.source_dictionary
@ -95,14 +96,14 @@ class GeneratorHubInterface(nn.Module):
# optimize model for generation
for model in self.models:
model.prepare_for_inference_(args)
model.prepare_for_inference_(cfg)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None))
self.align_dict = utils.load_align_dict(cfg.generation.replace_unk)
self.tokenizer = encoders.build_tokenizer(args)
self.bpe = encoders.build_bpe(args)
self.tokenizer = encoders.build_tokenizer(cfg.tokenizer)
self.bpe = encoders.build_bpe(cfg.bpe)
self.max_positions = utils.resolve_max_positions(
self.task.max_positions(), *[model.max_positions() for model in models]
@ -156,10 +157,11 @@ class GeneratorHubInterface(nn.Module):
)[0]
# build generator using current args as well as any kwargs
gen_args = copy.copy(self.args)
gen_args.beam = beam
for k, v in kwargs.items():
setattr(gen_args, k, v)
gen_args = copy.copy(self.cfg)
with open_dict(gen_args):
gen_args.beam = beam
for k, v in kwargs.items():
setattr(gen_args, k, v)
generator = self.task.build_generator(self.models, gen_args)
inference_step_args = inference_step_args or {}
@ -253,8 +255,8 @@ class GeneratorHubInterface(nn.Module):
lengths = torch.LongTensor([t.numel() for t in tokens])
batch_iterator = self.task.get_batch_iterator(
dataset=self.task.build_dataset_for_inference(tokens, lengths),
max_tokens=self.args.max_tokens,
max_sentences=self.args.batch_size,
max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.cfg.dataset.batch_size,
max_positions=self.max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs,
disable_iterator_cache=True,
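The copy.copy(...) plus open_dict pattern used above when building the generator is needed because these configs are struct-locked; a self-contained illustration with plain OmegaConf:
from omegaconf import OmegaConf, open_dict

gen_args = OmegaConf.create({"beam": 5})
OmegaConf.set_struct(gen_args, True)
with open_dict(gen_args):
    gen_args.sampling = True  # adding a new key outside open_dict would raise once struct is set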

View File

@ -9,6 +9,7 @@ Train a network across multiple GPUs.
from fairseq import distributed_utils
from fairseq.trainer import Trainer
from omegaconf import DictConfig
try:
@ -28,14 +29,14 @@ except (ImportError, ModuleNotFoundError):
class MegatronTrainer(Trainer):
"""Main class for model parallel with data parallel training."""
def __init__(self, args, task, model, criterion):
def __init__(self, cfg: DictConfig, task, model, criterion, **kwargs):
if not has_megatron_submodule:
raise ImportError(
"\n\nPlease install the megatron submodule:"
"\n\n git submodule update --init "
"fairseq/model_parallel/megatron"
)
super().__init__(args, task, model, criterion)
super().__init__(cfg, task, model, criterion, **kwargs)
@property
def data_parallel_world_size(self):

View File

@ -96,7 +96,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
encoder_output_tuple = self.encoder(input)
return self.decoder(encoder_output_tuple)
def prepare_for_inference_(self, args):
def prepare_for_inference_(self, cfg):
if self.encoder is not None and self.decoder is not None:
logger.info("Encoder and Decoder already initialized")
return
@ -111,9 +111,9 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
decoder_module_list.append(module)
module_count += 1
self.model = None
self.encoder = TransformerEncoder(args, None, None, encoder_module_list)
self.encoder = TransformerEncoder(cfg.model, None, None, encoder_module_list)
self.decoder = TransformerDecoder(
args, None, None, decoder_module_list=decoder_module_list
cfg.model, None, None, decoder_module_list=decoder_module_list
)
@staticmethod
@ -320,7 +320,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
"""Maximum length supported by the decoder."""
return self.decoder_max_positions
def load_state_dict(self, state_dict, strict=True, args=None):
def load_state_dict(self, state_dict, strict=True, cfg=None):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.

View File

@ -72,6 +72,10 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
)
return cls(decoder)
@staticmethod
def add_args(parser):
TransformerLanguageModel.add_args(parser)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
def _vocab_init(tensor, **kwargs):

View File

@ -7,8 +7,6 @@
import argparse
import importlib
import os
from argparse import Namespace
from typing import Union
import fairseq
from fairseq.dataclass import FairseqDataclass
@ -52,10 +50,10 @@ __all__ = [
]
def build_model(model_cfg: Union[DictConfig, Namespace], task):
if isinstance(model_cfg, DictConfig):
return ARCH_MODEL_REGISTRY[model_cfg._name].build_model(model_cfg, task)
return ARCH_MODEL_REGISTRY[model_cfg.arch].build_model(model_cfg, task)
def build_model(cfg: DictConfig, task):
if isinstance(cfg, DictConfig):
return ARCH_MODEL_REGISTRY[cfg._name].build_model(cfg, task)
return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task)
def register_model(name, dataclass=None):
@ -92,7 +90,8 @@ def register_model(name, dataclass=None):
)
cls.__dataclass = dataclass
MODEL_DATACLASS_REGISTRY[name] = dataclass
if dataclass is not None:
MODEL_DATACLASS_REGISTRY[name] = dataclass
return cls
return register_model_cls
@ -108,14 +107,13 @@ def register_model_architecture(model_name, arch_name):
For example::
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
def lstm_luong_wmt_en_de(cfg):
cfg.model.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000)
(...)
The decorated function should take a single argument *args*, which is a
:class:`argparse.Namespace` of arguments parsed from the command-line. The
decorated function should modify these arguments in-place to match the
desired architecture.
The decorated function should take a single argument *cfg*, which is a
:class:`omegaconf.DictConfig`. The decorated function should modify these
arguments in-place to match the desired architecture.
Args:
model_name (str): the name of the Model (Model must already be

View File

@ -13,6 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
from omegaconf import open_dict
logger = logging.getLogger(__name__)
@ -24,13 +25,13 @@ class BARTHubInterface(nn.Module):
Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
"""
def __init__(self, args, task, model):
def __init__(self, cfg, task, model):
super().__init__()
self.args = args
self.cfg = cfg
self.task = task
self.model = model
self.bpe = encoders.build_bpe(args)
self.bpe = encoders.build_bpe(cfg.bpe)
self.max_positions = min(
utils.resolve_max_positions(
@ -120,10 +121,11 @@ class BARTHubInterface(nn.Module):
sample = self._build_sample(tokens)
# build generator using current args as well as any kwargs
gen_args = copy.copy(self.args)
gen_args.beam = beam
for k, v in kwargs.items():
setattr(gen_args, k, v)
gen_args = copy.copy(self.cfg)
with open_dict(gen_args):
gen_args.beam = beam
for k, v in kwargs.items():
setattr(gen_args, k, v)
generator = self.task.build_generator([self.model], gen_args)
translations = self.task.inference_step(
generator,

View File

@ -144,7 +144,9 @@ class BARTModel(TransformerModel):
num_classes=num_classes,
activation_fn=self.args.pooler_activation_fn,
pooler_dropout=self.args.pooler_dropout,
do_spectral_norm=self.args.spectral_norm_classification_head,
do_spectral_norm=getattr(
self.args, "spectral_norm_classification_head", False
),
)
def upgrade_state_dict_named(self, state_dict, name):

View File

@ -7,6 +7,7 @@ Base classes for various fairseq models.
"""
import logging
from argparse import Namespace
from typing import Dict, List, Optional, Tuple
import torch
@ -15,8 +16,12 @@ import torch.nn.functional as F
from fairseq import utils
from fairseq.checkpoint_utils import prune_state_dict
from fairseq.data import Dictionary
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor
@ -86,15 +91,26 @@ class BaseFairseqModel(nn.Module):
"""Maximum length supported by the model."""
return None
def load_state_dict(self, state_dict, strict=True, args=None):
def load_state_dict(
self,
state_dict,
strict=True,
model_cfg: Optional[DictConfig] = None,
args: Optional[Namespace] = None,
):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
if model_cfg is None and args is not None:
logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
model_cfg = convert_namespace_to_omegaconf(args).model
self.upgrade_state_dict(state_dict)
new_state_dict = prune_state_dict(state_dict, args)
new_state_dict = prune_state_dict(state_dict, model_cfg)
return super().load_state_dict(new_state_dict, strict)
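Callers are expected to migrate to the new keyword; a hedged sketch of both paths (model, cfg, legacy_args and the checkpoint filename below are placeholders, not fairseq API):
import torch

state = torch.load("checkpoint_last.pt")  # placeholder path
model.load_state_dict(state["model"], model_cfg=cfg.model)   # new, preferred
# model.load_state_dict(state["model"], args=legacy_args)    # still accepted, logs a deprecation warning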
def upgrade_state_dict(self, state_dict):
@ -133,18 +149,18 @@ class BaseFairseqModel(nn.Module):
self.apply(_apply)
def prepare_for_inference_(self, args):
def prepare_for_inference_(self, cfg: DictConfig):
"""Prepare model for inference."""
kwargs = {}
kwargs["beamable_mm_beam_size"] = (
None if getattr(args, "no_beamable_mm", False) else getattr(args, "beam", 5)
None
if getattr(cfg.generation, "no_beamable_mm", False)
else getattr(cfg.generation, "beam", 5)
)
kwargs["need_attn"] = getattr(args, "print_alignment", False)
if hasattr(args, "retain_dropout"):
kwargs["retain_dropout"] = args.retain_dropout
kwargs["retain_dropout_modules"] = getattr(
args, "retain_dropout_modules", None
)
kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
if getattr(cfg.generation, "retain_dropout", False):
kwargs["retain_dropout"] = cfg.generation.retain_dropout
kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
self.make_generation_fast_(**kwargs)
def make_generation_fast_(self, **kwargs):
@ -437,15 +453,26 @@ class FairseqMultiModel(BaseFairseqModel):
def forward_decoder(self, prev_output_tokens, **kwargs):
return self.decoder(prev_output_tokens, **kwargs)
def load_state_dict(self, state_dict, strict=True, args=None):
def load_state_dict(
self,
state_dict,
strict=True,
model_cfg=None,
args: Optional[Namespace] = None,
):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
if model_cfg is None and args is not None:
logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
model_cfg = convert_namespace_to_omegaconf(args).model
self.upgrade_state_dict(state_dict)
new_state_dict = prune_state_dict(state_dict, args)
new_state_dict = prune_state_dict(state_dict, model_cfg)
return super().load_state_dict(new_state_dict, strict)

View File

@ -194,14 +194,14 @@ class MultilingualTransformerModel(FairseqMultiModel):
module_class = TransformerEncoder if is_encoder else TransformerDecoder
return module_class(args, lang_dict, embed_tokens)
def load_state_dict(self, state_dict, strict=True, args=None):
def load_state_dict(self, state_dict, strict=True, model_cfg=None):
state_dict_subset = state_dict.copy()
for k, _ in state_dict.items():
assert k.startswith("models.")
lang_pair = k.split(".")[1]
if lang_pair not in self.models:
del state_dict_subset[k]
super().load_state_dict(state_dict_subset, strict=strict, args=args)
super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)
@register_model_architecture("multilingual_transformer", "multilingual_transformer")

View File

@ -17,13 +17,13 @@ class RobertaHubInterface(nn.Module):
Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta
"""
def __init__(self, args, task, model):
def __init__(self, cfg, task, model):
super().__init__()
self.args = args
self.cfg = cfg
self.task = task
self.model = model
self.bpe = encoders.build_bpe(args)
self.bpe = encoders.build_bpe(cfg.bpe)
# this is useful for determining the device
self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

View File

@ -494,7 +494,7 @@ def base_architecture(args):
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
args.spectral_norm_classification_head = getattr(
args, "spectral_nrom_classification_head", False
args, "spectral_norm_classification_head", False
)

View File

@ -578,10 +578,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
args.max_target_positions,
self.max_target_positions,
embed_dim,
self.padding_idx,
learned=args.decoder_learned_pos,
@ -963,6 +962,14 @@ def base_architecture(args):
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
@register_model_architecture("transformer", "transformer_iwslt_de_en")
def transformer_iwslt_de_en(args):

View File

@ -159,7 +159,7 @@ class TransformerLanguageModelConfig(FairseqDataclass):
add_bos_token: bool = II("task.add_bos_token")
tokens_per_sample: int = II("task.tokens_per_sample")
max_target_positions: Optional[int] = II("task.max_target_positions")
tpu: bool = II("params.common.tpu")
tpu: bool = II("common.tpu")
@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)

View File

@ -32,20 +32,20 @@ class TransformerEncoderLayer(nn.Module):
def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.quant_noise = getattr(args, "quant_noise_pq", 0)
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
self.self_attn = self.build_self_attention(self.embed_dim, args)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.activation_fn = utils.get_activation_fn(
activation=getattr(args, "activation_fn", "relu")
activation=getattr(args, 'activation_fn', 'relu') or "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0)
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)
@ -197,10 +197,10 @@ class TransformerDecoderLayer(nn.Module):
if getattr(args, "activation_fn", None) is not None
else "relu"
)
activation_dropout_p = getattr(args, "activation_dropout", 0)
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
if activation_dropout_p == 0:
# for backwards compatibility with models that use args.relu_dropout
activation_dropout_p = getattr(args, "relu_dropout", 0)
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
self.activation_dropout_module = FairseqDropout(
float(activation_dropout_p), module_name=self.__class__.__name__
)

View File

@ -6,8 +6,6 @@
import importlib
import os
from argparse import Namespace
from typing import Union
from fairseq import registry
from fairseq.optim.bmuf import FairseqBMUF # noqa
@ -19,7 +17,6 @@ from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optim
from fairseq.optim.shard import shard_
from omegaconf import DictConfig
__all__ = [
"FairseqOptimizer",
"FP16Optimizer",
@ -27,7 +24,6 @@ __all__ = [
"shard_",
]
(
_build_optimizer,
register_optimizer,
@ -37,12 +33,12 @@ __all__ = [
def build_optimizer(
optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs
cfg: DictConfig, params, *extra_args, **extra_kwargs
):
if all(isinstance(p, dict) for p in params):
params = [t for p in params for t in p.values()]
params = list(filter(lambda p: p.requires_grad, params))
return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs)
return _build_optimizer(cfg, params, *extra_args, **extra_kwargs)
# automatically import any Python files in the optim/ directory

View File

@ -5,6 +5,7 @@
import logging
import math
from collections import Collection
from dataclasses import dataclass, field
from typing import List
@ -14,7 +15,7 @@ import torch.optim
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer
from fairseq.optim.fused_adam import get_fused_adam_class
from omegaconf import II
from omegaconf import II, DictConfig
logger = logging.getLogger(__name__)
@ -33,8 +34,8 @@ class FairseqAdamConfig(FairseqDataclass):
default=False, metadata={"help": "Use fairseq.optim.adam.Adam"}
)
# TODO common vars below in parent
tpu: bool = II("params.common.tpu")
lr: List[float] = II("params.optimization.lr")
tpu: bool = II("common.tpu")
lr: List[float] = II("optimization.lr")
@register_optimizer("adam", dataclass=FairseqAdamConfig)
@ -46,15 +47,15 @@ class FairseqAdam(FairseqOptimizer):
analogous to torch.optim.AdamW from PyTorch.
"""
def __init__(self, args, params):
super().__init__(args)
def __init__(self, cfg: DictConfig, params):
super().__init__(cfg)
fused_adam_cls = get_fused_adam_class()
use_fused_adam = (
not getattr(args, "use_old_adam", False)
not getattr(cfg, "use_old_adam", False)
and fused_adam_cls is not None
and torch.cuda.is_available()
)
if getattr(args, "tpu", False):
if getattr(cfg, "tpu", False):
# on TPUs we use the Adam defined here, since it
# automatically casts gradients to FP32
self._optimizer = Adam(params, **self.optimizer_config)
@ -73,10 +74,12 @@ class FairseqAdam(FairseqOptimizer):
different learning rate.
"""
return {
"lr": self.args.lr[0],
"betas": eval(self.args.adam_betas),
"eps": self.args.adam_eps,
"weight_decay": self.args.weight_decay,
"lr": self.cfg.lr[0]
if isinstance(self.cfg.lr, Collection)
else self.cfg.lr,
"betas": eval(self.cfg.adam_betas),
"eps": self.cfg.adam_eps,
"weight_decay": self.cfg.weight_decay,
}
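The Collection check covers the fact that lr can arrive either as the list from optimization.lr or as a scalar already resolved elsewhere; a tiny standalone illustration of the guard (not fairseq code):
from collections.abc import Collection  # the diff imports Collection from collections; abc is the long-term home

def first_lr(lr):
    # optimization.lr is a list in the new config, but a plain float may also be passed
    return lr[0] if isinstance(lr, Collection) else lr

assert first_lr([0.25]) == 0.25
assert first_lr(0.25) == 0.25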
def average_params(self):

View File

@ -10,7 +10,7 @@ import torch.distributed as dist
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.fairseq_optimizer import FairseqOptimizer
from omegaconf import II
from omegaconf import II, DictConfig
@dataclass
@ -38,7 +38,7 @@ class FairseqBMUFConfig(FairseqDataclass):
},
)
distributed_world_size: int = II(
"params.distributed_training.distributed_world_size"
"distributed_training.distributed_world_size"
)
@ -52,20 +52,19 @@ class FairseqBMUF(FairseqOptimizer):
model-update filtering
"""
def __init__(self, args, optimizer):
super().__init__(args)
def __init__(self, cfg: DictConfig, optimizer):
super().__init__(cfg)
self._optimizer = optimizer
self._num_updates = 0
self.sync_iter = self.args.global_sync_iter
self.block_momentum = self.args.block_momentum
self.block_lr = self.args.block_lr
self.sync_iter = cfg.global_sync_iter
self.block_momentum = cfg.block_momentum
self.block_lr = cfg.block_lr
self._reset_local_data()
self.warmup_iteration = self.args.warmup_iterations
self.use_nbm = self.args.use_nbm
self.warmup_iteration = cfg.warmup_iterations
self.use_nbm = cfg.use_nbm
self.initial_state = self._optimizer.state_dict()
self.average_sync = self.args.average_sync
self.world_size = self.args.distributed_world_size
self.average_sync = self.cfg.average_sync
self.world_size = self.cfg.distributed_world_size
@staticmethod
def add_args(parser):

View File

@ -9,9 +9,9 @@ from fairseq.dataclass.utils import gen_parser_from_dataclass
class FairseqOptimizer(object):
def __init__(self, args):
def __init__(self, cfg):
super().__init__()
self.args = args
self.cfg = cfg
@classmethod
def add_args(cls, parser):

View File

@ -7,7 +7,8 @@ from collections import defaultdict
from itertools import chain
import torch
from fairseq import optim, utils
from fairseq import optim
from omegaconf import DictConfig
from .dynamic_loss_scaler import DynamicLossScaler
@ -211,7 +212,7 @@ class _FP16OptimizerMixin(object):
for fp32_params in self.fp32_params.values():
fp32_params.grad.zero_()
else:
raise ("self.fp32_params must be a tensor or dict")
raise RuntimeError("self.fp32_params must be a tensor or dict")
else:
for p32 in self.fp32_params:
p32.grad.zero_()
@ -226,58 +227,60 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
Wrap an *optimizer* to support FP16 (mixed precision) training.
"""
def __init__(self, args, params, fp32_optimizer, fp32_params):
super().__init__(args)
def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs):
super().__init__(cfg.optimizer)
self.fp16_params = params
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
if getattr(args, "fp16_scale_window", None) is None:
if len(args.update_freq) > 1:
if getattr(cfg.common, "fp16_scale_window", None) is None:
if len(cfg.optimization.update_freq) > 1:
raise ValueError(
"--fp16-scale-window must be given explicitly when using a "
"custom --update-freq schedule"
)
data_parallel_size = int(
args.distributed_world_size / args.model_parallel_size
cfg.distributed_training.distributed_world_size
/ cfg.common.model_parallel_size
)
scale_window = int(
2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0]
)
scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0])
else:
scale_window = args.fp16_scale_window
scale_window = cfg.common.fp16_scale_window
if not getattr(args, "bf16", False):
if not getattr(cfg.common, "bf16", False):
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
init_scale=cfg.common.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
threshold=args.threshold_loss_scale,
min_loss_scale=args.min_loss_scale,
tolerance=cfg.common.fp16_scale_tolerance,
threshold=cfg.common.threshold_loss_scale,
min_loss_scale=cfg.common.min_loss_scale,
)
else:
# disable loss scaling for bfloat16
self.scaler = None
@classmethod
def build_optimizer(cls, args, params):
def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
"""
Args:
args (argparse.Namespace): fairseq args
cfg (omegaconf.DictConfig): fairseq args
params (iterable): iterable of parameters to optimize
"""
flatten = not getattr(args, "fp16_no_flatten_grads", False)
if getattr(args, "bf16", False):
flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False)
if getattr(cfg.common, "bf16", False):
flatten = False # mixed precision is faster on TPUs without flat grads
fp32_params = cls.build_fp32_params(args, params, flatten=flatten)
fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten)
if flatten:
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params])
else:
fp32_optimizer = optim.build_optimizer(args, fp32_params)
fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params)
if flatten and not fp32_optimizer.supports_flat_params:
raise RuntimeError(
"chosen optimizer does not support flat params, "
"please set --fp16-no-flatten-grads"
f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads"
)
return cls(args, params, fp32_optimizer, fp32_params)
return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs)
@property
def optimizer(self):
@ -427,49 +430,52 @@ class MemoryEfficientFP16Optimizer(
*supports_memory_efficient_fp16* property.
"""
def __init__(self, args, params, optimizer):
def __init__(self, cfg: DictConfig, params, optimizer, **kwargs):
if not optimizer.supports_memory_efficient_fp16:
raise ValueError(
"Unsupported optimizer: {}".format(optimizer.__class__.__name__)
)
super().__init__(args)
super().__init__(cfg.optimizer)
self.wrapped_optimizer = optimizer
if getattr(args, "fp16_scale_window", None) is None:
if len(args.update_freq) > 1:
if getattr(cfg.common, "fp16_scale_window", None) is None:
if len(cfg.optimization.update_freq) > 1:
raise ValueError(
"--fp16-scale-window must be given explicitly when using a "
"custom --update-freq schedule"
)
data_parallel_size = int(
args.distributed_world_size / args.model_parallel_size
cfg.distributed_training.distributed_world_size
/ cfg.common.model_parallel_size
)
scale_window = (
2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0]
)
scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0]
else:
scale_window = args.fp16_scale_window
scale_window = cfg.common.fp16_scale_window
if not getattr(args, "bf16", False):
if not getattr(cfg.common, "bf16", False):
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
init_scale=cfg.common.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
threshold=args.threshold_loss_scale,
min_loss_scale=args.min_loss_scale,
tolerance=cfg.common.fp16_scale_tolerance,
threshold=cfg.common.threshold_loss_scale,
min_loss_scale=cfg.common.min_loss_scale,
)
else:
# disable loss scaling for bfloat16
self.scaler = None
@classmethod
def build_optimizer(cls, args, params):
def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
"""
Args:
args (argparse.Namespace): fairseq args
params (iterable): iterable of parameters to optimize
"""
fp16_optimizer = optim.build_optimizer(args, params)
return cls(args, params, fp16_optimizer)
fp16_optimizer = optim.build_optimizer(cfg.optimizer, params)
return cls(cfg, params, fp16_optimizer, **kwargs)
@property
def optimizer(self):

View File

@ -6,8 +6,6 @@
import importlib
import os
from argparse import Namespace
from typing import Union
from fairseq import registry
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa
@ -27,8 +25,8 @@ from omegaconf import DictConfig
)
def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer):
return build_lr_scheduler_(lr_scheduler_cfg, optimizer)
def build_lr_scheduler(cfg: DictConfig, optimizer):
return build_lr_scheduler_(cfg, optimizer)
# automatically import any Python files in the optim/lr_scheduler/ directory
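build_lr_scheduler now takes the lr-scheduler sub-config (a DictConfig) directly instead of a Union with Namespace. A minimal sketch of the intended call shape, assuming fairseq is installed and `optimizer` is an already-built FairseqOptimizer; the keys mirror the InverseSquareRootScheduleConfig further down in this diff and the values are illustrative only:

from omegaconf import OmegaConf

# Illustrative lr-scheduler node; "_name" selects the registered scheduler.
lr_scheduler_cfg = OmegaConf.create(
    {"_name": "inverse_sqrt", "warmup_updates": 4000, "warmup_init_lr": -1.0, "lr": [0.0005]}
)

# Sketch only (requires fairseq):
# from fairseq.optim import lr_scheduler
# scheduler = lr_scheduler.build_lr_scheduler(lr_scheduler_cfg, optimizer)
# scheduler.step_update(0)  # sets the initial (warmup) learning rate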

View File

@ -4,11 +4,12 @@
# LICENSE file in the root directory of this source tree.
import math
from collections import Collection
from dataclasses import dataclass, field
from typing import List
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
from omegaconf import II, DictConfig
from . import FairseqLRScheduler, register_lr_scheduler
@ -38,8 +39,8 @@ class CosineConfig(FairseqDataclass):
default=0.1, metadata={"help": "shrink factor for annealing"}
)
# TODO common var for parent class
lr: List[float] = II("params.optimization.lr")
max_update: int = II("params.optimization.max_update")
lr: List[float] = II("optimization.lr")
max_update: int = II("optimization.max_update")
@register_lr_scheduler("cosine", dataclass=CosineConfig)
@ -66,43 +67,51 @@ class CosineSchedule(FairseqLRScheduler):
after every iteration.
"""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
def __init__(
self, cfg: DictConfig, fairseq_optimizer
):
super().__init__(cfg, fairseq_optimizer)
if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with cosine."
" Consider --lr-scheduler=fixed instead."
)
warmup_end_lr = args.max_lr
if args.warmup_init_lr < 0:
args.warmup_init_lr = args.lr[0]
self.min_lr = args.lr[0]
self.max_lr = args.max_lr
warmup_end_lr = cfg.max_lr
lr = (
cfg.lr[0]
if isinstance(cfg.lr, Collection)
else cfg.lr
)
if cfg.warmup_init_lr < 0:
cfg.warmup_init_lr = lr
self.min_lr = lr
self.max_lr = cfg.max_lr
assert self.max_lr > self.min_lr, "max_lr must be more than lr"
self.t_mult = args.t_mult
self.period = args.lr_period_updates
self.t_mult = cfg.t_mult
self.period = cfg.lr_period_updates
if self.period <= 0:
assert (
args.max_update >= 0
cfg.max_update >= 0
), "Either --max_update or --lr-period-updates must be set"
self.period = args.max_update - args.warmup_updates
self.period = cfg.max_update - cfg.warmup_updates
if args.warmup_updates > 0:
if cfg.warmup_updates > 0:
# linearly warmup for the first args.warmup_updates
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
self.lr_step = (
warmup_end_lr - cfg.warmup_init_lr
) / cfg.warmup_updates
else:
self.lr_step = 1
self.warmup_updates = args.warmup_updates
self.lr_shrink = args.lr_shrink
self.warmup_updates = cfg.warmup_updates
self.lr_shrink = cfg.lr_shrink
# initial learning rate
self.lr = args.warmup_init_lr
self.lr = cfg.warmup_init_lr
self.optimizer.set_lr(self.lr)
def step(self, epoch, val_loss=None):
@ -113,10 +122,10 @@ class CosineSchedule(FairseqLRScheduler):
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates < self.args.warmup_updates:
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
if num_updates < self.cfg.warmup_updates:
self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
else:
curr_updates = num_updates - self.args.warmup_updates
curr_updates = num_updates - self.cfg.warmup_updates
if self.t_mult != 1:
i = math.floor(
math.log(

View File

@ -11,11 +11,11 @@ from .. import FairseqOptimizer
class FairseqLRScheduler(object):
def __init__(self, args, optimizer):
def __init__(self, cfg, optimizer):
super().__init__()
if not isinstance(optimizer, FairseqOptimizer):
raise ValueError("optimizer must be an instance of FairseqOptimizer")
self.args = args
self.cfg = cfg
self.optimizer = optimizer
self.best = None

View File

@ -3,11 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import Collection
from dataclasses import dataclass, field
from typing import List
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
from omegaconf import II, DictConfig
from . import FairseqLRScheduler, register_lr_scheduler
@ -25,7 +26,7 @@ class InverseSquareRootScheduleConfig(FairseqDataclass):
},
)
# TODO common vars at parent class
lr: List[float] = II("params.optimization.lr")
lr: List[float] = II("optimization.lr")
@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig)
@ -48,25 +49,33 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
lr = decay_factor / sqrt(update_num)
"""
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
if len(args.lr) > 1:
def __init__(self, cfg: DictConfig, optimizer):
super().__init__(cfg, optimizer)
if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with inverse_sqrt."
" Consider --lr-scheduler=fixed instead."
)
warmup_end_lr = args.lr[0]
if args.warmup_init_lr < 0:
args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr
warmup_end_lr = (
cfg.lr[0]
if isinstance(cfg.lr, Collection)
else cfg.lr
)
if cfg.warmup_init_lr < 0:
cfg.warmup_init_lr = (
0 if cfg.warmup_updates > 0 else warmup_end_lr
)
# linearly warmup for the first args.warmup_updates
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
self.lr_step = (
warmup_end_lr - cfg.warmup_init_lr
) / cfg.warmup_updates
# then, decay prop. to the inverse square root of the update number
self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5
self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5
# initial learning rate
self.lr = args.warmup_init_lr
self.lr = cfg.warmup_init_lr
self.optimizer.set_lr(self.lr)
def step(self, epoch, val_loss=None):
@ -77,8 +86,8 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates < self.args.warmup_updates:
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
if num_updates < self.cfg.warmup_updates:
self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
else:
self.lr = self.decay_factor * num_updates ** -0.5
self.optimizer.set_lr(self.lr)
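The interpolation targets drop the old "params." prefix because the config groups now live at the top level rather than under a params node. A self-contained illustration of how II resolves against the flattened layout (config values are illustrative):

from dataclasses import dataclass
from typing import List
from omegaconf import II, OmegaConf

@dataclass
class ExampleSchedulerConfig:
    warmup_updates: int = 4000
    # resolves against the top-level "optimization" group, not "params.optimization"
    lr: List[float] = II("optimization.lr")

cfg = OmegaConf.create(
    {
        "optimization": {"lr": [0.0005]},
        "lr_scheduler": OmegaConf.structured(ExampleSchedulerConfig),
    }
)
print(OmegaConf.to_container(cfg, resolve=True)["lr_scheduler"]["lr"])  # [0.0005]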

View File

@ -3,12 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import Collection
from dataclasses import dataclass, field
from typing import List
import torch
from fairseq.dataclass import FairseqDataclass
from omegaconf import II
from omegaconf import II, DictConfig
from torch.optim.optimizer import Optimizer, required
from . import FairseqOptimizer, register_optimizer
@ -19,13 +20,13 @@ class FairseqNAGConfig(FairseqDataclass):
momentum: float = field(default=0.99, metadata={"help": "momentum factor"})
weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
# TODO common vars in parent class
lr: List[float] = II("params.optimization.lr")
lr: List[float] = II("optimization.lr")
@register_optimizer("nag", dataclass=FairseqNAGConfig)
class FairseqNAG(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args)
def __init__(self, cfg: DictConfig, params):
super().__init__(cfg)
self._optimizer = NAG(params, **self.optimizer_config)
@property
@ -37,9 +38,11 @@ class FairseqNAG(FairseqOptimizer):
different learning rate.
"""
return {
"lr": self.args.lr[0],
"momentum": self.args.momentum,
"weight_decay": self.args.weight_decay,
"lr": self.cfg.lr[0]
if isinstance(self.cfg.lr, Collection)
else self.cfg.lr,
"momentum": self.cfg.momentum,
"weight_decay": self.cfg.weight_decay,
}
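The new isinstance checks guard against `lr` arriving either as a list (the legacy CLI path) or as a plain float once interpolation or command-line overrides are involved. The same pattern as the optimizer_config logic above, isolated into a small runnable sketch:

from collections.abc import Collection  # the diff itself uses the older `from collections import Collection`
from omegaconf import OmegaConf

def first_lr(lr):
    # mirrors the logic above: list -> first element, scalar -> as-is
    return lr[0] if isinstance(lr, Collection) else lr

print(first_lr(OmegaConf.create({"lr": [0.25]}).lr))  # 0.25 (ListConfig)
print(first_lr(0.25))                                 # 0.25 (plain float)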

View File

@ -12,7 +12,7 @@ except ImportError:
_has_fairscale = False
def shard_(args, optimizer, group):
def shard_(optimizer, group):
if not _has_fairscale:
raise ImportError(
"\n\nPlease install the fairscale package:" "\n\n pip install fairscale"

View File

@ -10,13 +10,15 @@ import torch
from fairseq import utils
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.dataclass.data_class import (
CheckpointParams,
CommonEvalParams,
CommonParams,
DatasetParams,
DistributedTrainingParams,
EvalLMParams,
OptimizationParams,
CheckpointConfig,
CommonConfig,
CommonEvalConfig,
DatasetConfig,
DistributedTrainingConfig,
EvalLMConfig,
GenerationConfig,
InteractiveConfig,
OptimizationConfig,
)
from fairseq.dataclass.utils import gen_parser_from_dataclass
@ -45,6 +47,7 @@ def get_generation_parser(interactive=False, default_task="translation"):
add_dataset_args(parser, gen=True)
add_distributed_training_args(parser, default_world_size=1)
add_generation_args(parser)
add_checkpoint_args(parser)
if interactive:
add_interactive_args(parser)
return parser
@ -67,7 +70,7 @@ def get_validation_parser(default_task=None):
add_dataset_args(parser, train=True)
add_distributed_training_args(parser, default_world_size=1)
group = parser.add_argument_group("Evaluation")
gen_parser_from_dataclass(group, CommonEvalParams())
gen_parser_from_dataclass(group, CommonEvalConfig())
return parser
@ -210,7 +213,7 @@ def get_parser(desc, default_task="translation"):
utils.import_user_module(usr_args)
parser = argparse.ArgumentParser(allow_abbrev=False)
gen_parser_from_dataclass(parser, CommonParams())
gen_parser_from_dataclass(parser, CommonConfig())
from fairseq.registry import REGISTRIES
@ -283,7 +286,7 @@ def add_preprocess_args(parser):
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group("dataset_data_loading")
gen_parser_from_dataclass(group, DatasetParams())
gen_parser_from_dataclass(group, DatasetConfig())
# fmt: on
return group
@ -293,7 +296,7 @@ def add_distributed_training_args(parser, default_world_size=None):
if default_world_size is None:
default_world_size = max(1, torch.cuda.device_count())
gen_parser_from_dataclass(
group, DistributedTrainingParams(distributed_world_size=default_world_size)
group, DistributedTrainingConfig(distributed_world_size=default_world_size)
)
return group
@ -301,7 +304,7 @@ def add_distributed_training_args(parser, default_world_size=None):
def add_optimization_args(parser):
group = parser.add_argument_group("optimization")
# fmt: off
gen_parser_from_dataclass(group, OptimizationParams())
gen_parser_from_dataclass(group, OptimizationConfig())
# fmt: on
return group
@ -309,117 +312,31 @@ def add_optimization_args(parser):
def add_checkpoint_args(parser):
group = parser.add_argument_group("checkpoint")
# fmt: off
gen_parser_from_dataclass(group, CheckpointParams())
gen_parser_from_dataclass(group, CheckpointConfig())
# fmt: on
return group
def add_common_eval_args(group):
gen_parser_from_dataclass(group, CommonEvalParams())
gen_parser_from_dataclass(group, CommonEvalConfig())
def add_eval_lm_args(parser):
group = parser.add_argument_group("LM Evaluation")
add_common_eval_args(group)
gen_parser_from_dataclass(group, EvalLMParams())
gen_parser_from_dataclass(group, EvalLMConfig())
def add_generation_args(parser):
group = parser.add_argument_group("Generation")
add_common_eval_args(group)
# fmt: off
group.add_argument('--beam', default=5, type=int, metavar='N',
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
help='number of hypotheses to output')
group.add_argument('--max-len-a', default=0, type=float, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--min-len', default=1, type=float, metavar='N',
help=('minimum generation length'))
group.add_argument('--match-source-len', default=False, action='store_true',
help=('generations should match the source length'))
group.add_argument('--no-early-stop', action='store_true',
help='deprecated')
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--no-beamable-mm', action='store_true',
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unkpen', default=0, type=float,
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
help='perform unknown replacement (optionally with alignment dictionary)')
group.add_argument('--sacrebleu', action='store_true',
help='score with sacrebleu')
group.add_argument('--score-reference', action='store_true',
help='just score the reference translation')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help='initialize generation by target prefix of given length')
group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N',
help='ngram blocking such that this size ngram cannot be repeated in the generation')
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS',
help='sample from the smallest set whose cumulative probability mass exceeds p for next words')
group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"],
help='enables lexically constrained decoding')
group.add_argument('--temperature', default=1., type=float, metavar='N',
help='temperature for generation')
group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
help='number of groups for Diverse Beam Search')
group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
help='strength of diversity penalty for Diverse Beam Search')
group.add_argument('--diversity-rate', default=-1.0, type=float, metavar='N',
help='strength of diversity penalty for Diverse Siblings Search')
group.add_argument('--print-alignment', action='store_true',
help='if set, uses attention feedback to compute and print alignment to source tokens')
group.add_argument('--print-step', action='store_true')
group.add_argument('--lm-path', default=None, type=str, metavar='PATH',
help='path to lm checkpoint for lm fusion')
group.add_argument('--lm-weight', default=0.0, type=float, metavar='N',
help='weight for lm probs for lm fusion')
# arguments for iterative refinement generator
group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
help='if > 0.0, it penalized early-stopping in decoding.')
group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N',
help='maximum iterations for iterative refinement.')
group.add_argument('--iter-decode-force-max-iter', action='store_true',
help='if set, run exact the maximum number of iterations without early stop')
group.add_argument('--iter-decode-with-beam', default=1, type=int, metavar='N',
help='if > 1, model will generate translations varying by the lengths.')
group.add_argument('--iter-decode-with-external-reranker', action='store_true',
help='if set, the last checkpoint are assumed to be a reranker to rescore the translations'),
group.add_argument('--retain-iter-history', action='store_true',
help='if set, decoding returns the whole history of iterative refinement')
group.add_argument('--retain-dropout', action='store_true',
help='Use dropout at inference time')
group.add_argument('--retain-dropout-modules', default=None, nargs='+', type=str,
help='if set, only retain dropout for the specified modules; '
'if not set, then dropout will be retained for all modules')
# special decoding format for advanced decoding.
group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs'])
# fmt: on
gen_parser_from_dataclass(group, GenerationConfig())
return group
def add_interactive_args(parser):
group = parser.add_argument_group("Interactive")
# fmt: off
group.add_argument('--buffer-size', default=0, type=int, metavar='N',
help='read this many sentences into a buffer before processing them')
group.add_argument('--input', default='-', type=str, metavar='FILE',
help='file to read from; use - for stdin')
# fmt: on
gen_parser_from_dataclass(group, InteractiveConfig())
def add_model_args(parser):

View File

@ -6,13 +6,14 @@
import logging
from fairseq.modules.quantization import pq, quantization_options, scalar
from omegaconf import DictConfig
logger = logging.getLogger(__name__)
def quantize_model_scalar(model, args):
quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
def quantize_model_scalar(model, model_cfg: DictConfig):
quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0
if quant_noise_scalar > 0:
# quantize_model edits the model in place
scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000)
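The trailing `or 0` is not cosmetic: with a composed config the key can be present but null (e.g. `quant_noise_scalar: null`), and getattr's default only applies when the attribute is missing entirely. A quick illustration:

from omegaconf import OmegaConf

model_cfg = OmegaConf.create({"quant_noise_scalar": None})

print(getattr(model_cfg, "quant_noise_scalar", 0))       # None -> would break the `> 0` comparison
print(getattr(model_cfg, "quant_noise_scalar", 0) or 0)  # 0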

View File

@ -3,14 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from argparse import Namespace
from typing import Union
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import populate_dataclass
from omegaconf import DictConfig
REGISTRIES = {}
@ -25,33 +24,30 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F
# maintain a registry of all registries
if registry_name in REGISTRIES:
return # registry already exists
REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default}
REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default, "dataclass_registry": DATACLASS_REGISTRY}
def build_x(args: Union[DictConfig, Namespace], *extra_args, **extra_kwargs):
if isinstance(args, DictConfig):
if getattr(args, "_name", None) is not None:
choice = args._name
elif hasattr(args, registry_name):
choice = args.registry_name
else:
raise RuntimeError(
f"Neither _name nor {registry_name} in args, args = {args}"
)
def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs):
if isinstance(cfg, DictConfig):
choice = cfg._name
elif isinstance(cfg, str):
choice = cfg
else:
choice = getattr(args, registry_name, None)
choice = getattr(cfg, registry_name, None)
if choice in DATACLASS_REGISTRY:
cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]())
if choice is None:
if required:
raise ValueError("--{} is required!".format(registry_name))
raise ValueError('{} is required!'.format(registry_name))
return None
cls = REGISTRY[choice]
if hasattr(cls, "build_" + registry_name):
builder = getattr(cls, "build_" + registry_name)
else:
builder = cls
if isinstance(args, Namespace):
set_defaults(args, cls)
return builder(args, *extra_args, **extra_kwargs)
return builder(cfg, *extra_args, **extra_kwargs)
def register_x(name, dataclass=None):
def register_x_cls(cls):
@ -77,30 +73,10 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F
cls.__dataclass = dataclass
REGISTRY[name] = cls
DATACLASS_REGISTRY[name] = cls.__dataclass
REGISTRY_CLASS_NAMES.add(cls.__name__)
if cls.__dataclass is not None:
DATACLASS_REGISTRY[name] = cls.__dataclass
return cls
return register_x_cls
return build_x, register_x, REGISTRY, DATACLASS_REGISTRY
def set_defaults(args: Namespace, cls):
"""Helper to set default arguments based on *add_args*."""
if not hasattr(cls, "add_args"):
return
parser = argparse.ArgumentParser(
argument_default=argparse.SUPPRESS, allow_abbrev=False
)
cls.add_args(parser)
# copied from argparse.py:
defaults = argparse.Namespace()
for action in parser._actions:
if action.dest is not argparse.SUPPRESS:
if not hasattr(defaults, action.dest):
if action.default is not argparse.SUPPRESS:
setattr(defaults, action.dest, action.default)
for key, default_value in vars(defaults).items():
if not hasattr(args, key):
setattr(args, key, default_value)
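The rewritten build_x dispatches on the type of its first argument instead of probing attributes: a DictConfig supplies `_name`, a bare string is the registry key itself, and a legacy Namespace still resolves through the registry attribute. The same dispatch as a runnable sketch (component names are illustrative):

from argparse import Namespace
from typing import Union
from omegaconf import DictConfig, OmegaConf

def resolve_choice(cfg: Union[DictConfig, str, Namespace], registry_name: str):
    # mirrors the branch order in build_x above
    if isinstance(cfg, DictConfig):
        return cfg._name
    if isinstance(cfg, str):
        return cfg
    return getattr(cfg, registry_name, None)

print(resolve_choice(OmegaConf.create({"_name": "gpt2"}), "bpe"))  # gpt2
print(resolve_choice("gpt2", "bpe"))                               # gpt2
print(resolve_choice(Namespace(bpe="gpt2"), "bpe"))                # gpt2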

View File

@ -9,11 +9,12 @@ import os
from abc import ABC, abstractmethod
from fairseq import registry
from omegaconf import DictConfig
class BaseScorer(ABC):
def __init__(self, args):
self.args = args
def __init__(self, cfg):
self.cfg = cfg
self.ref = []
self.pred = []
@ -39,19 +40,17 @@ _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry(
)
def build_scorer(args, tgt_dict):
from fairseq import utils
def build_scorer(choice, tgt_dict):
if isinstance(choice, DictConfig):
choice = choice._name
if args.sacrebleu:
utils.deprecation_warning(
"--sacrebleu is deprecated. Please use --scoring sacrebleu instead."
)
args.scoring = "sacrebleu"
if args.scoring == "bleu":
if choice == "bleu":
from fairseq.scoring import bleu
return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
return _build_scorer(args)
return bleu.Scorer(
bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk())
)
return _build_scorer(choice)
# automatically import any Python files in the current directory
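Scorer selection no longer sniffs flags like --sacrebleu off the args; the caller passes either the scoring sub-config or the bare registry key, and the "bleu" special case builds its config from the target dictionary. A hedged usage sketch (assumes fairseq is installed and `tgt_dict` is a loaded Dictionary):

# from fairseq.scoring import build_scorer
#
# scorer = build_scorer("sacrebleu", tgt_dict)   # plain string: registry key
# scorer = build_scorer(cfg.scoring, tgt_dict)   # DictConfig: its _name picks the scorer
#
# # the "bleu" choice is special-cased above and constructs its config explicitly:
# # bleu.Scorer(bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk()))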

View File

@ -6,8 +6,10 @@
import ctypes
import math
import sys
from dataclasses import dataclass, field
import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer
@ -27,31 +29,32 @@ class BleuStat(ctypes.Structure):
]
@register_scorer("sacrebleu")
@dataclass
class SacrebleuConfig(FairseqDataclass):
sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
default="13a", metadata={"help": "tokenizer"}
)
sacrebleu_lowercase: bool = field(
default=False, metadata={"help": "apply lowercasing"}
)
sacrebleu_char_level: bool = field(
default=False, metadata={"help": "evaluate at character level"}
)
@register_scorer("sacrebleu", dataclass=SacrebleuConfig)
class SacrebleuScorer(BaseScorer):
def __init__(self, args):
super(SacrebleuScorer, self).__init__(args)
def __init__(self, cfg):
super(SacrebleuScorer, self).__init__(cfg)
import sacrebleu
self.sacrebleu = sacrebleu
self.tokenizer = EvaluationTokenizer(
tokenizer_type=self.args.sacrebleu_tokenizer,
lowercase=self.args.sacrebleu_lowercase,
character_tokenization=self.args.sacrebleu_char_level,
tokenizer_type=cfg.sacrebleu_tokenizer,
lowercase=cfg.sacrebleu_lowercase,
character_tokenization=cfg.sacrebleu_char_level,
)
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--sacrebleu-tokenizer', type=str, default='13a',
choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES,
help='tokenizer')
parser.add_argument('--sacrebleu-lowercase', type=str, default=False,
help='apply lowercasing')
parser.add_argument('--sacrebleu-char-level', action='store_true',
help='evaluate at character level')
# fmt: on
def add_string(self, ref, pred):
self.ref.append(self.tokenizer.tokenize(ref))
self.pred.append(self.tokenizer.tokenize(pred))
@ -68,13 +71,20 @@ class SacrebleuScorer(BaseScorer):
).format()
@register_scorer("bleu")
@dataclass
class BleuConfig(FairseqDataclass):
pad: int = field(default=1, metadata={"help": "padding index"})
eos: int = field(default=2, metadata={"help": "eos index"})
unk: int = field(default=3, metadata={"help": "unk index"})
@register_scorer("bleu", dataclass=BleuConfig)
class Scorer(object):
def __init__(self, pad, eos, unk):
def __init__(self, cfg):
self.stat = BleuStat()
self.pad = pad
self.eos = eos
self.unk = unk
self.pad = cfg.pad
self.eos = cfg.eos
self.unk = cfg.unk
try:
from fairseq import libbleu

View File

@ -5,6 +5,8 @@
import unicodedata
from fairseq.dataclass.utils import ChoiceEnum
class EvaluationTokenizer(object):
"""A generic evaluation-time tokenizer, which leverages built-in tokenizers
@ -22,7 +24,7 @@ class EvaluationTokenizer(object):
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"]
ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
def __init__(
self,
@ -33,7 +35,7 @@ class EvaluationTokenizer(object):
):
from sacrebleu.tokenizers import TOKENIZERS
assert tokenizer_type in self.ALL_TOKENIZER_TYPES
assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
self.lowercase = lowercase
self.punctuation_removal = punctuation_removal
self.character_tokenization = character_tokenization

View File

@ -3,14 +3,31 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
from fairseq.dataclass import FairseqDataclass
from fairseq.scoring import BaseScorer, register_scorer
from fairseq.scoring.tokenizer import EvaluationTokenizer
@register_scorer("wer")
@dataclass
class WerScorerConfig(FairseqDataclass):
wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"}
)
wer_remove_punct: bool = field(
default=False, metadata={"help": "remove punctuation"}
)
wer_char_level: bool = field(
default=False, metadata={"help": "evaluate at character level"}
)
wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"})
@register_scorer("wer", dataclass=WerScorerConfig)
class WerScorer(BaseScorer):
def __init__(self, args):
super().__init__(args)
def __init__(self, cfg):
super().__init__(cfg)
self.reset()
try:
import editdistance as ed
@ -18,26 +35,12 @@ class WerScorer(BaseScorer):
raise ImportError("Please install editdistance to use WER scorer")
self.ed = ed
self.tokenizer = EvaluationTokenizer(
tokenizer_type=self.args.wer_tokenizer,
lowercase=self.args.wer_lowercase,
punctuation_removal=self.args.wer_remove_punct,
character_tokenization=self.args.wer_char_level,
tokenizer_type=self.cfg.wer_tokenizer,
lowercase=self.cfg.wer_lowercase,
punctuation_removal=self.cfg.wer_remove_punct,
character_tokenization=self.cfg.wer_char_level,
)
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--wer-tokenizer', type=str, default='none',
choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES,
help='sacreBLEU tokenizer to use for evaluation')
parser.add_argument('--wer-remove-punct', action='store_true',
help='remove punctuation')
parser.add_argument('--wer-char-level', action='store_true',
help='evaluate at character level')
parser.add_argument('--wer-lowercase', action='store_true',
help='lowercasing')
# fmt: on
def reset(self):
self.distance = 0
self.ref_length = 0
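With add_args gone, the scorer is configured entirely through the WerScorerConfig dataclass above. A hedged usage sketch (requires fairseq plus the editdistance package; field values are illustrative):

# from fairseq.scoring.wer import WerScorer, WerScorerConfig
#
# scorer = WerScorer(WerScorerConfig(wer_tokenizer="13a", wer_lowercase=True))
# scorer.add_string("the reference transcript", "the hypothesised transcript")
# print(scorer.result_string())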

View File

@ -7,8 +7,6 @@
import argparse
import importlib
import os
from argparse import Namespace
from typing import Union
from fairseq.dataclass import FairseqDataclass
from omegaconf import DictConfig
@ -22,10 +20,10 @@ TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
def setup_task(task_cfg: Union[DictConfig, Namespace], **kwargs):
if isinstance(task_cfg, DictConfig):
return TASK_REGISTRY[task_cfg._name].setup_task(task_cfg, **kwargs)
return TASK_REGISTRY[task_cfg.task].setup_task(task_cfg, **kwargs)
def setup_task(cfg: DictConfig, **kwargs):
if isinstance(cfg, DictConfig):
return TASK_REGISTRY[cfg._name].setup_task(cfg, **kwargs)
return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs)
def register_task(name, dataclass=None):
@ -70,7 +68,8 @@ def register_task(name, dataclass=None):
)
cls.__dataclass = dataclass
TASK_DATACLASS_REGISTRY[name] = dataclass
if dataclass is not None:
TASK_DATACLASS_REGISTRY[name] = dataclass
return cls
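With the None-guard above, tasks registered without a dataclass simply stay out of TASK_DATACLASS_REGISTRY. Both registration styles, as a hedged sketch (class and config names are illustrative):

# from fairseq.tasks import register_task, FairseqTask, LegacyFairseqTask
#
# @register_task("my_legacy_task")                          # argparse-style task: add_args only, no dataclass
# class MyLegacyTask(LegacyFairseqTask):
#     ...
#
# @register_task("my_hydra_task", dataclass=MyTaskConfig)   # dataclass also lands in TASK_DATACLASS_REGISTRY
# class MyHydraTask(FairseqTask):
#     ...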

View File

@ -79,7 +79,7 @@ class AudioPretrainingTask(LegacyFairseqTask):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
args (omegaconf.DictConfig): parsed command-line arguments
"""
return cls(args)

View File

@ -12,6 +12,7 @@ import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
from fairseq.dataclass.utils import gen_parser_from_dataclass
from omegaconf import DictConfig
logger = logging.getLogger(__name__)
@ -39,8 +40,8 @@ class FairseqTask(object):
"""
return criterion.logging_outputs_can_be_summed()
def __init__(self, args):
self.args = args
def __init__(self, cfg: DictConfig, **kwargs):
self.cfg = cfg
self.datasets = {}
self.dataset_to_epoch_iter = {}
@ -78,16 +79,16 @@ class FairseqTask(object):
return d
@classmethod
def setup_task(cls, args, **kwargs):
def setup_task(cls, cfg: DictConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
cfg (omegaconf.DictConfig): parsed command-line arguments
"""
return cls(args, **kwargs)
return cls(cfg, **kwargs)
def has_sharded_data(self, split):
return os.pathsep in getattr(self.args, "data", "")
return os.pathsep in getattr(self.cfg, "data", "")
def load_dataset(self, split, combine=False, **kwargs):
"""Load a given dataset split.
@ -254,39 +255,39 @@ class FairseqTask(object):
return epoch_iter
def build_model(self, args):
def build_model(self, cfg: DictConfig):
"""
Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
task.
Args:
args (argparse.Namespace): parsed command-line arguments
cfg (omegaconf.DictConfig): configuration object
Returns:
a :class:`~fairseq.models.BaseFairseqModel` instance
"""
from fairseq import models, quantization_utils
model = models.build_model(args, self)
if getattr(args, "tpu", False):
model = models.build_model(cfg, self)
if getattr(cfg, "tpu", False):
model.prepare_for_tpu_()
model = quantization_utils.quantize_model_scalar(model, args)
model = quantization_utils.quantize_model_scalar(model, cfg)
return model
def build_criterion(self, args):
def build_criterion(self, cfg: DictConfig):
"""
Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
this task.
Args:
args (argparse.Namespace): parsed command-line arguments
cfg (omegaconf.DictConfig): configuration object
Returns:
a :class:`~fairseq.criterions.FairseqCriterion` instance
"""
from fairseq import criterions
return criterions.build_criterion(args, self)
return criterions.build_criterion(cfg, self)
def build_generator(
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None

View File

@ -28,7 +28,7 @@ from fairseq.data import (
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.data.shorten_dataset import maybe_shorten_dataset
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import LegacyFairseqTask, register_task
from omegaconf import II
@ -85,16 +85,16 @@ class LanguageModelingConfig(FairseqDataclass):
},
)
# TODO common vars below add to parent
seed: int = II("params.common.seed")
seed: int = II("common.seed")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"params.dataset.dataset_impl"
"dataset.dataset_impl"
)
data_buffer_size: int = II("params.dataset.data_buffer_size")
tpu: bool = II("params.common.tpu")
data_buffer_size: int = II("dataset.data_buffer_size")
tpu: bool = II("common.tpu")
@register_task("language_modeling", dataclass=LanguageModelingConfig)
class LanguageModelingTask(FairseqTask):
class LanguageModelingTask(LegacyFairseqTask):
"""
Train a language model.

View File

@ -117,7 +117,7 @@ class MultilingualTranslationTask(LegacyFairseqTask):
return cls(args, dicts, training)
@classmethod
def prepare(cls, args, **kargs):
def update_args(cls, args):
args.left_pad_source = utils.eval_bool(args.left_pad_source)
args.left_pad_target = utils.eval_bool(args.left_pad_target)
@ -127,6 +127,10 @@ class MultilingualTranslationTask(LegacyFairseqTask):
)
if isinstance(args.lang_pairs, str):
args.lang_pairs = args.lang_pairs.split(",")
@classmethod
def prepare(cls, args, **kargs):
cls.update_args(args)
sorted_langs = sorted(
list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")})
)
@ -298,6 +302,10 @@ class MultilingualTranslationTask(LegacyFairseqTask):
if len(messages) > 0:
raise ValueError(" ".join(messages))
# Update args -> the constructor may have modified the args object,
# but that doesn't guarantee this method sees the same (updated) one, so re-apply the update here
self.update_args(args)
# Check if task args are consistent with model args
check_args()

View File

@ -13,7 +13,7 @@ from fairseq.data.audio.speech_to_text_dataset import (
SpeechToTextDataset,
SpeechToTextDatasetCreator,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.tasks import LegacyFairseqTask, register_task
logging.basicConfig(
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
@register_task("speech_to_text")
class SpeechToTextTask(FairseqTask):
class SpeechToTextTask(LegacyFairseqTask):
@staticmethod
def add_args(parser):
parser.add_argument("data", help="manifest root path")

View File

@ -11,15 +11,18 @@ import contextlib
import logging
import sys
import time
from argparse import Namespace
from itertools import chain
from typing import Any, Dict, List
import torch
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics
from fairseq.nan_detector import NanDetector
from fairseq.optim import lr_scheduler
from omegaconf import DictConfig
logger = logging.getLogger(__name__)
@ -35,19 +38,25 @@ class Trainer(object):
communication of the gradients across workers.
"""
def __init__(self, args, task, model, criterion, quantizer=None):
self.args = args
def __init__(self, cfg: DictConfig, task, model, criterion, quantizer=None):
if isinstance(cfg, Namespace):
logger.warning(
"argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
)
cfg = convert_namespace_to_omegaconf(cfg)
self.cfg = cfg
self.task = task
# catalog shared parameters
shared_params = _catalog_shared_params(model)
self.tpu = getattr(args, "tpu", False)
self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
self.tpu = cfg.common.tpu
self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
if self.cuda:
self.device = torch.device("cuda")
elif self.tpu:
self.device = utils.get_tpu_device(args)
self.device = utils.get_tpu_device()
else:
self.device = torch.device("cpu")
@ -58,19 +67,21 @@ class Trainer(object):
import torch_xla.core.xla_model as xm
self._model = xm.send_cpu_data_to_device(self._model, self.device)
if args.fp16:
if cfg.common.fp16:
self._criterion = self._criterion.half()
self._model = self._model.half()
elif args.bf16:
elif cfg.common.bf16:
self._criterion = self._criterion.to(dtype=torch.bfloat16)
self._model = self._model.to(dtype=torch.bfloat16)
if not args.pipeline_model_parallel:
if not cfg.distributed_training.pipeline_model_parallel:
self._criterion = self._criterion.to(device=self.device)
self._model = self._model.to(device=self.device)
self.pipeline_model_parallel = args.pipeline_model_parallel
self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
self.last_device = None
if self.cuda and self.pipeline_model_parallel:
self.last_device = torch.device(args.pipeline_devices[-1])
self.last_device = torch.device(
cfg.distributed_training.pipeline_devices[-1]
)
# check that shared parameters are preserved after device transfer
for shared_param in shared_params:
@ -129,7 +140,7 @@ class Trainer(object):
@property
def data_parallel_world_size(self):
return self.args.distributed_world_size
return self.cfg.distributed_training.distributed_world_size
@property
def data_parallel_process_group(self):
@ -140,11 +151,11 @@ class Trainer(object):
@property
def data_parallel_rank(self):
return self.args.distributed_rank
return self.cfg.distributed_training.distributed_rank
@property
def is_data_parallel_master(self):
return distributed_utils.is_master(self.args)
return distributed_utils.is_master(self.cfg.distributed_training)
@property
def criterion(self):
@ -152,11 +163,11 @@ class Trainer(object):
if (
utils.has_parameters(self._criterion)
and self.data_parallel_world_size > 1
and not self.args.use_bmuf
and not self.cfg.optimization.use_bmuf
and not self.tpu
):
self._wrapped_criterion = models.DistributedFairseqModel(
self.args,
self.cfg.distributed_training,
self._criterion,
process_group=self.data_parallel_process_group,
)
@ -169,11 +180,11 @@ class Trainer(object):
if self._wrapped_model is None:
if (
self.data_parallel_world_size > 1
and not self.args.use_bmuf
and not self.cfg.optimization.use_bmuf
and not self.tpu
):
self._wrapped_model = models.DistributedFairseqModel(
self.args,
self.cfg.distributed_training,
self._model,
process_group=self.data_parallel_process_group,
)
@ -201,44 +212,51 @@ class Trainer(object):
)
)
if self.args.fp16 or self.args.bf16:
if self.cfg.common.fp16 or self.cfg.common.bf16:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
logger.info(
"NOTE: your device does NOT support faster training with --fp16, "
"please switch to FP32 which is likely to be faster"
)
if self.args.memory_efficient_fp16 or self.args.memory_efficient_bf16:
if (
self.cfg.common.memory_efficient_fp16
or self.cfg.common.memory_efficient_bf16
):
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
self.args, params
self.cfg, params
)
else:
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
else:
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
logger.info("NOTE: your device may support faster training with --fp16")
self._optimizer = optim.build_optimizer(self.args, params)
self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
if self.args.use_bmuf:
self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
if self.cfg.optimization.use_bmuf:
self._optimizer = optim.FairseqBMUF(
self.cfg.bmuf,
self._optimizer,
)
if self.args.zero_sharding == "os":
if self.cfg.distributed_training.zero_sharding == "os":
if (
self.args.fp16
and not self.args.memory_efficient_fp16
and not self.args.memory_efficient_bf16
) and not self.args.fp16_no_flatten_grads:
self.cfg.common.fp16
and not self.cfg.common.memory_efficient_fp16
and not self.cfg.common.memory_efficient_bf16
) and not self.cfg.common.fp16_no_flatten_grads:
raise ValueError(
"ZeRO is incomptabile with fp16 and flattened grads. "
"Please use --fp16-no-flatten-grads"
)
else:
optim.shard_(
self.args, self._optimizer, self.data_parallel_process_group
)
optim.shard_(self._optimizer, self.data_parallel_process_group)
# We should initialize the learning rate scheduler immediately after
# building the optimizer, so that the initial learning rate is set.
self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
self._lr_scheduler = lr_scheduler.build_lr_scheduler(
self.cfg.lr_scheduler,
self.optimizer,
)
self._lr_scheduler.step_update(0)
def consolidate_optimizer(self):
@ -253,7 +271,7 @@ class Trainer(object):
extra_state["previous_training_time"] = self.cumulative_training_time()
checkpoint_utils.save_state(
filename,
self.args,
self.cfg,
self.get_model().state_dict(),
self.get_criterion(),
self.optimizer,
@ -277,11 +295,10 @@ class Trainer(object):
bexists = PathManager.isfile(filename)
if bexists:
state = checkpoint_utils.load_checkpoint_to_cpu(filename)
# load model parameters
try:
self.get_model().load_state_dict(
state["model"], strict=True, args=self.args
state["model"], strict=True, model_cfg=self.cfg.model
)
if utils.has_parameters(self.get_criterion()):
self.get_criterion().load_state_dict(
@ -355,28 +372,28 @@ class Trainer(object):
if load_dataset:
logger.info("loading train data for epoch {}".format(epoch))
self.task.load_dataset(
self.args.train_subset,
self.cfg.dataset.train_subset,
epoch=epoch,
combine=combine,
data_selector=data_selector,
)
batch_iterator = self.task.get_batch_iterator(
dataset=self.task.dataset(self.args.train_subset),
max_tokens=self.args.max_tokens,
max_sentences=self.args.batch_size,
dataset=self.task.dataset(self.cfg.dataset.train_subset),
max_tokens=self.cfg.dataset.max_tokens,
max_sentences=self.cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions(
self.task.max_positions(),
self.model.max_positions(),
self.args.max_tokens,
self.cfg.dataset.max_tokens,
),
ignore_invalid_inputs=True,
required_batch_size_multiple=self.args.required_batch_size_multiple,
seed=self.args.seed,
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
seed=self.cfg.common.seed,
num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
num_workers=self.args.num_workers,
num_workers=self.cfg.dataset.num_workers,
epoch=epoch,
data_buffer_size=self.args.data_buffer_size,
data_buffer_size=self.cfg.dataset.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
)
self.reset_dummy_batch(batch_iterator.first_batch)
@ -390,19 +407,19 @@ class Trainer(object):
"""Return an EpochBatchIterator over given validation subset for a given epoch."""
batch_iterator = self.task.get_batch_iterator(
dataset=self.task.dataset(subset),
max_tokens=self.args.max_tokens_valid,
max_sentences=self.args.batch_size_valid,
max_tokens=self.cfg.dataset.max_tokens_valid,
max_sentences=self.cfg.dataset.batch_size_valid,
max_positions=utils.resolve_max_positions(
self.task.max_positions(),
self.model.max_positions(),
),
ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=self.args.required_batch_size_multiple,
seed=self.args.seed,
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
seed=self.cfg.common.seed,
num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank,
num_workers=self.args.num_workers,
data_buffer_size=self.args.data_buffer_size,
num_workers=self.cfg.dataset.num_workers,
data_buffer_size=self.cfg.dataset.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
)
self.reset_dummy_batch(batch_iterator.first_batch)
@ -504,7 +521,7 @@ class Trainer(object):
self.zero_grad()
if self.cuda:
torch.cuda.empty_cache()
if self.args.distributed_world_size == 1:
if self.cfg.distributed_training.distributed_world_size == 1:
return None
else:
raise e
@ -565,7 +582,7 @@ class Trainer(object):
# multiply gradients by (# GPUs / sample_size) since DDP
# already normalizes by the number of GPUs. Thus we get
# (sum_of_gradients / sample_size).
if not self.args.use_bmuf:
if not self.cfg.optimization.use_bmuf:
self.optimizer.multiply_grads(
self.data_parallel_world_size / sample_size
)
@ -575,12 +592,12 @@ class Trainer(object):
with torch.autograd.profiler.record_function("clip-grads"):
# clip grads
grad_norm = self.clip_grad_norm(self.args.clip_norm)
grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)
# check that grad norms are consistent across workers
if (
not self.args.use_bmuf
and self.args.distributed_wrapper != "SlowMo"
not self.cfg.optimization.use_bmuf
and self.cfg.distributed_training.distributed_wrapper != "SlowMo"
and not self.tpu
):
self._check_grad_norms(grad_norm)
@ -624,7 +641,10 @@ class Trainer(object):
self.optimizer.optimizer
)
if not overflow or self.args.distributed_wrapper == "SlowMo":
if (
not overflow
or self.cfg.distributed_training.distributed_wrapper == "SlowMo"
):
self.set_num_updates(self.get_num_updates() + 1)
if self.tpu:
@ -636,7 +656,7 @@ class Trainer(object):
# only log stats every log_interval steps
# this causes wps to be misreported when log_interval > 1
logging_output = {}
if self.get_num_updates() % self.args.log_interval == 0:
if self.get_num_updates() % self.cfg.common.log_interval == 0:
# log memory usage
mem_info = xm.get_memory_info(self.device)
gb_free = mem_info["kb_free"] / 1024 / 1024
@ -677,16 +697,16 @@ class Trainer(object):
# clear CUDA cache to reduce memory fragmentation
if (
self.cuda
and self.args.empty_cache_freq > 0
and self.cfg.common.empty_cache_freq > 0
and (
(self.get_num_updates() + self.args.empty_cache_freq - 1)
% self.args.empty_cache_freq
(self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
% self.cfg.common.empty_cache_freq
)
== 0
):
torch.cuda.empty_cache()
if self.args.fp16:
if self.cfg.common.fp16:
metrics.log_scalar(
"loss_scale",
self.optimizer.scaler.loss_scale,
@ -883,10 +903,10 @@ class Trainer(object):
return t.to(dtype=torch.bfloat16)
return t
if self.args.fp16:
if self.cfg.common.fp16:
sample = utils.apply_to_sample(apply_half, sample)
if self.args.bf16:
if self.cfg.common.bf16:
sample = utils.apply_to_sample(apply_bfloat16, sample)
return sample
@ -894,7 +914,7 @@ class Trainer(object):
def _set_seed(self):
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
seed = self.cfg.common.seed + self.get_num_updates()
utils.set_torch_seed(seed)
def _sync_stats(self):
@ -902,10 +922,12 @@ class Trainer(object):
# BMUF and it's a bmuf sync with warmup iterations completed before.
if self.data_parallel_world_size == 1:
return False
elif self.args.use_bmuf:
return (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and (
elif self.cfg.optimization.use_bmuf:
return (
self.get_num_updates() + 1
) > self.args.warmup_iterations
) % self.cfg.bmuf.global_sync_iter == 0 and (
self.get_num_updates() + 1
) > self.cfg.bmuf.warmup_iterations
else:
return True
@ -950,7 +972,7 @@ class Trainer(object):
zip(
*distributed_utils.all_gather_list(
[logging_outputs] + list(extra_stats_to_sum),
max_size=getattr(self.args, "all_gather_list_size", 16384),
max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
group=self.data_parallel_process_group,
)
)
@ -1038,11 +1060,11 @@ class Trainer(object):
if grad_norm is not None:
metrics.log_speed("ups", 1.0, priority=100, round=2)
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
if self.args.clip_norm > 0:
if self.cfg.optimization.clip_norm > 0:
metrics.log_scalar(
"clip",
torch.where(
grad_norm > self.args.clip_norm,
grad_norm > self.cfg.optimization.clip_norm,
grad_norm.new_tensor(100),
grad_norm.new_tensor(0),
),
@ -1087,7 +1109,7 @@ class Trainer(object):
logger.warning(
"XLA compilation detected on device #{}; too many of these can lead "
"to slow training, but we expect a few in the beginning".format(
self.args.distributed_rank
self.cfg.distributed_training.distributed_rank
)
)
self._num_xla_compiles = num_xla_compiles
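Throughout the trainer, flat args attributes become grouped lookups (args.fp16 -> cfg.common.fp16, args.max_tokens -> cfg.dataset.max_tokens, args.clip_norm -> cfg.optimization.clip_norm, and so on), and an argparse Namespace passed by a legacy caller is converted up front via convert_namespace_to_omegaconf. A small illustration of the grouped layout the trainer now reads (a subset of the groups referenced in this hunk; values are illustrative):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "common": {"fp16": False, "bf16": False, "seed": 1, "log_interval": 100},
        "dataset": {"max_tokens": 4096, "batch_size": None, "num_workers": 1},
        "optimization": {"clip_norm": 25.0, "use_bmuf": False},
        "distributed_training": {"distributed_world_size": 1, "distributed_rank": 0},
    }
)

print(cfg.common.fp16, cfg.dataset.max_tokens, cfg.optimization.clip_norm)
# False 4096 25.0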

View File

@ -11,13 +11,19 @@ Evaluate the perplexity of a trained language model.
import logging
import math
import os
from argparse import Namespace
import torch
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from omegaconf import DictConfig
logging.basicConfig(
@ -60,65 +66,60 @@ class WordStat(object):
)
def main(parsed_args, **unused_kwargs):
assert parsed_args.path is not None, "--path required for evaluation!"
def main(cfg: DictConfig, override_args=None, **unused_kwargs):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
if torch.cuda.is_available() and not parsed_args.cpu:
torch.cuda.set_device(parsed_args.device_id)
utils.import_user_module(cfg.common)
utils.import_user_module(parsed_args)
use_fp16 = cfg.common.fp16
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
logger.info(parsed_args)
if use_cuda:
torch.cuda.set_device(cfg.distributed_training.device_id)
use_cuda = torch.cuda.is_available() and not parsed_args.cpu
if override_args is not None:
overrides = vars(override_args)
overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
else:
overrides = None
task = tasks.setup_task(parsed_args)
logger.info(cfg)
# Load ensemble
logger.info("loading model(s) from {}".format(parsed_args.path))
models, args = checkpoint_utils.load_model_ensemble(
parsed_args.path.split(os.pathsep),
arg_overrides=eval(parsed_args.model_overrides),
task=task,
suffix=getattr(parsed_args, "checkpoint_suffix", ""),
strict=(parsed_args.checkpoint_shard_count == 1),
num_shards=parsed_args.checkpoint_shard_count,
)
for arg in vars(parsed_args).keys():
if arg not in {
"self_target",
"future_target",
"past_target",
"tokens_per_sample",
"output_size_dictionary",
"add_bos_token",
}:
setattr(args, arg, getattr(parsed_args, arg))
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
# reduce tokens per sample by the required context window size
args.tokens_per_sample -= args.context_window
task = tasks.setup_task(args)
cfg.task.tokens_per_sample -= cfg.eval_lm.context_window
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
[cfg.common_eval.path],
arg_overrides=overrides,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
# Load dataset splits
task.load_dataset(args.gen_subset)
dataset = task.dataset(args.gen_subset)
if args.context_window > 0:
gen_subset = cfg.dataset.gen_subset
task.load_dataset(gen_subset)
dataset = task.dataset(gen_subset)
if cfg.eval_lm.context_window > 0:
dataset = LMContextWindowDataset(
dataset=dataset,
tokens_per_sample=args.tokens_per_sample,
context_window=args.context_window,
tokens_per_sample=cfg.task.tokens_per_sample,
context_window=cfg.eval_lm.context_window,
pad_idx=task.source_dictionary.pad(),
)
logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset)))
logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset)))
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
if args.fp16:
if use_fp16:
model.half()
if use_cuda and not args.pipeline_model_parallel:
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda()
model.prepare_for_inference_(args)
model.prepare_for_inference_(cfg)
assert len(models) > 0
@ -128,35 +129,41 @@ def main(parsed_args, **unused_kwargs):
itr = task.get_batch_iterator(
dataset=dataset,
max_tokens=args.max_tokens or 36000,
max_sentences=args.batch_size,
max_tokens=cfg.dataset.max_tokens or 36000,
max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions(
*[model.max_positions() for model in models]
),
ignore_invalid_inputs=True,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
data_buffer_size=args.data_buffer_size,
num_shards=max(
cfg.dataset.num_shards,
cfg.distributed_training.distributed_world_size,
),
shard_id=max(
cfg.dataset.shard_id,
cfg.distributed_training.distributed_rank,
),
num_workers=cfg.dataset.num_workers,
data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar(
itr,
log_format=args.log_format,
log_interval=args.log_interval,
default_log_format=("tqdm" if not args.no_progress_bar else "none"),
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)
scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)
score_sum = 0.0
count = 0
if args.remove_bpe is not None:
if args.remove_bpe == "sentencepiece":
if cfg.common_eval.remove_bpe is not None:
if cfg.common_eval.remove_bpe == "sentencepiece":
raise NotImplementedError
else:
bpe_cont = args.remove_bpe.rstrip()
bpe_cont = cfg.common_eval.remove_bpe.rstrip()
bpe_toks = {
i
for i in range(len(task.source_dictionary))
@ -189,7 +196,7 @@ def main(parsed_args, **unused_kwargs):
tgt_len = tokens.numel()
pos_scores = hypo["positional_scores"].float()
if getattr(args, "add_bos_token", False):
if cfg.task.add_bos_token:
assert hypo["tokens"][0].item() == task.target_dictionary.bos()
tokens = tokens[1:]
pos_scores = pos_scores[1:]
@ -212,7 +219,7 @@ def main(parsed_args, **unused_kwargs):
score_sum += pos_scores.sum().cpu()
count += pos_scores.numel() - skipped_toks
if args.output_word_probs or args.output_word_stats:
if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
w = ""
word_prob = []
is_bpe = False
@ -238,7 +245,7 @@ def main(parsed_args, **unused_kwargs):
)
is_bpe = False
w = ""
if args.output_word_probs:
if cfg.eval_lm.output_word_probs:
logger.info(
str(int(sample_id))
+ " "
@ -264,7 +271,7 @@ def main(parsed_args, **unused_kwargs):
)
)
if args.output_word_stats:
if cfg.eval_lm.output_word_stats:
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
logger.info(ws)
@ -272,8 +279,16 @@ def main(parsed_args, **unused_kwargs):
def cli_main():
parser = options.get_eval_lm_parser()
args = options.parse_args_and_arch(parser)
distributed_utils.call_main(args, main)
# only override args that are explicitly given on the command line
override_parser = options.get_validation_parser()
override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)
distributed_utils.call_main(args, main, override_args=override_args)
if __name__ == "__main__":
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main()

View File

@ -12,33 +12,45 @@ import logging
import math
import os
import sys
from argparse import Namespace
from itertools import chain
import numpy as np
import torch
from fairseq import checkpoint_utils, options, scoring, tasks, utils
from fairseq.data import encoders
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.logging.meters import StopwatchMeter, TimeMeter
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from omegaconf import DictConfig
def main(args):
assert args.path is not None, "--path required for generation!"
def main(cfg: DictConfig):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
assert cfg.common_eval.path is not None, "--path required for generation!"
assert (
not args.sampling or args.nbest == args.beam
not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
args.replace_unk is None or args.dataset_impl == "raw"
cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
if args.results_path is not None:
os.makedirs(args.results_path, exist_ok=True)
if cfg.common_eval.results_path is not None:
os.makedirs(cfg.common_eval.results_path, exist_ok=True)
output_path = os.path.join(
args.results_path, "generate-{}.txt".format(args.gen_subset)
cfg.common_eval.results_path,
"generate-{}.txt".format(cfg.dataset.gen_subset),
)
with open(output_path, "w", buffering=1, encoding="utf-8") as h:
return _main(args, h)
return _main(cfg, h)
else:
return _main(args, sys.stdout)
return _main(cfg, sys.stdout)
def get_symbols_to_strip_from_output(generator):
@ -48,7 +60,7 @@ def get_symbols_to_strip_from_output(generator):
return {generator.eos}
def _main(args, output_file):
def _main(cfg: DictConfig, output_file):
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
@ -57,22 +69,22 @@ def _main(args, output_file):
)
logger = logging.getLogger("fairseq_cli.generate")
utils.import_user_module(args)
utils.import_user_module(cfg.common)
if args.max_tokens is None and args.batch_size is None:
args.max_tokens = 12000
logger.info(args)
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
cfg.dataset.max_tokens = 12000
logger.info(cfg)
# Fix seed for stochastic decoding
if args.seed is not None and not args.no_seed_provided:
np.random.seed(args.seed)
utils.set_torch_seed(args.seed)
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
np.random.seed(cfg.common.seed)
utils.set_torch_seed(cfg.common.seed)
use_cuda = torch.cuda.is_available() and not args.cpu
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
task = tasks.setup_task(cfg.task)
task.load_dataset(cfg.dataset.gen_subset)
# Set dictionaries
try:
@ -81,32 +93,30 @@ def _main(args, output_file):
src_dict = None
tgt_dict = task.target_dictionary
overrides = ast.literal_eval(args.model_overrides)
overrides = ast.literal_eval(cfg.common_eval.model_overrides)
# Load ensemble
logger.info("loading model(s) from {}".format(args.path))
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
utils.split_paths(args.path),
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
task=task,
suffix=getattr(args, "checkpoint_suffix", ""),
strict=(args.checkpoint_shard_count == 1),
num_shards=args.checkpoint_shard_count,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
if args.lm_path is not None:
overrides["data"] = args.data
if cfg.generation.lm_path is not None:
overrides["data"] = cfg.task.data
try:
lms, _ = checkpoint_utils.load_model_ensemble(
[args.lm_path],
arg_overrides=overrides,
task=None,
[cfg.generation.lm_path], arg_overrides=overrides, task=None
)
except:
logger.warning(
f"Failed to load language model! Please make sure that the language model dict is the same "
f"as target dict and is located in the data dir ({args.data})"
f"as target dict and is located in the data dir ({cfg.task.data})"
)
raise
@ -118,49 +128,50 @@ def _main(args, output_file):
for model in chain(models, lms):
if model is None:
continue
if args.fp16:
if cfg.common.fp16:
model.half()
if use_cuda and not args.pipeline_model_parallel:
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda()
model.prepare_for_inference_(args)
model.prepare_for_inference_(cfg)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
align_dict = utils.load_align_dict(cfg.generation.replace_unk)
# Load dataset (possibly sharded)
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
dataset=task.dataset(cfg.dataset.gen_subset),
max_tokens=cfg.dataset.max_tokens,
max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models]
task.max_positions(), *[m.max_positions() for m in models]
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
data_buffer_size=args.data_buffer_size,
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
seed=cfg.common.seed,
num_shards=cfg.distributed_training.distributed_world_size,
shard_id=cfg.distributed_training.distributed_rank,
num_workers=cfg.dataset.num_workers,
data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar(
itr,
log_format=args.log_format,
log_interval=args.log_interval,
default_log_format=("tqdm" if not args.no_progress_bar else "none"),
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
)
# Initialize generator
gen_timer = StopwatchMeter()
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight}
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
generator = task.build_generator(
models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs
models, cfg.task, extra_gen_cls_kwargs=extra_gen_cls_kwargs
)
# Handle tokenization and BPE
tokenizer = task.build_tokenizer(args)
bpe = task.build_bpe(args)
tokenizer = encoders.build_tokenizer(cfg.tokenizer)
bpe = encoders.build_bpe(cfg.bpe)
def decode_fn(x):
if bpe is not None:
@ -169,7 +180,7 @@ def _main(args, output_file):
x = tokenizer.decode(x)
return x
scorer = scoring.build_scorer(args, tgt_dict)
scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
num_sentences = 0
has_target = True
@ -180,8 +191,8 @@ def _main(args, output_file):
continue
prefix_tokens = None
if args.prefix_size > 0:
prefix_tokens = sample["target"][:, : args.prefix_size]
if cfg.generation.prefix_size > 0:
prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
constraints = None
if "constraints" in sample:
@ -217,19 +228,21 @@ def _main(args, output_file):
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = task.dataset(args.gen_subset).tgt.get_original_text(
src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
sample_id
)
target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
sample_id
)
else:
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe)
else:
src_str = ""
if has_target:
target_str = tgt_dict.string(
target_tokens,
args.remove_bpe,
cfg.common_eval.remove_bpe,
escape_unk=True,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(
generator
@ -240,25 +253,25 @@ def _main(args, output_file):
if has_target:
target_str = decode_fn(target_str)
if not args.quiet:
if not cfg.common_eval.quiet:
if src_dict is not None:
print("S-{}\t{}".format(sample_id, src_str), file=output_file)
if has_target:
print("T-{}\t{}".format(sample_id, target_str), file=output_file)
# Process top predictions
for j, hypo in enumerate(hypos[i][: args.nbest]):
for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo["tokens"].int().cpu(),
src_str=src_str,
alignment=hypo["alignment"],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
remove_bpe=cfg.common_eval.remove_bpe,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)
detok_hypo_str = decode_fn(hypo_str)
if not args.quiet:
if not cfg.common_eval.quiet:
score = hypo["score"] / math.log(2) # convert to base 2
# original hypothesis (after tokenization and BPE)
print(
@ -286,7 +299,7 @@ def _main(args, output_file):
file=output_file,
)
if args.print_alignment:
if cfg.generation.print_alignment:
print(
"A-{}\t{}".format(
sample_id,
@ -300,13 +313,13 @@ def _main(args, output_file):
file=output_file,
)
if args.print_step:
if cfg.generation.print_step:
print(
"I-{}\t{}".format(sample_id, hypo["steps"]),
file=output_file,
)
if getattr(args, "retain_iter_history", False):
if cfg.generation.retain_iter_history:
for step, h in enumerate(hypo["history"]):
_, h_str, _ = utils.post_process_prediction(
hypo_tokens=h["tokens"].int().cpu(),
@ -323,7 +336,7 @@ def _main(args, output_file):
# Score only the top hypothesis
if has_target and j == 0:
if align_dict is not None or args.remove_bpe is not None:
if align_dict is not None or cfg.common_eval.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tgt_dict.encode_line(
target_str, add_if_not_exist=True
@ -353,8 +366,8 @@ def _main(args, output_file):
)
)
if has_target:
if args.bpe and not args.sacrebleu:
if args.remove_bpe:
if cfg.bpe and not cfg.generation.sacrebleu:
if cfg.common_eval.remove_bpe:
logger.warning(
"BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
)
@ -365,7 +378,7 @@ def _main(args, output_file):
# use print to be consistent with other main outputs: S-, H-, T-, D- and so on
print(
"Generate {} with beam={}: {}".format(
args.gen_subset, args.beam, scorer.result_string()
cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
),
file=output_file,
)
@ -380,4 +393,7 @@ def cli_main():
if __name__ == "__main__":
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main()
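A hedged sketch of the tokenizer/BPE handling rewritten above (not code from the commit): the encoders are now built from their own config groups, cfg.tokenizer and cfg.bpe, instead of from the flat args object, and either builder may return None when the corresponding group is unset.

from fairseq.data import encoders

def build_decode_fn(cfg):
    # per-component config groups composed by Hydra
    tokenizer = encoders.build_tokenizer(cfg.tokenizer)
    bpe = encoders.build_bpe(cfg.bpe)

    def decode_fn(x):
        # undo BPE first, then detokenize, mirroring _main() above
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    return decode_fn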

View File

@ -7,20 +7,27 @@
Translate raw text with a trained model. Batches data on-the-fly.
"""
import ast
import fileinput
import logging
import math
import os
import sys
import time
from argparse import Namespace
from collections import namedtuple
import numpy as np
import torch
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import encoders
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from omegaconf import DictConfig
logging.basicConfig(
@ -49,11 +56,11 @@ def buffered_read(input, buffer_size):
yield buffer
def make_batches(lines, args, task, max_positions, encode_fn):
def make_batches(lines, cfg, task, max_positions, encode_fn):
def encode_fn_target(x):
return encode_fn(x)
if args.constraints:
if cfg.generation.constraints:
# Strip (tab-delimited) constraints, if present, from input lines,
# store them in batch_constraints
batch_constraints = [list() for _ in lines]
@ -79,7 +86,7 @@ def make_batches(lines, args, task, max_positions, encode_fn):
for src_str in lines
]
if args.constraints:
if cfg.generation.constraints:
constraints_tensor = pack_constraints(batch_constraints)
else:
constraints_tensor = None
@ -89,10 +96,10 @@ def make_batches(lines, args, task, max_positions, encode_fn):
dataset=task.build_dataset_for_inference(
tokens, lengths, constraints=constraints_tensor
),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_tokens=cfg.dataset.max_tokens,
max_sentences=cfg.dataset.batch_size,
max_positions=max_positions,
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
).next_epoch_itr(shuffle=False)
for batch in itr:
ids = batch["id"]
@ -108,45 +115,50 @@ def make_batches(lines, args, task, max_positions, encode_fn):
)
def main(args):
def main(cfg: DictConfig):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
start_time = time.time()
total_translate_time = 0
utils.import_user_module(args)
utils.import_user_module(cfg.common)
if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.batch_size is None:
args.batch_size = 1
if cfg.interactive.buffer_size < 1:
cfg.interactive.buffer_size = 1
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
cfg.dataset.batch_size = 1
assert (
not args.sampling or args.nbest == args.beam
not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
not args.batch_size or args.batch_size <= args.buffer_size
not cfg.dataset.batch_size
or cfg.dataset.batch_size <= cfg.interactive.buffer_size
), "--batch-size cannot be larger than --buffer-size"
logger.info(args)
logger.info(cfg)
# Fix seed for stochastic decoding
if args.seed is not None and not args.no_seed_provided:
np.random.seed(args.seed)
utils.set_torch_seed(args.seed)
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
np.random.seed(cfg.common.seed)
utils.set_torch_seed(cfg.common.seed)
use_cuda = torch.cuda.is_available() and not args.cpu
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
# Setup task, e.g., translation
task = tasks.setup_task(args)
task = tasks.setup_task(cfg.task)
# Load ensemble
logger.info("loading model(s) from {}".format(args.path))
overrides = ast.literal_eval(cfg.common_eval.model_overrides)
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(os.pathsep),
arg_overrides=eval(args.model_overrides),
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
task=task,
suffix=getattr(args, "checkpoint_suffix", ""),
strict=(args.checkpoint_shard_count == 1),
num_shards=args.checkpoint_shard_count,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
# Set dictionaries
@ -155,18 +167,20 @@ def main(args):
# Optimize ensemble for generation
for model in models:
if args.fp16:
if model is None:
continue
if cfg.common.fp16:
model.half()
if use_cuda and not args.pipeline_model_parallel:
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda()
model.prepare_for_inference_(args)
model.prepare_for_inference_(cfg)
# Initialize generator
generator = task.build_generator(models, args)
generator = task.build_generator(models, cfg.task)
# Handle tokenization and BPE
tokenizer = encoders.build_tokenizer(args)
bpe = encoders.build_bpe(args)
tokenizer = encoders.build_tokenizer(cfg.tokenizer)
bpe = encoders.build_bpe(cfg.bpe)
def encode_fn(x):
if tokenizer is not None:
@ -184,25 +198,25 @@ def main(args):
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
align_dict = utils.load_align_dict(cfg.generation.replace_unk)
max_positions = utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models]
)
if args.constraints:
if cfg.generation.constraints:
logger.warning(
"NOTE: Constrained decoding currently assumes a shared subword vocabulary."
)
if args.buffer_size > 1:
logger.info("Sentence buffer size: %s", args.buffer_size)
if cfg.interactive.buffer_size > 1:
logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
logger.info("NOTE: hypothesis and token scores are output in base 2")
logger.info("Type the input sentence and press return:")
start_id = 0
for inputs in buffered_read(args.input, args.buffer_size):
for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size):
results = []
for batch in make_batches(inputs, args, task, max_positions, encode_fn):
for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
bsz = batch.src_tokens.size(0)
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
@ -226,7 +240,7 @@ def main(args):
translate_time = time.time() - translate_start_time
total_translate_time += translate_time
list_constraints = [[] for _ in range(bsz)]
if args.constraints:
if cfg.generation.constraints:
list_constraints = [unpack_constraints(c) for c in constraints]
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
@ -246,25 +260,25 @@ def main(args):
# sort output to match input order
for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe)
print("S-{}\t{}".format(id_, src_str))
print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
for constraint in info["constraints"]:
print(
"C-{}\t{}".format(
id_, tgt_dict.string(constraint, args.remove_bpe)
id_, tgt_dict.string(constraint, cfg.common_eval.remove_bpe)
)
)
# Process top predictions
for hypo in hypos[: min(len(hypos), args.nbest)]:
for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo["tokens"].int().cpu(),
src_str=src_str,
alignment=hypo["alignment"],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
remove_bpe=cfg.common_eval.remove_bpe,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)
detok_hypo_str = decode_fn(hypo_str)
@ -285,7 +299,7 @@ def main(args):
),
)
)
if args.print_alignment:
if cfg.generation.print_alignment:
alignment_str = " ".join(
["{}-{}".format(src, tgt) for src, tgt in alignment]
)
@ -308,4 +322,7 @@ def cli_main():
if __name__ == "__main__":
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main()
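A hedged illustration of the tab-delimited input format that make_batches() above strips when cfg.generation.constraints is set: the source sentence comes first, and each following tab-separated field is treated as one constraint. The example sentence below is invented.

# one buffered input line with two constraints attached
line = "we want to go to the store\twollen\tzum Laden"
fields = line.split("\t")
src_str, constraint_strs = fields[0], fields[1:]
print(src_str)          # we want to go to the store
print(constraint_strs)  # ['wollen', 'zum Laden']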

View File

@ -78,7 +78,13 @@ def cli_main():
def score(fdsys):
with open(args.ref) as fdref:
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
scorer = bleu.Scorer(
bleu.BleuConfig(
pad=dict.pad(),
eos=dict.eos(),
unk=dict.unk(),
)
)
for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
sys_tok = dict.encode_line(sys_tok)
ref_tok = dict.encode_line(ref_tok)
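A hedged usage sketch of the BleuConfig-based constructor shown above. The import path, the scorer.add() argument order (reference first, hypothesis second), and result_string() are assumptions based on how the scorer is used elsewhere in these diffs; `dictionary` stands in for the dictionary loaded by score.py.

from fairseq.scoring import bleu   # assumed import path for this commit

def corpus_bleu(dictionary, sys_lines, ref_lines):
    scorer = bleu.Scorer(
        bleu.BleuConfig(
            pad=dictionary.pad(),
            eos=dictionary.eos(),
            unk=dictionary.unk(),
        )
    )
    for sys_tok, ref_tok in zip(sys_lines, ref_lines):
        # assumed argument order: reference tokens, then system tokens
        scorer.add(dictionary.encode_line(ref_tok), dictionary.encode_line(sys_tok))
    return scorer.result_string()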

View File

@ -11,11 +11,13 @@ import argparse
import logging
import math
import os
import random
import sys
from typing import Dict, Optional, Any, List, Tuple, Callable
import numpy as np
import torch
from hydra.core.config_store import ConfigStore
from fairseq import (
checkpoint_utils,
distributed_utils,
@ -25,8 +27,12 @@ from fairseq import (
utils,
)
from fairseq.data import iterators
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from omegaconf import DictConfig
from hydra.experimental import initialize
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.trainer import Trainer
@ -39,90 +45,86 @@ logging.basicConfig(
logger = logging.getLogger("fairseq_cli.train")
def main(args):
utils.import_user_module(args)
def main(cfg: DictConfig) -> None:
if isinstance(cfg, argparse.Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
assert (
args.max_tokens is not None or args.batch_size is not None
), "Must specify batch size either with --max-tokens or --batch-size"
utils.import_user_module(cfg.common)
assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \
'Must specify batch size either with --max-tokens or --batch-size'
metrics.reset()
np.random.seed(args.seed)
utils.set_torch_seed(args.seed)
np.random.seed(cfg.common.seed)
utils.set_torch_seed(cfg.common.seed)
if distributed_utils.is_master(args):
checkpoint_utils.verify_checkpoint_directory(args.save_dir)
if distributed_utils.is_master(cfg.distributed_training):
checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
# Print args
logger.info(args)
logger.info(cfg)
# Setup task, e.g., translation, language modeling, etc.
task = tasks.setup_task(args)
task = tasks.setup_task(cfg.task)
# Load valid dataset (we load training data below, based on the latest checkpoint)
for valid_sub_split in args.valid_subset.split(","):
for valid_sub_split in cfg.dataset.valid_subset.split(','):
task.load_dataset(valid_sub_split, combine=False, epoch=1)
# Build model and criterion
model = task.build_model(args)
criterion = task.build_criterion(args)
model = task.build_model(cfg.model)
criterion = task.build_criterion(cfg.criterion)
logger.info(model)
logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__))
logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__))
logger.info(
"criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
)
logger.info(
"num. model params: {} (num. trained: {})".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
)
"criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__)
)
logger.info("num. model params: {} (num. trained: {})".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
))
# (optionally) Configure quantization
if args.quantization_config_path is not None:
if cfg.common.quantization_config_path is not None:
quantizer = quantization_utils.Quantizer(
config_path=args.quantization_config_path,
max_epoch=args.max_epoch,
max_update=args.max_update,
config_path=cfg.common.quantization_config_path,
max_epoch=cfg.optimization.max_epoch,
max_update=cfg.optimization.max_update,
)
else:
quantizer = None
# Build trainer
if args.model_parallel_size == 1:
trainer = Trainer(args, task, model, criterion, quantizer)
if cfg.common.model_parallel_size == 1:
trainer = Trainer(cfg, task, model, criterion, quantizer)
else:
trainer = MegatronTrainer(args, task, model, criterion)
trainer = MegatronTrainer(cfg, task, model, criterion)
logger.info(
"training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
)
logger.info(
"max tokens per GPU = {} and max sentences per GPU = {}".format(
args.max_tokens, args.batch_size
)
)
logger.info('training on {} devices (GPUs/TPUs)'.format(cfg.distributed_training.distributed_world_size))
logger.info('max tokens per GPU = {} and batch size per GPU = {}'.format(
cfg.dataset.max_tokens,
cfg.dataset.batch_size,
))
# Load the latest checkpoint if one is available and restore the
# corresponding train iterator
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
args,
cfg.checkpoint,
trainer,
# don't cache epoch iterators for sharded datasets
disable_iterator_cache=task.has_sharded_data("train"),
)
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
max_epoch = cfg.optimization.max_epoch or math.inf
lr = trainer.get_lr()
train_meter = meters.StopwatchMeter()
train_meter.start()
while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
while (
lr > cfg.optimization.min_lr
and epoch_itr.next_epoch_idx <= max_epoch
):
# train for one epoch
valid_losses, should_stop = train(args, trainer, task, epoch_itr)
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
if should_stop:
break
@ -140,15 +142,15 @@ def main(args):
logger.info("done training in {:.1f} seconds".format(train_meter.sum))
def should_stop_early(args, valid_loss):
def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
# skip check if no validation was done in the current epoch
if valid_loss is None:
return False
if args.patience <= 0:
if cfg.checkpoint.patience <= 0:
return False
def is_better(a, b):
return a > b if args.maximize_best_checkpoint_metric else a < b
return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
prev_best = getattr(should_stop_early, "best", None)
if prev_best is None or is_better(valid_loss, prev_best):
@ -157,48 +159,43 @@ def should_stop_early(args, valid_loss):
return False
else:
should_stop_early.num_runs += 1
if should_stop_early.num_runs >= args.patience:
logger.info(
"early stop since valid performance hasn't improved for last {} runs".format(
args.patience
)
)
if should_stop_early.num_runs >= cfg.checkpoint.patience:
logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(cfg.checkpoint.patience))
return True
else:
return False
@metrics.aggregate("train")
def train(args, trainer, task, epoch_itr):
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]:
"""Train the model for one epoch and return validation losses."""
# Initialize data iterator
itr = epoch_itr.next_epoch_itr(
fix_batches_to_gpus=args.fix_batches_to_gpus,
shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
)
update_freq = (
args.update_freq[epoch_itr.epoch - 1]
if epoch_itr.epoch <= len(args.update_freq)
else args.update_freq[-1]
cfg.optimization.update_freq[epoch_itr.epoch - 1]
if epoch_itr.epoch <= len(cfg.optimization.update_freq)
else cfg.optimization.update_freq[-1]
)
itr = iterators.GroupedIterator(itr, update_freq)
if getattr(args, "tpu", False):
if getattr(cfg.common, "tpu", False):
itr = utils.tpu_data_loader(itr)
progress = progress_bar.progress_bar(
itr,
log_format=args.log_format,
log_interval=args.log_interval,
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
epoch=epoch_itr.epoch,
tensorboard_logdir=(
args.tensorboard_logdir if distributed_utils.is_master(args) else None
cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
),
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
)
trainer.begin_epoch(epoch_itr.epoch)
valid_losses = [None]
valid_subsets = args.valid_subset.split(",")
valid_subsets = cfg.dataset.valid_subset.split(',')
should_stop = False
num_updates = trainer.get_num_updates()
for i, samples in enumerate(progress):
@ -210,7 +207,7 @@ def train(args, trainer, task, epoch_itr):
if log_output is not None: # not OOM, overflow, ...
# log mid-epoch stats
num_updates = trainer.get_num_updates()
if num_updates % args.log_interval == 0:
if num_updates % cfg.common.log_interval == 0:
stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
progress.log(stats, tag="train_inner", step=num_updates)
@ -220,7 +217,7 @@ def train(args, trainer, task, epoch_itr):
end_of_epoch = not itr.has_next()
valid_losses, should_stop = validate_and_save(
args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
)
if should_stop:
@ -236,64 +233,64 @@ def train(args, trainer, task, epoch_itr):
return valid_losses, should_stop
def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]:
num_updates = trainer.get_num_updates()
max_update = args.max_update or math.inf
max_update = cfg.optimization.max_update or math.inf
do_save = (
(end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
(end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
or num_updates >= max_update
or (
args.save_interval_updates > 0
cfg.checkpoint.save_interval_updates > 0
and num_updates > 0
and num_updates % args.save_interval_updates == 0
and num_updates >= args.validate_after_updates
and num_updates % cfg.checkpoint.save_interval_updates == 0
and num_updates >= cfg.dataset.validate_after_updates
)
)
do_validate = (
(not end_of_epoch and do_save) # validate during mid-epoch saves
or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
or num_updates >= max_update
or (
args.validate_interval_updates > 0
cfg.dataset.validate_interval_updates > 0
and num_updates > 0
and num_updates % args.validate_interval_updates == 0
and num_updates % cfg.dataset.validate_interval_updates == 0
)
) and not args.disable_validation
) and not cfg.dataset.disable_validation
# Validate
valid_losses = [None]
if do_validate:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
# Stopping conditions
should_stop = (
should_stop_early(args, valid_losses[0])
should_stop_early(cfg, valid_losses[0])
or num_updates >= max_update
or (
args.stop_time_hours > 0
and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
cfg.optimization.stop_time_hours > 0
and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours
)
)
# Save checkpoint
if do_save or should_stop:
logger.info("begin save checkpoint")
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0])
return valid_losses, should_stop
def get_training_stats(stats):
def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
return stats
def validate(args, trainer, task, epoch_itr, subsets):
def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]:
"""Evaluate the model on the validation set(s) and return the losses."""
if args.fixed_validation_seed is not None:
if cfg.dataset.fixed_validation_seed is not None:
# set fixed seed for every validation
utils.set_torch_seed(args.fixed_validation_seed)
utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
trainer.begin_valid_epoch(epoch_itr.epoch)
valid_losses = []
@ -302,18 +299,18 @@ def validate(args, trainer, task, epoch_itr, subsets):
# Initialize data iterator
itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
if getattr(args, "tpu", False):
if cfg.common.tpu:
itr = utils.tpu_data_loader(itr)
progress = progress_bar.progress_bar(
itr,
log_format=args.log_format,
log_interval=args.log_interval,
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
epoch=epoch_itr.epoch,
prefix=f"valid on '{subset}' subset",
tensorboard_logdir=(
args.tensorboard_logdir if distributed_utils.is_master(args) else None
cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
),
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
)
# create a new root metrics aggregator so validation metrics
@ -323,34 +320,40 @@ def validate(args, trainer, task, epoch_itr, subsets):
trainer.valid_step(sample)
# log validation stats
stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
progress.print(stats, tag=subset, step=trainer.get_num_updates())
valid_losses.append(stats[args.best_checkpoint_metric])
valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
return valid_losses
def get_valid_stats(args, trainer, stats):
def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]:
stats["num_updates"] = trainer.get_num_updates()
if hasattr(checkpoint_utils.save_checkpoint, "best"):
key = "best_{0}".format(args.best_checkpoint_metric)
best_function = max if args.maximize_best_checkpoint_metric else min
key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
stats[key] = best_function(
checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric]
checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric]
)
return stats
def cli_main(modify_parser=None):
def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None:
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
cfg = convert_namespace_to_omegaconf(args)
if args.profile:
with torch.cuda.profiler.profile():
with torch.autograd.profiler.emit_nvtx():
distributed_utils.call_main(args, main)
distributed_utils.call_main(cfg, main)
else:
distributed_utils.call_main(args, main)
distributed_utils.call_main(cfg, main)
if __name__ == "__main__":
if __name__ == '__main__':
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main()
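Condensing the train.py hunks above into one hedged sketch of the new entry flow (no new behavior is implied; config_path is the relative path used by the commit and resolves relative to the calling module):

import torch
from fairseq import distributed_utils, options
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq_cli.train import main   # the training function rewritten above
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize

def run(modify_parser=None):
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
    cfg = convert_namespace_to_omegaconf(args)   # legacy Namespace -> grouped DictConfig
    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)

if __name__ == "__main__":
    cs = ConfigStore.instance()
    register_hydra_cfg(cs)                       # register structured configs with Hydra
    initialize(config_path="../config", strict=True)
    run()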

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python3 -u
#!/usr/bin/env python3 -u
# !/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
@ -8,11 +8,17 @@
import logging
import os
import sys
from argparse import Namespace
from itertools import chain
import torch
from fairseq import checkpoint_utils, distributed_utils, options, utils
from fairseq.dataclass.data_class import register_hydra_cfg
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import metrics, progress_bar
from hydra.core.config_store import ConfigStore
from hydra.experimental import initialize
from omegaconf import DictConfig
logging.basicConfig(
@ -24,18 +30,21 @@ logging.basicConfig(
logger = logging.getLogger("fairseq_cli.validate")
def main(args, override_args=None):
utils.import_user_module(args)
def main(cfg: DictConfig, override_args=None):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
utils.import_user_module(cfg.common)
assert (
args.max_tokens is not None or args.batch_size is not None
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
), "Must specify batch size either with --max-tokens or --batch-size"
use_fp16 = args.fp16
use_cuda = torch.cuda.is_available() and not args.cpu
use_fp16 = cfg.common.fp16
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
if use_cuda:
torch.cuda.set_device(args.device_id)
torch.cuda.set_device(cfg.distributed_training.device_id)
if override_args is not None:
overrides = vars(override_args)
@ -44,11 +53,11 @@ def main(args, override_args=None):
overrides = None
# Load ensemble
logger.info("loading model(s) from {}".format(args.path))
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
[args.path],
[cfg.common_eval.path],
arg_overrides=overrides,
suffix=getattr(args, "checkpoint_suffix", ""),
suffix=cfg.checkpoint.checkpoint_suffix,
)
model = models[0]
@ -63,10 +72,10 @@ def main(args, override_args=None):
logger.info(model_args)
# Build criterion
criterion = task.build_criterion(model_args)
criterion = task.build_criterion(model_args.criterion)
criterion.eval()
for subset in args.valid_subset.split(","):
for subset in cfg.dataset.valid_subset.split(","):
try:
task.load_dataset(subset, combine=False, epoch=1)
dataset = task.dataset(subset)
@ -76,26 +85,26 @@ def main(args, override_args=None):
# Initialize data iterator
itr = task.get_batch_iterator(
dataset=dataset,
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_tokens=cfg.dataset.max_tokens,
max_sentences=cfg.dataset.batch_size,
max_positions=utils.resolve_max_positions(
task.max_positions(),
*[m.max_positions() for m in models],
),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=args.required_batch_size_multiple,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
data_buffer_size=args.data_buffer_size,
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
seed=cfg.common.seed,
num_shards=cfg.distributed_training.distributed_world_size,
shard_id=cfg.distributed_training.distributed_rank,
num_workers=cfg.dataset.num_workers,
data_buffer_size=cfg.dataset.data_buffer_size,
).next_epoch_itr(shuffle=False)
progress = progress_bar.progress_bar(
itr,
log_format=args.log_format,
log_interval=args.log_interval,
log_format=cfg.common.log_format,
log_interval=cfg.common.log_interval,
prefix=f"valid on '{subset}' subset",
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
)
log_outputs = []
@ -105,10 +114,10 @@ def main(args, override_args=None):
progress.log(log_output, step=i)
log_outputs.append(log_output)
if args.distributed_world_size > 1:
if cfg.distributed_training.distributed_world_size > 1:
log_outputs = distributed_utils.all_gather_list(
log_outputs,
max_size=getattr(args, "all_gather_list_size", 16384),
max_size=cfg.common.all_gather_list_size,
)
log_outputs = list(chain.from_iterable(log_outputs))
@ -131,4 +140,7 @@ def cli_main():
if __name__ == "__main__":
cs = ConfigStore.instance()
register_hydra_cfg(cs)
initialize(config_path="../config", strict=True)
cli_main()

View File

@ -272,6 +272,7 @@ class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
model_cls.add_args(parser)
args = parser.parse_args([])
if extra_args_setters is not None:
for args_setter in extra_args_setters:
args_setter(args)
@ -515,9 +516,7 @@ class CrossEntropyCriterionTestBase(unittest.TestCase):
def setUp(self):
args = self.setUpArgs()
self.model = DummyEncoderModel(encoder=DummyEncoder())
self.criterion = self.criterion_cls.build_criterion(
args=args, task=DummyTask(args)
)
self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args))
def get_src_tokens(self, correct_prediction, aggregate):
"""

View File

@ -11,7 +11,7 @@ from multiprocessing import Manager
import torch
import torch.nn as nn
from fairseq import distributed_utils, optim
from omegaconf import OmegaConf
class Model(nn.Module):
def __init__(self, input_size, output_size):
@ -23,13 +23,14 @@ class Model(nn.Module):
return output
def setup_model_loss_criterion(args, rank, is_cuda):
def setup_model_loss_criterion(cfg, args, rank, is_cuda):
"""
setup model, criterion and optimizer based on input args
"""
args.distributed_rank = rank
if args.distributed_world_size > 1:
distributed_utils.distributed_init(args)
cfg.distributed_training.distributed_rank = args.distributed_rank
if cfg.distributed_training.distributed_world_size > 1:
distributed_utils.distributed_init(cfg)
torch.manual_seed(1)
model = Model(args.input_size, args.nb_classes)
loss_fn = nn.CrossEntropyLoss()
@ -38,7 +39,10 @@ def setup_model_loss_criterion(args, rank, is_cuda):
loss_fn = loss_fn.cuda()
optimizer = optim.sgd.SGD(args, model.parameters())
optimizer = optim.FairseqBMUF(args, optimizer)
optimizer = optim.FairseqBMUF(
cfg=cfg.bmuf,
optimizer=optimizer
)
return model, loss_fn, optimizer
@ -52,13 +56,13 @@ def train_step(input, target, model, loss_fn, optimizer, **unused):
optimizer.step()
def single_gpu_training(args, rank, iterations, shared_results):
def single_gpu_training(cfg, args, rank, iterations, shared_results):
is_cuda = torch.cuda.is_available()
if is_cuda:
torch.cuda.set_device(rank)
model, loss_fn, optimizer = setup_model_loss_criterion(args, rank, is_cuda)
model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)
for _ in range(iterations):
input = torch.randn(1, args.input_size)
@ -103,18 +107,44 @@ def setup_args():
args.distributed_init_host = "localhost"
args.distributed_port = port + 1
args.local_world_size = args.distributed_world_size
return args
cfg = OmegaConf.create()
cfg.optimization = OmegaConf.create()
cfg.common = OmegaConf.create()
cfg.distributed_training = OmegaConf.create()
cfg.dataset = OmegaConf.create()
cfg.bmuf = OmegaConf.create()
cfg.optimizer = OmegaConf.create()
cfg.bmuf.global_sync_iter = args.global_sync_iter
cfg.bmuf.block_momentum = args.block_momentum
cfg.bmuf.block_lr = args.block_lr
cfg.dataset.batch_size = args.batch_size
cfg.optimization.lr = args.lr
cfg.optimizer.momentum = args.momentum
cfg.optimizer.weight_decay = args.weight_decay
cfg.bmuf.warmup_iterations = args.warmup_iterations
cfg.bmuf.use_nbm = args.use_nbm
cfg.bmuf.average_sync = args.average_sync
cfg.common.model_parallel_size = args.model_parallel_size
cfg.distributed_training.distributed_backend = args.distributed_backend
cfg.distributed_training.distributed_world_size = args.distributed_world_size
cfg.bmuf.distributed_world_size = args.distributed_world_size
cfg.distributed_training.distributed_init_method = args.distributed_init_method
cfg.distributed_training.distributed_port = args.distributed_port
return cfg, args
@unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs")
class TestBMUF(unittest.TestCase):
def bmuf_process(self, args, iterations):
def bmuf_process(self, cfg, args, iterations):
processes = []
results = Manager().dict()
ctx = torch.multiprocessing.get_context("spawn")
for rank in range(args.distributed_world_size):
p = ctx.Process(
target=single_gpu_training, args=(args, rank, iterations, results)
target=single_gpu_training, args=(cfg, args, rank, iterations, results)
)
p.start()
processes.append(p)
@ -125,19 +155,20 @@ class TestBMUF(unittest.TestCase):
def test_bmuf_sync(self):
# Train model for 1 iteration and do bmuf sync without doing warmup
args = setup_args()
cfg, args = setup_args()
iterations = 1
results = self.bmuf_process(args, iterations)
results = self.bmuf_process(cfg, args, iterations)
# Make sure params in both machines are same
assert len(results) == 2
self.assertAlmostEqual(results[0], results[1])
def test_warmup_sync(self):
# Train model for 20 iteration and do warmup sync without doing bmuf sync
args = setup_args()
cfg, args = setup_args()
args.warmup_iterations = 20
cfg.bmuf.warmup_iterations = args.warmup_iterations
iterations = 20
results = self.bmuf_process(args, iterations)
results = self.bmuf_process(cfg, args, iterations)
# Make sure params in both machines are same
assert len(results) == 2
self.assertAlmostEqual(results[0], results[1])
@ -145,22 +176,27 @@ class TestBMUF(unittest.TestCase):
def test_warmup_sync_bmuf_sync(self):
# Train model for 25 iteration and do warmup sync after 20 iteration
# and bmuf sync after 25 iteration
args = setup_args()
cfg, args = setup_args()
args.warmup_iterations = 20
args.global_sync_iter = 5
cfg.bmuf.warmup_iterations = args.warmup_iterations
cfg.bmuf.global_sync_iter = args.global_sync_iter
iterations = 25
results = self.bmuf_process(args, iterations)
results = self.bmuf_process(cfg, args, iterations)
# Make sure params in both machines are same
assert len(results) == 2
self.assertAlmostEqual(results[0], results[1])
def test_single_gpu_bmuf(self):
# Train model for 5 iterations and use GPU 1
args = setup_args()
cfg, args = setup_args()
args.distributed_world_size = 1
args.warmup_iterations = 5
cfg.distributed_training.distributed_world_size = args.distributed_world_size
cfg.bmuf.distributed_world_size = args.distributed_world_size
cfg.bmuf.warmup_iterations = args.warmup_iterations
iterations = 20
results = self.bmuf_process(args, iterations)
results = self.bmuf_process(cfg, args, iterations)
assert len(results) == 1
def assertAlmostEqual(self, t1, t2):
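As a hedged aside on setup_args() above, which assigns the nested test config group by group: OmegaConf.create also accepts the whole nested mapping in one call, and the result supports both attribute- and item-style access. The values below are illustrative only, not the ones the test uses.

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "common": {"model_parallel_size": 1},
        "distributed_training": {
            "distributed_backend": "nccl",       # illustrative values
            "distributed_world_size": 2,
            "distributed_init_method": None,
            "distributed_port": -1,
        },
        "bmuf": {
            "block_lr": 1,
            "block_momentum": 0.875,
            "global_sync_iter": 50,
            "warmup_iterations": 500,
            "use_nbm": False,
            "average_sync": True,
            "distributed_world_size": 2,
        },
        "dataset": {"batch_size": 32},
        "optimization": {"lr": [0.01]},
        "optimizer": {"momentum": 0.9, "weight_decay": 0.0},
    }
)

assert cfg.bmuf.global_sync_iter == 50      # attribute access
assert cfg["dataset"]["batch_size"] == 32   # item access on the same object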

View File

@ -9,6 +9,7 @@ import unittest
import torch
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
from omegaconf import OmegaConf
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
@ -27,17 +28,23 @@ class TestGradientScaling(unittest.TestCase):
self.model.cuda().half()
self.params = list(self.model.parameters())
self.namespace_dls = argparse.Namespace(
optimizer="adam",
lr=[0.1],
adam_betas="(0.9, 0.999)",
adam_eps=1e-8,
weight_decay=0.0,
fp16_init_scale=1,
fp16_scale_window=1,
fp16_scale_tolerance=1,
threshold_loss_scale=1,
min_loss_scale=1e-4,
self.cfg_dls = OmegaConf.create(
{
"optimizer": {
"_name": "adam",
"lr": [0.1],
"adam_betas": "(0.9, 0.999)",
"adam_eps": 1e-8,
"weight_decay": 0.0,
},
"common": {
"fp16_init_scale": 1,
"fp16_scale_window": 1,
"fp16_scale_tolerance": 1,
"threshold_loss_scale": 1,
"min_loss_scale": 1e-4,
},
}
)
def run_iter(self, model, params, optimizer):
@ -68,7 +75,7 @@ class TestGradientScaling(unittest.TestCase):
def test_mixed_precision(self):
model = copy.deepcopy(self.model)
params = list(model.parameters())
optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params)
optimizer = FP16Optimizer.build_optimizer(self.cfg_dls, params)
self.run_iter(model, params, optimizer)
self.assertTrue(
@ -87,9 +94,7 @@ class TestGradientScaling(unittest.TestCase):
def test_memory_efficient(self):
model = copy.deepcopy(self.model)
params = list(model.parameters())
optimizer = MemoryEfficientFP16Optimizer.build_optimizer(
self.namespace_dls, params
)
optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.cfg_dls, params)
self.run_iter(model, params, optimizer)
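A hedged note on the cfg_dls structure above: the "_name" key under the optimizer group takes over the role of the old flat optimizer="adam" attribute, while the loss-scaling knobs move to the common group. A minimal sketch of building and reading such a config (the prints are only for illustration):

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "optimizer": {
            "_name": "adam",                # names the registered optimizer component
            "lr": [0.1],
            "adam_betas": "(0.9, 0.999)",
            "adam_eps": 1e-8,
            "weight_decay": 0.0,
        },
        "common": {
            "fp16_init_scale": 1,
            "fp16_scale_window": 1,
            "fp16_scale_tolerance": 1,
            "threshold_loss_scale": 1,
            "min_loss_scale": 1e-4,
        },
    }
)

print(cfg.optimizer._name)          # adam
print(cfg.common.fp16_init_scale)   # 1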

View File

@ -6,6 +6,7 @@
import logging
import unittest
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.transformer import TransformerModel
from tests.test_sequence_generator import get_dummy_task_and_parser
@ -25,7 +26,8 @@ class TestInferenceDropout(unittest.TestCase):
def test_sets_inference_dropout_to_true(self):
self.args.retain_dropout = True
self.transformer_model = TransformerModel.build_model(self.args, self.task)
self.transformer_model.prepare_for_inference_(self.args)
cfg = convert_namespace_to_omegaconf(self.args)
self.transformer_model.prepare_for_inference_(cfg)
assert self.transformer_model.encoder.dropout_module.apply_during_inference
assert self.transformer_model.decoder.dropout_module.apply_during_inference
for layer in self.transformer_model.encoder.layers:
@ -33,7 +35,8 @@ class TestInferenceDropout(unittest.TestCase):
def test_inference_dropout_false_by_default(self):
self.transformer_model = TransformerModel.build_model(self.args, self.task)
self.transformer_model.prepare_for_inference_(self.args)
cfg = convert_namespace_to_omegaconf(self.args)
self.transformer_model.prepare_for_inference_(cfg)
assert not self.transformer_model.encoder.dropout_module.apply_during_inference
assert not self.transformer_model.decoder.dropout_module.apply_during_inference
for layer in self.transformer_model.encoder.layers:
@ -59,7 +62,8 @@ class TestInferenceDropout(unittest.TestCase):
"TransformerEncoderLayer",
]
self.transformer_model = TransformerModel.build_model(self.args, self.task)
self.transformer_model.prepare_for_inference_(self.args)
cfg = convert_namespace_to_omegaconf(self.args)
self.transformer_model.prepare_for_inference_(cfg)
assert self.transformer_model.encoder.dropout_module.apply_during_inference
assert not self.transformer_model.decoder.dropout_module.apply_during_inference
for layer in self.transformer_model.decoder.layers:

View File

@ -10,6 +10,7 @@ import unittest
import torch
from fairseq.optim.adam import FairseqAdam
from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
from omegaconf import OmegaConf
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
@ -26,25 +27,36 @@ class TestMemoryEfficientFP16(unittest.TestCase):
params = list(model.parameters())
# initialize memory efficient FP16 optimizer
# with pseudo DictConfigs
optimizer = FairseqAdam(
argparse.Namespace(
lr=[0.00001],
adam_betas="(0.9, 0.999)",
adam_eps=1e-8,
weight_decay=0.0,
cfg=OmegaConf.create(
vars(
argparse.Namespace(
adam_betas="(0.9, 0.999)",
adam_eps=1e-8,
weight_decay=0.0,
lr=[0.00001],
)
)
),
params,
params=params,
)
me_optimizer = MemoryEfficientFP16Optimizer(
argparse.Namespace(
fp16_init_scale=1,
fp16_scale_window=1,
fp16_scale_tolerance=1,
threshold_loss_scale=1,
min_loss_scale=1e-4,
cfg=OmegaConf.create(
{
"common": vars(
argparse.Namespace(
fp16_init_scale=1,
fp16_scale_window=1,
fp16_scale_tolerance=1,
threshold_loss_scale=1,
min_loss_scale=1e-4,
)
)
}
),
params,
optimizer,
params=params,
optimizer=optimizer,
)
# optimizer state is created in the first step

View File

@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
import torch
from fairseq import checkpoint_utils, data
from omegaconf import OmegaConf
def mock_trainer(epoch, num_updates, iterations_in_epoch):
@ -56,21 +57,29 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
return trainer, epoch_itr
def get_mock_args(finetune_from_model=None):
args_mock = MagicMock()
args_mock.optimizer_overrides = "{}"
args_mock.reset_dataloader = False
args_mock.reset_meters = False
args_mock.reset_optimizer = False
args_mock.reset_lr_scheduler = False
args_mock.finetune_from_model = finetune_from_model
args_mock.model_parallel_size = 1
return args_mock
def get_mock_cfg(finetune_from_model):
cfg_mock = OmegaConf.create(
{
"checkpoint": {
"optimizer_overrides": "{}",
"reset_dataloader": False,
"reset_meters": False,
"reset_optimizer": False,
"reset_lr_scheduler": False,
"finetune_from_model": finetune_from_model,
"model_parallel_size": 1,
},
"common": {
"model_parallel_size": 1,
},
}
)
return cfg_mock
class TestLoadCheckpoint(unittest.TestCase):
def setUp(self):
self.args_mock = get_mock_args()
self.cfg_mock = get_mock_cfg(None)
self.patches = {
"os.makedirs": MagicMock(),
"os.path.join": MagicMock(),
@ -91,7 +100,9 @@ class TestLoadCheckpoint(unittest.TestCase):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
_, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
_, epoch_itr = checkpoint_utils.load_checkpoint(
self.cfg_mock.checkpoint, trainer
)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
@ -120,7 +131,9 @@ class TestLoadCheckpoint(unittest.TestCase):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
_, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
_, epoch_itr = checkpoint_utils.load_checkpoint(
self.cfg_mock.checkpoint, trainer
)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 3)
@ -133,7 +146,9 @@ class TestLoadCheckpoint(unittest.TestCase):
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
self.patches["os.path.isfile"].return_value = False
_, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
_, epoch_itr = checkpoint_utils.load_checkpoint(
self.cfg_mock.checkpoint, trainer
)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 1)
@ -152,10 +167,12 @@ class TestLoadCheckpoint(unittest.TestCase):
"reset_dataloader",
]:
with self.subTest(arg=arg):
args_mock = get_mock_args("/temp/checkpoint_pretrained.pt")
setattr(args_mock, arg, True)
cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
cfg_mock["checkpoint"][arg] = True
with self.assertRaises(Exception) as context:
_, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
_, _ = checkpoint_utils.load_checkpoint(
cfg_mock.checkpoint, trainer
)
self.assertTrue(
"--finetune-from-model can not be set together with either --reset-optimizer"
@ -168,8 +185,6 @@ class TestLoadCheckpoint(unittest.TestCase):
trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
from_model_path = "/temp/checkpoint_pretrained.pt"
args_mock = get_mock_args(from_model_path)
args_mock.restore_file = "checkpoint_last.pt"
def mock_finetune_exist(path):
if path == from_model_path:
@ -180,7 +195,9 @@ class TestLoadCheckpoint(unittest.TestCase):
self.patches[
"fairseq.file_io.PathManager.exists"
].side_effect = mock_finetune_exist
_, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
cfg_mock = get_mock_cfg(from_model_path)
cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
_, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
(
checkpoint_path,
reset_optimizer,
@ -197,8 +214,6 @@ class TestLoadCheckpoint(unittest.TestCase):
trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
from_model_path = "/temp/checkpoint_pretrained.pt"
args_mock = get_mock_args(from_model_path)
args_mock.restore_file = "checkpoint_last.pt"
# launch second time
# both restore_file=checkpoint_last.pt and finetune_from_model are set
@ -211,7 +226,9 @@ class TestLoadCheckpoint(unittest.TestCase):
self.patches[
"fairseq.file_io.PathManager.exists"
].side_effect = mock_finetune_exist
_, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
cfg_mock = get_mock_cfg(from_model_path)
cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
_, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
(
checkpoint_path,
reset_optimizer,

View File

@ -20,7 +20,7 @@ from fairseq.models import (
FairseqIncrementalDecoder,
)
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.tasks import FairseqTask, LegacyFairseqTask
from fairseq.tasks import LegacyFairseqTask
from fairseq_cli import generate, interactive, preprocess, train, validate