mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
Move checkpoint state_dict creation into Trainer (#1666)
Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1666

Context: the checkpoint saving call stack has become a bit convoluted:
```
train.py
+ checkpoint_utils.save_checkpoint
 + trainer.save_checkpoint
  + checkpoint_utils.save_state
   + checkpoint_utils.torch_persistent_save
```

This diff slightly simplifies the checkpoint saving logic by exposing a `state_dict` method inside the Trainer. This simplifies the call stack to:
```
train.py
+ checkpoint_utils.save_checkpoint
 + trainer.save_checkpoint
  + checkpoint_utils.torch_persistent_save
```

This new structure is important for the FullyShardedDataParallel diff (next diff in the stack), since it enables the Trainer to save multiple checkpoints for the different optimizer state shards.

Test Plan:
- unit tests
- trained WMT En-De models; confirmed checkpoints save/load properly, resuming from a checkpoint gives identical results
- `buck test fblearner/flow/projects/langtech/translation:tests` (2 failures are in trunk too): https://www.internalfb.com/intern/testinfra/testconsole/testrun/2533274840914654/

Reviewed By: zhengwy888

Differential Revision: D26771146

Pulled By: myleott

fbshipit-source-id: 10f91979cd42205c1d8abcaa9ab56f63eba31e93
This commit is contained in:
parent
f1c595beb8
commit
6d23cc7e7c
@ -31,7 +31,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
||||
from fairseq import meters
|
||||
|
||||
# only one worker should attempt to create the required dir
|
||||
if cfg.distributed_rank == 0:
|
||||
if trainer.data_parallel_rank == 0:
|
||||
os.makedirs(cfg.save_dir, exist_ok=True)
|
||||
|
||||
prev_best = getattr(save_checkpoint, "best", val_loss)
|
||||
@ -44,7 +44,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
||||
|
||||
trainer.consolidate_optimizer()
|
||||
|
||||
if not trainer.is_data_parallel_master:
|
||||
if not trainer.should_save_checkpoint_on_current_rank:
|
||||
return
|
||||
|
||||
write_timer = meters.StopwatchMeter()
|
||||
@ -59,7 +59,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
||||
def is_better(a, b):
|
||||
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
|
||||
|
||||
suffix = cfg.checkpoint_suffix or ""
|
||||
suffix = trainer.checkpoint_suffix
|
||||
checkpoint_conds = collections.OrderedDict()
|
||||
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
|
||||
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
|
||||
@ -165,7 +165,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
||||
" or reset_lr_scheduler or reset_meters or reset_dataloader"
|
||||
)
|
||||
|
||||
suffix = cfg.checkpoint_suffix
|
||||
suffix = trainer.checkpoint_suffix
|
||||
if (
|
||||
cfg.restore_file == "checkpoint_last.pt"
|
||||
): # default value of restore_file is 'checkpoint_last.pt'
|
||||
@ -190,7 +190,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
||||
raise ValueError(
|
||||
f"--funetune-from-model {cfg.finetune_from_model} does not exist"
|
||||
)
|
||||
elif cfg.model_parallel_size > 1:
|
||||
elif suffix is not None:
|
||||
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
|
||||
else:
|
||||
checkpoint_path = cfg.restore_file
|
||||
@ -405,8 +405,8 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
|
||||
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
|
||||
|
||||
|
||||
def torch_persistent_save(cfg: CheckpointConfig, obj, filename):
|
||||
if cfg.write_checkpoints_asynchronously:
|
||||
def torch_persistent_save(obj, filename, async_write: bool = False):
|
||||
if async_write:
|
||||
with PathManager.opena(filename, "wb") as f:
|
||||
_torch_persistent_save(obj, f)
|
||||
else:
|
||||
@ -434,61 +434,6 @@ def _torch_persistent_save(obj, f):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def save_state(
|
||||
filename,
|
||||
cfg: FairseqConfig,
|
||||
model_state_dict,
|
||||
criterion,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
num_updates,
|
||||
optim_history=None,
|
||||
extra_state=None,
|
||||
task=None,
|
||||
**kwargs,
|
||||
):
|
||||
from fairseq import utils
|
||||
|
||||
if optim_history is None:
|
||||
optim_history = []
|
||||
if extra_state is None:
|
||||
extra_state = {}
|
||||
state_dict = {
|
||||
"cfg": OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else cfg,
|
||||
"args": kwargs.get("args", None),
|
||||
"model": model_state_dict or {},
|
||||
"optimizer_history": optim_history
|
||||
+ [
|
||||
{
|
||||
"criterion_name": criterion.__class__.__name__,
|
||||
"optimizer_name": optimizer.__class__.__name__,
|
||||
"lr_scheduler_state": lr_scheduler.state_dict(),
|
||||
"num_updates": num_updates,
|
||||
}
|
||||
],
|
||||
"extra_state": extra_state,
|
||||
"task_state": task.state_dict() if task is not None else {},
|
||||
}
|
||||
if utils.has_parameters(criterion):
|
||||
state_dict["criterion"] = criterion.state_dict()
|
||||
|
||||
if cfg is None:
|
||||
cfg = state_dict["args"]
|
||||
assert cfg is not None, "must provide cfg or args"
|
||||
|
||||
if isinstance(cfg, DictConfig):
|
||||
no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
|
||||
else:
|
||||
no_save_optimizer_state = cfg.no_save_optimizer_state
|
||||
if not no_save_optimizer_state:
|
||||
state_dict["last_optimizer_state"] = optimizer.state_dict()
|
||||
|
||||
# keep everything on CPU
|
||||
state_dict = utils.move_to_cpu(state_dict)
|
||||
|
||||
torch_persistent_save(cfg.checkpoint, state_dict, filename)
|
||||
|
||||
|
||||
def _upgrade_state_dict(state):
|
||||
"""Helper for upgrading old model checkpoints."""
|
||||
from fairseq import models, registry, tasks
|
||||
@ -529,7 +474,7 @@ def _upgrade_state_dict(state):
|
||||
if "num_updates" not in state["optimizer_history"][-1]:
|
||||
state["optimizer_history"][-1]["num_updates"] = 0
|
||||
# old model checkpoints may not have separate source/target positions
|
||||
if hasattr(state["args"], "max_positions") and not hasattr(
|
||||
if "args" in state and hasattr(state["args"], "max_positions") and not hasattr(
|
||||
state["args"], "max_source_positions"
|
||||
):
|
||||
state["args"].max_source_positions = state["args"].max_positions
|
||||
|
@ -618,7 +618,6 @@ class CheckpointConfig(FairseqDataclass):
|
||||
},
|
||||
)
|
||||
model_parallel_size: int = II("common.model_parallel_size")
|
||||
distributed_rank: int = II("distributed_training.distributed_rank")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -25,6 +25,8 @@ from fairseq.logging import meters, metrics
|
||||
from fairseq.nan_detector import NanDetector
|
||||
from fairseq.optim import lr_scheduler
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -171,6 +173,16 @@ class Trainer(object):
|
||||
and not self.cfg.optimization.use_bmuf
|
||||
)
|
||||
|
||||
@property
|
||||
def should_save_checkpoint_on_current_rank(self) -> bool:
|
||||
"""Indicates whether to save checkpoints on the current DDP rank."""
|
||||
return self.is_data_parallel_master
|
||||
|
||||
@property
|
||||
def checkpoint_suffix(self) -> str:
|
||||
"""Suffix to add to the checkpoint file name."""
|
||||
return self.cfg.checkpoint.checkpoint_suffix or ""
|
||||
|
||||
@property
|
||||
def criterion(self):
|
||||
if self._wrapped_criterion is None:
|
||||
@ -274,23 +286,48 @@ class Trainer(object):
|
||||
if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
|
||||
self.optimizer.optimizer.consolidate_state_dict()
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = {
|
||||
"args": None, # legacy
|
||||
"cfg": (
|
||||
OmegaConf.to_container(self.cfg)
|
||||
if OmegaConf.is_config(self.cfg) else self.cfg
|
||||
),
|
||||
"model": self.model.state_dict(),
|
||||
"criterion": (
|
||||
self.criterion.state_dict()
|
||||
if utils.has_parameters(self.criterion) else None
|
||||
),
|
||||
"optimizer_history": (self._optim_history or [])
|
||||
+ [
|
||||
{
|
||||
"criterion_name": self.get_criterion().__class__.__name__,
|
||||
"optimizer_name": self.optimizer.__class__.__name__,
|
||||
"lr_scheduler_state": self.lr_scheduler.state_dict(),
|
||||
"num_updates": self.get_num_updates(),
|
||||
}
|
||||
],
|
||||
"task_state": self.task.state_dict() if self.task is not None else {},
|
||||
"extra_state": {
|
||||
"metrics": metrics.state_dict(),
|
||||
"previous_training_time": self.cumulative_training_time(),
|
||||
}
|
||||
}
|
||||
if not self.cfg.checkpoint.no_save_optimizer_state:
|
||||
state_dict["last_optimizer_state"] = self.optimizer.state_dict()
|
||||
return state_dict
|
||||
|
||||
def save_checkpoint(self, filename, extra_state):
|
||||
"""Save all training state in a checkpoint file."""
|
||||
if self.is_data_parallel_master: # only save one checkpoint
|
||||
logger.info(f"Saving checkpoint to {filename}")
|
||||
extra_state["metrics"] = metrics.state_dict()
|
||||
extra_state["previous_training_time"] = self.cumulative_training_time()
|
||||
checkpoint_utils.save_state(
|
||||
# call state_dict on all ranks in case it needs internal communication
|
||||
state_dict = utils.move_to_cpu(self.state_dict())
|
||||
state_dict["extra_state"].update(extra_state)
|
||||
if self.should_save_checkpoint_on_current_rank:
|
||||
checkpoint_utils.torch_persistent_save(
|
||||
state_dict,
|
||||
filename,
|
||||
self.cfg,
|
||||
self.model.state_dict(),
|
||||
self.get_criterion(),
|
||||
self.optimizer,
|
||||
self.lr_scheduler,
|
||||
self.get_num_updates(),
|
||||
optim_history=self._optim_history,
|
||||
extra_state=extra_state,
|
||||
task=self.task,
|
||||
async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
|
||||
)
|
||||
logger.info(f"Finished saving checkpoint to {filename}")
|
||||
|
||||
|
@ -90,15 +90,14 @@ class TestCheckpointUtils(unittest.TestCase):
|
||||
self.assertEqual(len(ensemble[0].decoder.layers), 1)
|
||||
|
||||
def test_torch_persistent_save_async(self):
|
||||
cfg = OmegaConf.create()
|
||||
cfg.dataset = OmegaConf.create()
|
||||
cfg.dataset.write_checkpoints_asynchronously = True
|
||||
state_dict = {}
|
||||
filename = "async_checkpoint.pt"
|
||||
|
||||
with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena:
|
||||
with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save:
|
||||
checkpoint_utils.torch_persistent_save(cfg.dataset, state_dict, filename)
|
||||
checkpoint_utils.torch_persistent_save(
|
||||
state_dict, filename, async_write=True
|
||||
)
|
||||
mock_opena.assert_called_with(filename, "wb")
|
||||
mock_save.assert_called()
|
||||
|
||||
|
@ -68,6 +68,7 @@ def get_mock_cfg(finetune_from_model):
|
||||
"reset_lr_scheduler": False,
|
||||
"finetune_from_model": finetune_from_model,
|
||||
"model_parallel_size": 1,
|
||||
"restore_file": "checkpoint_last.pt",
|
||||
},
|
||||
"common": {
|
||||
"model_parallel_size": 1,
|
||||
|
Loading…
Reference in New Issue
Block a user