mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-07-14 18:50:22 +03:00
Summary: Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1510 this is the main pr that switches on hydra functionality in fairseq we migrate "args" object into omegaconf "DictConfig" at all legacy entry points in addition this migrates various components from secondary registries (like bpe encoders and tokenizers) to make the migration smoother i am going through code that references migrated fairseq components and changing it to inherit from "Legacy*" components instead. hopefully tests will catch most of this Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1343 Reviewed By: myleott Differential Revision: D23973928 Pulled By: alexeib fbshipit-source-id: dd9554981fff51ea75c1ff343874d1d6e61793c9
This commit is contained in:
parent
c76cb6dfb9
commit
3b27ed7996
@ -1,7 +1,111 @@
|
||||
# @package _group_
|
||||
common:
|
||||
no_progress_bar: false
|
||||
log_interval: 100
|
||||
log_format: null
|
||||
tensorboard_logdir: null
|
||||
seed: 1
|
||||
cpu: false
|
||||
tpu: false
|
||||
bf16: false
|
||||
fp16: false
|
||||
memory_efficient_fp16: false
|
||||
memory_efficient_bf16: false
|
||||
fp16_no_flatten_grads: false
|
||||
fp16_init_scale: 128
|
||||
fp16_scale_window: null
|
||||
fp16_scale_tolerance: 0.0
|
||||
min_loss_scale: 1.0e-4
|
||||
threshold_loss_scale: null
|
||||
user_dir: null
|
||||
empty_cache_freq: 0
|
||||
all_gather_list_size: 16384
|
||||
model_parallel_size: 1
|
||||
quantization_config_path: null
|
||||
profile: false
|
||||
distributed_training:
|
||||
distributed_rank: 0
|
||||
distributed_backend: "nccl"
|
||||
distributed_init_method: null
|
||||
distributed_port: -1
|
||||
device_id: 0
|
||||
local_rank: 0
|
||||
distributed_no_spawn: false
|
||||
ddp_backend: "c10d"
|
||||
bucket_cap_mb: 25
|
||||
fix_batches_to_gpus: false
|
||||
find_unused_parameters: false
|
||||
fast_stat_sync: false
|
||||
broadcast_buffers: false
|
||||
distributed_wrapper: "DDP"
|
||||
slowmo_momentum: null
|
||||
slowmo_algorithm: "LocalSGD"
|
||||
localsgd_frequency: 3
|
||||
dataset:
|
||||
num_workers: 1
|
||||
skip_invalid_size_inputs_valid_test: false
|
||||
max_tokens: null
|
||||
batch_size: null
|
||||
required_batch_size_multiple: 8
|
||||
dataset_impl: null
|
||||
data_buffer_size: 10
|
||||
train_subset: "train"
|
||||
valid_subset: "valid"
|
||||
validate_interval: 1
|
||||
fixed_validation_seed: null
|
||||
disable_validation: false
|
||||
curriculum: 0
|
||||
gen_subset: "test"
|
||||
num_shards: 1
|
||||
shard_id: 0
|
||||
max_tokens_valid: ${dataset.max_tokens}
|
||||
batch_size_valid: ${dataset.batch_size}
|
||||
optimization:
|
||||
max_epoch: 0
|
||||
max_update: 0
|
||||
clip_norm: 25.0
|
||||
sentence_avg: false
|
||||
update_freq: [ 1 ]
|
||||
lr: [ 0.25 ]
|
||||
min_lr: -1.0
|
||||
use_bmuf: false
|
||||
checkpoint:
|
||||
save_dir: "checkpoints"
|
||||
restore_file: "checkpoint_last.pt"
|
||||
reset_dataloader: false
|
||||
reset_lr_scheduler: false
|
||||
reset_meters: false
|
||||
reset_optimizer: false
|
||||
optimizer_overrides: "{}"
|
||||
save_interval: 1
|
||||
save_interval_updates: 0
|
||||
keep_interval_updates: -1
|
||||
keep_last_epochs: -1
|
||||
keep_best_checkpoints: -1
|
||||
no_save: false
|
||||
no_epoch_checkpoints: false
|
||||
no_last_checkpoints: false
|
||||
no_save_optimizer_state: false
|
||||
best_checkpoint_metric: "loss"
|
||||
maximize_best_checkpoint_metric: false
|
||||
patience: -1
|
||||
checkpoint_suffix: ""
|
||||
bmuf:
|
||||
block_lr: 1
|
||||
block_momentum: 0.875
|
||||
global_sync_iter: 50
|
||||
warmup_iterations: 500
|
||||
use_nbm: false
|
||||
average_sync: false
|
||||
defaults:
|
||||
- params: training_params
|
||||
- task: language_modeling
|
||||
- model: transformer_lm
|
||||
- criterion: cross_entropy
|
||||
- optimizer: adam
|
||||
- lr_scheduler: inverse_sqrt
|
||||
- task: language_modeling
|
||||
- model: null
|
||||
- criterion: null
|
||||
- optimizer: null
|
||||
- lr_scheduler: null
|
||||
- bpe: null
|
||||
- tokenizer: null
|
||||
- scoring: null
|
||||
- generation: null
|
||||
- common_eval: null
|
||||
- eval_lm: null
|
||||
|
@ -1,7 +0,0 @@
|
||||
defaults:
|
||||
- params: eval_lm_params
|
||||
- task: language_modeling
|
||||
- model: transformer_lm
|
||||
- criterion: cross_entropy
|
||||
- optimizer: adam
|
||||
- lr_scheduler: inverse_sqrt
|
@ -1,3 +1,3 @@
|
||||
# @package _group_
|
||||
sentence_avg: ${params.optimization.sentence_avg}
|
||||
ddp_backend: ${params.distributed_training.ddp_backend}
|
||||
sentence_avg: ${optimization.sentence_avg}
|
||||
ddp_backend: ${distributed_training.ddp_backend}
|
||||
|
@ -1,3 +1,2 @@
|
||||
# @package _group_
|
||||
sentence_avg: ${params.optimization.sentence_avg}
|
||||
ddp_backend: ${params.distributed_training.ddp_backend}
|
||||
sentence_avg: ${optimization.sentence_avg}
|
||||
|
@ -1,105 +0,0 @@
|
||||
# @package _group_
|
||||
common:
|
||||
no_progress_bar: false
|
||||
log_interval: 100
|
||||
log_format: null
|
||||
tensorboard_logdir: null
|
||||
seed: 1
|
||||
cpu: false
|
||||
fp16: false
|
||||
memory_efficient_fp16: false
|
||||
fp16_no_flatten_grads: false
|
||||
fp16_init_scale: 128
|
||||
fp16_scale_window: null
|
||||
fp16_scale_tolerance: 0.0
|
||||
min_loss_scale: 1.0e-4
|
||||
threshold_loss_scale: null
|
||||
user_dir: null
|
||||
empty_cache_freq: 0
|
||||
all_gather_list_size: 16384
|
||||
model_parallel_size: 1
|
||||
checkpoint_suffix: ""
|
||||
quantization_config_path: null
|
||||
distributed_training:
|
||||
distributed_rank: 0
|
||||
distributed_backend: "nccl"
|
||||
distributed_init_method: null
|
||||
distributed_port: -1
|
||||
device_id: 0
|
||||
local_rank: 0
|
||||
distributed_no_spawn: false
|
||||
ddp_backend: "c10d"
|
||||
bucket_cap_mb: 25
|
||||
fix_batches_to_gpus: false
|
||||
find_unused_parameters: false
|
||||
fast_stat_sync: false
|
||||
broadcast_buffers: false
|
||||
distributed_wrapper: "DDP"
|
||||
slowmo_momentum: null
|
||||
slowmo_algorithm: "LocalSGD"
|
||||
localsgd_frequency: 3
|
||||
dataset:
|
||||
num_workers: 1
|
||||
skip_invalid_size_inputs_valid_test: false
|
||||
max_tokens: null
|
||||
batch_size: ${params.dataset.batch_size}
|
||||
required_batch_size_multiple: 8
|
||||
dataset_impl: null
|
||||
data_buffer_size: 10
|
||||
train_subset: "train"
|
||||
valid_subset: "valid"
|
||||
validate_interval: 1
|
||||
fixed_validation_seed: null
|
||||
disable_validation: false
|
||||
curriculum: 0
|
||||
gen_subset: "test"
|
||||
num_shards: 1
|
||||
shard_id: 0
|
||||
max_tokens_valid: ${params.dataset.max_tokens}
|
||||
batch_size_valid: ${params.dataset.batch_size}
|
||||
optimization:
|
||||
max_epoch: 0
|
||||
max_update: 0
|
||||
clip_norm: 25.0
|
||||
sentence_avg: false
|
||||
update_freq: [1]
|
||||
lr: [0.25]
|
||||
min_lr: -1.0
|
||||
use_bmuf: false
|
||||
checkpoint:
|
||||
save_dir: "checkpoints"
|
||||
restore_file: "checkpoint_last.pt"
|
||||
reset_dataloader: false
|
||||
reset_lr_scheduler: false
|
||||
reset_meters: false
|
||||
reset_optimizer: false
|
||||
optimizer_overrides: "{}"
|
||||
save_interval: 1
|
||||
save_interval_updates: 0
|
||||
keep_interval_updates: -1
|
||||
keep_last_epochs: -1
|
||||
keep_best_checkpoints: -1
|
||||
no_save: false
|
||||
no_epoch_checkpoints: false
|
||||
no_last_checkpoints: false
|
||||
no_save_optimizer_state: false
|
||||
best_checkpoint_metric: "loss"
|
||||
maximize_best_checkpoint_metric: false
|
||||
patience: -1
|
||||
common_eval:
|
||||
path: null
|
||||
remove_bpe: null
|
||||
quiet: false
|
||||
model_overrides: "{}"
|
||||
results_path: null
|
||||
eval_lm:
|
||||
output_word_probs: false
|
||||
output_word_stats: false
|
||||
context_window: 0
|
||||
bmuf:
|
||||
block_lr: 1
|
||||
block_momentum: 0.875
|
||||
global_sync_iter: 50
|
||||
warmup_iterations: 500
|
||||
use_nbm: false
|
||||
average_sync: false
|
@ -1,95 +0,0 @@
|
||||
# @package _group_
|
||||
common:
|
||||
no_progress_bar: false
|
||||
log_interval: 100
|
||||
log_format: null
|
||||
tensorboard_logdir: null
|
||||
seed: 1
|
||||
cpu: false
|
||||
fp16: false
|
||||
memory_efficient_fp16: false
|
||||
fp16_no_flatten_grads: false
|
||||
fp16_init_scale: 128
|
||||
fp16_scale_window: null
|
||||
fp16_scale_tolerance: 0.0
|
||||
min_loss_scale: 1.0e-4
|
||||
threshold_loss_scale: null
|
||||
user_dir: null
|
||||
empty_cache_freq: 0
|
||||
all_gather_list_size: 16384
|
||||
model_parallel_size: 1
|
||||
checkpoint_suffix: ""
|
||||
quantization_config_path: null
|
||||
distributed_training:
|
||||
distributed_rank: 0
|
||||
distributed_backend: "nccl"
|
||||
distributed_init_method: null
|
||||
distributed_port: -1
|
||||
device_id: 0
|
||||
local_rank: 0
|
||||
distributed_no_spawn: false
|
||||
ddp_backend: "c10d"
|
||||
bucket_cap_mb: 25
|
||||
fix_batches_to_gpus: false
|
||||
find_unused_parameters: false
|
||||
fast_stat_sync: false
|
||||
broadcast_buffers: false
|
||||
distributed_wrapper: "DDP"
|
||||
slowmo_momentum: null
|
||||
slowmo_algorithm: "LocalSGD"
|
||||
localsgd_frequency: 3
|
||||
dataset:
|
||||
num_workers: 1
|
||||
skip_invalid_size_inputs_valid_test: false
|
||||
max_tokens: null
|
||||
batch_size: ${params.dataset.batch_size}
|
||||
required_batch_size_multiple: 8
|
||||
dataset_impl: null
|
||||
data_buffer_size: 10
|
||||
train_subset: "train"
|
||||
valid_subset: "valid"
|
||||
validate_interval: 1
|
||||
fixed_validation_seed: null
|
||||
disable_validation: false
|
||||
curriculum: 0
|
||||
gen_subset: "test"
|
||||
num_shards: 1
|
||||
shard_id: 0
|
||||
max_tokens_valid: ${params.dataset.max_tokens}
|
||||
batch_size_valid: ${params.dataset.batch_size}
|
||||
optimization:
|
||||
max_epoch: 0
|
||||
max_update: 0
|
||||
clip_norm: 25.0
|
||||
sentence_avg: false
|
||||
update_freq: [1]
|
||||
lr: [0.25]
|
||||
min_lr: -1.0
|
||||
use_bmuf: false
|
||||
checkpoint:
|
||||
save_dir: "checkpoints"
|
||||
restore_file: "checkpoint_last.pt"
|
||||
reset_dataloader: false
|
||||
reset_lr_scheduler: false
|
||||
reset_meters: false
|
||||
reset_optimizer: false
|
||||
optimizer_overrides: "{}"
|
||||
save_interval: 1
|
||||
save_interval_updates: 0
|
||||
keep_interval_updates: -1
|
||||
keep_last_epochs: -1
|
||||
keep_best_checkpoints: -1
|
||||
no_save: false
|
||||
no_epoch_checkpoints: false
|
||||
no_last_checkpoints: false
|
||||
no_save_optimizer_state: false
|
||||
best_checkpoint_metric: "loss"
|
||||
maximize_best_checkpoint_metric: false
|
||||
patience: -1
|
||||
bmuf:
|
||||
block_lr: 1
|
||||
block_momentum: 0.875
|
||||
global_sync_iter: 50
|
||||
warmup_iterations: 500
|
||||
use_nbm: false
|
||||
average_sync: false
|
@ -13,7 +13,6 @@ For example, if we'd like to train a language model with transformer, we could p
|
||||
|
||||
```
|
||||
defaults:
|
||||
- params: training_params
|
||||
- task: language_modeling
|
||||
- model: transformer_lm
|
||||
- criterion: cross_entropy
|
||||
@ -21,7 +20,7 @@ defaults:
|
||||
- lr_scheduler: inverse_sqrt
|
||||
```
|
||||
|
||||
- Provide generic parameters common across different training jobs: `config/params/training_params.yaml`
|
||||
- Provide generic parameters common across different jobs: `config.yaml`
|
||||
- Provide task parameters: `config/task/language_modeling.yaml`
|
||||
- Provide model parameters: `config/model/transformer_lm.yaml`
|
||||
- Provide criterion parameters: `config/criterion/cross_entropy.yaml`
|
||||
@ -41,7 +40,6 @@ Alternatively, if we need to override certain params from the command line, we c
|
||||
|
||||
```
|
||||
python fairseq_cli/train_hydra.py
|
||||
params=training_params \
|
||||
task=language_modeling \
|
||||
task.data=/private/home/abaevski/data/wiki103 \
|
||||
task.tokens_per_sample=512 \
|
||||
@ -56,17 +54,17 @@ lr_scheduler=inverse_sqrt \
|
||||
lr_scheduler.warmup_updates=4000 \
|
||||
lr_scheduler.warmup_init_lr=1e-07 \
|
||||
criterion=cross_entropy \
|
||||
params.common.fp16=true \
|
||||
params.common.log_format=json \
|
||||
params.common.log_interval=1 \
|
||||
params.dataset.max_tokens=1024 \
|
||||
params.dataset.num_workers=4 \
|
||||
params.optimization.update_freq=[16] \
|
||||
params.optimization.max_update=50000 \
|
||||
params.optimization.clip_norm=0.0 \
|
||||
params.optimization.lr=[0.0005] \
|
||||
params.checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
|
||||
params.checkpoint.save_interval_updates=10
|
||||
common.fp16=true \
|
||||
common.log_format=json \
|
||||
common.log_interval=1 \
|
||||
dataset.max_tokens=1024 \
|
||||
dataset.num_workers=4 \
|
||||
optimization.update_freq=[16] \
|
||||
optimization.max_update=50000 \
|
||||
optimization.clip_norm=0.0 \
|
||||
optimization.lr=[0.0005] \
|
||||
checkpoint.save_dir=/checkpoint/mtian/transformer_wikitext-103-hydra-args-cli \
|
||||
checkpoint.save_interval_updates=10
|
||||
```
|
||||
|
||||
## Migrate existing/Creating new modules to hydra interface
|
||||
|
@ -212,7 +212,7 @@ following contents::
|
||||
|
||||
|
||||
@register_task('simple_classification')
|
||||
class SimpleClassificationTask(FairseqTask):
|
||||
class SimpleClassificationTask(LegacyFairseqTask):
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
@ -27,7 +27,13 @@ def score_target_hypo(
|
||||
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
|
||||
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
|
||||
dict = dictionary.Dictionary()
|
||||
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
|
||||
scorer = scorer = bleu.Scorer(
|
||||
bleu.BleuConfig(
|
||||
pad=dict.pad(),
|
||||
eos=dict.eos(),
|
||||
unk=dict.unk(),
|
||||
)
|
||||
)
|
||||
|
||||
ordered_hypos = {}
|
||||
ordered_targets = {}
|
||||
|
@ -20,8 +20,8 @@ class WSCCriterion(LegacyFairseqCriterion):
|
||||
self.prediction_h = open(self.args.save_predictions, "w")
|
||||
else:
|
||||
self.prediction_h = None
|
||||
self.bpe = encoders.build_bpe(args)
|
||||
self.tokenizer = encoders.build_tokenizer(args)
|
||||
self.bpe = encoders.build_bpe(args.bpe)
|
||||
self.tokenizer = encoders.build_tokenizer(args.tokenizer)
|
||||
|
||||
def __del__(self):
|
||||
if self.prediction_h is not None:
|
||||
|
@ -85,7 +85,7 @@ Produce model scores for the generated translations using `--retain-dropout` opt
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt --beam 5
|
||||
--source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout
|
||||
--retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer
|
||||
--retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]'
|
||||
TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out
|
||||
|
||||
grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores
|
||||
|
@ -3,36 +3,42 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import ast
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from fairseq.dataclass.utils import (
|
||||
convert_namespace_to_omegaconf,
|
||||
overwrite_args_by_name,
|
||||
)
|
||||
from fairseq.file_io import PathManager
|
||||
from fairseq.models import FairseqDecoder, FairseqEncoder
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from torch.serialization import default_restore_location
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def save_checkpoint(args, trainer, epoch_itr, val_loss):
|
||||
from fairseq import distributed_utils, meters
|
||||
def save_checkpoint(cfg: DictConfig, trainer, epoch_itr, val_loss):
|
||||
from fairseq import meters
|
||||
|
||||
# only one worker should attempt to create the required dir
|
||||
if args.distributed_rank == 0:
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
if cfg.distributed_rank == 0:
|
||||
os.makedirs(cfg.save_dir, exist_ok=True)
|
||||
|
||||
prev_best = getattr(save_checkpoint, "best", val_loss)
|
||||
if val_loss is not None:
|
||||
best_function = max if args.maximize_best_checkpoint_metric else min
|
||||
best_function = max if cfg.maximize_best_checkpoint_metric else min
|
||||
save_checkpoint.best = best_function(val_loss, prev_best)
|
||||
|
||||
if args.no_save:
|
||||
if cfg.no_save:
|
||||
return
|
||||
|
||||
trainer.consolidate_optimizer()
|
||||
@ -41,7 +47,7 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
|
||||
return
|
||||
|
||||
def is_better(a, b):
|
||||
return a >= b if args.maximize_best_checkpoint_metric else a <= b
|
||||
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
|
||||
|
||||
write_timer = meters.StopwatchMeter()
|
||||
write_timer.start()
|
||||
@ -50,38 +56,36 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
|
||||
end_of_epoch = epoch_itr.end_of_epoch()
|
||||
updates = trainer.get_num_updates()
|
||||
|
||||
suffix = getattr(args, "checkpoint_suffix", "")
|
||||
suffix = cfg.checkpoint_suffix or ""
|
||||
checkpoint_conds = collections.OrderedDict()
|
||||
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
|
||||
end_of_epoch
|
||||
and not args.no_epoch_checkpoints
|
||||
and epoch % args.save_interval == 0
|
||||
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
|
||||
)
|
||||
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
|
||||
not end_of_epoch
|
||||
and args.save_interval_updates > 0
|
||||
and updates % args.save_interval_updates == 0
|
||||
and cfg.save_interval_updates > 0
|
||||
and updates % cfg.save_interval_updates == 0
|
||||
)
|
||||
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
|
||||
not hasattr(save_checkpoint, "best")
|
||||
or is_better(val_loss, save_checkpoint.best)
|
||||
)
|
||||
if val_loss is not None and args.keep_best_checkpoints > 0:
|
||||
if val_loss is not None and cfg.keep_best_checkpoints > 0:
|
||||
checkpoint_conds[
|
||||
"checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss)
|
||||
"checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss)
|
||||
] = not hasattr(save_checkpoint, "best") or is_better(
|
||||
val_loss, save_checkpoint.best
|
||||
)
|
||||
checkpoint_conds[
|
||||
"checkpoint_last{}.pt".format(suffix)
|
||||
] = not args.no_last_checkpoints
|
||||
] = not cfg.no_last_checkpoints
|
||||
|
||||
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
||||
if hasattr(save_checkpoint, "best"):
|
||||
extra_state.update({"best": save_checkpoint.best})
|
||||
|
||||
checkpoints = [
|
||||
os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
|
||||
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
|
||||
]
|
||||
if len(checkpoints) > 0:
|
||||
trainer.save_checkpoint(checkpoints[0], extra_state)
|
||||
@ -95,51 +99,52 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
|
||||
)
|
||||
)
|
||||
|
||||
if not end_of_epoch and args.keep_interval_updates > 0:
|
||||
if not end_of_epoch and cfg.keep_interval_updates > 0:
|
||||
# remove old checkpoints; checkpoints are sorted in descending order
|
||||
checkpoints = checkpoint_paths(
|
||||
args.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
|
||||
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
|
||||
)
|
||||
for old_chk in checkpoints[args.keep_interval_updates :]:
|
||||
for old_chk in checkpoints[cfg.keep_interval_updates :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
|
||||
if args.keep_last_epochs > 0:
|
||||
if cfg.keep_last_epochs > 0:
|
||||
# remove old epoch checkpoints; checkpoints are sorted in descending order
|
||||
checkpoints = checkpoint_paths(args.save_dir, pattern=r"checkpoint(\d+)\.pt")
|
||||
for old_chk in checkpoints[args.keep_last_epochs :]:
|
||||
checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt")
|
||||
for old_chk in checkpoints[cfg.keep_last_epochs :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
|
||||
if args.keep_best_checkpoints > 0:
|
||||
if cfg.keep_best_checkpoints > 0:
|
||||
# only keep the best N checkpoints according to validation metric
|
||||
checkpoints = checkpoint_paths(
|
||||
args.save_dir,
|
||||
cfg.save_dir,
|
||||
pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
|
||||
args.best_checkpoint_metric
|
||||
cfg.best_checkpoint_metric
|
||||
),
|
||||
)
|
||||
if not args.maximize_best_checkpoint_metric:
|
||||
if not cfg.maximize_best_checkpoint_metric:
|
||||
checkpoints = checkpoints[::-1]
|
||||
for old_chk in checkpoints[args.keep_best_checkpoints :]:
|
||||
for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
|
||||
|
||||
def load_checkpoint(args, trainer, **passthrough_args):
|
||||
def load_checkpoint(cfg: DictConfig, trainer, **passthrough_args):
|
||||
"""
|
||||
Load a checkpoint and restore the training iterator.
|
||||
|
||||
*passthrough_args* will be passed through to
|
||||
``trainer.get_train_iterator``.
|
||||
"""
|
||||
reset_optimizer = args.reset_optimizer
|
||||
reset_lr_scheduler = args.reset_lr_scheduler
|
||||
optimizer_overrides = eval(args.optimizer_overrides)
|
||||
reset_meters = args.reset_meters
|
||||
reset_dataloader = args.reset_dataloader
|
||||
|
||||
if getattr(args, "finetune_from_model", None) is not None and (
|
||||
reset_optimizer = cfg.reset_optimizer
|
||||
reset_lr_scheduler = cfg.reset_lr_scheduler
|
||||
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
|
||||
reset_meters = cfg.reset_meters
|
||||
reset_dataloader = cfg.reset_dataloader
|
||||
|
||||
if cfg.finetune_from_model is not None and (
|
||||
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
|
||||
):
|
||||
raise ValueError(
|
||||
@ -147,19 +152,19 @@ def load_checkpoint(args, trainer, **passthrough_args):
|
||||
" or reset_lr_scheduler or reset_meters or reset_dataloader"
|
||||
)
|
||||
|
||||
suffix = getattr(args, "checkpoint_suffix", "")
|
||||
suffix = cfg.checkpoint_suffix
|
||||
if (
|
||||
args.restore_file == "checkpoint_last.pt"
|
||||
cfg.restore_file == "checkpoint_last.pt"
|
||||
): # default value of restore_file is 'checkpoint_last.pt'
|
||||
checkpoint_path = os.path.join(
|
||||
args.save_dir, "checkpoint_last{}.pt".format(suffix)
|
||||
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
|
||||
)
|
||||
first_launch = not PathManager.exists(checkpoint_path)
|
||||
if getattr(args, "finetune_from_model", None) is not None and first_launch:
|
||||
if cfg.finetune_from_model is not None and first_launch:
|
||||
# if there is no last checkpoint to restore, start the finetune from pretrained model
|
||||
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
|
||||
if PathManager.exists(args.finetune_from_model):
|
||||
checkpoint_path = args.finetune_from_model
|
||||
if PathManager.exists(cfg.finetune_from_model):
|
||||
checkpoint_path = cfg.finetune_from_model
|
||||
reset_optimizer = True
|
||||
reset_lr_scheduler = True
|
||||
reset_meters = True
|
||||
@ -170,19 +175,17 @@ def load_checkpoint(args, trainer, **passthrough_args):
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"--funetune-from-model {args.finetune_from_model} does not exist"
|
||||
f"--funetune-from-model {cfg.finetune_from_model} does not exist"
|
||||
)
|
||||
elif getattr(args, "model_parallel_size", 1) > 1:
|
||||
checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
|
||||
elif cfg.model_parallel_size > 1:
|
||||
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
|
||||
else:
|
||||
checkpoint_path = args.restore_file
|
||||
checkpoint_path = cfg.restore_file
|
||||
|
||||
if args.restore_file != "checkpoint_last.pt" and getattr(
|
||||
args, "finetune_from_model", None
|
||||
):
|
||||
if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
|
||||
raise ValueError(
|
||||
"--finetune-from-model and --restore-file (non-default value) "
|
||||
"can not be specified together: " + str(args)
|
||||
"can not be specified together: " + str(cfg)
|
||||
)
|
||||
|
||||
extra_state = trainer.load_checkpoint(
|
||||
@ -225,10 +228,14 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
|
||||
f, map_location=lambda s, l: default_restore_location(s, "cpu")
|
||||
)
|
||||
|
||||
args = state["args"]
|
||||
if arg_overrides is not None:
|
||||
if "args" in state and state["args"] is not None and arg_overrides is not None:
|
||||
args = state["args"]
|
||||
for arg_name, arg_val in arg_overrides.items():
|
||||
setattr(args, arg_name, arg_val)
|
||||
|
||||
if "cfg" in state and state["cfg"] is not None and arg_overrides is not None:
|
||||
overwrite_args_by_name(state["cfg"], arg_overrides)
|
||||
|
||||
state = _upgrade_state_dict(state)
|
||||
return state
|
||||
|
||||
@ -274,19 +281,28 @@ def load_model_ensemble_and_task(
|
||||
filename = filename.replace(".pt", suffix + ".pt")
|
||||
else:
|
||||
filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
|
||||
|
||||
if not PathManager.exists(filename):
|
||||
raise IOError("Model file not found: {}".format(filename))
|
||||
state = load_checkpoint_to_cpu(filename, arg_overrides)
|
||||
if shard_idx == 0:
|
||||
args = state["args"]
|
||||
if task is None:
|
||||
task = tasks.setup_task(args)
|
||||
if "args" in state and state["args"] is not None:
|
||||
cfg = convert_namespace_to_omegaconf(state["args"])
|
||||
elif "cfg" in state and state["cfg"] is not None:
|
||||
cfg = state["cfg"]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
||||
)
|
||||
|
||||
# build model for ensemble
|
||||
model = task.build_model(args)
|
||||
model.load_state_dict(state["model"], strict=strict, args=args)
|
||||
if task is None:
|
||||
task = tasks.setup_task(cfg.task)
|
||||
|
||||
# build model for ensemble
|
||||
model = task.build_model(cfg.model)
|
||||
|
||||
model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
|
||||
ensemble.append(model)
|
||||
return ensemble, args, task
|
||||
return ensemble, cfg, task
|
||||
|
||||
|
||||
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
|
||||
@ -323,7 +339,7 @@ def torch_persistent_save(obj, f):
|
||||
|
||||
def save_state(
|
||||
filename,
|
||||
args,
|
||||
cfg: DictConfig,
|
||||
model_state_dict,
|
||||
criterion,
|
||||
optimizer,
|
||||
@ -331,6 +347,7 @@ def save_state(
|
||||
num_updates,
|
||||
optim_history=None,
|
||||
extra_state=None,
|
||||
**kwargs,
|
||||
):
|
||||
from fairseq import utils
|
||||
|
||||
@ -339,7 +356,8 @@ def save_state(
|
||||
if extra_state is None:
|
||||
extra_state = {}
|
||||
state_dict = {
|
||||
"args": args,
|
||||
"cfg": cfg,
|
||||
"args": kwargs.get("args", None),
|
||||
"model": model_state_dict or {},
|
||||
"optimizer_history": optim_history
|
||||
+ [
|
||||
@ -354,11 +372,17 @@ def save_state(
|
||||
}
|
||||
if utils.has_parameters(criterion):
|
||||
state_dict["criterion"] = criterion.state_dict()
|
||||
if not args.no_save_optimizer_state:
|
||||
state_dict["last_optimizer_state"] = optimizer.state_dict()
|
||||
|
||||
# convert all state to CPU
|
||||
state_dict = utils.move_to_cpu(state_dict)
|
||||
if cfg is None:
|
||||
cfg = state_dict["args"]
|
||||
assert cfg is not None, "must provide cfg or args"
|
||||
|
||||
if isinstance(cfg, DictConfig):
|
||||
no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
|
||||
else:
|
||||
no_save_optimizer_state = cfg.no_save_optimizer_state
|
||||
if not no_save_optimizer_state:
|
||||
state_dict["last_optimizer_state"] = optimizer.state_dict()
|
||||
|
||||
with PathManager.open(filename, "wb") as f:
|
||||
torch_persistent_save(state_dict, f)
|
||||
@ -403,46 +427,49 @@ def _upgrade_state_dict(state):
|
||||
# keep track of number of updates
|
||||
if "num_updates" not in state["optimizer_history"][-1]:
|
||||
state["optimizer_history"][-1]["num_updates"] = 0
|
||||
# old model checkpoints may not have separate source/target positions
|
||||
if hasattr(state["args"], "max_positions") and not hasattr(
|
||||
state["args"], "max_source_positions"
|
||||
):
|
||||
state["args"].max_source_positions = state["args"].max_positions
|
||||
state["args"].max_target_positions = state["args"].max_positions
|
||||
# use stateful training data iterator
|
||||
if "train_iterator" not in state["extra_state"]:
|
||||
state["extra_state"]["train_iterator"] = {
|
||||
"epoch": state["extra_state"]["epoch"],
|
||||
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
|
||||
}
|
||||
# default to translation task
|
||||
if not hasattr(state["args"], "task"):
|
||||
state["args"].task = "translation"
|
||||
# --raw-text and --lazy-load are deprecated
|
||||
if getattr(state["args"], "raw_text", False):
|
||||
state["args"].dataset_impl = "raw"
|
||||
elif getattr(state["args"], "lazy_load", False):
|
||||
state["args"].dataset_impl = "lazy"
|
||||
# epochs start at 1
|
||||
if state["extra_state"]["train_iterator"] is not None:
|
||||
state["extra_state"]["train_iterator"]["epoch"] = max(
|
||||
state["extra_state"]["train_iterator"].get("epoch", 1),
|
||||
1,
|
||||
)
|
||||
|
||||
# set any missing default values in the task, model or other registries
|
||||
registry.set_defaults(state["args"], tasks.TASK_REGISTRY[state["args"].task])
|
||||
registry.set_defaults(state["args"], models.ARCH_MODEL_REGISTRY[state["args"].arch])
|
||||
for registry_name, REGISTRY in registry.REGISTRIES.items():
|
||||
choice = getattr(state["args"], registry_name, None)
|
||||
if choice is not None:
|
||||
cls = REGISTRY["registry"][choice]
|
||||
registry.set_defaults(state["args"], cls)
|
||||
# old model checkpoints may not have separate source/target positions
|
||||
# backward compatibility, cfg updates
|
||||
if "args" in state and state["args"] is not None:
|
||||
# default to translation task
|
||||
if not hasattr(state["args"], "task"):
|
||||
state["args"].task = "translation"
|
||||
# --raw-text and --lazy-load are deprecated
|
||||
if getattr(state["args"], "raw_text", False):
|
||||
state["args"].dataset_impl = "raw"
|
||||
elif getattr(state["args"], "lazy_load", False):
|
||||
state["args"].dataset_impl = "lazy"
|
||||
# epochs start at 1
|
||||
if state["extra_state"]["train_iterator"] is not None:
|
||||
state["extra_state"]["train_iterator"]["epoch"] = max(
|
||||
state["extra_state"]["train_iterator"].get("epoch", 1), 1
|
||||
)
|
||||
|
||||
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
|
||||
|
||||
if "cfg" in state and state["cfg"] is not None:
|
||||
with open_dict(state["cfg"]):
|
||||
if state["cfg"].task is not None:
|
||||
if hasattr(state["cfg"].task, "max_positions") and not hasattr(
|
||||
state["cfg"].task, "max_source_positions"
|
||||
):
|
||||
state["cfg"].task.max_source_positions = state[
|
||||
"cfg"
|
||||
].task.max_positions
|
||||
state["cfg"].task.max_target_positions = state[
|
||||
"cfg"
|
||||
].task.max_positions
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def prune_state_dict(state_dict, args):
|
||||
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
|
||||
"""Prune the given state_dict if desired for LayerDrop
|
||||
(https://arxiv.org/abs/1909.11556).
|
||||
|
||||
@ -453,16 +480,20 @@ def prune_state_dict(state_dict, args):
|
||||
It's called by functions that load models from checkpoints and does not
|
||||
need to be called directly.
|
||||
"""
|
||||
if not args or args.arch == "ptt_transformer":
|
||||
arch = None
|
||||
if model_cfg is not None:
|
||||
arch = (
|
||||
model_cfg._name
|
||||
if isinstance(model_cfg, DictConfig)
|
||||
else getattr(model_cfg, "arch", None)
|
||||
)
|
||||
|
||||
if not model_cfg or arch is None or arch == "ptt_transformer":
|
||||
# args should not be none, but don't crash if it is.
|
||||
return state_dict
|
||||
|
||||
encoder_layers_to_keep = (
|
||||
args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
|
||||
)
|
||||
decoder_layers_to_keep = (
|
||||
args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
|
||||
)
|
||||
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
|
||||
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
|
||||
|
||||
if not encoder_layers_to_keep and not decoder_layers_to_keep:
|
||||
return state_dict
|
||||
@ -474,7 +505,7 @@ def prune_state_dict(state_dict, args):
|
||||
|
||||
def create_pruning_pass(layers_to_keep, layer_name):
|
||||
keep_layers = sorted(
|
||||
[int(layer_string) for layer_string in layers_to_keep.split(",")]
|
||||
int(layer_string) for layer_string in layers_to_keep.split(",")
|
||||
)
|
||||
mapping_dict = {}
|
||||
for i in range(len(keep_layers)):
|
||||
@ -518,10 +549,12 @@ def prune_state_dict(state_dict, args):
|
||||
|
||||
# Since layers are now pruned, *_layers_to_keep are no longer needed.
|
||||
# This is more of "It would make it work fix" rather than a proper fix.
|
||||
if "encoder_layers_to_keep" in vars(args):
|
||||
args.encoder_layers_to_keep = None
|
||||
if "decoder_layers_to_keep" in vars(args):
|
||||
args.decoder_layers_to_keep = None
|
||||
|
||||
with open_dict(model_cfg):
|
||||
if hasattr(model_cfg, "encoder_layers_to_keep"):
|
||||
model_cfg.encoder_layers_to_keep = None
|
||||
if hasattr(model_cfg, "decoder_layers_to_keep"):
|
||||
model_cfg.decoder_layers_to_keep = None
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
@ -6,8 +6,6 @@
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Union
|
||||
|
||||
from fairseq import registry
|
||||
from fairseq.criterions.fairseq_criterion import ( # noqa
|
||||
@ -27,8 +25,8 @@ from omegaconf import DictConfig
|
||||
)
|
||||
|
||||
|
||||
def build_criterion(criterion_cfg: Union[DictConfig, Namespace], task):
|
||||
return build_criterion_(criterion_cfg, task)
|
||||
def build_criterion(cfg: DictConfig, task):
|
||||
return build_criterion_(cfg, task)
|
||||
|
||||
|
||||
# automatically import any Python files in the criterions/ directory
|
||||
|
@ -11,13 +11,13 @@ from fairseq import metrics, utils
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.dataclass.constants import DDP_BACKEND_CHOICES
|
||||
from omegaconf import II
|
||||
from omegaconf import II, DictConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdaptiveLossConfig(FairseqDataclass):
|
||||
sentence_avg: bool = II("params.optimization.sentence_avg")
|
||||
ddp_backend: DDP_BACKEND_CHOICES = II("params.distributed_training.ddp_backend")
|
||||
sentence_avg: bool = II("optimization.sentence_avg")
|
||||
ddp_backend: DDP_BACKEND_CHOICES = II("distributed_training.ddp_backend")
|
||||
|
||||
|
||||
@register_criterion("adaptive_loss", dataclass=AdaptiveLossConfig)
|
||||
@ -31,14 +31,14 @@ class AdaptiveLoss(FairseqCriterion):
|
||||
self.sentence_avg = sentence_avg
|
||||
|
||||
@classmethod
|
||||
def build_criterion(cls, args, task):
|
||||
if getattr(args, "ddp_backend", None) == "c10d":
|
||||
def build_criterion(cls, cfg: DictConfig, task):
|
||||
if cfg.ddp_backend == "c10d":
|
||||
raise Exception(
|
||||
"AdaptiveLoss is not compatible with the c10d "
|
||||
"version of DistributedDataParallel. Please use "
|
||||
"`--ddp-backend=no_c10d` instead."
|
||||
)
|
||||
return cls(task, args.sentence_avg)
|
||||
return cls(task, cfg.sentence_avg)
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
@ -15,7 +15,7 @@ from omegaconf import II
|
||||
|
||||
@dataclass
|
||||
class CrossEntropyCriterionConfig(FairseqDataclass):
|
||||
sentence_avg: bool = II("params.optimization.sentence_avg")
|
||||
sentence_avg: bool = II("optimization.sentence_avg")
|
||||
|
||||
|
||||
@register_criterion("cross_entropy", dataclass=CrossEntropyCriterionConfig)
|
||||
|
@ -10,24 +10,24 @@ from argparse import Namespace
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
|
||||
from fairseq.data.data_utils import post_process
|
||||
from fairseq.logging.meters import safe_round
|
||||
|
||||
|
||||
@register_criterion("ctc")
|
||||
class CtcCriterion(FairseqCriterion):
|
||||
def __init__(self, task, wer_args, zero_infinity, sentence_avg, remove_bpe):
|
||||
super().__init__(task)
|
||||
class CtcCriterion(LegacyFairseqCriterion):
|
||||
def __init__(self, args, task):
|
||||
super().__init__(args, task)
|
||||
self.blank_idx = task.target_dictionary.bos()
|
||||
self.pad_idx = task.target_dictionary.pad()
|
||||
self.eos_idx = task.target_dictionary.eos()
|
||||
self.post_process = remove_bpe if remove_bpe else "letter"
|
||||
self.post_process = args.remove_bpe if args.remove_bpe else "letter"
|
||||
|
||||
if wer_args is not None:
|
||||
if args.wer_args is not None:
|
||||
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
|
||||
|
||||
wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(wer_args)
|
||||
wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args)
|
||||
|
||||
dec_args = Namespace()
|
||||
dec_args.nbest = 1
|
||||
@ -46,8 +46,8 @@ class CtcCriterion(FairseqCriterion):
|
||||
else:
|
||||
self.w2l_decoder = None
|
||||
|
||||
self.zero_infinity = zero_infinity
|
||||
self.sentence_avg = sentence_avg
|
||||
self.zero_infinity = args.zero_infinity
|
||||
self.sentence_avg = args.sentence_avg
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List
|
||||
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
||||
from omegaconf import DictConfig
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
@ -27,10 +28,8 @@ class FairseqCriterion(_Loss):
|
||||
gen_parser_from_dataclass(parser, dc())
|
||||
|
||||
@classmethod
|
||||
def build_criterion(cls, args, task):
|
||||
def build_criterion(cls, cfg: DictConfig, task):
|
||||
"""Construct a criterion from command-line args."""
|
||||
# Criterions can override this, but for convenience we also try
|
||||
# to automatically map argparse.Namespace keys to corresponding
|
||||
# arguments in the __init__.
|
||||
init_args = {}
|
||||
for p in inspect.signature(cls).parameters.values():
|
||||
@ -47,8 +46,8 @@ class FairseqCriterion(_Loss):
|
||||
|
||||
if p.name == "task":
|
||||
init_args["task"] = task
|
||||
elif hasattr(args, p.name):
|
||||
init_args[p.name] = getattr(args, p.name)
|
||||
elif hasattr(cfg, p.name):
|
||||
init_args[p.name] = getattr(cfg, p.name)
|
||||
elif p.default != p.empty:
|
||||
pass # we'll use the default value
|
||||
else:
|
||||
@ -70,7 +69,7 @@ class FairseqCriterion(_Loss):
|
||||
|
||||
@staticmethod
|
||||
def aggregate_logging_outputs(
|
||||
logging_outputs: List[Dict[str, Any]],
|
||||
logging_outputs: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
utils.deprecation_warning(
|
||||
|
@ -4,6 +4,8 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.data.encoders.byte_utils import (
|
||||
@ -12,19 +14,20 @@ from fairseq.data.encoders.byte_utils import (
|
||||
byte_encode,
|
||||
smart_byte_decode,
|
||||
)
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@register_bpe("byte_bpe")
|
||||
@dataclass
|
||||
class ByteBpeConfig(FairseqDataclass):
|
||||
sentencepiece_model_path: str = field(
|
||||
default="???", metadata={"help": "path to sentencepiece model"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
|
||||
class ByteBPE(object):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--sentencepiece-model-path', type=str,
|
||||
help='path to sentencepiece model')
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, args):
|
||||
vocab = file_utils.cached_path(args.sentencepiece_model_path)
|
||||
def __init__(self, cfg):
|
||||
vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
|
||||
|
@ -15,7 +15,7 @@ from fairseq.data.encoders.byte_utils import (
|
||||
|
||||
@register_bpe("bytes")
|
||||
class Bytes(object):
|
||||
def __init__(self, args):
|
||||
def __init__(self, *unused):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
|
@ -13,7 +13,7 @@ SPACE_ESCAPE = chr(9601)
|
||||
|
||||
@register_bpe("characters")
|
||||
class Characters(object):
|
||||
def __init__(self, args):
|
||||
def __init__(self, *unused):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
|
@ -3,23 +3,24 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@register_bpe("fastbpe")
|
||||
@dataclass
|
||||
class fastBPEConfig(FairseqDataclass):
|
||||
bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})
|
||||
|
||||
|
||||
@register_bpe("fastbpe", dataclass=fastBPEConfig)
|
||||
class fastBPE(object):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--bpe-codes', type=str,
|
||||
help='path to fastBPE BPE')
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, args):
|
||||
if args.bpe_codes is None:
|
||||
def __init__(self, cfg):
|
||||
if cfg.bpe_codes is None:
|
||||
raise ValueError("--bpe-codes is required for --bpe=fastbpe")
|
||||
codes = file_utils.cached_path(args.bpe_codes)
|
||||
codes = file_utils.cached_path(cfg.bpe_codes)
|
||||
try:
|
||||
import fastBPE
|
||||
|
||||
|
@ -3,8 +3,11 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
from .gpt2_bpe_utils import get_encoder
|
||||
|
||||
@ -13,26 +16,21 @@ DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.
|
||||
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
|
||||
|
||||
|
||||
@register_bpe("gpt2")
|
||||
class GPT2BPE(object):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--gpt2-encoder-json', type=str,
|
||||
default=DEFAULT_ENCODER_JSON,
|
||||
help='path to encoder.json')
|
||||
parser.add_argument('--gpt2-vocab-bpe', type=str,
|
||||
default=DEFAULT_VOCAB_BPE,
|
||||
help='path to vocab.bpe')
|
||||
# fmt: on
|
||||
@dataclass
|
||||
class GPT2BPEConfig(FairseqDataclass):
|
||||
gpt2_encoder_json: str = field(
|
||||
default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
|
||||
)
|
||||
gpt2_vocab_bpe: str = field(
|
||||
default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
|
||||
)
|
||||
|
||||
def __init__(self, args):
|
||||
encoder_json = file_utils.cached_path(
|
||||
getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON)
|
||||
)
|
||||
vocab_bpe = file_utils.cached_path(
|
||||
getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE)
|
||||
)
|
||||
|
||||
@register_bpe("gpt2", dataclass=GPT2BPEConfig)
|
||||
class GPT2BPE(object):
|
||||
def __init__(self, cfg):
|
||||
encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
|
||||
vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
|
||||
self.bpe = get_encoder(encoder_json, vocab_bpe)
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
|
@ -3,22 +3,24 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@register_bpe("bert")
|
||||
@dataclass
|
||||
class BertBPEConfig(FairseqDataclass):
|
||||
bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"})
|
||||
bpe_vocab_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "bpe vocab file"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("bert", dataclass=BertBPEConfig)
|
||||
class BertBPE(object):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--bpe-cased', action='store_true',
|
||||
help='set for cased BPE',
|
||||
default=False)
|
||||
parser.add_argument('--bpe-vocab-file', type=str,
|
||||
help='bpe vocab file.')
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, args):
|
||||
def __init__(self, cfg):
|
||||
try:
|
||||
from transformers import BertTokenizer
|
||||
except ImportError:
|
||||
@ -26,13 +28,13 @@ class BertBPE(object):
|
||||
"Please install transformers with: pip install transformers"
|
||||
)
|
||||
|
||||
if "bpe_vocab_file" in args:
|
||||
if cfg.bpe_vocab_file:
|
||||
self.bert_tokenizer = BertTokenizer(
|
||||
args.bpe_vocab_file, do_lower_case=not args.bpe_cased
|
||||
cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased
|
||||
)
|
||||
else:
|
||||
vocab_file_name = (
|
||||
"bert-base-cased" if args.bpe_cased else "bert-base-uncased"
|
||||
"bert-base-cased" if cfg.bpe_cased else "bert-base-uncased"
|
||||
)
|
||||
self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
|
||||
|
||||
|
@ -3,21 +3,24 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@register_bpe("hf_byte_bpe")
|
||||
@dataclass
|
||||
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
|
||||
bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"})
|
||||
bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"})
|
||||
bpe_add_prefix_space: bool = field(
|
||||
default=False, metadata={"help": "add prefix space before encoding"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
|
||||
class HuggingFaceByteLevelBPE(object):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--bpe-merges', help='path to merges.txt')
|
||||
parser.add_argument('--bpe-vocab', help='path to vocab.json')
|
||||
parser.add_argument('--bpe-add-prefix-space', action='store_true',
|
||||
help='add prefix space before encoding')
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, args):
|
||||
def __init__(self, cfg):
|
||||
try:
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
except ImportError:
|
||||
@ -26,9 +29,9 @@ class HuggingFaceByteLevelBPE(object):
|
||||
)
|
||||
|
||||
self.bpe = ByteLevelBPETokenizer(
|
||||
args.bpe_vocab,
|
||||
args.bpe_merges,
|
||||
add_prefix_space=getattr(args, "bpe_add_prefix_space", False),
|
||||
cfg.bpe_vocab,
|
||||
cfg.bpe_merges,
|
||||
add_prefix_space=cfg.bpe_add_prefix_space,
|
||||
)
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
|
@ -3,37 +3,35 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq.data.encoders import register_tokenizer
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@register_tokenizer("moses")
|
||||
@dataclass
|
||||
class MosesTokenizerConfig(FairseqDataclass):
|
||||
source_lang: str = field(default="en", metadata={"help": "source language"})
|
||||
target_lang: str = field(default="en", metadata={"help": "target language"})
|
||||
moses_no_dash_splits: bool = field(
|
||||
default=False, metadata={"help": "don't apply dash split rules"}
|
||||
)
|
||||
moses_no_escape: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
|
||||
)
|
||||
|
||||
|
||||
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
|
||||
class MosesTokenizer(object):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--moses-source-lang', metavar='SRC',
|
||||
help='source language')
|
||||
parser.add_argument('--moses-target-lang', metavar='TARGET',
|
||||
help='target language')
|
||||
parser.add_argument('--moses-no-dash-splits', action='store_true', default=False,
|
||||
help='don\'t apply dash split rules')
|
||||
parser.add_argument('--moses-no-escape', action='store_true', default=False,
|
||||
help='don\'t perform HTML escaping on apostrophy, quotes, etc.')
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
|
||||
if getattr(args, "moses_source_lang", None) is None:
|
||||
args.moses_source_lang = getattr(args, "source_lang", "en")
|
||||
if getattr(args, "moses_target_lang", None) is None:
|
||||
args.moses_target_lang = getattr(args, "target_lang", "en")
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
|
||||
try:
|
||||
from sacremoses import MosesTokenizer, MosesDetokenizer
|
||||
|
||||
self.tok = MosesTokenizer(args.moses_source_lang)
|
||||
self.detok = MosesDetokenizer(args.moses_target_lang)
|
||||
self.tok = MosesTokenizer(cfg.source_lang)
|
||||
self.detok = MosesDetokenizer(cfg.target_lang)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install Moses tokenizer with: pip install sacremoses"
|
||||
@ -42,9 +40,9 @@ class MosesTokenizer(object):
|
||||
def encode(self, x: str) -> str:
|
||||
return self.tok.tokenize(
|
||||
x,
|
||||
aggressive_dash_splits=(not self.args.moses_no_dash_splits),
|
||||
aggressive_dash_splits=(not self.cfg.moses_no_dash_splits),
|
||||
return_str=True,
|
||||
escape=(not self.args.moses_no_escape),
|
||||
escape=(not self.cfg.moses_no_escape),
|
||||
)
|
||||
|
||||
def decode(self, x: str) -> str:
|
||||
|
@ -8,7 +8,7 @@ from fairseq.data.encoders import register_tokenizer
|
||||
|
||||
@register_tokenizer("nltk")
|
||||
class NLTKTokenizer(object):
|
||||
def __init__(self, source_lang=None, target_lang=None):
|
||||
def __init__(self, *unused):
|
||||
try:
|
||||
from nltk.tokenize import word_tokenize
|
||||
|
||||
|
@ -3,21 +3,24 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@register_bpe("sentencepiece")
|
||||
@dataclass
|
||||
class SentencepieceConfig(FairseqDataclass):
|
||||
sentencepiece_model: str = field(
|
||||
default="???", metadata={"help": "path to sentencepiece model"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("sentencepiece", dataclass=SentencepieceConfig)
|
||||
class SentencepieceBPE(object):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--sentencepiece-model', type=str,
|
||||
help='path to sentencepiece model')
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, args):
|
||||
sentencepiece_model = file_utils.cached_path(args.sentencepiece_model)
|
||||
def __init__(self, cfg):
|
||||
sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model)
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
|
||||
|
@ -10,7 +10,7 @@ from fairseq.data.encoders import register_tokenizer
|
||||
|
||||
@register_tokenizer("space")
|
||||
class SpaceTokenizer(object):
|
||||
def __init__(self, source_lang=None, target_lang=None):
|
||||
def __init__(self, *unused):
|
||||
self.space_tok = re.compile(r"\s+")
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
|
@ -3,25 +3,25 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@register_bpe("subword_nmt")
|
||||
@dataclass
|
||||
class SubwordNMTBPEConfig(FairseqDataclass):
|
||||
bpe_codes: str = field(default="???", metadata={"help": "path to subword NMT BPE"})
|
||||
bpe_separator: str = field(default="@@", metadata={"help": "BPE separator"})
|
||||
|
||||
|
||||
@register_bpe("subword_nmt", dataclass=SubwordNMTBPEConfig)
|
||||
class SubwordNMTBPE(object):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--bpe-codes', type=str,
|
||||
help='path to subword NMT BPE')
|
||||
parser.add_argument('--bpe-separator', default='@@',
|
||||
help='BPE separator')
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, args):
|
||||
if args.bpe_codes is None:
|
||||
def __init__(self, cfg):
|
||||
if cfg.bpe_codes is None:
|
||||
raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
|
||||
codes = file_utils.cached_path(args.bpe_codes)
|
||||
codes = file_utils.cached_path(cfg.bpe_codes)
|
||||
try:
|
||||
from subword_nmt import apply_bpe
|
||||
|
||||
@ -31,7 +31,7 @@ class SubwordNMTBPE(object):
|
||||
"--codes",
|
||||
codes,
|
||||
"--separator",
|
||||
args.bpe_separator,
|
||||
cfg.bpe_separator,
|
||||
]
|
||||
)
|
||||
self.bpe = apply_bpe.BPE(
|
||||
|
@ -9,5 +9,7 @@ from fairseq.dataclass.utils import ChoiceEnum
|
||||
LOG_FORMAT_CHOICES = ChoiceEnum(["json", "none", "simple", "tqdm"])
|
||||
DDP_BACKEND_CHOICES = ChoiceEnum(["c10d", "no_c10d"])
|
||||
DISTRIBUTED_WRAPPER_CHOICES = ChoiceEnum(["DDP", "SlowMo"])
|
||||
GENERATION_CONSTRAINTS_CHOICES = ChoiceEnum(["ordered", "unordered"])
|
||||
GENERATION_DECODING_FORMAT_CHOICES = ChoiceEnum(["unigram", "ensemble", "vote", "dp", "bs"])
|
||||
ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"])
|
||||
PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"])
|
||||
|
@ -3,32 +3,37 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import _MISSING_TYPE, dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from fairseq.criterions import CRITERION_DATACLASS_REGISTRY
|
||||
from fairseq.data.indexed_dataset import get_available_dataset_impl
|
||||
from fairseq.dataclass.constants import (
|
||||
DDP_BACKEND_CHOICES,
|
||||
DISTRIBUTED_WRAPPER_CHOICES,
|
||||
GENERATION_CONSTRAINTS_CHOICES,
|
||||
GENERATION_DECODING_FORMAT_CHOICES,
|
||||
LOG_FORMAT_CHOICES,
|
||||
PIPELINE_CHECKPOINT_CHOICES,
|
||||
ZERO_SHARDING_CHOICES,
|
||||
)
|
||||
from fairseq.dataclass.utils import ChoiceEnum, FairseqDataclass
|
||||
from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_DATACLASS_REGISTRY
|
||||
from fairseq.optim import OPTIMIZER_DATACLASS_REGISTRY
|
||||
from fairseq.optim.bmuf import FairseqBMUFConfig
|
||||
from fairseq.optim.lr_scheduler import LR_SCHEDULER_DATACLASS_REGISTRY
|
||||
from fairseq.registry import REGISTRIES
|
||||
from fairseq.tasks import TASK_DATACLASS_REGISTRY
|
||||
from hydra.core.config_store import ConfigStore
|
||||
from omegaconf import II
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonParams(FairseqDataclass):
|
||||
class CommonConfig(FairseqDataclass):
|
||||
# This is the core dataclass including common parameters shared by all different jobs. Please append your params to other dataclasses if they were
|
||||
# used for a particular purpose or task, such as those dedicated for `distributed training`, `optimization`, etc.
|
||||
no_progress_bar: bool = field(
|
||||
@ -109,18 +114,6 @@ class CommonParams(FairseqDataclass):
|
||||
model_parallel_size: int = field(
|
||||
default=1, metadata={"help": "total number of GPUs to parallelize model over"}
|
||||
)
|
||||
checkpoint_suffix: str = field(
|
||||
default="", metadata={"help": "suffix to add to the checkpoint file name"}
|
||||
)
|
||||
checkpoint_shard_count: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Number of shards containing the checkpoint - "
|
||||
"if the checkpoint is over 300GB, it is preferable "
|
||||
"to split it into shards to prevent OOM on CPU while loading "
|
||||
"the checkpoint"
|
||||
},
|
||||
)
|
||||
quantization_config_path: Optional[str] = field(
|
||||
default=None, metadata={"help": "path to quantization config file"}
|
||||
)
|
||||
@ -130,7 +123,7 @@ class CommonParams(FairseqDataclass):
|
||||
|
||||
|
||||
@dataclass
|
||||
class DistributedTrainingParams(FairseqDataclass):
|
||||
class DistributedTrainingConfig(FairseqDataclass):
|
||||
distributed_world_size: int = field(
|
||||
default=max(1, torch.cuda.device_count()),
|
||||
metadata={
|
||||
@ -229,7 +222,7 @@ class DistributedTrainingParams(FairseqDataclass):
|
||||
default=False,
|
||||
metadata={"help": "if set, use pipeline model parallelism across GPUs"},
|
||||
)
|
||||
pipeline_balance: str = field(
|
||||
pipeline_balance: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "partition the model into N_K pieces, where each piece "
|
||||
@ -237,7 +230,7 @@ class DistributedTrainingParams(FairseqDataclass):
|
||||
"should equal the total number of layers in the model"
|
||||
},
|
||||
)
|
||||
pipeline_devices: str = field(
|
||||
pipeline_devices: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "a list of device indices indicating which device to place "
|
||||
@ -245,10 +238,10 @@ class DistributedTrainingParams(FairseqDataclass):
|
||||
"equal the length of the --pipeline-balance argument"
|
||||
},
|
||||
)
|
||||
pipeline_chunks: int = field(
|
||||
pipeline_chunks: Optional[int] = field(
|
||||
default=0, metadata={"help": "microbatch count for pipeline model parallelism"}
|
||||
)
|
||||
pipeline_encoder_balance: str = field(
|
||||
pipeline_encoder_balance: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "partition the pipeline parallel encoder into N_K pieces, where each piece "
|
||||
@ -256,7 +249,7 @@ class DistributedTrainingParams(FairseqDataclass):
|
||||
"should equal the total number of encoder layers in the model"
|
||||
},
|
||||
)
|
||||
pipeline_encoder_devices: str = field(
|
||||
pipeline_encoder_devices: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "a list of device indices indicating which device to place "
|
||||
@ -264,7 +257,7 @@ class DistributedTrainingParams(FairseqDataclass):
|
||||
"equal the length of the --pipeline-encoder-balance argument"
|
||||
},
|
||||
)
|
||||
pipeline_decoder_balance: str = field(
|
||||
pipeline_decoder_balance: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "partition the pipeline parallel decoder into N_K pieces, where each piece "
|
||||
@ -272,7 +265,7 @@ class DistributedTrainingParams(FairseqDataclass):
|
||||
"should equal the total number of decoder layers in the model"
|
||||
},
|
||||
)
|
||||
pipeline_decoder_devices: str = field(
|
||||
pipeline_decoder_devices: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "a list of device indices indicating which device to place "
|
||||
@ -287,10 +280,11 @@ class DistributedTrainingParams(FairseqDataclass):
|
||||
zero_sharding: ZERO_SHARDING_CHOICES = field(
|
||||
default="none", metadata={"help": "ZeRO sharding"}
|
||||
)
|
||||
tpu: bool = II("common.tpu")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetParams(FairseqDataclass):
|
||||
class DatasetConfig(FairseqDataclass):
|
||||
num_workers: int = field(
|
||||
default=1, metadata={"help": "how many subprocesses to use for data loading"}
|
||||
)
|
||||
@ -374,7 +368,7 @@ class DatasetParams(FairseqDataclass):
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptimizationParams(FairseqDataclass):
|
||||
class OptimizationConfig(FairseqDataclass):
|
||||
max_epoch: int = field(
|
||||
default=0, metadata={"help": "force stop training at specified epoch"}
|
||||
)
|
||||
@ -421,7 +415,7 @@ class OptimizationParams(FairseqDataclass):
|
||||
|
||||
|
||||
@dataclass
|
||||
class CheckpointParams(FairseqDataclass):
|
||||
class CheckpointConfig(FairseqDataclass):
|
||||
save_dir: str = field(
|
||||
default="checkpoints", metadata={"help": "path to save checkpoints"}
|
||||
)
|
||||
@ -514,12 +508,217 @@ class CheckpointParams(FairseqDataclass):
|
||||
)
|
||||
},
|
||||
)
|
||||
checkpoint_suffix: str = field(
|
||||
default="", metadata={"help": "suffix to add to the checkpoint file name"}
|
||||
)
|
||||
checkpoint_shard_count: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Number of shards containing the checkpoint - "
|
||||
"if the checkpoint is over 300GB, it is preferable "
|
||||
"to split it into shards to prevent OOM on CPU while loading "
|
||||
"the checkpoint"
|
||||
},
|
||||
)
|
||||
model_parallel_size: int = II("common.model_parallel_size")
|
||||
distributed_rank: int = II("distributed_training.distributed_rank")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonEvalParams(FairseqDataclass):
|
||||
class GenerationConfig(FairseqDataclass):
|
||||
beam: int = field(
|
||||
default=5,
|
||||
metadata={"help": "beam size"},
|
||||
)
|
||||
nbest: int = field(
|
||||
default=1,
|
||||
metadata={"help": "number of hypotheses to output"},
|
||||
)
|
||||
max_len_a: float = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "generate sequences of maximum length ax + b, where x is the source length"
|
||||
},
|
||||
)
|
||||
max_len_b: int = field(
|
||||
default=200,
|
||||
metadata={
|
||||
"help": "generate sequences of maximum length ax + b, where x is the source length"
|
||||
},
|
||||
)
|
||||
min_len: int = field(
|
||||
default=1,
|
||||
metadata={"help": "minimum generation length"},
|
||||
)
|
||||
match_source_len: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "generations should match the source length"},
|
||||
)
|
||||
unnormalized: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "compare unnormalized hypothesis scores"},
|
||||
)
|
||||
no_early_stop: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "deprecated"},
|
||||
)
|
||||
no_beamable_mm: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "don't use BeamableMM in attention layers"},
|
||||
)
|
||||
lenpen: float = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences"
|
||||
},
|
||||
)
|
||||
unkpen: float = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "unknown word penalty: <0 produces more unks, >0 produces fewer"
|
||||
},
|
||||
)
|
||||
replace_unk: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "perform unknown replacement (optionally with alignment dictionary)",
|
||||
"argparse_const": "@@ ",
|
||||
},
|
||||
)
|
||||
sacrebleu: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "score with sacrebleu"},
|
||||
)
|
||||
score_reference: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "just score the reference translation"},
|
||||
)
|
||||
prefix_size: int = field(
|
||||
default=0,
|
||||
metadata={"help": "initialize generation by target prefix of given length"},
|
||||
)
|
||||
no_repeat_ngram_size: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "ngram blocking such that this size ngram cannot be repeated in the generation"
|
||||
},
|
||||
)
|
||||
sampling: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "sample hypotheses instead of using beam search"},
|
||||
)
|
||||
sampling_topk: int = field(
|
||||
default=-1,
|
||||
metadata={"help": "sample from top K likely next words instead of all words"},
|
||||
)
|
||||
sampling_topp: float = field(
|
||||
default=-1.0,
|
||||
metadata={
|
||||
"help": "sample from the smallest set whose cumulative probability mass exceeds p for next words"
|
||||
},
|
||||
)
|
||||
constraints: Optional[GENERATION_CONSTRAINTS_CHOICES] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "enables lexically constrained decoding",
|
||||
"argparse_const": "ordered",
|
||||
},
|
||||
)
|
||||
temperature: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "temperature for generation"},
|
||||
)
|
||||
diverse_beam_groups: int = field(
|
||||
default=-1,
|
||||
metadata={"help": "number of groups for Diverse Beam Search"},
|
||||
)
|
||||
diverse_beam_strength: float = field(
|
||||
default=0.5,
|
||||
metadata={"help": "strength of diversity penalty for Diverse Beam Search"},
|
||||
)
|
||||
diversity_rate: float = field(
|
||||
default=-1.0,
|
||||
metadata={"help": "strength of diversity penalty for Diverse Siblings Search"},
|
||||
)
|
||||
print_alignment: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, uses attention feedback to compute and print alignment to source tokens"
|
||||
},
|
||||
)
|
||||
print_step: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "print steps"},
|
||||
)
|
||||
lm_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "path to lm checkpoint for lm fusion"},
|
||||
)
|
||||
lm_weight: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "weight for lm probs for lm fusion"},
|
||||
)
|
||||
|
||||
# arguments for iterative refinement generator
|
||||
iter_decode_eos_penalty: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "if > 0.0, it penalized early-stopping in decoding."},
|
||||
)
|
||||
iter_decode_max_iter: int = field(
|
||||
default=10,
|
||||
metadata={"help": "maximum iterations for iterative refinement."},
|
||||
)
|
||||
iter_decode_force_max_iter: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, run exact the maximum number of iterations without early stop"
|
||||
},
|
||||
)
|
||||
iter_decode_with_beam: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "if > 1, model will generate translations varying by the lengths."
|
||||
},
|
||||
)
|
||||
iter_decode_with_external_reranker: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, the last checkpoint are assumed to be a reranker to rescore the translations"
|
||||
},
|
||||
)
|
||||
retain_iter_history: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "if set, decoding returns the whole history of iterative refinement"
|
||||
},
|
||||
)
|
||||
retain_dropout: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use dropout at inference time"},
|
||||
)
|
||||
retain_dropout_modules: Optional[List[str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "if set, only retain dropout for the specified modules; "
|
||||
"if not set, then dropout will be retained for all modules"
|
||||
},
|
||||
)
|
||||
# special decoding format for advanced decoding.
|
||||
decoding_format: Optional[GENERATION_DECODING_FORMAT_CHOICES] = field(
|
||||
default=None,
|
||||
metadata={"help": "special decoding format for advanced decoding."},
|
||||
)
|
||||
no_seed_provided: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "if set, dont use seed for initializing random generators"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonEvalConfig(FairseqDataclass):
|
||||
path: Optional[str] = field(
|
||||
default=None, metadata={"help": "path(s) to model file(s), colon separated"}
|
||||
default=None,
|
||||
metadata={"help": "path(s) to model file(s), colon separated"},
|
||||
)
|
||||
remove_bpe: Optional[str] = field(
|
||||
default=None,
|
||||
@ -541,7 +740,7 @@ class CommonEvalParams(FairseqDataclass):
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalLMParams(FairseqDataclass):
|
||||
class EvalLMConfig(FairseqDataclass):
|
||||
output_word_probs: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@ -569,37 +768,31 @@ class EvalLMParams(FairseqDataclass):
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig(FairseqDataclass):
|
||||
"""Config for training, a composition of training params"""
|
||||
|
||||
common: CommonParams = CommonParams()
|
||||
distributed_training: DistributedTrainingParams = DistributedTrainingParams()
|
||||
dataset: DatasetParams = DatasetParams()
|
||||
optimization: OptimizationParams = OptimizationParams()
|
||||
checkpoint: CheckpointParams = CheckpointParams()
|
||||
bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
|
||||
class InteractiveConfig(FairseqDataclass):
|
||||
buffer_size: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "read this many sentences into a buffer before processing them"
|
||||
},
|
||||
)
|
||||
input: str = field(
|
||||
default="-",
|
||||
metadata={"help": "file to read from; use - for stdin"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalLMConfig(FairseqDataclass):
|
||||
"""Config for eval lm, a composition of eval_lm params"""
|
||||
|
||||
common: CommonParams = CommonParams()
|
||||
distributed_training: DistributedTrainingParams = DistributedTrainingParams()
|
||||
dataset: DatasetParams = DatasetParams()
|
||||
optimization: OptimizationParams = OptimizationParams()
|
||||
checkpoint: CheckpointParams = CheckpointParams()
|
||||
bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
|
||||
common_eval: CommonEvalParams = CommonEvalParams()
|
||||
eval_lm: EvalLMParams = EvalLMParams()
|
||||
|
||||
|
||||
def register_params_dataclass(
|
||||
cs: ConfigStore, name: str, group: str, data_class: Type[FairseqDataclass]
|
||||
) -> None:
|
||||
"""register params dataclass in config store"""
|
||||
node_ = data_class(_name=data_class.name())
|
||||
cs.store(name=name, group=group, node=node_)
|
||||
CONFIGS = {
|
||||
"common": CommonConfig,
|
||||
"common_eval": CommonEvalConfig,
|
||||
"distributed_training": DistributedTrainingConfig,
|
||||
"dataset": DatasetConfig,
|
||||
"optimization": OptimizationConfig,
|
||||
"checkpoint": CheckpointConfig,
|
||||
"bmuf": FairseqBMUFConfig,
|
||||
"generation": GenerationConfig,
|
||||
"eval_lm": EvalLMConfig,
|
||||
"interactive": InteractiveConfig,
|
||||
}
|
||||
|
||||
|
||||
def register_module_dataclass(
|
||||
@ -608,100 +801,67 @@ def register_module_dataclass(
|
||||
"""register dataclasses defined in modules in config store, for example, in migrated tasks, models, etc."""
|
||||
# note that if `group == model`, we register all model archs, not the model name.
|
||||
for k, v in registry.items():
|
||||
if v is not None:
|
||||
node_ = v(_name=k)
|
||||
cs.store(name=k, group=group, node=node_)
|
||||
node_ = v()
|
||||
node_._name = k
|
||||
cs.store(name=k, group=group, node=node_, provider="fairseq")
|
||||
|
||||
|
||||
def register_training_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
|
||||
def register_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
|
||||
"""cs: config store instance, register common training configs"""
|
||||
|
||||
register_params_dataclass(
|
||||
cs, name="training_params", group="params", data_class=TrainingConfig
|
||||
)
|
||||
for k, v in CONFIGS.items():
|
||||
try:
|
||||
cs.store(name=k, node=v())
|
||||
except BaseException:
|
||||
logger.error(f"{k} - {v()}")
|
||||
raise
|
||||
|
||||
register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
|
||||
register_module_dataclass(cs, MODEL_DATACLASS_REGISTRY, "model")
|
||||
register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion")
|
||||
register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer")
|
||||
register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler")
|
||||
|
||||
|
||||
def register_eval_lm_hydra_cfg(cs: ConfigStore, name: str = "default") -> None:
|
||||
"""cs: config store instance, register common training configs"""
|
||||
|
||||
register_params_dataclass(
|
||||
cs, name="eval_lm_params", group="params", data_class=EvalLMConfig
|
||||
)
|
||||
|
||||
register_module_dataclass(cs, TASK_DATACLASS_REGISTRY, "task")
|
||||
register_module_dataclass(cs, CRITERION_DATACLASS_REGISTRY, "criterion")
|
||||
register_module_dataclass(cs, OPTIMIZER_DATACLASS_REGISTRY, "optimizer")
|
||||
register_module_dataclass(cs, LR_SCHEDULER_DATACLASS_REGISTRY, "lr_scheduler")
|
||||
for k, v in REGISTRIES.items():
|
||||
register_module_dataclass(cs, v["dataclass_registry"], k)
|
||||
|
||||
|
||||
def _override_attr(
|
||||
sub_node: str, data_class: Type[FairseqDataclass], args: Namespace
|
||||
) -> List[str]:
|
||||
overrides = []
|
||||
for k in data_class.__dataclass_fields__.keys():
|
||||
if k == "_name":
|
||||
|
||||
def get_default(f):
|
||||
if not isinstance(f.default_factory, _MISSING_TYPE):
|
||||
return f.default_factory()
|
||||
return f.default
|
||||
|
||||
for k, v in data_class.__dataclass_fields__.items():
|
||||
if k.startswith("_"):
|
||||
# private member, skip
|
||||
continue
|
||||
if not hasattr(args, k):
|
||||
# print(f"cannot override {sub_node}.{k} since args does not have attribute {k}")
|
||||
continue
|
||||
if getattr(args, k) is None:
|
||||
|
||||
val = get_default(v) if not hasattr(args, k) else getattr(args, k)
|
||||
|
||||
if val is None:
|
||||
overrides.append("{}.{}=null".format(sub_node, k))
|
||||
elif getattr(args, k) == "":
|
||||
elif val == "":
|
||||
overrides.append("{}.{}=''".format(sub_node, k))
|
||||
elif isinstance(getattr(args, k), str):
|
||||
if (
|
||||
getattr(args, k).startswith("[")
|
||||
or getattr(args, k).startswith("(")
|
||||
or getattr(args, k).startswith("{")
|
||||
or ("," in getattr(args, k))
|
||||
):
|
||||
overrides.append("{}.{}='{}'".format(sub_node, k, getattr(args, k)))
|
||||
else:
|
||||
overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k)))
|
||||
elif isinstance(val, str):
|
||||
overrides.append("{}.{}='{}'".format(sub_node, k, val))
|
||||
else:
|
||||
overrides.append("{}.{}={}".format(sub_node, k, getattr(args, k)))
|
||||
overrides.append("{}.{}={}".format(sub_node, k, val))
|
||||
return overrides
|
||||
|
||||
|
||||
def override_training_args(args: Namespace) -> Tuple[List[str], List[str]]:
|
||||
overrides = []
|
||||
|
||||
overrides.extend(_override_attr("params.common", CommonParams, args))
|
||||
overrides.extend(_override_attr("params.dataset", DatasetParams, args))
|
||||
overrides.extend(
|
||||
_override_attr("params.distributed_training", DistributedTrainingParams, args)
|
||||
)
|
||||
overrides.extend(_override_attr("params.optimization", OptimizationParams, args))
|
||||
overrides.extend(_override_attr("params.checkpoint", CheckpointParams, args))
|
||||
overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args))
|
||||
module_overrides, module_deletes = override_module_args(args)
|
||||
overrides.extend(module_overrides)
|
||||
|
||||
return overrides, module_deletes
|
||||
|
||||
|
||||
def override_eval_lm_args(args: Namespace) -> Tuple[List[str], List[str]]:
|
||||
overrides = []
|
||||
|
||||
overrides.extend(_override_attr("params.common", CommonParams, args))
|
||||
overrides.extend(_override_attr("params.dataset", DatasetParams, args))
|
||||
overrides.extend(
|
||||
_override_attr("params.distributed_training", DistributedTrainingParams, args)
|
||||
)
|
||||
overrides.extend(_override_attr("params.common_eval", CommonEvalParams, args))
|
||||
overrides.extend(_override_attr("params.eval_lm", EvalLMParams, args))
|
||||
overrides.extend(_override_attr("params.bmuf", FairseqBMUFConfig, args))
|
||||
module_overrides, module_deletes = override_module_args(args)
|
||||
overrides.extend(module_overrides)
|
||||
|
||||
return overrides, module_deletes
|
||||
def migrate_registry(
|
||||
name, value, registry, args, overrides, deletes, use_name_as_val=False
|
||||
):
|
||||
if value in registry:
|
||||
overrides.append("{}={}".format(name, value))
|
||||
overrides.append("{}._name={}".format(name, value))
|
||||
overrides.extend(_override_attr(name, registry[value], args))
|
||||
elif use_name_as_val and value is not None:
|
||||
overrides.append("{}={}".format(name, value))
|
||||
else:
|
||||
deletes.append(name)
|
||||
|
||||
|
||||
def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
|
||||
@ -709,53 +869,34 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:
|
||||
overrides = []
|
||||
deletes = []
|
||||
|
||||
for k, v in CONFIGS.items():
|
||||
overrides.extend(_override_attr(k, v, args))
|
||||
|
||||
if args is not None:
|
||||
assert (
|
||||
hasattr(args, "task")
|
||||
and hasattr(args, "criterion")
|
||||
and hasattr(args, "optimizer")
|
||||
and hasattr(args, "lr_scheduler")
|
||||
)
|
||||
if args.task in TASK_DATACLASS_REGISTRY:
|
||||
overrides.append("task={}".format(args.task))
|
||||
overrides.append("task._name={}".format(args.task))
|
||||
overrides.extend(
|
||||
_override_attr("task", TASK_DATACLASS_REGISTRY[args.task], args)
|
||||
if hasattr(args, "task"):
|
||||
migrate_registry(
|
||||
"task", args.task, TASK_DATACLASS_REGISTRY, args, overrides, deletes
|
||||
)
|
||||
else:
|
||||
deletes.append("task")
|
||||
if args.criterion in CRITERION_DATACLASS_REGISTRY:
|
||||
overrides.append("criterion={}".format(args.criterion))
|
||||
overrides.append("criterion._name={}".format(args.criterion))
|
||||
overrides.extend(
|
||||
_override_attr(
|
||||
"criterion", CRITERION_DATACLASS_REGISTRY[args.criterion], args
|
||||
)
|
||||
)
|
||||
else:
|
||||
deletes.append("criterion")
|
||||
if args.optimizer in OPTIMIZER_DATACLASS_REGISTRY:
|
||||
overrides.append("optimizer={}".format(args.optimizer))
|
||||
overrides.append("optimizer._name={}".format(args.optimizer))
|
||||
overrides.extend(
|
||||
_override_attr(
|
||||
"optimizer", OPTIMIZER_DATACLASS_REGISTRY[args.optimizer], args
|
||||
)
|
||||
)
|
||||
else:
|
||||
deletes.append("optimizer")
|
||||
if args.lr_scheduler in LR_SCHEDULER_DATACLASS_REGISTRY:
|
||||
overrides.append("lr_scheduler={}".format(args.lr_scheduler))
|
||||
overrides.append("lr_scheduler._name={}".format(args.lr_scheduler))
|
||||
overrides.extend(
|
||||
_override_attr(
|
||||
"lr_scheduler",
|
||||
LR_SCHEDULER_DATACLASS_REGISTRY[args.lr_scheduler],
|
||||
|
||||
# these options will be set to "None" if they have not yet been migrated
|
||||
# so we can populate them with the entire flat args
|
||||
CORE_REGISTRIES = {"criterion", "optimizer", "lr_scheduler"}
|
||||
|
||||
for k, v in REGISTRIES.items():
|
||||
if hasattr(args, k):
|
||||
migrate_registry(
|
||||
k,
|
||||
getattr(args, k),
|
||||
v["dataclass_registry"],
|
||||
args,
|
||||
overrides,
|
||||
deletes,
|
||||
use_name_as_val=k not in CORE_REGISTRIES,
|
||||
)
|
||||
)
|
||||
else:
|
||||
deletes.append("lr_scheduler")
|
||||
else:
|
||||
deletes.append(k)
|
||||
|
||||
no_dc = True
|
||||
if hasattr(args, "arch"):
|
||||
|
@ -3,17 +3,24 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import MISSING, dataclass
|
||||
import ast
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from dataclasses import _MISSING_TYPE, MISSING, dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from hydra.core.global_hydra import GlobalHydra
|
||||
from hydra.experimental import compose, initialize
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
|
||||
|
||||
def eval_str_list(x, x_type=float):
|
||||
if x is None:
|
||||
return None
|
||||
if isinstance(x, str):
|
||||
x = eval(x)
|
||||
if len(x) == 0:
|
||||
return []
|
||||
x = ast.literal_eval(x)
|
||||
try:
|
||||
return list(map(x_type, x))
|
||||
except TypeError:
|
||||
@ -70,22 +77,11 @@ class FairseqDataclass:
|
||||
!= self.__dataclass_fields__[attribute_name].default
|
||||
):
|
||||
return getattr(self, attribute_name)
|
||||
return self.__dataclass_fields__[attribute_name].default
|
||||
|
||||
def _get_default_factory(self, attribute_name: str) -> Any:
|
||||
if hasattr(self, attribute_name):
|
||||
if str(getattr(self, attribute_name)).startswith("${"):
|
||||
return str(getattr(self, attribute_name))
|
||||
elif str(self.__dataclass_fields__[attribute_name].default).startswith(
|
||||
"${"
|
||||
):
|
||||
return str(self.__dataclass_fields__[attribute_name].default)
|
||||
elif (
|
||||
getattr(self, attribute_name)
|
||||
!= self.__dataclass_fields__[attribute_name].default_factory()
|
||||
):
|
||||
return getattr(self, attribute_name)
|
||||
return self.__dataclass_fields__[attribute_name].default_factory()
|
||||
f = self.__dataclass_fields__[attribute_name]
|
||||
if not isinstance(f.default_factory, _MISSING_TYPE):
|
||||
return f.default_factory()
|
||||
return f.default
|
||||
|
||||
def _get_type(self, attribute_name: str) -> Any:
|
||||
return self.__dataclass_fields__[attribute_name].type
|
||||
@ -119,7 +115,7 @@ def gen_parser_from_dataclass(
|
||||
|
||||
def interpret_dc_type(field_type):
|
||||
if isinstance(field_type, str):
|
||||
raise RuntimeError()
|
||||
raise RuntimeError("field should be a type")
|
||||
typestring = str(field_type)
|
||||
if re.match(r"(typing.|^)Union\[(.*), NoneType\]$", typestring):
|
||||
return field_type.__args__[0]
|
||||
@ -129,12 +125,13 @@ def gen_parser_from_dataclass(
|
||||
dataclass_instance: FairseqDataclass, k: str
|
||||
) -> Dict[str, Any]:
|
||||
"""k: dataclass attributes"""
|
||||
|
||||
kwargs = {}
|
||||
|
||||
field_type = dataclass_instance._get_type(k)
|
||||
inter_type = interpret_dc_type(field_type)
|
||||
if isinstance(inter_type, type) and issubclass(inter_type, List):
|
||||
field_default = dataclass_instance._get_default_factory(k)
|
||||
else:
|
||||
field_default = dataclass_instance._get_default(k)
|
||||
|
||||
field_default = dataclass_instance._get_default(k)
|
||||
|
||||
if isinstance(inter_type, type) and issubclass(inter_type, Enum):
|
||||
field_choices = [t.value for t in list(inter_type)]
|
||||
@ -143,7 +140,7 @@ def gen_parser_from_dataclass(
|
||||
|
||||
field_help = dataclass_instance._get_help(k)
|
||||
field_const = dataclass_instance._get_argparse_const(k)
|
||||
kwargs = {}
|
||||
|
||||
if isinstance(field_default, str) and field_default.startswith("${"):
|
||||
kwargs["default"] = field_default
|
||||
else:
|
||||
@ -163,7 +160,11 @@ def gen_parser_from_dataclass(
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
if field_default is not MISSING:
|
||||
kwargs["default"] = ",".join(map(str, field_default))
|
||||
kwargs["default"] = (
|
||||
",".join(map(str, field_default))
|
||||
if field_default is not None
|
||||
else None
|
||||
)
|
||||
elif (
|
||||
isinstance(inter_type, type) and issubclass(inter_type, Enum)
|
||||
) or "Enum" in str(inter_type):
|
||||
@ -187,6 +188,7 @@ def gen_parser_from_dataclass(
|
||||
if field_const is not None:
|
||||
kwargs["const"] = field_const
|
||||
kwargs["nargs"] = "?"
|
||||
|
||||
return kwargs
|
||||
|
||||
for k in dataclass_instance._get_all_attributes():
|
||||
@ -194,8 +196,122 @@ def gen_parser_from_dataclass(
|
||||
if field_name is None:
|
||||
continue
|
||||
kwargs = get_kwargs_from_dc(dataclass_instance, k)
|
||||
if isinstance(kwargs["default"], str) and kwargs["default"].startswith("${"):
|
||||
continue
|
||||
if delete_default:
|
||||
del kwargs["default"]
|
||||
|
||||
if "default" in kwargs:
|
||||
if isinstance(kwargs["default"], str) and kwargs["default"].startswith(
|
||||
"${"
|
||||
):
|
||||
continue
|
||||
if delete_default:
|
||||
del kwargs["default"]
|
||||
parser.add_argument(field_name, **kwargs)
|
||||
|
||||
|
||||
def _set_legacy_defaults(args, cls):
|
||||
"""Helper to set default arguments based on *add_args*."""
|
||||
if not hasattr(cls, "add_args"):
|
||||
return
|
||||
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
argument_default=argparse.SUPPRESS, allow_abbrev=False
|
||||
)
|
||||
cls.add_args(parser)
|
||||
# copied from argparse.py:
|
||||
defaults = argparse.Namespace()
|
||||
for action in parser._actions:
|
||||
if action.dest is not argparse.SUPPRESS:
|
||||
if not hasattr(defaults, action.dest):
|
||||
if action.default is not argparse.SUPPRESS:
|
||||
setattr(defaults, action.dest, action.default)
|
||||
for key, default_value in vars(defaults).items():
|
||||
if not hasattr(args, key):
|
||||
setattr(args, key, default_value)
|
||||
|
||||
|
||||
def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
|
||||
from fairseq.dataclass.data_class import override_module_args
|
||||
|
||||
# Here we are using field values provided in args to override counterparts inside config object
|
||||
overrides, deletes = override_module_args(args)
|
||||
|
||||
cfg_name = "config"
|
||||
cfg_path = f"../../{cfg_name}"
|
||||
|
||||
if not GlobalHydra().is_initialized():
|
||||
initialize(config_path=cfg_path)
|
||||
|
||||
composed_cfg = compose(cfg_name, overrides=overrides, strict=False)
|
||||
for k in deletes:
|
||||
composed_cfg[k] = None
|
||||
|
||||
cfg = OmegaConf.create(
|
||||
OmegaConf.to_container(composed_cfg, resolve=True, enum_to_str=True)
|
||||
)
|
||||
|
||||
# hack to be able to set Namespace in dict config. this should be removed when we update to newer
|
||||
# omegaconf version that supports object flags, or when we migrate all existing models
|
||||
from omegaconf import _utils
|
||||
|
||||
old_primitive = _utils.is_primitive_type
|
||||
_utils.is_primitive_type = lambda _: True
|
||||
|
||||
if cfg.task is None and getattr(args, "task", None):
|
||||
cfg.task = Namespace(**vars(args))
|
||||
from fairseq.tasks import TASK_REGISTRY
|
||||
|
||||
_set_legacy_defaults(cfg.task, TASK_REGISTRY[args.task])
|
||||
cfg.task._name = args.task
|
||||
if cfg.model is None and getattr(args, "arch", None):
|
||||
cfg.model = Namespace(**vars(args))
|
||||
from fairseq.models import ARCH_MODEL_REGISTRY
|
||||
|
||||
_set_legacy_defaults(cfg.model, ARCH_MODEL_REGISTRY[args.arch])
|
||||
cfg.model._name = args.arch
|
||||
if cfg.optimizer is None and getattr(args, "optimizer", None):
|
||||
cfg.optimizer = Namespace(**vars(args))
|
||||
from fairseq.optim import OPTIMIZER_REGISTRY
|
||||
|
||||
_set_legacy_defaults(cfg.optimizer, OPTIMIZER_REGISTRY[args.optimizer])
|
||||
cfg.optimizer._name = args.optimizer
|
||||
if cfg.lr_scheduler is None and getattr(args, "lr_scheduler", None):
|
||||
cfg.lr_scheduler = Namespace(**vars(args))
|
||||
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
|
||||
|
||||
_set_legacy_defaults(cfg.lr_scheduler, LR_SCHEDULER_REGISTRY[args.lr_scheduler])
|
||||
cfg.lr_scheduler._name = args.lr_scheduler
|
||||
if cfg.criterion is None and getattr(args, "criterion", None):
|
||||
cfg.criterion = Namespace(**vars(args))
|
||||
from fairseq.criterions import CRITERION_REGISTRY
|
||||
|
||||
_set_legacy_defaults(cfg.criterion, CRITERION_REGISTRY[args.criterion])
|
||||
cfg.criterion._name = args.criterion
|
||||
|
||||
_utils.is_primitive_type = old_primitive
|
||||
OmegaConf.set_struct(cfg, True)
|
||||
return cfg
|
||||
|
||||
|
||||
def populate_dataclass(
|
||||
args: Namespace, dataclass: FairseqDataclass
|
||||
) -> FairseqDataclass:
|
||||
for k in dataclass.__dataclass_fields__.keys():
|
||||
if k.startswith("_"):
|
||||
# private member, skip
|
||||
continue
|
||||
if hasattr(args, k):
|
||||
setattr(dataclass, k, getattr(args, k))
|
||||
|
||||
return dataclass
|
||||
|
||||
|
||||
def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]):
|
||||
# this will be deprecated when we get rid of argparse and model_overrides logic
|
||||
|
||||
with open_dict(cfg):
|
||||
for k in cfg.keys():
|
||||
if isinstance(cfg[k], DictConfig):
|
||||
overwrite_args_by_name(cfg[k], overrides)
|
||||
elif k in overrides:
|
||||
cfg[k] = overrides[k]
|
||||
|
@ -11,35 +11,38 @@ import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Mapping
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from fairseq import utils
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from omegaconf import DictConfig, open_dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_master(args):
|
||||
return args.distributed_rank == 0
|
||||
def is_master(cfg: DictConfig):
|
||||
return cfg.distributed_rank == 0
|
||||
|
||||
|
||||
def infer_init_method(args, force_distributed=False):
|
||||
if args.distributed_init_method is not None or getattr(args, "tpu", False):
|
||||
def infer_init_method(cfg: DictConfig, force_distributed=False):
|
||||
if cfg.distributed_init_method is not None or cfg.tpu:
|
||||
return
|
||||
|
||||
if args.pipeline_model_parallel:
|
||||
if cfg.pipeline_model_parallel:
|
||||
balance_exists = (
|
||||
args.pipeline_balance is not None
|
||||
or args.pipeline_encoder_balance is not None
|
||||
or args.pipeline_decoder_balance is not None
|
||||
cfg.pipeline_balance is not None
|
||||
or cfg.pipeline_encoder_balance is not None
|
||||
or cfg.pipeline_decoder_balance is not None
|
||||
)
|
||||
devices_exist = (
|
||||
args.pipeline_devices is not None
|
||||
or args.pipeline_encoder_devices is not None
|
||||
or args.pipeline_decoder_devices is not None
|
||||
cfg.pipeline_devices is not None
|
||||
or cfg.pipeline_encoder_devices is not None
|
||||
or cfg.pipeline_decoder_devices is not None
|
||||
)
|
||||
if not balance_exists:
|
||||
raise ValueError(
|
||||
@ -50,19 +53,19 @@ def infer_init_method(args, force_distributed=False):
|
||||
"--pipeline-devices is currently required for pipeline model parallelism"
|
||||
)
|
||||
|
||||
args.pipeline_balance = utils.eval_str_list(args.pipeline_balance, type=int)
|
||||
if args.pipeline_devices is not None:
|
||||
args.pipeline_devices = utils.eval_str_list(args.pipeline_devices, type=int)
|
||||
num_pipeline_devices = len(set(args.pipeline_devices))
|
||||
cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int)
|
||||
if cfg.pipeline_devices is not None:
|
||||
cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int)
|
||||
num_pipeline_devices = len(set(cfg.pipeline_devices))
|
||||
else:
|
||||
args.pipeline_encoder_devices = utils.eval_str_list(
|
||||
args.pipeline_encoder_devices, type=int
|
||||
cfg.pipeline_encoder_devices = utils.eval_str_list(
|
||||
cfg.pipeline_encoder_devices, type=int
|
||||
)
|
||||
args.pipeline_decoder_devices = utils.eval_str_list(
|
||||
args.pipeline_decoder_devices, type=int
|
||||
cfg.pipeline_decoder_devices = utils.eval_str_list(
|
||||
cfg.pipeline_decoder_devices, type=int
|
||||
)
|
||||
num_pipeline_devices = len(
|
||||
set(args.pipeline_encoder_devices + args.pipeline_decoder_devices)
|
||||
set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices)
|
||||
)
|
||||
gpus_per_node = torch.cuda.device_count()
|
||||
assert (
|
||||
@ -79,14 +82,14 @@ def infer_init_method(args, force_distributed=False):
|
||||
key in os.environ
|
||||
for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
|
||||
):
|
||||
args.distributed_init_method = "env://"
|
||||
args.distributed_world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.distributed_rank = int(os.environ["RANK"])
|
||||
cfg.distributed_init_method = "env://"
|
||||
cfg.distributed_world_size = int(os.environ["WORLD_SIZE"])
|
||||
cfg.distributed_rank = int(os.environ["RANK"])
|
||||
# processes are created by torch.distributed.launch
|
||||
args.distributed_no_spawn = True
|
||||
cfg.distributed_no_spawn = True
|
||||
|
||||
# we can determine the init method automatically for Slurm
|
||||
elif args.distributed_port > 0:
|
||||
elif cfg.distributed_port > 0:
|
||||
node_list = os.environ.get("SLURM_STEP_NODELIST")
|
||||
if node_list is None:
|
||||
node_list = os.environ.get("SLURM_JOB_NODELIST")
|
||||
@ -95,9 +98,9 @@ def infer_init_method(args, force_distributed=False):
|
||||
hostnames = subprocess.check_output(
|
||||
["scontrol", "show", "hostnames", node_list]
|
||||
)
|
||||
args.distributed_init_method = "tcp://{host}:{port}".format(
|
||||
cfg.distributed_init_method = "tcp://{host}:{port}".format(
|
||||
host=hostnames.split()[0].decode("utf-8"),
|
||||
port=args.distributed_port,
|
||||
port=cfg.distributed_port,
|
||||
)
|
||||
nnodes = int(os.environ.get("SLURM_NNODES"))
|
||||
ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
|
||||
@ -111,88 +114,94 @@ def infer_init_method(args, force_distributed=False):
|
||||
if ntasks_per_node == 1:
|
||||
gpus_per_node = torch.cuda.device_count()
|
||||
node_id = int(os.environ.get("SLURM_NODEID"))
|
||||
args.distributed_rank = node_id * gpus_per_node
|
||||
args.distributed_world_size = nnodes * gpus_per_node
|
||||
elif args.pipeline_model_parallel:
|
||||
cfg.distributed_rank = node_id * gpus_per_node
|
||||
cfg.distributed_world_size = nnodes * gpus_per_node
|
||||
elif cfg.pipeline_model_parallel:
|
||||
assert ntasks_per_node == num_pipelines_per_node, (
|
||||
"SLURM --ntasks-per-node must match number of pipelines per "
|
||||
"node (={})".format(num_pipelines_per_node)
|
||||
)
|
||||
args.distributed_no_spawn = True
|
||||
cfg.distributed_no_spawn = True
|
||||
# For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
|
||||
# the first node, [1, 2] on the second node, etc. This
|
||||
# matches torch.distributed.launch.
|
||||
node_id = int(os.environ.get("SLURM_NODEID"))
|
||||
local_id = int(os.environ.get("SLURM_LOCALID"))
|
||||
args.distributed_rank = node_id * num_pipelines_per_node + local_id
|
||||
cfg.distributed_rank = node_id * num_pipelines_per_node + local_id
|
||||
# In the above example, device_id will always be in [0, 1],
|
||||
# which also matches torch.distributed.launch.
|
||||
args.device_id = local_id
|
||||
cfg.device_id = local_id
|
||||
# We also want to set distributed_world_size to be the total
|
||||
# number of pipelines across all nodes.
|
||||
args.distributed_world_size = nnodes * num_pipelines_per_node
|
||||
cfg.distributed_world_size = nnodes * num_pipelines_per_node
|
||||
else:
|
||||
assert ntasks_per_node == args.distributed_world_size // nnodes
|
||||
args.distributed_no_spawn = True
|
||||
args.distributed_rank = int(os.environ.get("SLURM_PROCID"))
|
||||
args.device_id = int(os.environ.get("SLURM_LOCALID"))
|
||||
assert ntasks_per_node == cfg.distributed_world_size // nnodes
|
||||
cfg.distributed_no_spawn = True
|
||||
cfg.distributed_rank = int(os.environ.get("SLURM_PROCID"))
|
||||
cfg.device_id = int(os.environ.get("SLURM_LOCALID"))
|
||||
except subprocess.CalledProcessError as e: # scontrol failed
|
||||
raise e
|
||||
except FileNotFoundError: # Slurm is not installed
|
||||
pass
|
||||
|
||||
elif args.distributed_world_size > 1 or force_distributed:
|
||||
elif cfg.distributed_world_size > 1 or force_distributed:
|
||||
# fallback for single node with multiple GPUs
|
||||
assert args.distributed_world_size <= torch.cuda.device_count()
|
||||
assert cfg.distributed_world_size <= torch.cuda.device_count()
|
||||
port = random.randint(10000, 20000)
|
||||
args.distributed_init_method = "tcp://localhost:{port}".format(port=port)
|
||||
cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)
|
||||
|
||||
if args.pipeline_model_parallel:
|
||||
if not args.distributed_no_spawn:
|
||||
if cfg.pipeline_model_parallel:
|
||||
if not cfg.distributed_no_spawn:
|
||||
# When distributed_no_spawn is False, we expect distributed_rank and
|
||||
# distributed_world_size to be based on the total number of GPUs, so
|
||||
# we need to correct them to be based on the number of pipelines.
|
||||
assert args.distributed_world_size % num_pipeline_devices == 0
|
||||
args.distributed_world_size = (
|
||||
args.distributed_world_size // num_pipeline_devices
|
||||
assert cfg.distributed_world_size % num_pipeline_devices == 0
|
||||
cfg.distributed_world_size = (
|
||||
cfg.distributed_world_size // num_pipeline_devices
|
||||
)
|
||||
# In the case of 4-way MP on nodes with 8 GPUs, we want
|
||||
# distributed_rank to be the starting GPU index for each pipeline
|
||||
# i.e., 0, 2, ...
|
||||
assert args.distributed_rank % gpus_per_node == 0
|
||||
assert args.distributed_rank % num_pipeline_devices == 0
|
||||
args.distributed_rank = args.distributed_rank // num_pipeline_devices
|
||||
# launch one process per pipeline
|
||||
args.distributed_num_procs = num_pipelines_per_node
|
||||
assert cfg.distributed_rank % gpus_per_node == 0
|
||||
assert cfg.distributed_rank % num_pipeline_devices == 0
|
||||
|
||||
with open_dict(cfg):
|
||||
cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices
|
||||
# launch one process per pipeline
|
||||
cfg.distributed_num_procs = num_pipelines_per_node
|
||||
|
||||
# if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0
|
||||
# and 4, indicating the starting device IDs for each pipeline
|
||||
args.device_id *= num_pipeline_devices
|
||||
cfg.device_id *= num_pipeline_devices
|
||||
|
||||
if args.device_id > 0:
|
||||
if cfg.device_id > 0:
|
||||
# if there's multiple pipelines on a node (e.g., 4-way MP on an 8
|
||||
# GPU node), we need to adjust pipeline_devices accordingly
|
||||
logger.debug(
|
||||
"setting CUDA device={} on rank {}".format(
|
||||
args.device_id, args.distributed_rank
|
||||
cfg.device_id, cfg.distributed_rank
|
||||
)
|
||||
)
|
||||
torch.cuda.set_device(args.device_id)
|
||||
args.pipeline_devices = [args.device_id + d for d in args.pipeline_devices]
|
||||
torch.cuda.set_device(cfg.device_id)
|
||||
with open_dict(cfg):
|
||||
cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices]
|
||||
logger.info(
|
||||
"setting pipeline_devices={} on rank {}".format(
|
||||
args.pipeline_devices, args.distributed_rank
|
||||
),
|
||||
cfg.pipeline_devices, cfg.distributed_rank
|
||||
)
|
||||
)
|
||||
elif not cfg.distributed_no_spawn:
|
||||
with open_dict(cfg):
|
||||
cfg.distributed_num_procs = min(
|
||||
torch.cuda.device_count(), cfg.distributed_world_size
|
||||
)
|
||||
elif not args.distributed_no_spawn:
|
||||
args.distributed_num_procs = min(
|
||||
torch.cuda.device_count(),
|
||||
args.distributed_world_size,
|
||||
)
|
||||
|
||||
|
||||
def distributed_init(args):
|
||||
if not getattr(args, "tpu", False):
|
||||
def distributed_init(cfg: DictConfig):
|
||||
if isinstance(cfg, Namespace):
|
||||
cfg = convert_namespace_to_omegaconf(cfg)
|
||||
|
||||
if not cfg.common.tpu:
|
||||
if torch.distributed.is_initialized():
|
||||
warnings.warn(
|
||||
"Distributed is already initialized, cannot initialize twice!"
|
||||
@ -200,20 +209,20 @@ def distributed_init(args):
|
||||
else:
|
||||
logger.info(
|
||||
"distributed init (rank {}): {}".format(
|
||||
args.distributed_rank,
|
||||
args.distributed_init_method,
|
||||
cfg.distributed_training.distributed_rank,
|
||||
cfg.distributed_training.distributed_init_method,
|
||||
)
|
||||
)
|
||||
dist.init_process_group(
|
||||
backend=args.distributed_backend,
|
||||
init_method=args.distributed_init_method,
|
||||
world_size=args.distributed_world_size,
|
||||
rank=args.distributed_rank,
|
||||
backend=cfg.distributed_training.distributed_backend,
|
||||
init_method=cfg.distributed_training.distributed_init_method,
|
||||
world_size=cfg.distributed_training.distributed_world_size,
|
||||
rank=cfg.distributed_training.distributed_rank,
|
||||
)
|
||||
logger.info(
|
||||
"initialized host {} as rank {}".format(
|
||||
socket.gethostname(),
|
||||
args.distributed_rank,
|
||||
cfg.distributed_training.distributed_rank,
|
||||
)
|
||||
)
|
||||
|
||||
@ -221,20 +230,22 @@ def distributed_init(args):
|
||||
if torch.cuda.is_available():
|
||||
dist.all_reduce(torch.zeros(1).cuda())
|
||||
|
||||
args.distributed_rank = torch.distributed.get_rank()
|
||||
cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
|
||||
else:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
assert xm.xrt_world_size() == args.distributed_world_size
|
||||
args.device_id = xm.get_local_ordinal()
|
||||
args.distributed_rank = xm.get_ordinal()
|
||||
assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
|
||||
cfg.distributed_training.device_id = xm.get_local_ordinal()
|
||||
cfg.distributed_training.distributed_rank = xm.get_ordinal()
|
||||
xm.rendezvous("distributed_init") # wait for all workers
|
||||
xm.mark_step()
|
||||
|
||||
if not is_master(args):
|
||||
if is_master(cfg.distributed_training):
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
else:
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
||||
if args.model_parallel_size > 1:
|
||||
if cfg.common.model_parallel_size > 1:
|
||||
try:
|
||||
from fairseq.model_parallel.megatron.mpu import (
|
||||
get_model_parallel_rank,
|
||||
@ -247,58 +258,61 @@ def distributed_init(args):
|
||||
"\n\n git submodule update --init "
|
||||
"fairseq/model_parallel/megatron"
|
||||
)
|
||||
initialize_model_parallel(args.model_parallel_size)
|
||||
model_parallel_cuda_manual_seed(args.seed)
|
||||
initialize_model_parallel(cfg.common.model_parallel_size)
|
||||
model_parallel_cuda_manual_seed(cfg.common.seed)
|
||||
model_part_number = get_model_parallel_rank()
|
||||
args.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
|
||||
return args.distributed_rank
|
||||
cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
|
||||
return cfg.distributed_training.distributed_rank
|
||||
|
||||
|
||||
def distributed_main(i, main, args, kwargs):
|
||||
args.device_id = i
|
||||
if torch.cuda.is_available() and not args.cpu and not getattr(args, "tpu", False):
|
||||
torch.cuda.set_device(args.device_id)
|
||||
if args.distributed_rank is None: # torch.multiprocessing.spawn
|
||||
args.distributed_rank = kwargs.pop("start_rank", 0) + i
|
||||
def distributed_main(i, main, cfg: DictConfig, kwargs):
|
||||
cfg.distributed_training.device_id = i
|
||||
if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu:
|
||||
torch.cuda.set_device(cfg.distributed_training.device_id)
|
||||
if cfg.distributed_training.distributed_rank is None: # torch.multiprocessing.spawn
|
||||
cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i
|
||||
|
||||
args.distributed_rank = distributed_init(args)
|
||||
cfg.distributed_training.distributed_rank = distributed_init(cfg)
|
||||
|
||||
after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
|
||||
if after_distributed_init_fn:
|
||||
args = after_distributed_init_fn(args)
|
||||
cfg = after_distributed_init_fn(cfg)
|
||||
|
||||
main(args, **kwargs)
|
||||
main(cfg, **kwargs)
|
||||
|
||||
|
||||
def call_main(args, main, **kwargs):
|
||||
if args.distributed_init_method is None:
|
||||
infer_init_method(args)
|
||||
def call_main(cfg: DictConfig, main, **kwargs):
|
||||
if cfg.distributed_training.distributed_init_method is None:
|
||||
infer_init_method(cfg.distributed_training)
|
||||
|
||||
if args.distributed_init_method is not None:
|
||||
if cfg.distributed_training.distributed_init_method is not None:
|
||||
# distributed training
|
||||
if not args.distributed_no_spawn:
|
||||
start_rank = args.distributed_rank
|
||||
args.distributed_rank = None # assign automatically
|
||||
if not cfg.distributed_training.distributed_no_spawn:
|
||||
start_rank = cfg.distributed_training.distributed_rank
|
||||
cfg.distributed_training.distributed_rank = None # assign automatically
|
||||
kwargs["start_rank"] = start_rank
|
||||
torch.multiprocessing.spawn(
|
||||
fn=distributed_main,
|
||||
args=(main, args, kwargs),
|
||||
nprocs=args.distributed_num_procs,
|
||||
args=(main, cfg, kwargs),
|
||||
nprocs=min(
|
||||
torch.cuda.device_count(),
|
||||
cfg.distributed_training.distributed_world_size,
|
||||
),
|
||||
)
|
||||
else:
|
||||
distributed_main(args.device_id, main, args, kwargs)
|
||||
elif getattr(args, "tpu", False) and args.distributed_world_size > 1:
|
||||
distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
|
||||
elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1:
|
||||
import torch_xla.distributed.xla_multiprocessing as xmp
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
xmp.spawn(
|
||||
fn=distributed_main,
|
||||
args=(main, args, kwargs),
|
||||
args=(main, cfg, kwargs),
|
||||
nprocs=8, # use all 8 TPU cores
|
||||
)
|
||||
else:
|
||||
# single GPU main
|
||||
main(args, **kwargs)
|
||||
main(cfg, **kwargs)
|
||||
|
||||
|
||||
def get_rank():
|
||||
@ -392,11 +406,7 @@ def all_gather_list(data, group=None, max_size=16384):
|
||||
)
|
||||
|
||||
|
||||
def all_reduce_dict(
|
||||
data: Mapping[str, Any],
|
||||
device,
|
||||
group=None,
|
||||
) -> Dict[str, Any]:
|
||||
def all_reduce_dict(data: Mapping[str, Any], device, group=None) -> Dict[str, Any]:
|
||||
"""
|
||||
AllReduce a dictionary of values across workers. We separately
|
||||
reduce items that are already on the device and items on CPU for
|
||||
|
@ -8,11 +8,12 @@ import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Iterator, List, Tuple
|
||||
from typing import Any, Dict, Iterator, List
|
||||
|
||||
import torch
|
||||
from fairseq import utils
|
||||
from fairseq.data import encoders
|
||||
from omegaconf import open_dict
|
||||
from torch import nn
|
||||
|
||||
|
||||
@ -85,9 +86,9 @@ class GeneratorHubInterface(nn.Module):
|
||||
translation or language model.
|
||||
"""
|
||||
|
||||
def __init__(self, args, task, models):
|
||||
def __init__(self, cfg, task, models):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.cfg = cfg
|
||||
self.task = task
|
||||
self.models = nn.ModuleList(models)
|
||||
self.src_dict = task.source_dictionary
|
||||
@ -95,14 +96,14 @@ class GeneratorHubInterface(nn.Module):
|
||||
|
||||
# optimize model for generation
|
||||
for model in self.models:
|
||||
model.prepare_for_inference_(args)
|
||||
model.prepare_for_inference_(cfg)
|
||||
|
||||
# Load alignment dictionary for unknown word replacement
|
||||
# (None if no unknown word replacement, empty if no path to align dictionary)
|
||||
self.align_dict = utils.load_align_dict(getattr(args, "replace_unk", None))
|
||||
self.align_dict = utils.load_align_dict(cfg.generation.replace_unk)
|
||||
|
||||
self.tokenizer = encoders.build_tokenizer(args)
|
||||
self.bpe = encoders.build_bpe(args)
|
||||
self.tokenizer = encoders.build_tokenizer(cfg.tokenizer)
|
||||
self.bpe = encoders.build_bpe(cfg.bpe)
|
||||
|
||||
self.max_positions = utils.resolve_max_positions(
|
||||
self.task.max_positions(), *[model.max_positions() for model in models]
|
||||
@ -156,10 +157,11 @@ class GeneratorHubInterface(nn.Module):
|
||||
)[0]
|
||||
|
||||
# build generator using current args as well as any kwargs
|
||||
gen_args = copy.copy(self.args)
|
||||
gen_args.beam = beam
|
||||
for k, v in kwargs.items():
|
||||
setattr(gen_args, k, v)
|
||||
gen_args = copy.copy(self.cfg)
|
||||
with open_dict(gen_args):
|
||||
gen_args.beam = beam
|
||||
for k, v in kwargs.items():
|
||||
setattr(gen_args, k, v)
|
||||
generator = self.task.build_generator(self.models, gen_args)
|
||||
|
||||
inference_step_args = inference_step_args or {}
|
||||
@ -253,8 +255,8 @@ class GeneratorHubInterface(nn.Module):
|
||||
lengths = torch.LongTensor([t.numel() for t in tokens])
|
||||
batch_iterator = self.task.get_batch_iterator(
|
||||
dataset=self.task.build_dataset_for_inference(tokens, lengths),
|
||||
max_tokens=self.args.max_tokens,
|
||||
max_sentences=self.args.batch_size,
|
||||
max_tokens=self.cfg.dataset.max_tokens,
|
||||
max_sentences=self.cfg.dataset.batch_size,
|
||||
max_positions=self.max_positions,
|
||||
ignore_invalid_inputs=skip_invalid_size_inputs,
|
||||
disable_iterator_cache=True,
|
||||
|
@ -9,6 +9,7 @@ Train a network across multiple GPUs.
|
||||
|
||||
from fairseq import distributed_utils
|
||||
from fairseq.trainer import Trainer
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
try:
|
||||
@ -28,14 +29,14 @@ except (ImportError, ModuleNotFoundError):
|
||||
class MegatronTrainer(Trainer):
|
||||
"""Main class for model parallel with data parallel training."""
|
||||
|
||||
def __init__(self, args, task, model, criterion):
|
||||
def __init__(self, cfg: DictConfig, task, model, criterion, **kwargs):
|
||||
if not has_megatron_submodule:
|
||||
raise ImportError(
|
||||
"\n\nPlease install the megatron submodule:"
|
||||
"\n\n git submodule update --init "
|
||||
"fairseq/model_parallel/megatron"
|
||||
)
|
||||
super().__init__(args, task, model, criterion)
|
||||
super().__init__(cfg, task, model, criterion, **kwargs)
|
||||
|
||||
@property
|
||||
def data_parallel_world_size(self):
|
||||
|
@ -96,7 +96,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
|
||||
encoder_output_tuple = self.encoder(input)
|
||||
return self.decoder(encoder_output_tuple)
|
||||
|
||||
def prepare_for_inference_(self, args):
|
||||
def prepare_for_inference_(self, cfg):
|
||||
if self.encoder is not None and self.decoder is not None:
|
||||
logger.info("Encoder and Decoder already initialized")
|
||||
return
|
||||
@ -111,9 +111,9 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
|
||||
decoder_module_list.append(module)
|
||||
module_count += 1
|
||||
self.model = None
|
||||
self.encoder = TransformerEncoder(args, None, None, encoder_module_list)
|
||||
self.encoder = TransformerEncoder(cfg.model, None, None, encoder_module_list)
|
||||
self.decoder = TransformerDecoder(
|
||||
args, None, None, decoder_module_list=decoder_module_list
|
||||
cfg.model, None, None, decoder_module_list=decoder_module_list
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -320,7 +320,7 @@ class PipelineParallelTransformerModel(BaseFairseqModel):
|
||||
"""Maximum length supported by the decoder."""
|
||||
return self.decoder_max_positions
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True, args=None):
|
||||
def load_state_dict(self, state_dict, strict=True, cfg=None):
|
||||
"""Copies parameters and buffers from *state_dict* into this module and
|
||||
its descendants.
|
||||
|
||||
|
@ -72,6 +72,10 @@ class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
|
||||
)
|
||||
return cls(decoder)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
TransformerLanguageModel.add_args(parser)
|
||||
|
||||
@classmethod
|
||||
def build_embedding(cls, args, dictionary, embed_dim, path=None):
|
||||
def _vocab_init(tensor, **kwargs):
|
||||
|
@ -7,8 +7,6 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Union
|
||||
|
||||
import fairseq
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
@ -52,10 +50,10 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def build_model(model_cfg: Union[DictConfig, Namespace], task):
|
||||
if isinstance(model_cfg, DictConfig):
|
||||
return ARCH_MODEL_REGISTRY[model_cfg._name].build_model(model_cfg, task)
|
||||
return ARCH_MODEL_REGISTRY[model_cfg.arch].build_model(model_cfg, task)
|
||||
def build_model(cfg: DictConfig, task):
|
||||
if isinstance(cfg, DictConfig):
|
||||
return ARCH_MODEL_REGISTRY[cfg._name].build_model(cfg, task)
|
||||
return ARCH_MODEL_REGISTRY[cfg.arch].build_model(cfg, task)
|
||||
|
||||
|
||||
def register_model(name, dataclass=None):
|
||||
@ -92,7 +90,8 @@ def register_model(name, dataclass=None):
|
||||
)
|
||||
|
||||
cls.__dataclass = dataclass
|
||||
MODEL_DATACLASS_REGISTRY[name] = dataclass
|
||||
if dataclass is not None:
|
||||
MODEL_DATACLASS_REGISTRY[name] = dataclass
|
||||
return cls
|
||||
|
||||
return register_model_cls
|
||||
@ -108,14 +107,13 @@ def register_model_architecture(model_name, arch_name):
|
||||
For example::
|
||||
|
||||
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
|
||||
def lstm_luong_wmt_en_de(args):
|
||||
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
|
||||
def lstm_luong_wmt_en_de(cfg):
|
||||
args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000)
|
||||
(...)
|
||||
|
||||
The decorated function should take a single argument *args*, which is a
|
||||
:class:`argparse.Namespace` of arguments parsed from the command-line. The
|
||||
decorated function should modify these arguments in-place to match the
|
||||
desired architecture.
|
||||
The decorated function should take a single argument *cfg*, which is a
|
||||
:class:`omegaconf.DictConfig`. The decorated function should modify these
|
||||
arguments in-place to match the desired architecture.
|
||||
|
||||
Args:
|
||||
model_name (str): the name of the Model (Model must already be
|
||||
|
@ -13,6 +13,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairseq import utils
|
||||
from fairseq.data import encoders
|
||||
from omegaconf import open_dict
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -24,13 +25,13 @@ class BARTHubInterface(nn.Module):
|
||||
Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
|
||||
"""
|
||||
|
||||
def __init__(self, args, task, model):
|
||||
def __init__(self, cfg, task, model):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.cfg = cfg
|
||||
self.task = task
|
||||
self.model = model
|
||||
|
||||
self.bpe = encoders.build_bpe(args)
|
||||
self.bpe = encoders.build_bpe(cfg.bpe)
|
||||
|
||||
self.max_positions = min(
|
||||
utils.resolve_max_positions(
|
||||
@ -120,10 +121,11 @@ class BARTHubInterface(nn.Module):
|
||||
sample = self._build_sample(tokens)
|
||||
|
||||
# build generator using current args as well as any kwargs
|
||||
gen_args = copy.copy(self.args)
|
||||
gen_args.beam = beam
|
||||
for k, v in kwargs.items():
|
||||
setattr(gen_args, k, v)
|
||||
gen_args = copy.copy(self.cfg)
|
||||
with open_dict(gen_args):
|
||||
gen_args.beam = beam
|
||||
for k, v in kwargs.items():
|
||||
setattr(gen_args, k, v)
|
||||
generator = self.task.build_generator([self.model], gen_args)
|
||||
translations = self.task.inference_step(
|
||||
generator,
|
||||
|
@ -144,7 +144,9 @@ class BARTModel(TransformerModel):
|
||||
num_classes=num_classes,
|
||||
activation_fn=self.args.pooler_activation_fn,
|
||||
pooler_dropout=self.args.pooler_dropout,
|
||||
do_spectral_norm=self.args.spectral_norm_classification_head,
|
||||
do_spectral_norm=getattr(
|
||||
self.args, "spectral_norm_classification_head", False
|
||||
),
|
||||
)
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
|
@ -7,6 +7,7 @@ Base classes for various fairseq models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from argparse import Namespace
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -15,8 +16,12 @@ import torch.nn.functional as F
|
||||
from fairseq import utils
|
||||
from fairseq.checkpoint_utils import prune_state_dict
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
||||
from fairseq.dataclass.utils import (
|
||||
convert_namespace_to_omegaconf,
|
||||
gen_parser_from_dataclass,
|
||||
)
|
||||
from fairseq.models import FairseqDecoder, FairseqEncoder
|
||||
from omegaconf import DictConfig
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@ -86,15 +91,26 @@ class BaseFairseqModel(nn.Module):
|
||||
"""Maximum length supported by the model."""
|
||||
return None
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True, args=None):
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
strict=True,
|
||||
model_cfg: Optional[DictConfig] = None,
|
||||
args: Optional[Namespace] = None,
|
||||
):
|
||||
"""Copies parameters and buffers from *state_dict* into this module and
|
||||
its descendants.
|
||||
|
||||
Overrides the method in :class:`nn.Module`. Compared with that method
|
||||
this additionally "upgrades" *state_dicts* from old checkpoints.
|
||||
"""
|
||||
|
||||
if model_cfg is None and args is not None:
|
||||
logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
|
||||
model_cfg = convert_namespace_to_omegaconf(args).model
|
||||
|
||||
self.upgrade_state_dict(state_dict)
|
||||
new_state_dict = prune_state_dict(state_dict, args)
|
||||
new_state_dict = prune_state_dict(state_dict, model_cfg)
|
||||
return super().load_state_dict(new_state_dict, strict)
|
||||
|
||||
def upgrade_state_dict(self, state_dict):
|
||||
@ -133,18 +149,18 @@ class BaseFairseqModel(nn.Module):
|
||||
|
||||
self.apply(_apply)
|
||||
|
||||
def prepare_for_inference_(self, args):
|
||||
def prepare_for_inference_(self, cfg: DictConfig):
|
||||
"""Prepare model for inference."""
|
||||
kwargs = {}
|
||||
kwargs["beamable_mm_beam_size"] = (
|
||||
None if getattr(args, "no_beamable_mm", False) else getattr(args, "beam", 5)
|
||||
None
|
||||
if getattr(cfg.generation, "no_beamable_mm", False)
|
||||
else getattr(cfg.generation, "beam", 5)
|
||||
)
|
||||
kwargs["need_attn"] = getattr(args, "print_alignment", False)
|
||||
if hasattr(args, "retain_dropout"):
|
||||
kwargs["retain_dropout"] = args.retain_dropout
|
||||
kwargs["retain_dropout_modules"] = getattr(
|
||||
args, "retain_dropout_modules", None
|
||||
)
|
||||
kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
|
||||
if getattr(cfg.generation, "retain_dropout", False):
|
||||
kwargs["retain_dropout"] = cfg.generation.retain_dropout
|
||||
kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
|
||||
self.make_generation_fast_(**kwargs)
|
||||
|
||||
def make_generation_fast_(self, **kwargs):
|
||||
@ -437,15 +453,26 @@ class FairseqMultiModel(BaseFairseqModel):
|
||||
def forward_decoder(self, prev_output_tokens, **kwargs):
|
||||
return self.decoder(prev_output_tokens, **kwargs)
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True, args=None):
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
strict=True,
|
||||
model_cfg=None,
|
||||
args: Optional[Namespace] = None,
|
||||
):
|
||||
"""Copies parameters and buffers from *state_dict* into this module and
|
||||
its descendants.
|
||||
|
||||
Overrides the method in :class:`nn.Module`. Compared with that method
|
||||
this additionally "upgrades" *state_dicts* from old checkpoints.
|
||||
"""
|
||||
|
||||
if model_cfg is None and args is not None:
|
||||
logger.warn("using 'args' is deprecated, please update your code to use dataclass config")
|
||||
model_cfg = convert_namespace_to_omegaconf(args).model
|
||||
|
||||
self.upgrade_state_dict(state_dict)
|
||||
new_state_dict = prune_state_dict(state_dict, args)
|
||||
new_state_dict = prune_state_dict(state_dict, model_cfg)
|
||||
return super().load_state_dict(new_state_dict, strict)
|
||||
|
||||
|
||||
|
@ -194,14 +194,14 @@ class MultilingualTransformerModel(FairseqMultiModel):
|
||||
module_class = TransformerEncoder if is_encoder else TransformerDecoder
|
||||
return module_class(args, lang_dict, embed_tokens)
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True, args=None):
|
||||
def load_state_dict(self, state_dict, strict=True, model_cfg=None):
|
||||
state_dict_subset = state_dict.copy()
|
||||
for k, _ in state_dict.items():
|
||||
assert k.startswith("models.")
|
||||
lang_pair = k.split(".")[1]
|
||||
if lang_pair not in self.models:
|
||||
del state_dict_subset[k]
|
||||
super().load_state_dict(state_dict_subset, strict=strict, args=args)
|
||||
super().load_state_dict(state_dict_subset, strict=strict, model_cfg=model_cfg)
|
||||
|
||||
|
||||
@register_model_architecture("multilingual_transformer", "multilingual_transformer")
|
||||
|
@ -17,13 +17,13 @@ class RobertaHubInterface(nn.Module):
|
||||
Usage: https://github.com/pytorch/fairseq/tree/master/examples/roberta
|
||||
"""
|
||||
|
||||
def __init__(self, args, task, model):
|
||||
def __init__(self, cfg, task, model):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.cfg = cfg
|
||||
self.task = task
|
||||
self.model = model
|
||||
|
||||
self.bpe = encoders.build_bpe(args)
|
||||
self.bpe = encoders.build_bpe(cfg.bpe)
|
||||
|
||||
# this is useful for determining the device
|
||||
self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))
|
||||
|
@ -494,7 +494,7 @@ def base_architecture(args):
|
||||
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
|
||||
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
|
||||
args.spectral_norm_classification_head = getattr(
|
||||
args, "spectral_nrom_classification_head", False
|
||||
args, "spectral_norm_classification_head", False
|
||||
)
|
||||
|
||||
|
||||
|
@ -578,10 +578,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
|
||||
if embed_dim != input_embed_dim
|
||||
else None
|
||||
)
|
||||
|
||||
self.embed_positions = (
|
||||
PositionalEmbedding(
|
||||
args.max_target_positions,
|
||||
self.max_target_positions,
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
learned=args.decoder_learned_pos,
|
||||
@ -963,6 +962,14 @@ def base_architecture(args):
|
||||
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
||||
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
|
||||
|
||||
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
|
||||
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
|
||||
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
||||
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
|
||||
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
||||
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
|
||||
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
|
||||
|
||||
|
||||
@register_model_architecture("transformer", "transformer_iwslt_de_en")
|
||||
def transformer_iwslt_de_en(args):
|
||||
|
@ -159,7 +159,7 @@ class TransformerLanguageModelConfig(FairseqDataclass):
|
||||
add_bos_token: bool = II("task.add_bos_token")
|
||||
tokens_per_sample: int = II("task.tokens_per_sample")
|
||||
max_target_positions: Optional[int] = II("task.max_target_positions")
|
||||
tpu: bool = II("params.common.tpu")
|
||||
tpu: bool = II("common.tpu")
|
||||
|
||||
|
||||
@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
|
||||
|
@ -32,20 +32,20 @@ class TransformerEncoderLayer(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.embed_dim = args.encoder_embed_dim
|
||||
self.quant_noise = getattr(args, "quant_noise_pq", 0)
|
||||
self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)
|
||||
self.quant_noise = getattr(args, 'quant_noise_pq', 0)
|
||||
self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
|
||||
self.self_attn = self.build_self_attention(self.embed_dim, args)
|
||||
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
|
||||
self.dropout_module = FairseqDropout(
|
||||
args.dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
self.activation_fn = utils.get_activation_fn(
|
||||
activation=getattr(args, "activation_fn", "relu")
|
||||
activation=getattr(args, 'activation_fn', 'relu') or "relu"
|
||||
)
|
||||
activation_dropout_p = getattr(args, "activation_dropout", 0)
|
||||
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
|
||||
if activation_dropout_p == 0:
|
||||
# for backwards compatibility with models that use args.relu_dropout
|
||||
activation_dropout_p = getattr(args, "relu_dropout", 0)
|
||||
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
|
||||
self.activation_dropout_module = FairseqDropout(
|
||||
float(activation_dropout_p), module_name=self.__class__.__name__
|
||||
)
|
||||
@ -197,10 +197,10 @@ class TransformerDecoderLayer(nn.Module):
|
||||
if getattr(args, "activation_fn", None) is not None
|
||||
else "relu"
|
||||
)
|
||||
activation_dropout_p = getattr(args, "activation_dropout", 0)
|
||||
activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
|
||||
if activation_dropout_p == 0:
|
||||
# for backwards compatibility with models that use args.relu_dropout
|
||||
activation_dropout_p = getattr(args, "relu_dropout", 0)
|
||||
activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
|
||||
self.activation_dropout_module = FairseqDropout(
|
||||
float(activation_dropout_p), module_name=self.__class__.__name__
|
||||
)
|
||||
|
@ -6,8 +6,6 @@
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Union
|
||||
|
||||
from fairseq import registry
|
||||
from fairseq.optim.bmuf import FairseqBMUF # noqa
|
||||
@ -19,7 +17,6 @@ from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optim
|
||||
from fairseq.optim.shard import shard_
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FairseqOptimizer",
|
||||
"FP16Optimizer",
|
||||
@ -27,7 +24,6 @@ __all__ = [
|
||||
"shard_",
|
||||
]
|
||||
|
||||
|
||||
(
|
||||
_build_optimizer,
|
||||
register_optimizer,
|
||||
@ -37,12 +33,12 @@ __all__ = [
|
||||
|
||||
|
||||
def build_optimizer(
|
||||
optimizer_cfg: Union[DictConfig, Namespace], params, *extra_args, **extra_kwargs
|
||||
cfg: DictConfig, params, *extra_args, **extra_kwargs
|
||||
):
|
||||
if all(isinstance(p, dict) for p in params):
|
||||
params = [t for p in params for t in p.values()]
|
||||
params = list(filter(lambda p: p.requires_grad, params))
|
||||
return _build_optimizer(optimizer_cfg, params, *extra_args, **extra_kwargs)
|
||||
return _build_optimizer(cfg, params, *extra_args, **extra_kwargs)
|
||||
|
||||
|
||||
# automatically import any Python files in the optim/ directory
|
||||
|
@ -5,6 +5,7 @@
|
||||
|
||||
import logging
|
||||
import math
|
||||
from collections import Collection
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
@ -14,7 +15,7 @@ import torch.optim
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.optim import FairseqOptimizer, register_optimizer
|
||||
from fairseq.optim.fused_adam import get_fused_adam_class
|
||||
from omegaconf import II
|
||||
from omegaconf import II, DictConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -33,8 +34,8 @@ class FairseqAdamConfig(FairseqDataclass):
|
||||
default=False, metadata={"help": "Use fairseq.optim.adam.Adam"}
|
||||
)
|
||||
# TODO common vars below in parent
|
||||
tpu: bool = II("params.common.tpu")
|
||||
lr: List[float] = II("params.optimization.lr")
|
||||
tpu: bool = II("common.tpu")
|
||||
lr: List[float] = II("optimization.lr")
|
||||
|
||||
|
||||
@register_optimizer("adam", dataclass=FairseqAdamConfig)
|
||||
@ -46,15 +47,15 @@ class FairseqAdam(FairseqOptimizer):
|
||||
analogous to torch.optim.AdamW from PyTorch.
|
||||
"""
|
||||
|
||||
def __init__(self, args, params):
|
||||
super().__init__(args)
|
||||
def __init__(self, cfg: DictConfig, params):
|
||||
super().__init__(cfg)
|
||||
fused_adam_cls = get_fused_adam_class()
|
||||
use_fused_adam = (
|
||||
not getattr(args, "use_old_adam", False)
|
||||
not getattr(cfg, "use_old_adam", False)
|
||||
and fused_adam_cls is not None
|
||||
and torch.cuda.is_available()
|
||||
)
|
||||
if getattr(args, "tpu", False):
|
||||
if getattr(cfg, "tpu", False):
|
||||
# on TPUs we use the Adam defined here, since it
|
||||
# automatically casts gradients to FP32
|
||||
self._optimizer = Adam(params, **self.optimizer_config)
|
||||
@ -73,10 +74,12 @@ class FairseqAdam(FairseqOptimizer):
|
||||
different learning rate.
|
||||
"""
|
||||
return {
|
||||
"lr": self.args.lr[0],
|
||||
"betas": eval(self.args.adam_betas),
|
||||
"eps": self.args.adam_eps,
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": self.cfg.lr[0]
|
||||
if isinstance(self.cfg.lr, Collection)
|
||||
else self.cfg.lr,
|
||||
"betas": eval(self.cfg.adam_betas),
|
||||
"eps": self.cfg.adam_eps,
|
||||
"weight_decay": self.cfg.weight_decay,
|
||||
}
|
||||
|
||||
def average_params(self):
|
||||
|
@ -10,7 +10,7 @@ import torch.distributed as dist
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
||||
from fairseq.optim.fairseq_optimizer import FairseqOptimizer
|
||||
from omegaconf import II
|
||||
from omegaconf import II, DictConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -38,7 +38,7 @@ class FairseqBMUFConfig(FairseqDataclass):
|
||||
},
|
||||
)
|
||||
distributed_world_size: int = II(
|
||||
"params.distributed_training.distributed_world_size"
|
||||
"distributed_training.distributed_world_size"
|
||||
)
|
||||
|
||||
|
||||
@ -52,20 +52,19 @@ class FairseqBMUF(FairseqOptimizer):
|
||||
model-update filtering
|
||||
"""
|
||||
|
||||
def __init__(self, args, optimizer):
|
||||
|
||||
super().__init__(args)
|
||||
def __init__(self, cfg: DictConfig, optimizer):
|
||||
super().__init__(cfg)
|
||||
self._optimizer = optimizer
|
||||
self._num_updates = 0
|
||||
self.sync_iter = self.args.global_sync_iter
|
||||
self.block_momentum = self.args.block_momentum
|
||||
self.block_lr = self.args.block_lr
|
||||
self.sync_iter = cfg.global_sync_iter
|
||||
self.block_momentum = cfg.block_momentum
|
||||
self.block_lr = cfg.block_lr
|
||||
self._reset_local_data()
|
||||
self.warmup_iteration = self.args.warmup_iterations
|
||||
self.use_nbm = self.args.use_nbm
|
||||
self.warmup_iteration = cfg.warmup_iterations
|
||||
self.use_nbm = cfg.use_nbm
|
||||
self.initial_state = self._optimizer.state_dict()
|
||||
self.average_sync = self.args.average_sync
|
||||
self.world_size = self.args.distributed_world_size
|
||||
self.average_sync = self.cfg.average_sync
|
||||
self.world_size = self.cfg.distributed_world_size
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
@ -9,9 +9,9 @@ from fairseq.dataclass.utils import gen_parser_from_dataclass
|
||||
|
||||
|
||||
class FairseqOptimizer(object):
|
||||
def __init__(self, args):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.args = args
|
||||
self.cfg = cfg
|
||||
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
|
@ -7,7 +7,8 @@ from collections import defaultdict
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
from fairseq import optim, utils
|
||||
from fairseq import optim
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from .dynamic_loss_scaler import DynamicLossScaler
|
||||
|
||||
@ -211,7 +212,7 @@ class _FP16OptimizerMixin(object):
|
||||
for fp32_params in self.fp32_params.values():
|
||||
fp32_params.grad.zero_()
|
||||
else:
|
||||
raise ("self.fp32_params must be a tensor or dict")
|
||||
raise RuntimeError("self.fp32_params must be a tensor or dict")
|
||||
else:
|
||||
for p32 in self.fp32_params:
|
||||
p32.grad.zero_()
|
||||
@ -226,58 +227,60 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
|
||||
Wrap an *optimizer* to support FP16 (mixed precision) training.
|
||||
"""
|
||||
|
||||
def __init__(self, args, params, fp32_optimizer, fp32_params):
|
||||
super().__init__(args)
|
||||
def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs):
|
||||
super().__init__(cfg.optimizer)
|
||||
self.fp16_params = params
|
||||
self.fp32_optimizer = fp32_optimizer
|
||||
self.fp32_params = fp32_params
|
||||
|
||||
if getattr(args, "fp16_scale_window", None) is None:
|
||||
if len(args.update_freq) > 1:
|
||||
if getattr(cfg.common, "fp16_scale_window", None) is None:
|
||||
if len(cfg.optimization.update_freq) > 1:
|
||||
raise ValueError(
|
||||
"--fp16-scale-window must be given explicitly when using a "
|
||||
"custom --update-freq schedule"
|
||||
)
|
||||
data_parallel_size = int(
|
||||
args.distributed_world_size / args.model_parallel_size
|
||||
cfg.distributed_training.distributed_world_size
|
||||
/ cfg.common.model_parallel_size
|
||||
)
|
||||
scale_window = int(
|
||||
2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0]
|
||||
)
|
||||
scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0])
|
||||
else:
|
||||
scale_window = args.fp16_scale_window
|
||||
scale_window = cfg.common.fp16_scale_window
|
||||
|
||||
if not getattr(args, "bf16", False):
|
||||
if not getattr(cfg.common, "bf16", False):
|
||||
self.scaler = DynamicLossScaler(
|
||||
init_scale=args.fp16_init_scale,
|
||||
init_scale=cfg.common.fp16_init_scale,
|
||||
scale_window=scale_window,
|
||||
tolerance=args.fp16_scale_tolerance,
|
||||
threshold=args.threshold_loss_scale,
|
||||
min_loss_scale=args.min_loss_scale,
|
||||
tolerance=cfg.common.fp16_scale_tolerance,
|
||||
threshold=cfg.common.threshold_loss_scale,
|
||||
min_loss_scale=cfg.common.min_loss_scale,
|
||||
)
|
||||
else:
|
||||
# disable loss scaling for bfloat16
|
||||
self.scaler = None
|
||||
|
||||
@classmethod
|
||||
def build_optimizer(cls, args, params):
|
||||
def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
args (argparse.Namespace): fairseq args
|
||||
cfg (omegaconf.DictConfig): fairseq args
|
||||
params (iterable): iterable of parameters to optimize
|
||||
"""
|
||||
flatten = not getattr(args, "fp16_no_flatten_grads", False)
|
||||
if getattr(args, "bf16", False):
|
||||
flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False)
|
||||
if getattr(cfg.common, "bf16", False):
|
||||
flatten = False # mixed precision is faster on TPUs without flat grads
|
||||
fp32_params = cls.build_fp32_params(args, params, flatten=flatten)
|
||||
fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten)
|
||||
if flatten:
|
||||
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
|
||||
fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params])
|
||||
else:
|
||||
fp32_optimizer = optim.build_optimizer(args, fp32_params)
|
||||
fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params)
|
||||
if flatten and not fp32_optimizer.supports_flat_params:
|
||||
raise RuntimeError(
|
||||
"chosen optimizer does not support flat params, "
|
||||
"please set --fp16-no-flatten-grads"
|
||||
f"chosen optimizer {fp32_optimizer.__class__.__name__} does not support flat params, please set --fp16-no-flatten-grads"
|
||||
)
|
||||
return cls(args, params, fp32_optimizer, fp32_params)
|
||||
return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs)
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
@ -427,49 +430,52 @@ class MemoryEfficientFP16Optimizer(
|
||||
*supports_memory_efficient_fp16* property.
|
||||
"""
|
||||
|
||||
def __init__(self, args, params, optimizer):
|
||||
def __init__(self, cfg: DictConfig, params, optimizer, **kwargs):
|
||||
if not optimizer.supports_memory_efficient_fp16:
|
||||
raise ValueError(
|
||||
"Unsupported optimizer: {}".format(optimizer.__class__.__name__)
|
||||
)
|
||||
|
||||
super().__init__(args)
|
||||
super().__init__(cfg.optimizer)
|
||||
self.wrapped_optimizer = optimizer
|
||||
|
||||
if getattr(args, "fp16_scale_window", None) is None:
|
||||
if len(args.update_freq) > 1:
|
||||
if getattr(cfg.common, "fp16_scale_window", None) is None:
|
||||
if len(cfg.optimization.update_freq) > 1:
|
||||
raise ValueError(
|
||||
"--fp16-scale-window must be given explicitly when using a "
|
||||
"custom --update-freq schedule"
|
||||
)
|
||||
data_parallel_size = int(
|
||||
args.distributed_world_size / args.model_parallel_size
|
||||
cfg.distributed_training.distributed_world_size
|
||||
/ cfg.common.model_parallel_size
|
||||
)
|
||||
scale_window = (
|
||||
2 ** 14 / data_parallel_size / cfg.optimization.update_freq[0]
|
||||
)
|
||||
scale_window = 2 ** 14 / data_parallel_size / args.update_freq[0]
|
||||
else:
|
||||
scale_window = args.fp16_scale_window
|
||||
scale_window = cfg.common.fp16_scale_window
|
||||
|
||||
if not getattr(args, "bf16", False):
|
||||
if not getattr(cfg.common, "bf16", False):
|
||||
self.scaler = DynamicLossScaler(
|
||||
init_scale=args.fp16_init_scale,
|
||||
init_scale=cfg.common.fp16_init_scale,
|
||||
scale_window=scale_window,
|
||||
tolerance=args.fp16_scale_tolerance,
|
||||
threshold=args.threshold_loss_scale,
|
||||
min_loss_scale=args.min_loss_scale,
|
||||
tolerance=cfg.common.fp16_scale_tolerance,
|
||||
threshold=cfg.common.threshold_loss_scale,
|
||||
min_loss_scale=cfg.common.min_loss_scale,
|
||||
)
|
||||
else:
|
||||
# disable loss scaling for bfloat16
|
||||
self.scaler = None
|
||||
|
||||
@classmethod
|
||||
def build_optimizer(cls, args, params):
|
||||
def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
args (argparse.Namespace): fairseq args
|
||||
params (iterable): iterable of parameters to optimize
|
||||
"""
|
||||
fp16_optimizer = optim.build_optimizer(args, params)
|
||||
return cls(args, params, fp16_optimizer)
|
||||
fp16_optimizer = optim.build_optimizer(cfg.optimizer, params)
|
||||
return cls(cfg, params, fp16_optimizer, **kwargs)
|
||||
|
||||
@property
|
||||
def optimizer(self):
|
||||
|
@ -6,8 +6,6 @@
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Union
|
||||
|
||||
from fairseq import registry
|
||||
from fairseq.optim.lr_scheduler.fairseq_lr_scheduler import ( # noqa
|
||||
@ -27,8 +25,8 @@ from omegaconf import DictConfig
|
||||
)
|
||||
|
||||
|
||||
def build_lr_scheduler(lr_scheduler_cfg: Union[DictConfig, Namespace], optimizer):
|
||||
return build_lr_scheduler_(lr_scheduler_cfg, optimizer)
|
||||
def build_lr_scheduler(cfg: DictConfig, optimizer):
|
||||
return build_lr_scheduler_(cfg, optimizer)
|
||||
|
||||
|
||||
# automatically import any Python files in the optim/lr_scheduler/ directory
|
||||
|
@ -4,11 +4,12 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from collections import Collection
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from omegaconf import II
|
||||
from omegaconf import II, DictConfig
|
||||
|
||||
from . import FairseqLRScheduler, register_lr_scheduler
|
||||
|
||||
@ -38,8 +39,8 @@ class CosineConfig(FairseqDataclass):
|
||||
default=0.1, metadata={"help": "shrink factor for annealing"}
|
||||
)
|
||||
# TODO common var for parent class
|
||||
lr: List[float] = II("params.optimization.lr")
|
||||
max_update: int = II("params.optimization.max_update")
|
||||
lr: List[float] = II("optimization.lr")
|
||||
max_update: int = II("optimization.max_update")
|
||||
|
||||
|
||||
@register_lr_scheduler("cosine", dataclass=CosineConfig)
|
||||
@ -66,43 +67,51 @@ class CosineSchedule(FairseqLRScheduler):
|
||||
after every iteration.
|
||||
"""
|
||||
|
||||
def __init__(self, args, optimizer):
|
||||
super().__init__(args, optimizer)
|
||||
if len(args.lr) > 1:
|
||||
def __init__(
|
||||
self, cfg: DictConfig, fairseq_optimizer
|
||||
):
|
||||
super().__init__(cfg, fairseq_optimizer)
|
||||
if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
|
||||
raise ValueError(
|
||||
"Cannot use a fixed learning rate schedule with cosine."
|
||||
" Consider --lr-scheduler=fixed instead."
|
||||
)
|
||||
|
||||
warmup_end_lr = args.max_lr
|
||||
if args.warmup_init_lr < 0:
|
||||
args.warmup_init_lr = args.lr[0]
|
||||
|
||||
self.min_lr = args.lr[0]
|
||||
self.max_lr = args.max_lr
|
||||
warmup_end_lr = cfg.max_lr
|
||||
lr = (
|
||||
cfg.lr[0]
|
||||
if isinstance(cfg.lr, Collection)
|
||||
else cfg.lr
|
||||
)
|
||||
if cfg.warmup_init_lr < 0:
|
||||
cfg.warmup_init_lr = lr
|
||||
|
||||
self.min_lr = lr
|
||||
self.max_lr = cfg.max_lr
|
||||
assert self.max_lr > self.min_lr, "max_lr must be more than lr"
|
||||
|
||||
self.t_mult = args.t_mult
|
||||
self.period = args.lr_period_updates
|
||||
self.t_mult = cfg.t_mult
|
||||
self.period = cfg.lr_period_updates
|
||||
|
||||
if self.period <= 0:
|
||||
assert (
|
||||
args.max_update >= 0
|
||||
cfg.max_update >= 0
|
||||
), "Either --max_update or --lr-period-updates must be set"
|
||||
self.period = args.max_update - args.warmup_updates
|
||||
self.period = cfg.max_update - cfg.warmup_updates
|
||||
|
||||
if args.warmup_updates > 0:
|
||||
if cfg.warmup_updates > 0:
|
||||
# linearly warmup for the first args.warmup_updates
|
||||
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
|
||||
self.lr_step = (
|
||||
warmup_end_lr - cfg.warmup_init_lr
|
||||
) / cfg.warmup_updates
|
||||
else:
|
||||
self.lr_step = 1
|
||||
|
||||
self.warmup_updates = args.warmup_updates
|
||||
self.lr_shrink = args.lr_shrink
|
||||
self.warmup_updates = cfg.warmup_updates
|
||||
self.lr_shrink = cfg.lr_shrink
|
||||
|
||||
# initial learning rate
|
||||
self.lr = args.warmup_init_lr
|
||||
self.lr = cfg.warmup_init_lr
|
||||
self.optimizer.set_lr(self.lr)
|
||||
|
||||
def step(self, epoch, val_loss=None):
|
||||
@ -113,10 +122,10 @@ class CosineSchedule(FairseqLRScheduler):
|
||||
|
||||
def step_update(self, num_updates):
|
||||
"""Update the learning rate after each update."""
|
||||
if num_updates < self.args.warmup_updates:
|
||||
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
|
||||
if num_updates < self.cfg.warmup_updates:
|
||||
self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
|
||||
else:
|
||||
curr_updates = num_updates - self.args.warmup_updates
|
||||
curr_updates = num_updates - self.cfg.warmup_updates
|
||||
if self.t_mult != 1:
|
||||
i = math.floor(
|
||||
math.log(
|
||||
|
@ -11,11 +11,11 @@ from .. import FairseqOptimizer
|
||||
|
||||
|
||||
class FairseqLRScheduler(object):
|
||||
def __init__(self, args, optimizer):
|
||||
def __init__(self, cfg, optimizer):
|
||||
super().__init__()
|
||||
if not isinstance(optimizer, FairseqOptimizer):
|
||||
raise ValueError("optimizer must be an instance of FairseqOptimizer")
|
||||
self.args = args
|
||||
self.cfg = cfg
|
||||
self.optimizer = optimizer
|
||||
self.best = None
|
||||
|
||||
|
@ -3,11 +3,12 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import Collection
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from omegaconf import II
|
||||
from omegaconf import II, DictConfig
|
||||
|
||||
from . import FairseqLRScheduler, register_lr_scheduler
|
||||
|
||||
@ -25,7 +26,7 @@ class InverseSquareRootScheduleConfig(FairseqDataclass):
|
||||
},
|
||||
)
|
||||
# TODO common vars at parent class
|
||||
lr: List[float] = II("params.optimization.lr")
|
||||
lr: List[float] = II("optimization.lr")
|
||||
|
||||
|
||||
@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig)
|
||||
@ -48,25 +49,33 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
|
||||
lr = decay_factor / sqrt(update_num)
|
||||
"""
|
||||
|
||||
def __init__(self, args, optimizer):
|
||||
super().__init__(args, optimizer)
|
||||
if len(args.lr) > 1:
|
||||
def __init__(self, cfg: DictConfig, optimizer):
|
||||
super().__init__(cfg, optimizer)
|
||||
if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
|
||||
raise ValueError(
|
||||
"Cannot use a fixed learning rate schedule with inverse_sqrt."
|
||||
" Consider --lr-scheduler=fixed instead."
|
||||
)
|
||||
warmup_end_lr = args.lr[0]
|
||||
if args.warmup_init_lr < 0:
|
||||
args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr
|
||||
warmup_end_lr = (
|
||||
cfg.lr[0]
|
||||
if isinstance(cfg.lr, Collection)
|
||||
else cfg.lr
|
||||
)
|
||||
if cfg.warmup_init_lr < 0:
|
||||
cfg.warmup_init_lr = (
|
||||
0 if cfg.warmup_updates > 0 else warmup_end_lr
|
||||
)
|
||||
|
||||
# linearly warmup for the first args.warmup_updates
|
||||
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
|
||||
self.lr_step = (
|
||||
warmup_end_lr - cfg.warmup_init_lr
|
||||
) / cfg.warmup_updates
|
||||
|
||||
# then, decay prop. to the inverse square root of the update number
|
||||
self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5
|
||||
self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5
|
||||
|
||||
# initial learning rate
|
||||
self.lr = args.warmup_init_lr
|
||||
self.lr = cfg.warmup_init_lr
|
||||
self.optimizer.set_lr(self.lr)
|
||||
|
||||
def step(self, epoch, val_loss=None):
|
||||
@ -77,8 +86,8 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
|
||||
|
||||
def step_update(self, num_updates):
|
||||
"""Update the learning rate after each update."""
|
||||
if num_updates < self.args.warmup_updates:
|
||||
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
|
||||
if num_updates < self.cfg.warmup_updates:
|
||||
self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step
|
||||
else:
|
||||
self.lr = self.decay_factor * num_updates ** -0.5
|
||||
self.optimizer.set_lr(self.lr)
|
||||
|
@ -3,12 +3,13 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import Collection
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from omegaconf import II
|
||||
from omegaconf import II, DictConfig
|
||||
from torch.optim.optimizer import Optimizer, required
|
||||
|
||||
from . import FairseqOptimizer, register_optimizer
|
||||
@ -19,13 +20,13 @@ class FairseqNAGConfig(FairseqDataclass):
|
||||
momentum: float = field(default=0.99, metadata={"help": "momentum factor"})
|
||||
weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
|
||||
# TODO common vars in parent class
|
||||
lr: List[float] = II("params.optimization.lr")
|
||||
lr: List[float] = II("optimization.lr")
|
||||
|
||||
|
||||
@register_optimizer("nag", dataclass=FairseqNAGConfig)
|
||||
class FairseqNAG(FairseqOptimizer):
|
||||
def __init__(self, args, params):
|
||||
super().__init__(args)
|
||||
def __init__(self, cfg: DictConfig, params):
|
||||
super().__init__(cfg)
|
||||
self._optimizer = NAG(params, **self.optimizer_config)
|
||||
|
||||
@property
|
||||
@ -37,9 +38,11 @@ class FairseqNAG(FairseqOptimizer):
|
||||
different learning rate.
|
||||
"""
|
||||
return {
|
||||
"lr": self.args.lr[0],
|
||||
"momentum": self.args.momentum,
|
||||
"weight_decay": self.args.weight_decay,
|
||||
"lr": self.cfg.lr[0]
|
||||
if isinstance(self.cfg.lr, Collection)
|
||||
else self.cfg.lr,
|
||||
"momentum": self.cfg.momentum,
|
||||
"weight_decay": self.cfg.weight_decay,
|
||||
}
|
||||
|
||||
|
||||
|
@ -12,7 +12,7 @@ except ImportError:
|
||||
_has_fairscale = False
|
||||
|
||||
|
||||
def shard_(args, optimizer, group):
|
||||
def shard_(optimizer, group):
|
||||
if not _has_fairscale:
|
||||
raise ImportError(
|
||||
"\n\nPlease install the fairscale package:" "\n\n pip install fairscale"
|
||||
|
@ -10,13 +10,15 @@ import torch
|
||||
from fairseq import utils
|
||||
from fairseq.data.indexed_dataset import get_available_dataset_impl
|
||||
from fairseq.dataclass.data_class import (
|
||||
CheckpointParams,
|
||||
CommonEvalParams,
|
||||
CommonParams,
|
||||
DatasetParams,
|
||||
DistributedTrainingParams,
|
||||
EvalLMParams,
|
||||
OptimizationParams,
|
||||
CheckpointConfig,
|
||||
CommonConfig,
|
||||
CommonEvalConfig,
|
||||
DatasetConfig,
|
||||
DistributedTrainingConfig,
|
||||
EvalLMConfig,
|
||||
GenerationConfig,
|
||||
InteractiveConfig,
|
||||
OptimizationConfig,
|
||||
)
|
||||
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
||||
|
||||
@ -45,6 +47,7 @@ def get_generation_parser(interactive=False, default_task="translation"):
|
||||
add_dataset_args(parser, gen=True)
|
||||
add_distributed_training_args(parser, default_world_size=1)
|
||||
add_generation_args(parser)
|
||||
add_checkpoint_args(parser)
|
||||
if interactive:
|
||||
add_interactive_args(parser)
|
||||
return parser
|
||||
@ -67,7 +70,7 @@ def get_validation_parser(default_task=None):
|
||||
add_dataset_args(parser, train=True)
|
||||
add_distributed_training_args(parser, default_world_size=1)
|
||||
group = parser.add_argument_group("Evaluation")
|
||||
gen_parser_from_dataclass(group, CommonEvalParams())
|
||||
gen_parser_from_dataclass(group, CommonEvalConfig())
|
||||
return parser
|
||||
|
||||
|
||||
@ -210,7 +213,7 @@ def get_parser(desc, default_task="translation"):
|
||||
utils.import_user_module(usr_args)
|
||||
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
gen_parser_from_dataclass(parser, CommonParams())
|
||||
gen_parser_from_dataclass(parser, CommonConfig())
|
||||
|
||||
from fairseq.registry import REGISTRIES
|
||||
|
||||
@ -283,7 +286,7 @@ def add_preprocess_args(parser):
|
||||
|
||||
def add_dataset_args(parser, train=False, gen=False):
|
||||
group = parser.add_argument_group("dataset_data_loading")
|
||||
gen_parser_from_dataclass(group, DatasetParams())
|
||||
gen_parser_from_dataclass(group, DatasetConfig())
|
||||
# fmt: on
|
||||
return group
|
||||
|
||||
@ -293,7 +296,7 @@ def add_distributed_training_args(parser, default_world_size=None):
|
||||
if default_world_size is None:
|
||||
default_world_size = max(1, torch.cuda.device_count())
|
||||
gen_parser_from_dataclass(
|
||||
group, DistributedTrainingParams(distributed_world_size=default_world_size)
|
||||
group, DistributedTrainingConfig(distributed_world_size=default_world_size)
|
||||
)
|
||||
return group
|
||||
|
||||
@ -301,7 +304,7 @@ def add_distributed_training_args(parser, default_world_size=None):
|
||||
def add_optimization_args(parser):
|
||||
group = parser.add_argument_group("optimization")
|
||||
# fmt: off
|
||||
gen_parser_from_dataclass(group, OptimizationParams())
|
||||
gen_parser_from_dataclass(group, OptimizationConfig())
|
||||
# fmt: on
|
||||
return group
|
||||
|
||||
@ -309,117 +312,31 @@ def add_optimization_args(parser):
|
||||
def add_checkpoint_args(parser):
|
||||
group = parser.add_argument_group("checkpoint")
|
||||
# fmt: off
|
||||
gen_parser_from_dataclass(group, CheckpointParams())
|
||||
gen_parser_from_dataclass(group, CheckpointConfig())
|
||||
# fmt: on
|
||||
return group
|
||||
|
||||
|
||||
def add_common_eval_args(group):
|
||||
gen_parser_from_dataclass(group, CommonEvalParams())
|
||||
gen_parser_from_dataclass(group, CommonEvalConfig())
|
||||
|
||||
|
||||
def add_eval_lm_args(parser):
|
||||
group = parser.add_argument_group("LM Evaluation")
|
||||
add_common_eval_args(group)
|
||||
gen_parser_from_dataclass(group, EvalLMParams())
|
||||
gen_parser_from_dataclass(group, EvalLMConfig())
|
||||
|
||||
|
||||
def add_generation_args(parser):
|
||||
group = parser.add_argument_group("Generation")
|
||||
add_common_eval_args(group)
|
||||
# fmt: off
|
||||
group.add_argument('--beam', default=5, type=int, metavar='N',
|
||||
help='beam size')
|
||||
group.add_argument('--nbest', default=1, type=int, metavar='N',
|
||||
help='number of hypotheses to output')
|
||||
group.add_argument('--max-len-a', default=0, type=float, metavar='N',
|
||||
help=('generate sequences of maximum length ax + b, '
|
||||
'where x is the source length'))
|
||||
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
|
||||
help=('generate sequences of maximum length ax + b, '
|
||||
'where x is the source length'))
|
||||
group.add_argument('--min-len', default=1, type=float, metavar='N',
|
||||
help=('minimum generation length'))
|
||||
group.add_argument('--match-source-len', default=False, action='store_true',
|
||||
help=('generations should match the source length'))
|
||||
group.add_argument('--no-early-stop', action='store_true',
|
||||
help='deprecated')
|
||||
group.add_argument('--unnormalized', action='store_true',
|
||||
help='compare unnormalized hypothesis scores')
|
||||
group.add_argument('--no-beamable-mm', action='store_true',
|
||||
help='don\'t use BeamableMM in attention layers')
|
||||
group.add_argument('--lenpen', default=1, type=float,
|
||||
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
|
||||
group.add_argument('--unkpen', default=0, type=float,
|
||||
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
|
||||
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
|
||||
help='perform unknown replacement (optionally with alignment dictionary)')
|
||||
group.add_argument('--sacrebleu', action='store_true',
|
||||
help='score with sacrebleu')
|
||||
group.add_argument('--score-reference', action='store_true',
|
||||
help='just score the reference translation')
|
||||
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
|
||||
help='initialize generation by target prefix of given length')
|
||||
group.add_argument('--no-repeat-ngram-size', default=0, type=int, metavar='N',
|
||||
help='ngram blocking such that this size ngram cannot be repeated in the generation')
|
||||
group.add_argument('--sampling', action='store_true',
|
||||
help='sample hypotheses instead of using beam search')
|
||||
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
|
||||
help='sample from top K likely next words instead of all words')
|
||||
group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS',
|
||||
help='sample from the smallest set whose cumulative probability mass exceeds p for next words')
|
||||
group.add_argument('--constraints', const="ordered", nargs="?", choices=["ordered", "unordered"],
|
||||
help='enables lexically constrained decoding')
|
||||
group.add_argument('--temperature', default=1., type=float, metavar='N',
|
||||
help='temperature for generation')
|
||||
group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
|
||||
help='number of groups for Diverse Beam Search')
|
||||
group.add_argument('--diverse-beam-strength', default=0.5, type=float, metavar='N',
|
||||
help='strength of diversity penalty for Diverse Beam Search')
|
||||
group.add_argument('--diversity-rate', default=-1.0, type=float, metavar='N',
|
||||
help='strength of diversity penalty for Diverse Siblings Search')
|
||||
group.add_argument('--print-alignment', action='store_true',
|
||||
help='if set, uses attention feedback to compute and print alignment to source tokens')
|
||||
group.add_argument('--print-step', action='store_true')
|
||||
|
||||
group.add_argument('--lm-path', default=None, type=str, metavar='PATH',
|
||||
help='path to lm checkpoint for lm fusion')
|
||||
group.add_argument('--lm-weight', default=0.0, type=float, metavar='N',
|
||||
help='weight for lm probs for lm fusion')
|
||||
|
||||
# arguments for iterative refinement generator
|
||||
group.add_argument('--iter-decode-eos-penalty', default=0.0, type=float, metavar='N',
|
||||
help='if > 0.0, it penalized early-stopping in decoding.')
|
||||
group.add_argument('--iter-decode-max-iter', default=10, type=int, metavar='N',
|
||||
help='maximum iterations for iterative refinement.')
|
||||
group.add_argument('--iter-decode-force-max-iter', action='store_true',
|
||||
help='if set, run exact the maximum number of iterations without early stop')
|
||||
group.add_argument('--iter-decode-with-beam', default=1, type=int, metavar='N',
|
||||
help='if > 1, model will generate translations varying by the lengths.')
|
||||
group.add_argument('--iter-decode-with-external-reranker', action='store_true',
|
||||
help='if set, the last checkpoint are assumed to be a reranker to rescore the translations'),
|
||||
group.add_argument('--retain-iter-history', action='store_true',
|
||||
help='if set, decoding returns the whole history of iterative refinement')
|
||||
group.add_argument('--retain-dropout', action='store_true',
|
||||
help='Use dropout at inference time')
|
||||
group.add_argument('--retain-dropout-modules', default=None, nargs='+', type=str,
|
||||
help='if set, only retain dropout for the specified modules; '
|
||||
'if not set, then dropout will be retained for all modules')
|
||||
|
||||
# special decoding format for advanced decoding.
|
||||
group.add_argument('--decoding-format', default=None, type=str, choices=['unigram', 'ensemble', 'vote', 'dp', 'bs'])
|
||||
# fmt: on
|
||||
gen_parser_from_dataclass(group, GenerationConfig())
|
||||
return group
|
||||
|
||||
|
||||
def add_interactive_args(parser):
|
||||
group = parser.add_argument_group("Interactive")
|
||||
# fmt: off
|
||||
group.add_argument('--buffer-size', default=0, type=int, metavar='N',
|
||||
help='read this many sentences into a buffer before processing them')
|
||||
group.add_argument('--input', default='-', type=str, metavar='FILE',
|
||||
help='file to read from; use - for stdin')
|
||||
# fmt: on
|
||||
gen_parser_from_dataclass(group, InteractiveConfig())
|
||||
|
||||
|
||||
def add_model_args(parser):
|
||||
|
@ -6,13 +6,14 @@
|
||||
import logging
|
||||
|
||||
from fairseq.modules.quantization import pq, quantization_options, scalar
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def quantize_model_scalar(model, args):
|
||||
quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
|
||||
def quantize_model_scalar(model, model_cfg: DictConfig):
|
||||
quant_noise_scalar = getattr(model_cfg, "quant_noise_scalar", 0) or 0
|
||||
if quant_noise_scalar > 0:
|
||||
# quantize_model edits the model in place
|
||||
scalar.quantize_model_(model, p=quant_noise_scalar, bits=8, update_step=1000)
|
||||
|
@ -3,14 +3,13 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from argparse import Namespace
|
||||
|
||||
from typing import Union
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.dataclass.utils import populate_dataclass
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
REGISTRIES = {}
|
||||
|
||||
|
||||
@ -25,33 +24,30 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F
|
||||
# maintain a registry of all registries
|
||||
if registry_name in REGISTRIES:
|
||||
return # registry already exists
|
||||
REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default}
|
||||
REGISTRIES[registry_name] = {"registry": REGISTRY, "default": default, "dataclass_registry": DATACLASS_REGISTRY}
|
||||
|
||||
def build_x(args: Union[DictConfig, Namespace], *extra_args, **extra_kwargs):
|
||||
if isinstance(args, DictConfig):
|
||||
if getattr(args, "_name", None) is not None:
|
||||
choice = args._name
|
||||
elif hasattr(args, registry_name):
|
||||
choice = args.registry_name
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Neither _name nor {registry_name} in args, args = {args}"
|
||||
)
|
||||
def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs):
|
||||
if isinstance(cfg, DictConfig):
|
||||
choice = cfg._name
|
||||
elif isinstance(cfg, str):
|
||||
choice = cfg
|
||||
else:
|
||||
choice = getattr(args, registry_name, None)
|
||||
choice = getattr(cfg, registry_name, None)
|
||||
if choice in DATACLASS_REGISTRY:
|
||||
cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]())
|
||||
|
||||
if choice is None:
|
||||
if required:
|
||||
raise ValueError("--{} is required!".format(registry_name))
|
||||
raise ValueError('{} is required!'.format(registry_name))
|
||||
return None
|
||||
|
||||
cls = REGISTRY[choice]
|
||||
if hasattr(cls, "build_" + registry_name):
|
||||
builder = getattr(cls, "build_" + registry_name)
|
||||
else:
|
||||
builder = cls
|
||||
if isinstance(args, Namespace):
|
||||
set_defaults(args, cls)
|
||||
return builder(args, *extra_args, **extra_kwargs)
|
||||
|
||||
return builder(cfg, *extra_args, **extra_kwargs)
|
||||
|
||||
def register_x(name, dataclass=None):
|
||||
def register_x_cls(cls):
|
||||
@ -77,30 +73,10 @@ def setup_registry(registry_name: str, base_class=None, default=None, required=F
|
||||
|
||||
cls.__dataclass = dataclass
|
||||
REGISTRY[name] = cls
|
||||
DATACLASS_REGISTRY[name] = cls.__dataclass
|
||||
REGISTRY_CLASS_NAMES.add(cls.__name__)
|
||||
if cls.__dataclass is not None:
|
||||
DATACLASS_REGISTRY[name] = cls.__dataclass
|
||||
return cls
|
||||
|
||||
return register_x_cls
|
||||
|
||||
return build_x, register_x, REGISTRY, DATACLASS_REGISTRY
|
||||
|
||||
|
||||
def set_defaults(args: Namespace, cls):
|
||||
"""Helper to set default arguments based on *add_args*."""
|
||||
if not hasattr(cls, "add_args"):
|
||||
return
|
||||
parser = argparse.ArgumentParser(
|
||||
argument_default=argparse.SUPPRESS, allow_abbrev=False
|
||||
)
|
||||
cls.add_args(parser)
|
||||
# copied from argparse.py:
|
||||
defaults = argparse.Namespace()
|
||||
for action in parser._actions:
|
||||
if action.dest is not argparse.SUPPRESS:
|
||||
if not hasattr(defaults, action.dest):
|
||||
if action.default is not argparse.SUPPRESS:
|
||||
setattr(defaults, action.dest, action.default)
|
||||
for key, default_value in vars(defaults).items():
|
||||
if not hasattr(args, key):
|
||||
setattr(args, key, default_value)
|
||||
|
@ -9,11 +9,12 @@ import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fairseq import registry
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
class BaseScorer(ABC):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.ref = []
|
||||
self.pred = []
|
||||
|
||||
@ -39,19 +40,17 @@ _build_scorer, register_scorer, SCORER_REGISTRY, _ = registry.setup_registry(
|
||||
)
|
||||
|
||||
|
||||
def build_scorer(args, tgt_dict):
|
||||
from fairseq import utils
|
||||
def build_scorer(choice, tgt_dict):
|
||||
if isinstance(choice, DictConfig):
|
||||
choice = choice._name
|
||||
|
||||
if args.sacrebleu:
|
||||
utils.deprecation_warning(
|
||||
"--sacrebleu is deprecated. Please use --scoring sacrebleu instead."
|
||||
)
|
||||
args.scoring = "sacrebleu"
|
||||
if args.scoring == "bleu":
|
||||
if choice == "bleu":
|
||||
from fairseq.scoring import bleu
|
||||
|
||||
return bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
|
||||
return _build_scorer(args)
|
||||
return bleu.Scorer(
|
||||
bleu.BleuConfig(pad=tgt_dict.pad(), eos=tgt_dict.eos(), unk=tgt_dict.unk())
|
||||
)
|
||||
return _build_scorer(choice)
|
||||
|
||||
|
||||
# automatically import any Python files in the current directory
|
||||
|
@ -6,8 +6,10 @@
|
||||
import ctypes
|
||||
import math
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.scoring import BaseScorer, register_scorer
|
||||
from fairseq.scoring.tokenizer import EvaluationTokenizer
|
||||
|
||||
@ -27,31 +29,32 @@ class BleuStat(ctypes.Structure):
|
||||
]
|
||||
|
||||
|
||||
@register_scorer("sacrebleu")
|
||||
@dataclass
|
||||
class SacrebleuConfig(FairseqDataclass):
|
||||
sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
|
||||
default="13a", metadata={"help": "tokenizer"}
|
||||
)
|
||||
sacrebleu_lowercase: bool = field(
|
||||
default=False, metadata={"help": "apply lowercasing"}
|
||||
)
|
||||
sacrebleu_char_level: bool = field(
|
||||
default=False, metadata={"help": "evaluate at character level"}
|
||||
)
|
||||
|
||||
|
||||
@register_scorer("sacrebleu", dataclass=SacrebleuConfig)
|
||||
class SacrebleuScorer(BaseScorer):
|
||||
def __init__(self, args):
|
||||
super(SacrebleuScorer, self).__init__(args)
|
||||
def __init__(self, cfg):
|
||||
super(SacrebleuScorer, self).__init__(cfg)
|
||||
import sacrebleu
|
||||
|
||||
self.sacrebleu = sacrebleu
|
||||
self.tokenizer = EvaluationTokenizer(
|
||||
tokenizer_type=self.args.sacrebleu_tokenizer,
|
||||
lowercase=self.args.sacrebleu_lowercase,
|
||||
character_tokenization=self.args.sacrebleu_char_level,
|
||||
tokenizer_type=cfg.sacrebleu_tokenizer,
|
||||
lowercase=cfg.sacrebleu_lowercase,
|
||||
character_tokenization=cfg.sacrebleu_char_level,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--sacrebleu-tokenizer', type=str, default='13a',
|
||||
choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES,
|
||||
help='tokenizer')
|
||||
parser.add_argument('--sacrebleu-lowercase', type=str, default=False,
|
||||
help='apply lowercasing')
|
||||
parser.add_argument('--sacrebleu-char-level', action='store_true',
|
||||
help='evaluate at character level')
|
||||
# fmt: on
|
||||
|
||||
def add_string(self, ref, pred):
|
||||
self.ref.append(self.tokenizer.tokenize(ref))
|
||||
self.pred.append(self.tokenizer.tokenize(pred))
|
||||
@ -68,13 +71,20 @@ class SacrebleuScorer(BaseScorer):
|
||||
).format()
|
||||
|
||||
|
||||
@register_scorer("bleu")
|
||||
@dataclass
|
||||
class BleuConfig(FairseqDataclass):
|
||||
pad: int = field(default=1, metadata={"help": "padding index"})
|
||||
eos: int = field(default=2, metadata={"help": "eos index"})
|
||||
unk: int = field(default=3, metadata={"help": "unk index"})
|
||||
|
||||
|
||||
@register_scorer("bleu", dataclass=BleuConfig)
|
||||
class Scorer(object):
|
||||
def __init__(self, pad, eos, unk):
|
||||
def __init__(self, cfg):
|
||||
self.stat = BleuStat()
|
||||
self.pad = pad
|
||||
self.eos = eos
|
||||
self.unk = unk
|
||||
self.pad = cfg.pad
|
||||
self.eos = cfg.eos
|
||||
self.unk = cfg.unk
|
||||
|
||||
try:
|
||||
from fairseq import libbleu
|
||||
|
@ -5,6 +5,8 @@
|
||||
|
||||
import unicodedata
|
||||
|
||||
from fairseq.dataclass.utils import ChoiceEnum
|
||||
|
||||
|
||||
class EvaluationTokenizer(object):
|
||||
"""A generic evaluation-time tokenizer, which leverages built-in tokenizers
|
||||
@ -22,7 +24,7 @@ class EvaluationTokenizer(object):
|
||||
|
||||
SPACE = chr(32)
|
||||
SPACE_ESCAPE = chr(9601)
|
||||
ALL_TOKENIZER_TYPES = ["none", "13a", "intl", "zh", "ja-mecab"]
|
||||
ALL_TOKENIZER_TYPES = ChoiceEnum(["none", "13a", "intl", "zh", "ja-mecab"])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -33,7 +35,7 @@ class EvaluationTokenizer(object):
|
||||
):
|
||||
from sacrebleu.tokenizers import TOKENIZERS
|
||||
|
||||
assert tokenizer_type in self.ALL_TOKENIZER_TYPES
|
||||
assert tokenizer_type in TOKENIZERS, f"{tokenizer_type}, {TOKENIZERS}"
|
||||
self.lowercase = lowercase
|
||||
self.punctuation_removal = punctuation_removal
|
||||
self.character_tokenization = character_tokenization
|
||||
|
@ -3,14 +3,31 @@
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.scoring import BaseScorer, register_scorer
|
||||
from fairseq.scoring.tokenizer import EvaluationTokenizer
|
||||
|
||||
|
||||
@register_scorer("wer")
|
||||
@dataclass
|
||||
class WerScorerConfig(FairseqDataclass):
|
||||
wer_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
|
||||
default="none", metadata={"help": "sacreBLEU tokenizer to use for evaluation"}
|
||||
)
|
||||
wer_remove_punct: bool = field(
|
||||
default=False, metadata={"help": "remove punctuation"}
|
||||
)
|
||||
wer_char_level: bool = field(
|
||||
default=False, metadata={"help": "evaluate at character level"}
|
||||
)
|
||||
wer_lowercase: bool = field(default=False, metadata={"help": "lowercasing"})
|
||||
|
||||
|
||||
@register_scorer("wer", dataclass=WerScorerConfig)
|
||||
class WerScorer(BaseScorer):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self.reset()
|
||||
try:
|
||||
import editdistance as ed
|
||||
@ -18,26 +35,12 @@ class WerScorer(BaseScorer):
|
||||
raise ImportError("Please install editdistance to use WER scorer")
|
||||
self.ed = ed
|
||||
self.tokenizer = EvaluationTokenizer(
|
||||
tokenizer_type=self.args.wer_tokenizer,
|
||||
lowercase=self.args.wer_lowercase,
|
||||
punctuation_removal=self.args.wer_remove_punct,
|
||||
character_tokenization=self.args.wer_char_level,
|
||||
tokenizer_type=self.cfg.wer_tokenizer,
|
||||
lowercase=self.cfg.wer_lowercase,
|
||||
punctuation_removal=self.cfg.wer_remove_punct,
|
||||
character_tokenization=self.cfg.wer_char_level,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
# fmt: off
|
||||
parser.add_argument('--wer-tokenizer', type=str, default='none',
|
||||
choices=EvaluationTokenizer.ALL_TOKENIZER_TYPES,
|
||||
help='sacreBLEU tokenizer to use for evaluation')
|
||||
parser.add_argument('--wer-remove-punct', action='store_true',
|
||||
help='remove punctuation')
|
||||
parser.add_argument('--wer-char-level', action='store_true',
|
||||
help='evaluate at character level')
|
||||
parser.add_argument('--wer-lowercase', action='store_true',
|
||||
help='lowercasing')
|
||||
# fmt: on
|
||||
|
||||
def reset(self):
|
||||
self.distance = 0
|
||||
self.ref_length = 0
|
||||
|
@ -7,8 +7,6 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Union
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from omegaconf import DictConfig
|
||||
@ -22,10 +20,10 @@ TASK_REGISTRY = {}
|
||||
TASK_CLASS_NAMES = set()
|
||||
|
||||
|
||||
def setup_task(task_cfg: Union[DictConfig, Namespace], **kwargs):
|
||||
if isinstance(task_cfg, DictConfig):
|
||||
return TASK_REGISTRY[task_cfg._name].setup_task(task_cfg, **kwargs)
|
||||
return TASK_REGISTRY[task_cfg.task].setup_task(task_cfg, **kwargs)
|
||||
def setup_task(cfg: DictConfig, **kwargs):
|
||||
if isinstance(cfg, DictConfig):
|
||||
return TASK_REGISTRY[cfg._name].setup_task(cfg, **kwargs)
|
||||
return TASK_REGISTRY[cfg.task].setup_task(cfg, **kwargs)
|
||||
|
||||
|
||||
def register_task(name, dataclass=None):
|
||||
@ -70,7 +68,8 @@ def register_task(name, dataclass=None):
|
||||
)
|
||||
|
||||
cls.__dataclass = dataclass
|
||||
TASK_DATACLASS_REGISTRY[name] = dataclass
|
||||
if dataclass is not None:
|
||||
TASK_DATACLASS_REGISTRY[name] = dataclass
|
||||
|
||||
return cls
|
||||
|
||||
|
@ -79,7 +79,7 @@ class AudioPretrainingTask(LegacyFairseqTask):
|
||||
"""Setup the task (e.g., load dictionaries).
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
args (omegaconf.DictConfig): parsed command-line arguments
|
||||
"""
|
||||
return cls(args)
|
||||
|
||||
|
@ -12,6 +12,7 @@ import torch
|
||||
from fairseq import metrics, search, tokenizer, utils
|
||||
from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
|
||||
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -39,8 +40,8 @@ class FairseqTask(object):
|
||||
"""
|
||||
return criterion.logging_outputs_can_be_summed()
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
def __init__(self, cfg: DictConfig, **kwargs):
|
||||
self.cfg = cfg
|
||||
self.datasets = {}
|
||||
self.dataset_to_epoch_iter = {}
|
||||
|
||||
@ -78,16 +79,16 @@ class FairseqTask(object):
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, args, **kwargs):
|
||||
def setup_task(cls, cfg: DictConfig, **kwargs):
|
||||
"""Setup the task (e.g., load dictionaries).
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
cfg (omegaconf.DictConfig): parsed command-line arguments
|
||||
"""
|
||||
return cls(args, **kwargs)
|
||||
return cls(cfg, **kwargs)
|
||||
|
||||
def has_sharded_data(self, split):
|
||||
return os.pathsep in getattr(self.args, "data", "")
|
||||
return os.pathsep in getattr(self.cfg, "data", "")
|
||||
|
||||
def load_dataset(self, split, combine=False, **kwargs):
|
||||
"""Load a given dataset split.
|
||||
@ -254,39 +255,39 @@ class FairseqTask(object):
|
||||
|
||||
return epoch_iter
|
||||
|
||||
def build_model(self, args):
|
||||
def build_model(self, cfg: DictConfig):
|
||||
"""
|
||||
Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
|
||||
task.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
cfg (omegaconf.DictConfig): configuration object
|
||||
|
||||
Returns:
|
||||
a :class:`~fairseq.models.BaseFairseqModel` instance
|
||||
"""
|
||||
from fairseq import models, quantization_utils
|
||||
|
||||
model = models.build_model(args, self)
|
||||
if getattr(args, "tpu", False):
|
||||
model = models.build_model(cfg, self)
|
||||
if getattr(cfg, "tpu", False):
|
||||
model.prepare_for_tpu_()
|
||||
model = quantization_utils.quantize_model_scalar(model, args)
|
||||
model = quantization_utils.quantize_model_scalar(model, cfg)
|
||||
return model
|
||||
|
||||
def build_criterion(self, args):
|
||||
def build_criterion(self, cfg: DictConfig):
|
||||
"""
|
||||
Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
|
||||
this task.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
cfg (omegaconf.DictConfig): configration object
|
||||
|
||||
Returns:
|
||||
a :class:`~fairseq.criterions.FairseqCriterion` instance
|
||||
"""
|
||||
from fairseq import criterions
|
||||
|
||||
return criterions.build_criterion(args, self)
|
||||
return criterions.build_criterion(cfg, self)
|
||||
|
||||
def build_generator(
|
||||
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
|
||||
|
@ -28,7 +28,7 @@ from fairseq.data import (
|
||||
from fairseq.data.indexed_dataset import get_available_dataset_impl
|
||||
from fairseq.data.shorten_dataset import maybe_shorten_dataset
|
||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||
from fairseq.tasks import FairseqTask, register_task
|
||||
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||
from omegaconf import II
|
||||
|
||||
|
||||
@ -85,16 +85,16 @@ class LanguageModelingConfig(FairseqDataclass):
|
||||
},
|
||||
)
|
||||
# TODO common vars below add to parent
|
||||
seed: int = II("params.common.seed")
|
||||
seed: int = II("common.seed")
|
||||
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
|
||||
"params.dataset.dataset_impl"
|
||||
"dataset.dataset_impl"
|
||||
)
|
||||
data_buffer_size: int = II("params.dataset.data_buffer_size")
|
||||
tpu: bool = II("params.common.tpu")
|
||||
data_buffer_size: int = II("dataset.data_buffer_size")
|
||||
tpu: bool = II("common.tpu")
|
||||
|
||||
|
||||
@register_task("language_modeling", dataclass=LanguageModelingConfig)
|
||||
class LanguageModelingTask(FairseqTask):
|
||||
class LanguageModelingTask(LegacyFairseqTask):
|
||||
"""
|
||||
Train a language model.
|
||||
|
||||
|
@ -117,7 +117,7 @@ class MultilingualTranslationTask(LegacyFairseqTask):
|
||||
return cls(args, dicts, training)
|
||||
|
||||
@classmethod
|
||||
def prepare(cls, args, **kargs):
|
||||
def update_args(cls, args):
|
||||
args.left_pad_source = utils.eval_bool(args.left_pad_source)
|
||||
args.left_pad_target = utils.eval_bool(args.left_pad_target)
|
||||
|
||||
@ -127,6 +127,10 @@ class MultilingualTranslationTask(LegacyFairseqTask):
|
||||
)
|
||||
if isinstance(args.lang_pairs, str):
|
||||
args.lang_pairs = args.lang_pairs.split(",")
|
||||
|
||||
@classmethod
|
||||
def prepare(cls, args, **kargs):
|
||||
cls.update_args(args)
|
||||
sorted_langs = sorted(
|
||||
list({x for lang_pair in args.lang_pairs for x in lang_pair.split("-")})
|
||||
)
|
||||
@ -298,6 +302,10 @@ class MultilingualTranslationTask(LegacyFairseqTask):
|
||||
if len(messages) > 0:
|
||||
raise ValueError(" ".join(messages))
|
||||
|
||||
# Update args -> the fact that the constructor here
|
||||
# changes the args object doesn't mean you get the same one here
|
||||
self.update_args(args)
|
||||
|
||||
# Check if task args are consistant with model args
|
||||
check_args()
|
||||
|
||||
|
@ -13,7 +13,7 @@ from fairseq.data.audio.speech_to_text_dataset import (
|
||||
SpeechToTextDataset,
|
||||
SpeechToTextDatasetCreator,
|
||||
)
|
||||
from fairseq.tasks import FairseqTask, register_task
|
||||
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_task("speech_to_text")
|
||||
class SpeechToTextTask(FairseqTask):
|
||||
class SpeechToTextTask(LegacyFairseqTask):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
parser.add_argument("data", help="manifest root path")
|
||||
|
@ -11,15 +11,18 @@ import contextlib
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, distributed_utils, models, optim, utils
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.file_io import PathManager
|
||||
from fairseq.logging import meters, metrics
|
||||
from fairseq.nan_detector import NanDetector
|
||||
from fairseq.optim import lr_scheduler
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -35,19 +38,25 @@ class Trainer(object):
|
||||
communication of the gradients across workers.
|
||||
"""
|
||||
|
||||
def __init__(self, args, task, model, criterion, quantizer=None):
|
||||
self.args = args
|
||||
def __init__(self, cfg: DictConfig, task, model, criterion, quantizer=None):
|
||||
|
||||
if isinstance(cfg, Namespace):
|
||||
logger.warning(
|
||||
"argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
|
||||
)
|
||||
cfg = convert_namespace_to_omegaconf(cfg)
|
||||
|
||||
self.cfg = cfg
|
||||
self.task = task
|
||||
|
||||
# catalog shared parameters
|
||||
shared_params = _catalog_shared_params(model)
|
||||
|
||||
self.tpu = getattr(args, "tpu", False)
|
||||
self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu
|
||||
self.tpu = cfg.common.tpu
|
||||
self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu
|
||||
if self.cuda:
|
||||
self.device = torch.device("cuda")
|
||||
elif self.tpu:
|
||||
self.device = utils.get_tpu_device(args)
|
||||
self.device = utils.get_tpu_device()
|
||||
else:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
@ -58,19 +67,21 @@ class Trainer(object):
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
self._model = xm.send_cpu_data_to_device(self._model, self.device)
|
||||
if args.fp16:
|
||||
if cfg.common.fp16:
|
||||
self._criterion = self._criterion.half()
|
||||
self._model = self._model.half()
|
||||
elif args.bf16:
|
||||
elif cfg.common.bf16:
|
||||
self._criterion = self._criterion.to(dtype=torch.bfloat16)
|
||||
self._model = self._model.to(dtype=torch.bfloat16)
|
||||
if not args.pipeline_model_parallel:
|
||||
if not cfg.distributed_training.pipeline_model_parallel:
|
||||
self._criterion = self._criterion.to(device=self.device)
|
||||
self._model = self._model.to(device=self.device)
|
||||
self.pipeline_model_parallel = args.pipeline_model_parallel
|
||||
self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel
|
||||
self.last_device = None
|
||||
if self.cuda and self.pipeline_model_parallel:
|
||||
self.last_device = torch.device(args.pipeline_devices[-1])
|
||||
self.last_device = torch.device(
|
||||
cfg.distributed_training.pipeline_devices[-1]
|
||||
)
|
||||
|
||||
# check that shared parameters are preserved after device transfer
|
||||
for shared_param in shared_params:
|
||||
@ -129,7 +140,7 @@ class Trainer(object):
|
||||
|
||||
@property
|
||||
def data_parallel_world_size(self):
|
||||
return self.args.distributed_world_size
|
||||
return self.cfg.distributed_training.distributed_world_size
|
||||
|
||||
@property
|
||||
def data_parallel_process_group(self):
|
||||
@ -140,11 +151,11 @@ class Trainer(object):
|
||||
|
||||
@property
|
||||
def data_parallel_rank(self):
|
||||
return self.args.distributed_rank
|
||||
return self.cfg.distributed_training.distributed_rank
|
||||
|
||||
@property
|
||||
def is_data_parallel_master(self):
|
||||
return distributed_utils.is_master(self.args)
|
||||
return distributed_utils.is_master(self.cfg.distributed_training)
|
||||
|
||||
@property
|
||||
def criterion(self):
|
||||
@ -152,11 +163,11 @@ class Trainer(object):
|
||||
if (
|
||||
utils.has_parameters(self._criterion)
|
||||
and self.data_parallel_world_size > 1
|
||||
and not self.args.use_bmuf
|
||||
and not self.cfg.optimization.use_bmuf
|
||||
and not self.tpu
|
||||
):
|
||||
self._wrapped_criterion = models.DistributedFairseqModel(
|
||||
self.args,
|
||||
self.cfg.distributed_training,
|
||||
self._criterion,
|
||||
process_group=self.data_parallel_process_group,
|
||||
)
|
||||
@ -169,11 +180,11 @@ class Trainer(object):
|
||||
if self._wrapped_model is None:
|
||||
if (
|
||||
self.data_parallel_world_size > 1
|
||||
and not self.args.use_bmuf
|
||||
and not self.cfg.optimization.use_bmuf
|
||||
and not self.tpu
|
||||
):
|
||||
self._wrapped_model = models.DistributedFairseqModel(
|
||||
self.args,
|
||||
self.cfg.distributed_training,
|
||||
self._model,
|
||||
process_group=self.data_parallel_process_group,
|
||||
)
|
||||
@ -201,44 +212,51 @@ class Trainer(object):
|
||||
)
|
||||
)
|
||||
|
||||
if self.args.fp16 or self.args.bf16:
|
||||
if self.cfg.common.fp16 or self.cfg.common.bf16:
|
||||
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
|
||||
logger.info(
|
||||
"NOTE: your device does NOT support faster training with --fp16, "
|
||||
"please switch to FP32 which is likely to be faster"
|
||||
)
|
||||
if self.args.memory_efficient_fp16 or self.args.memory_efficient_bf16:
|
||||
if (
|
||||
self.cfg.common.memory_efficient_fp16
|
||||
or self.cfg.common.memory_efficient_bf16
|
||||
):
|
||||
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(
|
||||
self.args, params
|
||||
self.cfg, params
|
||||
)
|
||||
else:
|
||||
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
|
||||
self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params)
|
||||
else:
|
||||
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
|
||||
logger.info("NOTE: your device may support faster training with --fp16")
|
||||
self._optimizer = optim.build_optimizer(self.args, params)
|
||||
self._optimizer = optim.build_optimizer(self.cfg.optimizer, params)
|
||||
|
||||
if self.args.use_bmuf:
|
||||
self._optimizer = optim.FairseqBMUF(self.args, self._optimizer)
|
||||
if self.cfg.optimization.use_bmuf:
|
||||
self._optimizer = optim.FairseqBMUF(
|
||||
self.cfg.bmuf,
|
||||
self._optimizer,
|
||||
)
|
||||
|
||||
if self.args.zero_sharding == "os":
|
||||
if self.cfg.distributed_training.zero_sharding == "os":
|
||||
if (
|
||||
self.args.fp16
|
||||
and not self.args.memory_efficient_fp16
|
||||
and not self.args.memory_efficient_bf16
|
||||
) and not self.args.fp16_no_flatten_grads:
|
||||
self.cfg.common.fp16
|
||||
and not self.cfg.common.memory_efficient_fp16
|
||||
and not self.cfg.common.memory_efficient_bf16
|
||||
) and not self.cfg.common.fp16_no_flatten_grads:
|
||||
raise ValueError(
|
||||
"ZeRO is incomptabile with fp16 and flattened grads. "
|
||||
"Please use --fp16-no-flatten-grads"
|
||||
)
|
||||
else:
|
||||
optim.shard_(
|
||||
self.args, self._optimizer, self.data_parallel_process_group
|
||||
)
|
||||
optim.shard_(self._optimizer, self.data_parallel_process_group)
|
||||
|
||||
# We should initialize the learning rate scheduler immediately after
|
||||
# building the optimizer, so that the initial learning rate is set.
|
||||
self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
|
||||
self._lr_scheduler = lr_scheduler.build_lr_scheduler(
|
||||
self.cfg.lr_scheduler,
|
||||
self.optimizer,
|
||||
)
|
||||
self._lr_scheduler.step_update(0)
|
||||
|
||||
def consolidate_optimizer(self):
|
||||
@ -253,7 +271,7 @@ class Trainer(object):
|
||||
extra_state["previous_training_time"] = self.cumulative_training_time()
|
||||
checkpoint_utils.save_state(
|
||||
filename,
|
||||
self.args,
|
||||
self.cfg,
|
||||
self.get_model().state_dict(),
|
||||
self.get_criterion(),
|
||||
self.optimizer,
|
||||
@ -277,11 +295,10 @@ class Trainer(object):
|
||||
bexists = PathManager.isfile(filename)
|
||||
if bexists:
|
||||
state = checkpoint_utils.load_checkpoint_to_cpu(filename)
|
||||
|
||||
# load model parameters
|
||||
try:
|
||||
self.get_model().load_state_dict(
|
||||
state["model"], strict=True, args=self.args
|
||||
state["model"], strict=True, model_cfg=self.cfg.model
|
||||
)
|
||||
if utils.has_parameters(self.get_criterion()):
|
||||
self.get_criterion().load_state_dict(
|
||||
@ -355,28 +372,28 @@ class Trainer(object):
|
||||
if load_dataset:
|
||||
logger.info("loading train data for epoch {}".format(epoch))
|
||||
self.task.load_dataset(
|
||||
self.args.train_subset,
|
||||
self.cfg.dataset.train_subset,
|
||||
epoch=epoch,
|
||||
combine=combine,
|
||||
data_selector=data_selector,
|
||||
)
|
||||
batch_iterator = self.task.get_batch_iterator(
|
||||
dataset=self.task.dataset(self.args.train_subset),
|
||||
max_tokens=self.args.max_tokens,
|
||||
max_sentences=self.args.batch_size,
|
||||
dataset=self.task.dataset(self.cfg.dataset.train_subset),
|
||||
max_tokens=self.cfg.dataset.max_tokens,
|
||||
max_sentences=self.cfg.dataset.batch_size,
|
||||
max_positions=utils.resolve_max_positions(
|
||||
self.task.max_positions(),
|
||||
self.model.max_positions(),
|
||||
self.args.max_tokens,
|
||||
self.cfg.dataset.max_tokens,
|
||||
),
|
||||
ignore_invalid_inputs=True,
|
||||
required_batch_size_multiple=self.args.required_batch_size_multiple,
|
||||
seed=self.args.seed,
|
||||
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
|
||||
seed=self.cfg.common.seed,
|
||||
num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
|
||||
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
|
||||
num_workers=self.args.num_workers,
|
||||
num_workers=self.cfg.dataset.num_workers,
|
||||
epoch=epoch,
|
||||
data_buffer_size=self.args.data_buffer_size,
|
||||
data_buffer_size=self.cfg.dataset.data_buffer_size,
|
||||
disable_iterator_cache=disable_iterator_cache,
|
||||
)
|
||||
self.reset_dummy_batch(batch_iterator.first_batch)
|
||||
@ -390,19 +407,19 @@ class Trainer(object):
|
||||
"""Return an EpochBatchIterator over given validation subset for a given epoch."""
|
||||
batch_iterator = self.task.get_batch_iterator(
|
||||
dataset=self.task.dataset(subset),
|
||||
max_tokens=self.args.max_tokens_valid,
|
||||
max_sentences=self.args.batch_size_valid,
|
||||
max_tokens=self.cfg.dataset.max_tokens_valid,
|
||||
max_sentences=self.cfg.dataset.batch_size_valid,
|
||||
max_positions=utils.resolve_max_positions(
|
||||
self.task.max_positions(),
|
||||
self.model.max_positions(),
|
||||
),
|
||||
ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=self.args.required_batch_size_multiple,
|
||||
seed=self.args.seed,
|
||||
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple,
|
||||
seed=self.cfg.common.seed,
|
||||
num_shards=self.data_parallel_world_size,
|
||||
shard_id=self.data_parallel_rank,
|
||||
num_workers=self.args.num_workers,
|
||||
data_buffer_size=self.args.data_buffer_size,
|
||||
num_workers=self.cfg.dataset.num_workers,
|
||||
data_buffer_size=self.cfg.dataset.data_buffer_size,
|
||||
disable_iterator_cache=disable_iterator_cache,
|
||||
)
|
||||
self.reset_dummy_batch(batch_iterator.first_batch)
|
||||
@ -504,7 +521,7 @@ class Trainer(object):
|
||||
self.zero_grad()
|
||||
if self.cuda:
|
||||
torch.cuda.empty_cache()
|
||||
if self.args.distributed_world_size == 1:
|
||||
if self.cfg.distributed_training.distributed_world_size == 1:
|
||||
return None
|
||||
else:
|
||||
raise e
|
||||
@ -565,7 +582,7 @@ class Trainer(object):
|
||||
# multiply gradients by (# GPUs / sample_size) since DDP
|
||||
# already normalizes by the number of GPUs. Thus we get
|
||||
# (sum_of_gradients / sample_size).
|
||||
if not self.args.use_bmuf:
|
||||
if not self.cfg.optimization.use_bmuf:
|
||||
self.optimizer.multiply_grads(
|
||||
self.data_parallel_world_size / sample_size
|
||||
)
|
||||
@ -575,12 +592,12 @@ class Trainer(object):
|
||||
|
||||
with torch.autograd.profiler.record_function("clip-grads"):
|
||||
# clip grads
|
||||
grad_norm = self.clip_grad_norm(self.args.clip_norm)
|
||||
grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm)
|
||||
|
||||
# check that grad norms are consistent across workers
|
||||
if (
|
||||
not self.args.use_bmuf
|
||||
and self.args.distributed_wrapper != "SlowMo"
|
||||
not self.cfg.optimization.use_bmuf
|
||||
and self.cfg.distributed_training.distributed_wrapper != "SlowMo"
|
||||
and not self.tpu
|
||||
):
|
||||
self._check_grad_norms(grad_norm)
|
||||
@ -624,7 +641,10 @@ class Trainer(object):
|
||||
self.optimizer.optimizer
|
||||
)
|
||||
|
||||
if not overflow or self.args.distributed_wrapper == "SlowMo":
|
||||
if (
|
||||
not overflow
|
||||
or self.cfg.distributed_training.distributed_wrapper == "SlowMo"
|
||||
):
|
||||
self.set_num_updates(self.get_num_updates() + 1)
|
||||
|
||||
if self.tpu:
|
||||
@ -636,7 +656,7 @@ class Trainer(object):
|
||||
# only log stats every log_interval steps
|
||||
# this causes wps to be misreported when log_interval > 1
|
||||
logging_output = {}
|
||||
if self.get_num_updates() % self.args.log_interval == 0:
|
||||
if self.get_num_updates() % self.cfg.common.log_interval == 0:
|
||||
# log memory usage
|
||||
mem_info = xm.get_memory_info(self.device)
|
||||
gb_free = mem_info["kb_free"] / 1024 / 1024
|
||||
@ -677,16 +697,16 @@ class Trainer(object):
|
||||
# clear CUDA cache to reduce memory fragmentation
|
||||
if (
|
||||
self.cuda
|
||||
and self.args.empty_cache_freq > 0
|
||||
and self.cfg.common.empty_cache_freq > 0
|
||||
and (
|
||||
(self.get_num_updates() + self.args.empty_cache_freq - 1)
|
||||
% self.args.empty_cache_freq
|
||||
(self.get_num_updates() + self.cfg.common.empty_cache_freq - 1)
|
||||
% self.cfg.common.empty_cache_freq
|
||||
)
|
||||
== 0
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.args.fp16:
|
||||
if self.cfg.common.fp16:
|
||||
metrics.log_scalar(
|
||||
"loss_scale",
|
||||
self.optimizer.scaler.loss_scale,
|
||||
@ -883,10 +903,10 @@ class Trainer(object):
|
||||
return t.to(dtype=torch.bfloat16)
|
||||
return t
|
||||
|
||||
if self.args.fp16:
|
||||
if self.cfg.common.fp16:
|
||||
sample = utils.apply_to_sample(apply_half, sample)
|
||||
|
||||
if self.args.bf16:
|
||||
if self.cfg.common.bf16:
|
||||
sample = utils.apply_to_sample(apply_bfloat16, sample)
|
||||
|
||||
return sample
|
||||
@ -894,7 +914,7 @@ class Trainer(object):
|
||||
def _set_seed(self):
|
||||
# Set seed based on args.seed and the update number so that we get
|
||||
# reproducible results when resuming from checkpoints
|
||||
seed = self.args.seed + self.get_num_updates()
|
||||
seed = self.cfg.common.seed + self.get_num_updates()
|
||||
utils.set_torch_seed(seed)
|
||||
|
||||
def _sync_stats(self):
|
||||
@ -902,10 +922,12 @@ class Trainer(object):
|
||||
# BMUF and it's a bmuf sync with warmup iterations completed before.
|
||||
if self.data_parallel_world_size == 1:
|
||||
return False
|
||||
elif self.args.use_bmuf:
|
||||
return (self.get_num_updates() + 1) % self.args.global_sync_iter == 0 and (
|
||||
elif self.cfg.optimization.use_bmuf:
|
||||
return (
|
||||
self.get_num_updates() + 1
|
||||
) > self.args.warmup_iterations
|
||||
) % self.cfg.bmuf.global_sync_iter == 0 and (
|
||||
self.get_num_updates() + 1
|
||||
) > self.cfg.bmuf.warmup_iterations
|
||||
else:
|
||||
return True
|
||||
|
||||
@ -950,7 +972,7 @@ class Trainer(object):
|
||||
zip(
|
||||
*distributed_utils.all_gather_list(
|
||||
[logging_outputs] + list(extra_stats_to_sum),
|
||||
max_size=getattr(self.args, "all_gather_list_size", 16384),
|
||||
max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
|
||||
group=self.data_parallel_process_group,
|
||||
)
|
||||
)
|
||||
@ -1038,11 +1060,11 @@ class Trainer(object):
|
||||
if grad_norm is not None:
|
||||
metrics.log_speed("ups", 1.0, priority=100, round=2)
|
||||
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
|
||||
if self.args.clip_norm > 0:
|
||||
if self.cfg.optimization.clip_norm > 0:
|
||||
metrics.log_scalar(
|
||||
"clip",
|
||||
torch.where(
|
||||
grad_norm > self.args.clip_norm,
|
||||
grad_norm > self.cfg.optimization.clip_norm,
|
||||
grad_norm.new_tensor(100),
|
||||
grad_norm.new_tensor(0),
|
||||
),
|
||||
@ -1087,7 +1109,7 @@ class Trainer(object):
|
||||
logger.warning(
|
||||
"XLA compilation detected on device #{}; too many of these can lead "
|
||||
"to slow training, but we expect a few in the beginning".format(
|
||||
self.args.distributed_rank
|
||||
self.cfg.distributed_training.distributed_rank
|
||||
)
|
||||
)
|
||||
self._num_xla_compiles = num_xla_compiles
|
||||
|
@ -11,13 +11,19 @@ Evaluate the perplexity of a trained language model.
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
|
||||
from fairseq.data import LMContextWindowDataset
|
||||
from fairseq.dataclass.data_class import register_hydra_cfg
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.logging import progress_bar
|
||||
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
||||
from fairseq.sequence_scorer import SequenceScorer
|
||||
from hydra.core.config_store import ConfigStore
|
||||
from hydra.experimental import initialize
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@ -60,65 +66,60 @@ class WordStat(object):
|
||||
)
|
||||
|
||||
|
||||
def main(parsed_args, **unused_kwargs):
|
||||
assert parsed_args.path is not None, "--path required for evaluation!"
|
||||
def main(cfg: DictConfig, override_args=None, **unused_kwargs):
|
||||
if isinstance(cfg, Namespace):
|
||||
cfg = convert_namespace_to_omegaconf(cfg)
|
||||
|
||||
if torch.cuda.is_available() and not parsed_args.cpu:
|
||||
torch.cuda.set_device(parsed_args.device_id)
|
||||
utils.import_user_module(cfg.common)
|
||||
|
||||
utils.import_user_module(parsed_args)
|
||||
use_fp16 = cfg.common.fp16
|
||||
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
|
||||
|
||||
logger.info(parsed_args)
|
||||
if use_cuda:
|
||||
torch.cuda.set_device(cfg.distributed_training.device_id)
|
||||
|
||||
use_cuda = torch.cuda.is_available() and not parsed_args.cpu
|
||||
if override_args is not None:
|
||||
overrides = vars(override_args)
|
||||
overrides.update(eval(getattr(override_args, "model_overrides", "{}")))
|
||||
else:
|
||||
overrides = None
|
||||
|
||||
task = tasks.setup_task(parsed_args)
|
||||
logger.info(cfg)
|
||||
|
||||
# Load ensemble
|
||||
logger.info("loading model(s) from {}".format(parsed_args.path))
|
||||
models, args = checkpoint_utils.load_model_ensemble(
|
||||
parsed_args.path.split(os.pathsep),
|
||||
arg_overrides=eval(parsed_args.model_overrides),
|
||||
task=task,
|
||||
suffix=getattr(parsed_args, "checkpoint_suffix", ""),
|
||||
strict=(parsed_args.checkpoint_shard_count == 1),
|
||||
num_shards=parsed_args.checkpoint_shard_count,
|
||||
)
|
||||
|
||||
for arg in vars(parsed_args).keys():
|
||||
if arg not in {
|
||||
"self_target",
|
||||
"future_target",
|
||||
"past_target",
|
||||
"tokens_per_sample",
|
||||
"output_size_dictionary",
|
||||
"add_bos_token",
|
||||
}:
|
||||
setattr(args, arg, getattr(parsed_args, arg))
|
||||
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
|
||||
|
||||
# reduce tokens per sample by the required context window size
|
||||
args.tokens_per_sample -= args.context_window
|
||||
task = tasks.setup_task(args)
|
||||
cfg.task.tokens_per_sample -= cfg.eval_lm.context_window
|
||||
|
||||
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[cfg.common_eval.path],
|
||||
arg_overrides=overrides,
|
||||
suffix=cfg.checkpoint.checkpoint_suffix,
|
||||
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
|
||||
num_shards=cfg.checkpoint.checkpoint_shard_count,
|
||||
)
|
||||
|
||||
# Load dataset splits
|
||||
task.load_dataset(args.gen_subset)
|
||||
dataset = task.dataset(args.gen_subset)
|
||||
if args.context_window > 0:
|
||||
gen_subset = cfg.dataset.gen_subset
|
||||
task.load_dataset(gen_subset)
|
||||
dataset = task.dataset(gen_subset)
|
||||
if cfg.eval_lm.context_window > 0:
|
||||
dataset = LMContextWindowDataset(
|
||||
dataset=dataset,
|
||||
tokens_per_sample=args.tokens_per_sample,
|
||||
context_window=args.context_window,
|
||||
tokens_per_sample=cfg.task.tokens_per_sample,
|
||||
context_window=cfg.eval_lm.context_window,
|
||||
pad_idx=task.source_dictionary.pad(),
|
||||
)
|
||||
logger.info("{} {} {} examples".format(args.data, args.gen_subset, len(dataset)))
|
||||
logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset)))
|
||||
|
||||
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
|
||||
for model in models:
|
||||
if args.fp16:
|
||||
if use_fp16:
|
||||
model.half()
|
||||
if use_cuda and not args.pipeline_model_parallel:
|
||||
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
||||
model.cuda()
|
||||
model.prepare_for_inference_(args)
|
||||
model.prepare_for_inference_(cfg)
|
||||
|
||||
assert len(models) > 0
|
||||
|
||||
@ -128,35 +129,41 @@ def main(parsed_args, **unused_kwargs):
|
||||
|
||||
itr = task.get_batch_iterator(
|
||||
dataset=dataset,
|
||||
max_tokens=args.max_tokens or 36000,
|
||||
max_sentences=args.batch_size,
|
||||
max_tokens=cfg.dataset.max_tokens or 36000,
|
||||
max_sentences=cfg.dataset.batch_size,
|
||||
max_positions=utils.resolve_max_positions(
|
||||
*[model.max_positions() for model in models]
|
||||
),
|
||||
ignore_invalid_inputs=True,
|
||||
num_shards=args.num_shards,
|
||||
shard_id=args.shard_id,
|
||||
num_workers=args.num_workers,
|
||||
data_buffer_size=args.data_buffer_size,
|
||||
num_shards=max(
|
||||
cfg.dataset.num_shards,
|
||||
cfg.distributed_training.distributed_world_size,
|
||||
),
|
||||
shard_id=max(
|
||||
cfg.dataset.shard_id,
|
||||
cfg.distributed_training.distributed_rank,
|
||||
),
|
||||
num_workers=cfg.dataset.num_workers,
|
||||
data_buffer_size=cfg.dataset.data_buffer_size,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=args.log_format,
|
||||
log_interval=args.log_interval,
|
||||
default_log_format=("tqdm" if not args.no_progress_bar else "none"),
|
||||
log_format=cfg.common.log_format,
|
||||
log_interval=cfg.common.log_interval,
|
||||
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
||||
)
|
||||
|
||||
gen_timer = StopwatchMeter()
|
||||
scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)
|
||||
scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch)
|
||||
|
||||
score_sum = 0.0
|
||||
count = 0
|
||||
|
||||
if args.remove_bpe is not None:
|
||||
if args.remove_bpe == "sentencepiece":
|
||||
if cfg.common_eval.remove_bpe is not None:
|
||||
if cfg.common_eval.remove_bpe == "sentencepiece":
|
||||
raise NotImplementedError
|
||||
else:
|
||||
bpe_cont = args.remove_bpe.rstrip()
|
||||
bpe_cont = cfg.common_eval.remove_bpe.rstrip()
|
||||
bpe_toks = {
|
||||
i
|
||||
for i in range(len(task.source_dictionary))
|
||||
@ -189,7 +196,7 @@ def main(parsed_args, **unused_kwargs):
|
||||
tgt_len = tokens.numel()
|
||||
pos_scores = hypo["positional_scores"].float()
|
||||
|
||||
if getattr(args, "add_bos_token", False):
|
||||
if cfg.task.add_bos_token:
|
||||
assert hypo["tokens"][0].item() == task.target_dictionary.bos()
|
||||
tokens = tokens[1:]
|
||||
pos_scores = pos_scores[1:]
|
||||
@ -212,7 +219,7 @@ def main(parsed_args, **unused_kwargs):
|
||||
score_sum += pos_scores.sum().cpu()
|
||||
count += pos_scores.numel() - skipped_toks
|
||||
|
||||
if args.output_word_probs or args.output_word_stats:
|
||||
if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats:
|
||||
w = ""
|
||||
word_prob = []
|
||||
is_bpe = False
|
||||
@ -238,7 +245,7 @@ def main(parsed_args, **unused_kwargs):
|
||||
)
|
||||
is_bpe = False
|
||||
w = ""
|
||||
if args.output_word_probs:
|
||||
if cfg.eval_lm.output_word_probs:
|
||||
logger.info(
|
||||
str(int(sample_id))
|
||||
+ " "
|
||||
@ -264,7 +271,7 @@ def main(parsed_args, **unused_kwargs):
|
||||
)
|
||||
)
|
||||
|
||||
if args.output_word_stats:
|
||||
if cfg.eval_lm.output_word_stats:
|
||||
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
|
||||
logger.info(ws)
|
||||
|
||||
@ -272,8 +279,16 @@ def main(parsed_args, **unused_kwargs):
|
||||
def cli_main():
|
||||
parser = options.get_eval_lm_parser()
|
||||
args = options.parse_args_and_arch(parser)
|
||||
distributed_utils.call_main(args, main)
|
||||
|
||||
# only override args that are explicitly given on the command line
|
||||
override_parser = options.get_validation_parser()
|
||||
override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)
|
||||
|
||||
distributed_utils.call_main(args, main, override_args=override_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cs = ConfigStore.instance()
|
||||
register_hydra_cfg(cs)
|
||||
initialize(config_path="../config", strict=True)
|
||||
cli_main()
|
||||
|
@ -12,33 +12,45 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from itertools import chain
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, options, scoring, tasks, utils
|
||||
from fairseq.data import encoders
|
||||
from fairseq.dataclass.data_class import register_hydra_cfg
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.logging import progress_bar
|
||||
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
||||
from hydra.core.config_store import ConfigStore
|
||||
from hydra.experimental import initialize
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def main(args):
|
||||
assert args.path is not None, "--path required for generation!"
|
||||
def main(cfg: DictConfig):
|
||||
|
||||
if isinstance(cfg, Namespace):
|
||||
cfg = convert_namespace_to_omegaconf(cfg)
|
||||
|
||||
assert cfg.common_eval.path is not None, "--path required for generation!"
|
||||
assert (
|
||||
not args.sampling or args.nbest == args.beam
|
||||
not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
|
||||
), "--sampling requires --nbest to be equal to --beam"
|
||||
assert (
|
||||
args.replace_unk is None or args.dataset_impl == "raw"
|
||||
cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
|
||||
), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"
|
||||
|
||||
if args.results_path is not None:
|
||||
os.makedirs(args.results_path, exist_ok=True)
|
||||
if cfg.common_eval.results_path is not None:
|
||||
os.makedirs(cfg.common_eval.results_path, exist_ok=True)
|
||||
output_path = os.path.join(
|
||||
args.results_path, "generate-{}.txt".format(args.gen_subset)
|
||||
cfg.common_eval.results_path,
|
||||
"generate-{}.txt".format(cfg.dataset.gen_subset),
|
||||
)
|
||||
with open(output_path, "w", buffering=1, encoding="utf-8") as h:
|
||||
return _main(args, h)
|
||||
return _main(cfg, h)
|
||||
else:
|
||||
return _main(args, sys.stdout)
|
||||
return _main(cfg, sys.stdout)
|
||||
|
||||
|
||||
def get_symbols_to_strip_from_output(generator):
|
||||
@ -48,7 +60,7 @@ def get_symbols_to_strip_from_output(generator):
|
||||
return {generator.eos}
|
||||
|
||||
|
||||
def _main(args, output_file):
|
||||
def _main(cfg: DictConfig, output_file):
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
@ -57,22 +69,22 @@ def _main(args, output_file):
|
||||
)
|
||||
logger = logging.getLogger("fairseq_cli.generate")
|
||||
|
||||
utils.import_user_module(args)
|
||||
utils.import_user_module(cfg.common)
|
||||
|
||||
if args.max_tokens is None and args.batch_size is None:
|
||||
args.max_tokens = 12000
|
||||
logger.info(args)
|
||||
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
|
||||
cfg.dataset.max_tokens = 12000
|
||||
logger.info(cfg)
|
||||
|
||||
# Fix seed for stochastic decoding
|
||||
if args.seed is not None and not args.no_seed_provided:
|
||||
np.random.seed(args.seed)
|
||||
utils.set_torch_seed(args.seed)
|
||||
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
|
||||
np.random.seed(cfg.common.seed)
|
||||
utils.set_torch_seed(cfg.common.seed)
|
||||
|
||||
use_cuda = torch.cuda.is_available() and not args.cpu
|
||||
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
|
||||
|
||||
# Load dataset splits
|
||||
task = tasks.setup_task(args)
|
||||
task.load_dataset(args.gen_subset)
|
||||
task = tasks.setup_task(cfg.task)
|
||||
task.load_dataset(cfg.dataset.gen_subset)
|
||||
|
||||
# Set dictionaries
|
||||
try:
|
||||
@ -81,32 +93,30 @@ def _main(args, output_file):
|
||||
src_dict = None
|
||||
tgt_dict = task.target_dictionary
|
||||
|
||||
overrides = ast.literal_eval(args.model_overrides)
|
||||
overrides = ast.literal_eval(cfg.common_eval.model_overrides)
|
||||
|
||||
# Load ensemble
|
||||
logger.info("loading model(s) from {}".format(args.path))
|
||||
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
|
||||
models, _model_args = checkpoint_utils.load_model_ensemble(
|
||||
utils.split_paths(args.path),
|
||||
utils.split_paths(cfg.common_eval.path),
|
||||
arg_overrides=overrides,
|
||||
task=task,
|
||||
suffix=getattr(args, "checkpoint_suffix", ""),
|
||||
strict=(args.checkpoint_shard_count == 1),
|
||||
num_shards=args.checkpoint_shard_count,
|
||||
suffix=cfg.checkpoint.checkpoint_suffix,
|
||||
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
|
||||
num_shards=cfg.checkpoint.checkpoint_shard_count,
|
||||
)
|
||||
|
||||
if args.lm_path is not None:
|
||||
overrides["data"] = args.data
|
||||
if cfg.generation.lm_path is not None:
|
||||
overrides["data"] = cfg.task.data
|
||||
|
||||
try:
|
||||
lms, _ = checkpoint_utils.load_model_ensemble(
|
||||
[args.lm_path],
|
||||
arg_overrides=overrides,
|
||||
task=None,
|
||||
[cfg.generation.lm_path], arg_overrides=overrides, task=None
|
||||
)
|
||||
except:
|
||||
logger.warning(
|
||||
f"Failed to load language model! Please make sure that the language model dict is the same "
|
||||
f"as target dict and is located in the data dir ({args.data})"
|
||||
f"as target dict and is located in the data dir ({cfg.task.data})"
|
||||
)
|
||||
raise
|
||||
|
||||
@ -118,49 +128,50 @@ def _main(args, output_file):
|
||||
for model in chain(models, lms):
|
||||
if model is None:
|
||||
continue
|
||||
if args.fp16:
|
||||
if cfg.common.fp16:
|
||||
model.half()
|
||||
if use_cuda and not args.pipeline_model_parallel:
|
||||
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
||||
model.cuda()
|
||||
model.prepare_for_inference_(args)
|
||||
model.prepare_for_inference_(cfg)
|
||||
|
||||
# Load alignment dictionary for unknown word replacement
|
||||
# (None if no unknown word replacement, empty if no path to align dictionary)
|
||||
align_dict = utils.load_align_dict(args.replace_unk)
|
||||
align_dict = utils.load_align_dict(cfg.generation.replace_unk)
|
||||
|
||||
# Load dataset (possibly sharded)
|
||||
itr = task.get_batch_iterator(
|
||||
dataset=task.dataset(args.gen_subset),
|
||||
max_tokens=args.max_tokens,
|
||||
max_sentences=args.batch_size,
|
||||
dataset=task.dataset(cfg.dataset.gen_subset),
|
||||
max_tokens=cfg.dataset.max_tokens,
|
||||
max_sentences=cfg.dataset.batch_size,
|
||||
max_positions=utils.resolve_max_positions(
|
||||
task.max_positions(), *[model.max_positions() for model in models]
|
||||
task.max_positions(), *[m.max_positions() for m in models]
|
||||
),
|
||||
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=args.required_batch_size_multiple,
|
||||
num_shards=args.num_shards,
|
||||
shard_id=args.shard_id,
|
||||
num_workers=args.num_workers,
|
||||
data_buffer_size=args.data_buffer_size,
|
||||
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
|
||||
seed=cfg.common.seed,
|
||||
num_shards=cfg.distributed_training.distributed_world_size,
|
||||
shard_id=cfg.distributed_training.distributed_rank,
|
||||
num_workers=cfg.dataset.num_workers,
|
||||
data_buffer_size=cfg.dataset.data_buffer_size,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=args.log_format,
|
||||
log_interval=args.log_interval,
|
||||
default_log_format=("tqdm" if not args.no_progress_bar else "none"),
|
||||
log_format=cfg.common.log_format,
|
||||
log_interval=cfg.common.log_interval,
|
||||
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
||||
)
|
||||
|
||||
# Initialize generator
|
||||
gen_timer = StopwatchMeter()
|
||||
|
||||
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": args.lm_weight}
|
||||
extra_gen_cls_kwargs = {"lm_model": lms[0], "lm_weight": cfg.generation.lm_weight}
|
||||
generator = task.build_generator(
|
||||
models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs
|
||||
models, cfg.task, extra_gen_cls_kwargs=extra_gen_cls_kwargs
|
||||
)
|
||||
|
||||
# Handle tokenization and BPE
|
||||
tokenizer = task.build_tokenizer(args)
|
||||
bpe = task.build_bpe(args)
|
||||
tokenizer = encoders.build_tokenizer(cfg.tokenizer)
|
||||
bpe = encoders.build_bpe(cfg.bpe)
|
||||
|
||||
def decode_fn(x):
|
||||
if bpe is not None:
|
||||
@ -169,7 +180,7 @@ def _main(args, output_file):
|
||||
x = tokenizer.decode(x)
|
||||
return x
|
||||
|
||||
scorer = scoring.build_scorer(args, tgt_dict)
|
||||
scorer = scoring.build_scorer(cfg.scoring, tgt_dict)
|
||||
|
||||
num_sentences = 0
|
||||
has_target = True
|
||||
@ -180,8 +191,8 @@ def _main(args, output_file):
|
||||
continue
|
||||
|
||||
prefix_tokens = None
|
||||
if args.prefix_size > 0:
|
||||
prefix_tokens = sample["target"][:, : args.prefix_size]
|
||||
if cfg.generation.prefix_size > 0:
|
||||
prefix_tokens = sample["target"][:, : cfg.generation.prefix_size]
|
||||
|
||||
constraints = None
|
||||
if "constraints" in sample:
|
||||
@ -217,19 +228,21 @@ def _main(args, output_file):
|
||||
|
||||
# Either retrieve the original sentences or regenerate them from tokens.
|
||||
if align_dict is not None:
|
||||
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
|
||||
target_str = task.dataset(args.gen_subset).tgt.get_original_text(
|
||||
src_str = task.dataset(cfg.dataset.gen_subset).src.get_original_text(
|
||||
sample_id
|
||||
)
|
||||
target_str = task.dataset(cfg.dataset.gen_subset).tgt.get_original_text(
|
||||
sample_id
|
||||
)
|
||||
else:
|
||||
if src_dict is not None:
|
||||
src_str = src_dict.string(src_tokens, args.remove_bpe)
|
||||
src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe)
|
||||
else:
|
||||
src_str = ""
|
||||
if has_target:
|
||||
target_str = tgt_dict.string(
|
||||
target_tokens,
|
||||
args.remove_bpe,
|
||||
cfg.common_eval.remove_bpe,
|
||||
escape_unk=True,
|
||||
extra_symbols_to_ignore=get_symbols_to_strip_from_output(
|
||||
generator
|
||||
@ -240,25 +253,25 @@ def _main(args, output_file):
|
||||
if has_target:
|
||||
target_str = decode_fn(target_str)
|
||||
|
||||
if not args.quiet:
|
||||
if not cfg.common_eval.quiet:
|
||||
if src_dict is not None:
|
||||
print("S-{}\t{}".format(sample_id, src_str), file=output_file)
|
||||
if has_target:
|
||||
print("T-{}\t{}".format(sample_id, target_str), file=output_file)
|
||||
|
||||
# Process top predictions
|
||||
for j, hypo in enumerate(hypos[i][: args.nbest]):
|
||||
for j, hypo in enumerate(hypos[i][: cfg.generation.nbest]):
|
||||
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
|
||||
hypo_tokens=hypo["tokens"].int().cpu(),
|
||||
src_str=src_str,
|
||||
alignment=hypo["alignment"],
|
||||
align_dict=align_dict,
|
||||
tgt_dict=tgt_dict,
|
||||
remove_bpe=args.remove_bpe,
|
||||
remove_bpe=cfg.common_eval.remove_bpe,
|
||||
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
|
||||
)
|
||||
detok_hypo_str = decode_fn(hypo_str)
|
||||
if not args.quiet:
|
||||
if not cfg.common_eval.quiet:
|
||||
score = hypo["score"] / math.log(2) # convert to base 2
|
||||
# original hypothesis (after tokenization and BPE)
|
||||
print(
|
||||
@ -286,7 +299,7 @@ def _main(args, output_file):
|
||||
file=output_file,
|
||||
)
|
||||
|
||||
if args.print_alignment:
|
||||
if cfg.generation.print_alignment:
|
||||
print(
|
||||
"A-{}\t{}".format(
|
||||
sample_id,
|
||||
@ -300,13 +313,13 @@ def _main(args, output_file):
|
||||
file=output_file,
|
||||
)
|
||||
|
||||
if args.print_step:
|
||||
if cfg.generation.print_step:
|
||||
print(
|
||||
"I-{}\t{}".format(sample_id, hypo["steps"]),
|
||||
file=output_file,
|
||||
)
|
||||
|
||||
if getattr(args, "retain_iter_history", False):
|
||||
if cfg.generation.retain_iter_history:
|
||||
for step, h in enumerate(hypo["history"]):
|
||||
_, h_str, _ = utils.post_process_prediction(
|
||||
hypo_tokens=h["tokens"].int().cpu(),
|
||||
@ -323,7 +336,7 @@ def _main(args, output_file):
|
||||
|
||||
# Score only the top hypothesis
|
||||
if has_target and j == 0:
|
||||
if align_dict is not None or args.remove_bpe is not None:
|
||||
if align_dict is not None or cfg.common_eval.remove_bpe is not None:
|
||||
# Convert back to tokens for evaluation with unk replacement and/or without BPE
|
||||
target_tokens = tgt_dict.encode_line(
|
||||
target_str, add_if_not_exist=True
|
||||
@ -353,8 +366,8 @@ def _main(args, output_file):
|
||||
)
|
||||
)
|
||||
if has_target:
|
||||
if args.bpe and not args.sacrebleu:
|
||||
if args.remove_bpe:
|
||||
if cfg.bpe and not cfg.generation.sacrebleu:
|
||||
if cfg.common_eval.remove_bpe:
|
||||
logger.warning(
|
||||
"BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization"
|
||||
)
|
||||
@ -365,7 +378,7 @@ def _main(args, output_file):
|
||||
# use print to be consistent with other main outputs: S-, H-, T-, D- and so on
|
||||
print(
|
||||
"Generate {} with beam={}: {}".format(
|
||||
args.gen_subset, args.beam, scorer.result_string()
|
||||
cfg.dataset.gen_subset, cfg.generation.beam, scorer.result_string()
|
||||
),
|
||||
file=output_file,
|
||||
)
|
||||
@ -380,4 +393,7 @@ def cli_main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cs = ConfigStore.instance()
|
||||
register_hydra_cfg(cs)
|
||||
initialize(config_path="../config", strict=True)
|
||||
cli_main()
|
||||
|
@ -7,20 +7,27 @@
|
||||
Translate raw text with a trained model. Batches data on-the-fly.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import fileinput
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
|
||||
from fairseq.data import encoders
|
||||
from fairseq.dataclass.data_class import register_hydra_cfg
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
|
||||
from fairseq_cli.generate import get_symbols_to_strip_from_output
|
||||
from hydra.core.config_store import ConfigStore
|
||||
from hydra.experimental import initialize
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@ -49,11 +56,11 @@ def buffered_read(input, buffer_size):
|
||||
yield buffer
|
||||
|
||||
|
||||
def make_batches(lines, args, task, max_positions, encode_fn):
|
||||
def make_batches(lines, cfg, task, max_positions, encode_fn):
|
||||
def encode_fn_target(x):
|
||||
return encode_fn(x)
|
||||
|
||||
if args.constraints:
|
||||
if cfg.generation.constraints:
|
||||
# Strip (tab-delimited) contraints, if present, from input lines,
|
||||
# store them in batch_constraints
|
||||
batch_constraints = [list() for _ in lines]
|
||||
@ -79,7 +86,7 @@ def make_batches(lines, args, task, max_positions, encode_fn):
|
||||
for src_str in lines
|
||||
]
|
||||
|
||||
if args.constraints:
|
||||
if cfg.generation.constraints:
|
||||
constraints_tensor = pack_constraints(batch_constraints)
|
||||
else:
|
||||
constraints_tensor = None
|
||||
@ -89,10 +96,10 @@ def make_batches(lines, args, task, max_positions, encode_fn):
|
||||
dataset=task.build_dataset_for_inference(
|
||||
tokens, lengths, constraints=constraints_tensor
|
||||
),
|
||||
max_tokens=args.max_tokens,
|
||||
max_sentences=args.batch_size,
|
||||
max_tokens=cfg.dataset.max_tokens,
|
||||
max_sentences=cfg.dataset.batch_size,
|
||||
max_positions=max_positions,
|
||||
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
|
||||
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
for batch in itr:
|
||||
ids = batch["id"]
|
||||
@ -108,45 +115,50 @@ def make_batches(lines, args, task, max_positions, encode_fn):
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
def main(cfg: DictConfig):
|
||||
if isinstance(cfg, Namespace):
|
||||
cfg = convert_namespace_to_omegaconf(cfg)
|
||||
|
||||
start_time = time.time()
|
||||
total_translate_time = 0
|
||||
|
||||
utils.import_user_module(args)
|
||||
utils.import_user_module(cfg.common)
|
||||
|
||||
if args.buffer_size < 1:
|
||||
args.buffer_size = 1
|
||||
if args.max_tokens is None and args.batch_size is None:
|
||||
args.batch_size = 1
|
||||
if cfg.interactive.buffer_size < 1:
|
||||
cfg.interactive.buffer_size = 1
|
||||
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
|
||||
cfg.dataset.batch_size = 1
|
||||
|
||||
assert (
|
||||
not args.sampling or args.nbest == args.beam
|
||||
not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
|
||||
), "--sampling requires --nbest to be equal to --beam"
|
||||
assert (
|
||||
not args.batch_size or args.batch_size <= args.buffer_size
|
||||
not cfg.dataset.batch_size
|
||||
or cfg.dataset.batch_size <= cfg.interactive.buffer_size
|
||||
), "--batch-size cannot be larger than --buffer-size"
|
||||
|
||||
logger.info(args)
|
||||
logger.info(cfg)
|
||||
|
||||
# Fix seed for stochastic decoding
|
||||
if args.seed is not None and not args.no_seed_provided:
|
||||
np.random.seed(args.seed)
|
||||
utils.set_torch_seed(args.seed)
|
||||
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
|
||||
np.random.seed(cfg.common.seed)
|
||||
utils.set_torch_seed(cfg.common.seed)
|
||||
|
||||
use_cuda = torch.cuda.is_available() and not args.cpu
|
||||
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
|
||||
|
||||
# Setup task, e.g., translation
|
||||
task = tasks.setup_task(args)
|
||||
task = tasks.setup_task(cfg.task)
|
||||
|
||||
# Load ensemble
|
||||
logger.info("loading model(s) from {}".format(args.path))
|
||||
overrides = ast.literal_eval(cfg.common_eval.model_overrides)
|
||||
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
|
||||
models, _model_args = checkpoint_utils.load_model_ensemble(
|
||||
args.path.split(os.pathsep),
|
||||
arg_overrides=eval(args.model_overrides),
|
||||
utils.split_paths(cfg.common_eval.path),
|
||||
arg_overrides=overrides,
|
||||
task=task,
|
||||
suffix=getattr(args, "checkpoint_suffix", ""),
|
||||
strict=(args.checkpoint_shard_count == 1),
|
||||
num_shards=args.checkpoint_shard_count,
|
||||
suffix=cfg.checkpoint.checkpoint_suffix,
|
||||
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
|
||||
num_shards=cfg.checkpoint.checkpoint_shard_count,
|
||||
)
|
||||
|
||||
# Set dictionaries
|
||||
@ -155,18 +167,20 @@ def main(args):
|
||||
|
||||
# Optimize ensemble for generation
|
||||
for model in models:
|
||||
if args.fp16:
|
||||
if model is None:
|
||||
continue
|
||||
if cfg.common.fp16:
|
||||
model.half()
|
||||
if use_cuda and not args.pipeline_model_parallel:
|
||||
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
|
||||
model.cuda()
|
||||
model.prepare_for_inference_(args)
|
||||
model.prepare_for_inference_(cfg)
|
||||
|
||||
# Initialize generator
|
||||
generator = task.build_generator(models, args)
|
||||
generator = task.build_generator(models, cfg.task)
|
||||
|
||||
# Handle tokenization and BPE
|
||||
tokenizer = encoders.build_tokenizer(args)
|
||||
bpe = encoders.build_bpe(args)
|
||||
tokenizer = encoders.build_tokenizer(cfg.tokenizer)
|
||||
bpe = encoders.build_bpe(cfg.bpe)
|
||||
|
||||
def encode_fn(x):
|
||||
if tokenizer is not None:
|
||||
@ -184,25 +198,25 @@ def main(args):
|
||||
|
||||
# Load alignment dictionary for unknown word replacement
|
||||
# (None if no unknown word replacement, empty if no path to align dictionary)
|
||||
align_dict = utils.load_align_dict(args.replace_unk)
|
||||
align_dict = utils.load_align_dict(cfg.generation.replace_unk)
|
||||
|
||||
max_positions = utils.resolve_max_positions(
|
||||
task.max_positions(), *[model.max_positions() for model in models]
|
||||
)
|
||||
|
||||
if args.constraints:
|
||||
if cfg.generation.constraints:
|
||||
logger.warning(
|
||||
"NOTE: Constrained decoding currently assumes a shared subword vocabulary."
|
||||
)
|
||||
|
||||
if args.buffer_size > 1:
|
||||
logger.info("Sentence buffer size: %s", args.buffer_size)
|
||||
if cfg.interactive.buffer_size > 1:
|
||||
logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
|
||||
logger.info("NOTE: hypothesis and token scores are output in base 2")
|
||||
logger.info("Type the input sentence and press return:")
|
||||
start_id = 0
|
||||
for inputs in buffered_read(args.input, args.buffer_size):
|
||||
for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size):
|
||||
results = []
|
||||
for batch in make_batches(inputs, args, task, max_positions, encode_fn):
|
||||
for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
|
||||
bsz = batch.src_tokens.size(0)
|
||||
src_tokens = batch.src_tokens
|
||||
src_lengths = batch.src_lengths
|
||||
@ -226,7 +240,7 @@ def main(args):
|
||||
translate_time = time.time() - translate_start_time
|
||||
total_translate_time += translate_time
|
||||
list_constraints = [[] for _ in range(bsz)]
|
||||
if args.constraints:
|
||||
if cfg.generation.constraints:
|
||||
list_constraints = [unpack_constraints(c) for c in constraints]
|
||||
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
|
||||
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
|
||||
@ -246,25 +260,25 @@ def main(args):
|
||||
# sort output to match input order
|
||||
for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
|
||||
if src_dict is not None:
|
||||
src_str = src_dict.string(src_tokens, args.remove_bpe)
|
||||
src_str = src_dict.string(src_tokens, cfg.common_eval.remove_bpe)
|
||||
print("S-{}\t{}".format(id_, src_str))
|
||||
print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
|
||||
for constraint in info["constraints"]:
|
||||
print(
|
||||
"C-{}\t{}".format(
|
||||
id_, tgt_dict.string(constraint, args.remove_bpe)
|
||||
id_, tgt_dict.string(constraint, cfg.common_eval.remove_bpe)
|
||||
)
|
||||
)
|
||||
|
||||
# Process top predictions
|
||||
for hypo in hypos[: min(len(hypos), args.nbest)]:
|
||||
for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
|
||||
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
|
||||
hypo_tokens=hypo["tokens"].int().cpu(),
|
||||
src_str=src_str,
|
||||
alignment=hypo["alignment"],
|
||||
align_dict=align_dict,
|
||||
tgt_dict=tgt_dict,
|
||||
remove_bpe=args.remove_bpe,
|
||||
remove_bpe=cfg.common_eval.remove_bpe,
|
||||
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
|
||||
)
|
||||
detok_hypo_str = decode_fn(hypo_str)
|
||||
@ -285,7 +299,7 @@ def main(args):
|
||||
),
|
||||
)
|
||||
)
|
||||
if args.print_alignment:
|
||||
if cfg.generation.print_alignment:
|
||||
alignment_str = " ".join(
|
||||
["{}-{}".format(src, tgt) for src, tgt in alignment]
|
||||
)
|
||||
@ -308,4 +322,7 @@ def cli_main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cs = ConfigStore.instance()
|
||||
register_hydra_cfg(cs)
|
||||
initialize(config_path="../config", strict=True)
|
||||
cli_main()
|
||||
|
@ -78,7 +78,13 @@ def cli_main():
|
||||
|
||||
def score(fdsys):
|
||||
with open(args.ref) as fdref:
|
||||
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
|
||||
scorer = bleu.Scorer(
|
||||
bleu.BleuConfig(
|
||||
pad=dict.pad(),
|
||||
eos=dict.eos(),
|
||||
unk=dict.unk(),
|
||||
)
|
||||
)
|
||||
for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
|
||||
sys_tok = dict.encode_line(sys_tok)
|
||||
ref_tok = dict.encode_line(ref_tok)
|
||||
|
@ -11,11 +11,13 @@ import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import Dict, Optional, Any, List, Tuple, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from hydra.core.config_store import ConfigStore
|
||||
|
||||
from fairseq import (
|
||||
checkpoint_utils,
|
||||
distributed_utils,
|
||||
@ -25,8 +27,12 @@ from fairseq import (
|
||||
utils,
|
||||
)
|
||||
from fairseq.data import iterators
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.logging import meters, metrics, progress_bar
|
||||
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
|
||||
from omegaconf import DictConfig
|
||||
from hydra.experimental import initialize
|
||||
from fairseq.dataclass.data_class import register_hydra_cfg
|
||||
from fairseq.trainer import Trainer
|
||||
|
||||
|
||||
@ -39,90 +45,86 @@ logging.basicConfig(
|
||||
logger = logging.getLogger("fairseq_cli.train")
|
||||
|
||||
|
||||
def main(args):
|
||||
utils.import_user_module(args)
|
||||
def main(cfg: DictConfig) -> None:
|
||||
if isinstance(cfg, argparse.Namespace):
|
||||
cfg = convert_namespace_to_omegaconf(cfg)
|
||||
|
||||
assert (
|
||||
args.max_tokens is not None or args.batch_size is not None
|
||||
), "Must specify batch size either with --max-tokens or --batch-size"
|
||||
utils.import_user_module(cfg.common)
|
||||
|
||||
assert cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None, \
|
||||
'Must specify batch size either with --max-tokens or --batch-size'
|
||||
metrics.reset()
|
||||
|
||||
np.random.seed(args.seed)
|
||||
utils.set_torch_seed(args.seed)
|
||||
np.random.seed(cfg.common.seed)
|
||||
utils.set_torch_seed(cfg.common.seed)
|
||||
|
||||
if distributed_utils.is_master(args):
|
||||
checkpoint_utils.verify_checkpoint_directory(args.save_dir)
|
||||
if distributed_utils.is_master(cfg.distributed_training):
|
||||
checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
|
||||
|
||||
# Print args
|
||||
logger.info(args)
|
||||
logger.info(cfg)
|
||||
|
||||
# Setup task, e.g., translation, language modeling, etc.
|
||||
task = tasks.setup_task(args)
|
||||
|
||||
task = tasks.setup_task(cfg.task)
|
||||
# Load valid dataset (we load training data below, based on the latest checkpoint)
|
||||
for valid_sub_split in args.valid_subset.split(","):
|
||||
for valid_sub_split in cfg.dataset.valid_subset.split(','):
|
||||
task.load_dataset(valid_sub_split, combine=False, epoch=1)
|
||||
|
||||
# Build model and criterion
|
||||
model = task.build_model(args)
|
||||
criterion = task.build_criterion(args)
|
||||
model = task.build_model(cfg.model)
|
||||
criterion = task.build_criterion(cfg.criterion)
|
||||
logger.info(model)
|
||||
logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
|
||||
logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
|
||||
logger.info("task: {} ({})".format(cfg.task._name, task.__class__.__name__))
|
||||
logger.info("model: {} ({})".format(cfg.model._name, model.__class__.__name__))
|
||||
logger.info(
|
||||
"criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
|
||||
)
|
||||
logger.info(
|
||||
"num. model params: {} (num. trained: {})".format(
|
||||
sum(p.numel() for p in model.parameters()),
|
||||
sum(p.numel() for p in model.parameters() if p.requires_grad),
|
||||
)
|
||||
"criterion: {} ({})".format(cfg.criterion._name, criterion.__class__.__name__)
|
||||
)
|
||||
logger.info("num. model params: {} (num. trained: {})".format(
|
||||
sum(p.numel() for p in model.parameters()),
|
||||
sum(p.numel() for p in model.parameters() if p.requires_grad),
|
||||
))
|
||||
|
||||
# (optionally) Configure quantization
|
||||
if args.quantization_config_path is not None:
|
||||
if cfg.common.quantization_config_path is not None:
|
||||
quantizer = quantization_utils.Quantizer(
|
||||
config_path=args.quantization_config_path,
|
||||
max_epoch=args.max_epoch,
|
||||
max_update=args.max_update,
|
||||
config_path=cfg.common.quantization_config_path,
|
||||
max_epoch=cfg.optimization.max_epoch,
|
||||
max_update=cfg.optimization.max_update,
|
||||
)
|
||||
else:
|
||||
quantizer = None
|
||||
|
||||
# Build trainer
|
||||
if args.model_parallel_size == 1:
|
||||
trainer = Trainer(args, task, model, criterion, quantizer)
|
||||
if cfg.common.model_parallel_size == 1:
|
||||
trainer = Trainer(cfg, task, model, criterion, quantizer)
|
||||
else:
|
||||
trainer = MegatronTrainer(args, task, model, criterion)
|
||||
trainer = MegatronTrainer(cfg, task, model, criterion)
|
||||
|
||||
logger.info(
|
||||
"training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
|
||||
)
|
||||
logger.info(
|
||||
"max tokens per GPU = {} and max sentences per GPU = {}".format(
|
||||
args.max_tokens, args.batch_size
|
||||
)
|
||||
)
|
||||
logger.info('training on {} devices (GPUs/TPUs)'.format(cfg.distributed_training.distributed_world_size))
|
||||
logger.info('max tokens per GPU = {} and batch size per GPU = {}'.format(
|
||||
cfg.dataset.max_tokens,
|
||||
cfg.dataset.batch_size,
|
||||
))
|
||||
|
||||
# Load the latest checkpoint if one is available and restore the
|
||||
# corresponding train iterator
|
||||
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
|
||||
args,
|
||||
cfg.checkpoint,
|
||||
trainer,
|
||||
# don't cache epoch iterators for sharded datasets
|
||||
disable_iterator_cache=task.has_sharded_data("train"),
|
||||
)
|
||||
|
||||
# Train until the learning rate gets too small
|
||||
max_epoch = args.max_epoch or math.inf
|
||||
max_epoch = cfg.optimization.max_epoch or math.inf
|
||||
lr = trainer.get_lr()
|
||||
train_meter = meters.StopwatchMeter()
|
||||
train_meter.start()
|
||||
|
||||
while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
|
||||
while (
|
||||
lr > cfg.optimization.min_lr
|
||||
and epoch_itr.next_epoch_idx <= max_epoch
|
||||
):
|
||||
# train for one epoch
|
||||
valid_losses, should_stop = train(args, trainer, task, epoch_itr)
|
||||
valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
@ -140,15 +142,15 @@ def main(args):
|
||||
logger.info("done training in {:.1f} seconds".format(train_meter.sum))
|
||||
|
||||
|
||||
def should_stop_early(args, valid_loss):
|
||||
def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
|
||||
# skip check if no validation was done in the current epoch
|
||||
if valid_loss is None:
|
||||
return False
|
||||
if args.patience <= 0:
|
||||
if cfg.checkpoint.patience <= 0:
|
||||
return False
|
||||
|
||||
def is_better(a, b):
|
||||
return a > b if args.maximize_best_checkpoint_metric else a < b
|
||||
return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
|
||||
|
||||
prev_best = getattr(should_stop_early, "best", None)
|
||||
if prev_best is None or is_better(valid_loss, prev_best):
|
||||
@ -157,48 +159,43 @@ def should_stop_early(args, valid_loss):
|
||||
return False
|
||||
else:
|
||||
should_stop_early.num_runs += 1
|
||||
if should_stop_early.num_runs >= args.patience:
|
||||
logger.info(
|
||||
"early stop since valid performance hasn't improved for last {} runs".format(
|
||||
args.patience
|
||||
)
|
||||
)
|
||||
if should_stop_early.num_runs >= cfg.checkpoint.patience:
|
||||
logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(cfg.checkpoint.patience))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
@metrics.aggregate("train")
|
||||
def train(args, trainer, task, epoch_itr):
|
||||
def train(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr) -> Tuple[List[Optional[float]], bool]:
|
||||
"""Train the model for one epoch and return validation losses."""
|
||||
# Initialize data iterator
|
||||
itr = epoch_itr.next_epoch_itr(
|
||||
fix_batches_to_gpus=args.fix_batches_to_gpus,
|
||||
shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
|
||||
fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
|
||||
shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
|
||||
)
|
||||
update_freq = (
|
||||
args.update_freq[epoch_itr.epoch - 1]
|
||||
if epoch_itr.epoch <= len(args.update_freq)
|
||||
else args.update_freq[-1]
|
||||
cfg.optimization.update_freq[epoch_itr.epoch - 1]
|
||||
if epoch_itr.epoch <= len(cfg.optimization.update_freq)
|
||||
else cfg.optimization.update_freq[-1]
|
||||
)
|
||||
itr = iterators.GroupedIterator(itr, update_freq)
|
||||
if getattr(args, "tpu", False):
|
||||
if getattr(cfg.common, "tpu", False):
|
||||
itr = utils.tpu_data_loader(itr)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=args.log_format,
|
||||
log_interval=args.log_interval,
|
||||
log_format=cfg.common.log_format,
|
||||
log_interval=cfg.common.log_interval,
|
||||
epoch=epoch_itr.epoch,
|
||||
tensorboard_logdir=(
|
||||
args.tensorboard_logdir if distributed_utils.is_master(args) else None
|
||||
cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
|
||||
),
|
||||
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
|
||||
default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
|
||||
)
|
||||
|
||||
trainer.begin_epoch(epoch_itr.epoch)
|
||||
|
||||
valid_losses = [None]
|
||||
valid_subsets = args.valid_subset.split(",")
|
||||
valid_subsets = cfg.dataset.valid_subset.split(',')
|
||||
should_stop = False
|
||||
num_updates = trainer.get_num_updates()
|
||||
for i, samples in enumerate(progress):
|
||||
@ -210,7 +207,7 @@ def train(args, trainer, task, epoch_itr):
|
||||
if log_output is not None: # not OOM, overflow, ...
|
||||
# log mid-epoch stats
|
||||
num_updates = trainer.get_num_updates()
|
||||
if num_updates % args.log_interval == 0:
|
||||
if num_updates % cfg.common.log_interval == 0:
|
||||
stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
|
||||
progress.log(stats, tag="train_inner", step=num_updates)
|
||||
|
||||
@ -220,7 +217,7 @@ def train(args, trainer, task, epoch_itr):
|
||||
|
||||
end_of_epoch = not itr.has_next()
|
||||
valid_losses, should_stop = validate_and_save(
|
||||
args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
|
||||
cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
|
||||
)
|
||||
|
||||
if should_stop:
|
||||
@ -236,64 +233,64 @@ def train(args, trainer, task, epoch_itr):
|
||||
return valid_losses, should_stop
|
||||
|
||||
|
||||
def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
|
||||
def validate_and_save(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, valid_subsets: List[str], end_of_epoch: bool) -> Tuple[List[Optional[float]], bool]:
|
||||
num_updates = trainer.get_num_updates()
|
||||
max_update = args.max_update or math.inf
|
||||
max_update = cfg.optimization.max_update or math.inf
|
||||
do_save = (
|
||||
(end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
|
||||
(end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
|
||||
or num_updates >= max_update
|
||||
or (
|
||||
args.save_interval_updates > 0
|
||||
cfg.checkpoint.save_interval_updates > 0
|
||||
and num_updates > 0
|
||||
and num_updates % args.save_interval_updates == 0
|
||||
and num_updates >= args.validate_after_updates
|
||||
and num_updates % cfg.checkpoint.save_interval_updates == 0
|
||||
and num_updates >= cfg.dataset.validate_after_updates
|
||||
)
|
||||
)
|
||||
do_validate = (
|
||||
(not end_of_epoch and do_save) # validate during mid-epoch saves
|
||||
or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
|
||||
or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
|
||||
or num_updates >= max_update
|
||||
or (
|
||||
args.validate_interval_updates > 0
|
||||
cfg.dataset.validate_interval_updates > 0
|
||||
and num_updates > 0
|
||||
and num_updates % args.validate_interval_updates == 0
|
||||
and num_updates % cfg.dataset.validate_interval_updates == 0
|
||||
)
|
||||
) and not args.disable_validation
|
||||
) and not cfg.dataset.disable_validation
|
||||
|
||||
# Validate
|
||||
valid_losses = [None]
|
||||
if do_validate:
|
||||
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
|
||||
valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
|
||||
|
||||
# Stopping conditions
|
||||
should_stop = (
|
||||
should_stop_early(args, valid_losses[0])
|
||||
should_stop_early(cfg, valid_losses[0])
|
||||
or num_updates >= max_update
|
||||
or (
|
||||
args.stop_time_hours > 0
|
||||
and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
|
||||
cfg.optimization.stop_time_hours > 0
|
||||
and trainer.cumulative_training_time() / (60 * 60) > cfg.optimization.stop_time_hours
|
||||
)
|
||||
)
|
||||
|
||||
# Save checkpoint
|
||||
if do_save or should_stop:
|
||||
logger.info("begin save checkpoint")
|
||||
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
|
||||
checkpoint_utils.save_checkpoint(cfg.checkpoint, trainer, epoch_itr, valid_losses[0])
|
||||
|
||||
return valid_losses, should_stop
|
||||
|
||||
|
||||
def get_training_stats(stats):
|
||||
def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
|
||||
stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
|
||||
return stats
|
||||
|
||||
|
||||
def validate(args, trainer, task, epoch_itr, subsets):
|
||||
def validate(cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr, subsets: List[str]) -> List[Optional[float]]:
|
||||
"""Evaluate the model on the validation set(s) and return the losses."""
|
||||
|
||||
if args.fixed_validation_seed is not None:
|
||||
if cfg.dataset.fixed_validation_seed is not None:
|
||||
# set fixed seed for every validation
|
||||
utils.set_torch_seed(args.fixed_validation_seed)
|
||||
utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
|
||||
|
||||
trainer.begin_valid_epoch(epoch_itr.epoch)
|
||||
valid_losses = []
|
||||
@ -302,18 +299,18 @@ def validate(args, trainer, task, epoch_itr, subsets):
|
||||
|
||||
# Initialize data iterator
|
||||
itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
|
||||
if getattr(args, "tpu", False):
|
||||
if cfg.common.tpu:
|
||||
itr = utils.tpu_data_loader(itr)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=args.log_format,
|
||||
log_interval=args.log_interval,
|
||||
log_format=cfg.common.log_format,
|
||||
log_interval=cfg.common.log_interval,
|
||||
epoch=epoch_itr.epoch,
|
||||
prefix=f"valid on '{subset}' subset",
|
||||
tensorboard_logdir=(
|
||||
args.tensorboard_logdir if distributed_utils.is_master(args) else None
|
||||
cfg.common.tensorboard_logdir if distributed_utils.is_master(cfg.distributed_training) else None
|
||||
),
|
||||
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
|
||||
default_log_format=('tqdm' if not cfg.common.no_progress_bar else 'simple'),
|
||||
)
|
||||
|
||||
# create a new root metrics aggregator so validation metrics
|
||||
@ -323,34 +320,40 @@ def validate(args, trainer, task, epoch_itr, subsets):
|
||||
trainer.valid_step(sample)
|
||||
|
||||
# log validation stats
|
||||
stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
|
||||
stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
|
||||
progress.print(stats, tag=subset, step=trainer.get_num_updates())
|
||||
|
||||
valid_losses.append(stats[args.best_checkpoint_metric])
|
||||
valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
|
||||
return valid_losses
|
||||
|
||||
|
||||
def get_valid_stats(args, trainer, stats):
|
||||
def get_valid_stats(cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]:
|
||||
stats["num_updates"] = trainer.get_num_updates()
|
||||
if hasattr(checkpoint_utils.save_checkpoint, "best"):
|
||||
key = "best_{0}".format(args.best_checkpoint_metric)
|
||||
best_function = max if args.maximize_best_checkpoint_metric else min
|
||||
key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
|
||||
best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
|
||||
stats[key] = best_function(
|
||||
checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric]
|
||||
checkpoint_utils.save_checkpoint.best, stats[cfg.checkpoint.best_checkpoint_metric]
|
||||
)
|
||||
return stats
|
||||
|
||||
|
||||
def cli_main(modify_parser=None):
|
||||
def cli_main(modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None) -> None:
|
||||
parser = options.get_training_parser()
|
||||
args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
|
||||
|
||||
cfg = convert_namespace_to_omegaconf(args)
|
||||
|
||||
if args.profile:
|
||||
with torch.cuda.profiler.profile():
|
||||
with torch.autograd.profiler.emit_nvtx():
|
||||
distributed_utils.call_main(args, main)
|
||||
distributed_utils.call_main(cfg, main)
|
||||
else:
|
||||
distributed_utils.call_main(args, main)
|
||||
distributed_utils.call_main(cfg, main)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
cs = ConfigStore.instance()
|
||||
register_hydra_cfg(cs)
|
||||
initialize(config_path="../config", strict=True)
|
||||
cli_main()
|
||||
|
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3 -u
|
||||
#!/usr/bin/env python3 -u
|
||||
# !/usr/bin/env python3 -u
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
@ -8,11 +8,17 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, distributed_utils, options, utils
|
||||
from fairseq.dataclass.data_class import register_hydra_cfg
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.logging import metrics, progress_bar
|
||||
from hydra.core.config_store import ConfigStore
|
||||
from hydra.experimental import initialize
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@ -24,18 +30,21 @@ logging.basicConfig(
|
||||
logger = logging.getLogger("fairseq_cli.validate")
|
||||
|
||||
|
||||
def main(args, override_args=None):
|
||||
utils.import_user_module(args)
|
||||
def main(cfg: DictConfig, override_args=None):
|
||||
if isinstance(cfg, Namespace):
|
||||
cfg = convert_namespace_to_omegaconf(cfg)
|
||||
|
||||
utils.import_user_module(cfg.common)
|
||||
|
||||
assert (
|
||||
args.max_tokens is not None or args.batch_size is not None
|
||||
cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
|
||||
), "Must specify batch size either with --max-tokens or --batch-size"
|
||||
|
||||
use_fp16 = args.fp16
|
||||
use_cuda = torch.cuda.is_available() and not args.cpu
|
||||
use_fp16 = cfg.common.fp16
|
||||
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
|
||||
|
||||
if use_cuda:
|
||||
torch.cuda.set_device(args.device_id)
|
||||
torch.cuda.set_device(cfg.distributed_training.device_id)
|
||||
|
||||
if override_args is not None:
|
||||
overrides = vars(override_args)
|
||||
@ -44,11 +53,11 @@ def main(args, override_args=None):
|
||||
overrides = None
|
||||
|
||||
# Load ensemble
|
||||
logger.info("loading model(s) from {}".format(args.path))
|
||||
logger.info("loading model(s) from {}".format(cfg.common_eval.path))
|
||||
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
|
||||
[args.path],
|
||||
[cfg.common_eval.path],
|
||||
arg_overrides=overrides,
|
||||
suffix=getattr(args, "checkpoint_suffix", ""),
|
||||
suffix=cfg.checkpoint.checkpoint_suffix,
|
||||
)
|
||||
model = models[0]
|
||||
|
||||
@ -63,10 +72,10 @@ def main(args, override_args=None):
|
||||
logger.info(model_args)
|
||||
|
||||
# Build criterion
|
||||
criterion = task.build_criterion(model_args)
|
||||
criterion = task.build_criterion(model_args.criterion)
|
||||
criterion.eval()
|
||||
|
||||
for subset in args.valid_subset.split(","):
|
||||
for subset in cfg.dataset.valid_subset.split(","):
|
||||
try:
|
||||
task.load_dataset(subset, combine=False, epoch=1)
|
||||
dataset = task.dataset(subset)
|
||||
@ -76,26 +85,26 @@ def main(args, override_args=None):
|
||||
# Initialize data iterator
|
||||
itr = task.get_batch_iterator(
|
||||
dataset=dataset,
|
||||
max_tokens=args.max_tokens,
|
||||
max_sentences=args.batch_size,
|
||||
max_tokens=cfg.dataset.max_tokens,
|
||||
max_sentences=cfg.dataset.batch_size,
|
||||
max_positions=utils.resolve_max_positions(
|
||||
task.max_positions(),
|
||||
*[m.max_positions() for m in models],
|
||||
),
|
||||
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=args.required_batch_size_multiple,
|
||||
seed=args.seed,
|
||||
num_shards=args.distributed_world_size,
|
||||
shard_id=args.distributed_rank,
|
||||
num_workers=args.num_workers,
|
||||
data_buffer_size=args.data_buffer_size,
|
||||
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
|
||||
seed=cfg.common.seed,
|
||||
num_shards=cfg.distributed_training.distributed_world_size,
|
||||
shard_id=cfg.distributed_training.distributed_rank,
|
||||
num_workers=cfg.dataset.num_workers,
|
||||
data_buffer_size=cfg.dataset.data_buffer_size,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=args.log_format,
|
||||
log_interval=args.log_interval,
|
||||
log_format=cfg.common.log_format,
|
||||
log_interval=cfg.common.log_interval,
|
||||
prefix=f"valid on '{subset}' subset",
|
||||
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
|
||||
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
|
||||
)
|
||||
|
||||
log_outputs = []
|
||||
@ -105,10 +114,10 @@ def main(args, override_args=None):
|
||||
progress.log(log_output, step=i)
|
||||
log_outputs.append(log_output)
|
||||
|
||||
if args.distributed_world_size > 1:
|
||||
if cfg.distributed_training.distributed_world_size > 1:
|
||||
log_outputs = distributed_utils.all_gather_list(
|
||||
log_outputs,
|
||||
max_size=getattr(args, "all_gather_list_size", 16384),
|
||||
max_size=cfg.common.all_gather_list_size,
|
||||
)
|
||||
log_outputs = list(chain.from_iterable(log_outputs))
|
||||
|
||||
@ -131,4 +140,7 @@ def cli_main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cs = ConfigStore.instance()
|
||||
register_hydra_cfg(cs)
|
||||
initialize(config_path="../config", strict=True)
|
||||
cli_main()
|
||||
|
@ -272,6 +272,7 @@ class TestFairseqEncoderDecoderModelBase(TestBaseFairseqModelBase):
|
||||
model_cls.add_args(parser)
|
||||
|
||||
args = parser.parse_args([])
|
||||
|
||||
if extra_args_setters is not None:
|
||||
for args_setter in extra_args_setters:
|
||||
args_setter(args)
|
||||
@ -515,9 +516,7 @@ class CrossEntropyCriterionTestBase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
args = self.setUpArgs()
|
||||
self.model = DummyEncoderModel(encoder=DummyEncoder())
|
||||
self.criterion = self.criterion_cls.build_criterion(
|
||||
args=args, task=DummyTask(args)
|
||||
)
|
||||
self.criterion = self.criterion_cls.build_criterion(args, task=DummyTask(args))
|
||||
|
||||
def get_src_tokens(self, correct_prediction, aggregate):
|
||||
"""
|
||||
|
@ -11,7 +11,7 @@ from multiprocessing import Manager
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fairseq import distributed_utils, optim
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, input_size, output_size):
|
||||
@ -23,13 +23,14 @@ class Model(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def setup_model_loss_criterion(args, rank, is_cuda):
|
||||
def setup_model_loss_criterion(cfg, args, rank, is_cuda):
|
||||
"""
|
||||
setup model, criterion and optimizer based on input args
|
||||
"""
|
||||
args.distributed_rank = rank
|
||||
if args.distributed_world_size > 1:
|
||||
distributed_utils.distributed_init(args)
|
||||
cfg.distributed_training.distributed_rank = args.distributed_rank
|
||||
if cfg.distributed_training.distributed_world_size > 1:
|
||||
distributed_utils.distributed_init(cfg)
|
||||
torch.manual_seed(1)
|
||||
model = Model(args.input_size, args.nb_classes)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
@ -38,7 +39,10 @@ def setup_model_loss_criterion(args, rank, is_cuda):
|
||||
loss_fn = loss_fn.cuda()
|
||||
|
||||
optimizer = optim.sgd.SGD(args, model.parameters())
|
||||
optimizer = optim.FairseqBMUF(args, optimizer)
|
||||
optimizer = optim.FairseqBMUF(
|
||||
cfg=cfg.bmuf,
|
||||
optimizer=optimizer
|
||||
)
|
||||
|
||||
return model, loss_fn, optimizer
|
||||
|
||||
@ -52,13 +56,13 @@ def train_step(input, target, model, loss_fn, optimizer, **unused):
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def single_gpu_training(args, rank, iterations, shared_results):
|
||||
def single_gpu_training(cfg, args, rank, iterations, shared_results):
|
||||
|
||||
is_cuda = torch.cuda.is_available()
|
||||
if is_cuda:
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
model, loss_fn, optimizer = setup_model_loss_criterion(args, rank, is_cuda)
|
||||
model, loss_fn, optimizer = setup_model_loss_criterion(cfg, args, rank, is_cuda)
|
||||
|
||||
for _ in range(iterations):
|
||||
input = torch.randn(1, args.input_size)
|
||||
@ -103,18 +107,44 @@ def setup_args():
|
||||
args.distributed_init_host = "localhost"
|
||||
args.distributed_port = port + 1
|
||||
args.local_world_size = args.distributed_world_size
|
||||
return args
|
||||
|
||||
cfg = OmegaConf.create()
|
||||
cfg.optimization = OmegaConf.create()
|
||||
cfg.common = OmegaConf.create()
|
||||
cfg.distributed_training = OmegaConf.create()
|
||||
cfg.dataset = OmegaConf.create()
|
||||
cfg.bmuf = OmegaConf.create()
|
||||
cfg.optimizer = OmegaConf.create()
|
||||
|
||||
cfg.bmuf.global_sync_iter = args.global_sync_iter
|
||||
cfg.bmuf.block_momentum = args.block_momentum
|
||||
cfg.bmuf.block_lr = args.block_lr
|
||||
cfg.dataset.batch_size = args.batch_size
|
||||
cfg.optimization.lr = args.lr
|
||||
cfg.optimizer.momentum = args.momentum
|
||||
cfg.optimizer.weight_decay = args.weight_decay
|
||||
cfg.bmuf.warmup_iterations = args.warmup_iterations
|
||||
cfg.bmuf.use_nbm = args.use_nbm
|
||||
cfg.bmuf.average_sync = args.average_sync
|
||||
cfg.common.model_parallel_size = args.model_parallel_size
|
||||
cfg.distributed_training.distributed_backend = args.distributed_backend
|
||||
cfg.distributed_training.distributed_world_size = args.distributed_world_size
|
||||
cfg.bmuf.distributed_world_size = args.distributed_world_size
|
||||
cfg.distributed_training.distributed_init_method = args.distributed_init_method
|
||||
cfg.distributed_training.distributed_port = args.distributed_port
|
||||
|
||||
return cfg, args
|
||||
|
||||
|
||||
@unittest.skipIf(torch.cuda.device_count() < 2, "test requires 2 GPUs")
|
||||
class TestBMUF(unittest.TestCase):
|
||||
def bmuf_process(self, args, iterations):
|
||||
def bmuf_process(self, cfg, args, iterations):
|
||||
processes = []
|
||||
results = Manager().dict()
|
||||
ctx = torch.multiprocessing.get_context("spawn")
|
||||
for rank in range(args.distributed_world_size):
|
||||
p = ctx.Process(
|
||||
target=single_gpu_training, args=(args, rank, iterations, results)
|
||||
target=single_gpu_training, args=(cfg, args, rank, iterations, results)
|
||||
)
|
||||
p.start()
|
||||
processes.append(p)
|
||||
@ -125,19 +155,20 @@ class TestBMUF(unittest.TestCase):
|
||||
|
||||
def test_bmuf_sync(self):
|
||||
# Train model for 1 iteration and do bmuf sync without doing warmup
|
||||
args = setup_args()
|
||||
cfg, args = setup_args()
|
||||
iterations = 1
|
||||
results = self.bmuf_process(args, iterations)
|
||||
results = self.bmuf_process(cfg, args, iterations)
|
||||
# Make sure params in both machines are same
|
||||
assert len(results) == 2
|
||||
self.assertAlmostEqual(results[0], results[1])
|
||||
|
||||
def test_warmup_sync(self):
|
||||
# Train model for 20 iteration and do warmup sync without doing bmuf sync
|
||||
args = setup_args()
|
||||
cfg, args = setup_args()
|
||||
args.warmup_iterations = 20
|
||||
cfg.bmuf.warmup_iterations = args.warmup_iterations
|
||||
iterations = 20
|
||||
results = self.bmuf_process(args, iterations)
|
||||
results = self.bmuf_process(cfg, args, iterations)
|
||||
# Make sure params in both machines are same
|
||||
assert len(results) == 2
|
||||
self.assertAlmostEqual(results[0], results[1])
|
||||
@ -145,22 +176,27 @@ class TestBMUF(unittest.TestCase):
|
||||
def test_warmup_sync_bmuf_sync(self):
|
||||
# Train model for 25 iteration and do warmup sync after 20 iteration
|
||||
# and bmuf sync after 25 iteration
|
||||
args = setup_args()
|
||||
cfg, args = setup_args()
|
||||
args.warmup_iterations = 20
|
||||
args.global_sync_iter = 5
|
||||
cfg.bmuf.warmup_iterations = args.warmup_iterations
|
||||
cfg.bmuf.global_sync_iter = args.global_sync_iter
|
||||
iterations = 25
|
||||
results = self.bmuf_process(args, iterations)
|
||||
results = self.bmuf_process(cfg, args, iterations)
|
||||
# Make sure params in both machines are same
|
||||
assert len(results) == 2
|
||||
self.assertAlmostEqual(results[0], results[1])
|
||||
|
||||
def test_single_gpu_bmuf(self):
|
||||
# Train model for 5 iterations and use GPU 1
|
||||
args = setup_args()
|
||||
cfg, args = setup_args()
|
||||
args.distributed_world_size = 1
|
||||
args.warmup_iterations = 5
|
||||
cfg.distributed_training.distributed_world_size = args.distributed_world_size
|
||||
cfg.bmuf.distributed_world_size = args.distributed_world_size
|
||||
cfg.bmuf.warmup_iterations = args.warmup_iterations
|
||||
iterations = 20
|
||||
results = self.bmuf_process(args, iterations)
|
||||
results = self.bmuf_process(cfg, args, iterations)
|
||||
assert len(results) == 1
|
||||
|
||||
def assertAlmostEqual(self, t1, t2):
|
||||
|
@ -9,6 +9,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
from fairseq.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
|
||||
@ -27,17 +28,23 @@ class TestGradientScaling(unittest.TestCase):
|
||||
self.model.cuda().half()
|
||||
self.params = list(self.model.parameters())
|
||||
|
||||
self.namespace_dls = argparse.Namespace(
|
||||
optimizer="adam",
|
||||
lr=[0.1],
|
||||
adam_betas="(0.9, 0.999)",
|
||||
adam_eps=1e-8,
|
||||
weight_decay=0.0,
|
||||
fp16_init_scale=1,
|
||||
fp16_scale_window=1,
|
||||
fp16_scale_tolerance=1,
|
||||
threshold_loss_scale=1,
|
||||
min_loss_scale=1e-4,
|
||||
self.cfg_dls = OmegaConf.create(
|
||||
{
|
||||
"optimizer": {
|
||||
"_name": "adam",
|
||||
"lr": [0.1],
|
||||
"adam_betas": "(0.9, 0.999)",
|
||||
"adam_eps": 1e-8,
|
||||
"weight_decay": 0.0,
|
||||
},
|
||||
"common": {
|
||||
"fp16_init_scale": 1,
|
||||
"fp16_scale_window": 1,
|
||||
"fp16_scale_tolerance": 1,
|
||||
"threshold_loss_scale": 1,
|
||||
"min_loss_scale": 1e-4,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def run_iter(self, model, params, optimizer):
|
||||
@ -68,7 +75,7 @@ class TestGradientScaling(unittest.TestCase):
|
||||
def test_mixed_precision(self):
|
||||
model = copy.deepcopy(self.model)
|
||||
params = list(model.parameters())
|
||||
optimizer = FP16Optimizer.build_optimizer(self.namespace_dls, params)
|
||||
optimizer = FP16Optimizer.build_optimizer(self.cfg_dls, params)
|
||||
|
||||
self.run_iter(model, params, optimizer)
|
||||
self.assertTrue(
|
||||
@ -87,9 +94,7 @@ class TestGradientScaling(unittest.TestCase):
|
||||
def test_memory_efficient(self):
|
||||
model = copy.deepcopy(self.model)
|
||||
params = list(model.parameters())
|
||||
optimizer = MemoryEfficientFP16Optimizer.build_optimizer(
|
||||
self.namespace_dls, params
|
||||
)
|
||||
optimizer = MemoryEfficientFP16Optimizer.build_optimizer(self.cfg_dls, params)
|
||||
|
||||
self.run_iter(model, params, optimizer)
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.models.transformer import TransformerModel
|
||||
from tests.test_sequence_generator import get_dummy_task_and_parser
|
||||
|
||||
@ -25,7 +26,8 @@ class TestInferenceDropout(unittest.TestCase):
|
||||
def test_sets_inference_dropout_to_true(self):
|
||||
self.args.retain_dropout = True
|
||||
self.transformer_model = TransformerModel.build_model(self.args, self.task)
|
||||
self.transformer_model.prepare_for_inference_(self.args)
|
||||
cfg = convert_namespace_to_omegaconf(self.args)
|
||||
self.transformer_model.prepare_for_inference_(cfg)
|
||||
assert self.transformer_model.encoder.dropout_module.apply_during_inference
|
||||
assert self.transformer_model.decoder.dropout_module.apply_during_inference
|
||||
for layer in self.transformer_model.encoder.layers:
|
||||
@ -33,7 +35,8 @@ class TestInferenceDropout(unittest.TestCase):
|
||||
|
||||
def test_inference_dropout_false_by_default(self):
|
||||
self.transformer_model = TransformerModel.build_model(self.args, self.task)
|
||||
self.transformer_model.prepare_for_inference_(self.args)
|
||||
cfg = convert_namespace_to_omegaconf(self.args)
|
||||
self.transformer_model.prepare_for_inference_(cfg)
|
||||
assert not self.transformer_model.encoder.dropout_module.apply_during_inference
|
||||
assert not self.transformer_model.decoder.dropout_module.apply_during_inference
|
||||
for layer in self.transformer_model.encoder.layers:
|
||||
@ -59,7 +62,8 @@ class TestInferenceDropout(unittest.TestCase):
|
||||
"TransformerEncoderLayer",
|
||||
]
|
||||
self.transformer_model = TransformerModel.build_model(self.args, self.task)
|
||||
self.transformer_model.prepare_for_inference_(self.args)
|
||||
cfg = convert_namespace_to_omegaconf(self.args)
|
||||
self.transformer_model.prepare_for_inference_(cfg)
|
||||
assert self.transformer_model.encoder.dropout_module.apply_during_inference
|
||||
assert not self.transformer_model.decoder.dropout_module.apply_during_inference
|
||||
for layer in self.transformer_model.decoder.layers:
|
||||
|
@ -10,6 +10,7 @@ import unittest
|
||||
import torch
|
||||
from fairseq.optim.adam import FairseqAdam
|
||||
from fairseq.optim.fp16_optimizer import MemoryEfficientFP16Optimizer
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
|
||||
@ -26,25 +27,36 @@ class TestMemoryEfficientFP16(unittest.TestCase):
|
||||
params = list(model.parameters())
|
||||
|
||||
# initialize memory efficient FP16 optimizer
|
||||
# with pseudo DictConfigs
|
||||
optimizer = FairseqAdam(
|
||||
argparse.Namespace(
|
||||
lr=[0.00001],
|
||||
adam_betas="(0.9, 0.999)",
|
||||
adam_eps=1e-8,
|
||||
weight_decay=0.0,
|
||||
cfg=OmegaConf.create(
|
||||
vars(
|
||||
argparse.Namespace(
|
||||
adam_betas="(0.9, 0.999)",
|
||||
adam_eps=1e-8,
|
||||
weight_decay=0.0,
|
||||
lr=[0.00001],
|
||||
)
|
||||
)
|
||||
),
|
||||
params,
|
||||
params=params,
|
||||
)
|
||||
me_optimizer = MemoryEfficientFP16Optimizer(
|
||||
argparse.Namespace(
|
||||
fp16_init_scale=1,
|
||||
fp16_scale_window=1,
|
||||
fp16_scale_tolerance=1,
|
||||
threshold_loss_scale=1,
|
||||
min_loss_scale=1e-4,
|
||||
cfg=OmegaConf.create(
|
||||
{
|
||||
"common": vars(
|
||||
argparse.Namespace(
|
||||
fp16_init_scale=1,
|
||||
fp16_scale_window=1,
|
||||
fp16_scale_tolerance=1,
|
||||
threshold_loss_scale=1,
|
||||
min_loss_scale=1e-4,
|
||||
)
|
||||
)
|
||||
}
|
||||
),
|
||||
params,
|
||||
optimizer,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
|
||||
# optimizer state is created in the first step
|
||||
|
@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, data
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def mock_trainer(epoch, num_updates, iterations_in_epoch):
|
||||
@ -56,21 +57,29 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
|
||||
return trainer, epoch_itr
|
||||
|
||||
|
||||
def get_mock_args(finetune_from_model=None):
|
||||
args_mock = MagicMock()
|
||||
args_mock.optimizer_overrides = "{}"
|
||||
args_mock.reset_dataloader = False
|
||||
args_mock.reset_meters = False
|
||||
args_mock.reset_optimizer = False
|
||||
args_mock.reset_lr_scheduler = False
|
||||
args_mock.finetune_from_model = finetune_from_model
|
||||
args_mock.model_parallel_size = 1
|
||||
return args_mock
|
||||
def get_mock_cfg(finetune_from_model):
|
||||
cfg_mock = OmegaConf.create(
|
||||
{
|
||||
"checkpoint": {
|
||||
"optimizer_overrides": "{}",
|
||||
"reset_dataloader": False,
|
||||
"reset_meters": False,
|
||||
"reset_optimizer": False,
|
||||
"reset_lr_scheduler": False,
|
||||
"finetune_from_model": finetune_from_model,
|
||||
"model_parallel_size": 1,
|
||||
},
|
||||
"common": {
|
||||
"model_parallel_size": 1,
|
||||
},
|
||||
}
|
||||
)
|
||||
return cfg_mock
|
||||
|
||||
|
||||
class TestLoadCheckpoint(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.args_mock = get_mock_args()
|
||||
self.cfg_mock = get_mock_cfg(None)
|
||||
self.patches = {
|
||||
"os.makedirs": MagicMock(),
|
||||
"os.path.join": MagicMock(),
|
||||
@ -91,7 +100,9 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
|
||||
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
|
||||
|
||||
_, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
|
||||
_, epoch_itr = checkpoint_utils.load_checkpoint(
|
||||
self.cfg_mock.checkpoint, trainer
|
||||
)
|
||||
|
||||
self.assertEqual(epoch_itr.epoch, 2)
|
||||
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
|
||||
@ -120,7 +131,9 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
|
||||
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
|
||||
|
||||
_, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
|
||||
_, epoch_itr = checkpoint_utils.load_checkpoint(
|
||||
self.cfg_mock.checkpoint, trainer
|
||||
)
|
||||
itr = epoch_itr.next_epoch_itr(shuffle=False)
|
||||
|
||||
self.assertEqual(epoch_itr.epoch, 3)
|
||||
@ -133,7 +146,9 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
|
||||
self.patches["os.path.isfile"].return_value = False
|
||||
|
||||
_, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
|
||||
_, epoch_itr = checkpoint_utils.load_checkpoint(
|
||||
self.cfg_mock.checkpoint, trainer
|
||||
)
|
||||
itr = epoch_itr.next_epoch_itr(shuffle=False)
|
||||
|
||||
self.assertEqual(epoch_itr.epoch, 1)
|
||||
@ -152,10 +167,12 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
"reset_dataloader",
|
||||
]:
|
||||
with self.subTest(arg=arg):
|
||||
args_mock = get_mock_args("/temp/checkpoint_pretrained.pt")
|
||||
setattr(args_mock, arg, True)
|
||||
cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
|
||||
cfg_mock["checkpoint"][arg] = True
|
||||
with self.assertRaises(Exception) as context:
|
||||
_, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
|
||||
_, _ = checkpoint_utils.load_checkpoint(
|
||||
cfg_mock.checkpoint, trainer
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
"--finetune-from-model can not be set together with either --reset-optimizer"
|
||||
@ -168,8 +185,6 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
|
||||
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
|
||||
from_model_path = "/temp/checkpoint_pretrained.pt"
|
||||
args_mock = get_mock_args(from_model_path)
|
||||
args_mock.restore_file = "checkpoint_last.pt"
|
||||
|
||||
def mock_finetune_exist(path):
|
||||
if path == from_model_path:
|
||||
@ -180,7 +195,9 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
self.patches[
|
||||
"fairseq.file_io.PathManager.exists"
|
||||
].side_effect = mock_finetune_exist
|
||||
_, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
|
||||
cfg_mock = get_mock_cfg(from_model_path)
|
||||
cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
|
||||
_, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
|
||||
(
|
||||
checkpoint_path,
|
||||
reset_optimizer,
|
||||
@ -197,8 +214,6 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
|
||||
trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
|
||||
from_model_path = "/temp/checkpoint_pretrained.pt"
|
||||
args_mock = get_mock_args(from_model_path)
|
||||
args_mock.restore_file = "checkpoint_last.pt"
|
||||
|
||||
# launch second time
|
||||
# both restore_file=checkpoint_last.pt and finetune_from_model are set
|
||||
@ -211,7 +226,9 @@ class TestLoadCheckpoint(unittest.TestCase):
|
||||
self.patches[
|
||||
"fairseq.file_io.PathManager.exists"
|
||||
].side_effect = mock_finetune_exist
|
||||
_, _ = checkpoint_utils.load_checkpoint(args_mock, trainer)
|
||||
cfg_mock = get_mock_cfg(from_model_path)
|
||||
cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
|
||||
_, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
|
||||
(
|
||||
checkpoint_path,
|
||||
reset_optimizer,
|
||||
|
@ -20,7 +20,7 @@ from fairseq.models import (
|
||||
FairseqIncrementalDecoder,
|
||||
)
|
||||
from fairseq.models.fairseq_encoder import EncoderOut
|
||||
from fairseq.tasks import FairseqTask, LegacyFairseqTask
|
||||
from fairseq.tasks import LegacyFairseqTask
|
||||
from fairseq_cli import generate, interactive, preprocess, train, validate
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user