From 6d23cc7e7c32d1a6aa1d2d4a4c94abe50c980126 Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Thu, 4 Mar 2021 13:31:02 -0800 Subject: [PATCH] Move checkpoint state_dict creation into Trainer (#1666) Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1666 Context: the checkpoint saving call stack has become a bit convoluted: ``` train.py + checkpoint_utils.save_checkpoint + trainer.save_checkpoint + checkpoint_utils.save_state + checkpoint_utils.torch_persistent_save ``` This diff slightly simplifies the checkpoint saving logic by exposing a `state_dict` method inside the Trainer. This simplifies the call stack to: ``` train.py + checkpoint_utils.save_checkpoint + trainer.save_checkpoint + checkpoint_utils.torch_persistent_save ``` This new structure is important for the FullyShardedDataParallel diff (next diff in the stack), since it enables the Trainer to save multiple checkpoints for the different optimizer state shards. Test Plan: - unit tests - trained WMT En-De models; confirmed checkpoints save/load properly, resuming from a checkpoint gives identical results - `buck test fblearner/flow/projects/langtech/translation:tests` (2 failures are in trunk too): https://www.internalfb.com/intern/testinfra/testconsole/testrun/2533274840914654/ Reviewed By: zhengwy888 Differential Revision: D26771146 Pulled By: myleott fbshipit-source-id: 10f91979cd42205c1d8abcaa9ab56f63eba31e93 --- fairseq/checkpoint_utils.py | 71 ++++------------------------------ fairseq/dataclass/configs.py | 1 - fairseq/trainer.py | 67 +++++++++++++++++++++++++------- tests/test_checkpoint_utils.py | 7 ++-- tests/test_train.py | 1 + 5 files changed, 64 insertions(+), 83 deletions(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 97f22041b..5a98dad2a 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -31,7 +31,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): from fairseq import meters # only one worker should attempt to create the required dir - if cfg.distributed_rank == 0: + if trainer.data_parallel_rank == 0: os.makedirs(cfg.save_dir, exist_ok=True) prev_best = getattr(save_checkpoint, "best", val_loss) @@ -44,7 +44,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): trainer.consolidate_optimizer() - if not trainer.is_data_parallel_master: + if not trainer.should_save_checkpoint_on_current_rank: return write_timer = meters.StopwatchMeter() @@ -59,7 +59,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): def is_better(a, b): return a >= b if cfg.maximize_best_checkpoint_metric else a <= b - suffix = cfg.checkpoint_suffix or "" + suffix = trainer.checkpoint_suffix checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 @@ -165,7 +165,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): " or reset_lr_scheduler or reset_meters or reset_dataloader" ) - suffix = cfg.checkpoint_suffix + suffix = trainer.checkpoint_suffix if ( cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' @@ -190,7 +190,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): raise ValueError( f"--funetune-from-model {cfg.finetune_from_model} does not exist" ) - elif cfg.model_parallel_size > 1: + elif suffix is not None: checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt") else: checkpoint_path = cfg.restore_file @@ -405,8 +405,8 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] -def torch_persistent_save(cfg: CheckpointConfig, obj, filename): - if cfg.write_checkpoints_asynchronously: +def torch_persistent_save(obj, filename, async_write: bool = False): + if async_write: with PathManager.opena(filename, "wb") as f: _torch_persistent_save(obj, f) else: @@ -434,61 +434,6 @@ def _torch_persistent_save(obj, f): logger.error(traceback.format_exc()) -def save_state( - filename, - cfg: FairseqConfig, - model_state_dict, - criterion, - optimizer, - lr_scheduler, - num_updates, - optim_history=None, - extra_state=None, - task=None, - **kwargs, -): - from fairseq import utils - - if optim_history is None: - optim_history = [] - if extra_state is None: - extra_state = {} - state_dict = { - "cfg": OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else cfg, - "args": kwargs.get("args", None), - "model": model_state_dict or {}, - "optimizer_history": optim_history - + [ - { - "criterion_name": criterion.__class__.__name__, - "optimizer_name": optimizer.__class__.__name__, - "lr_scheduler_state": lr_scheduler.state_dict(), - "num_updates": num_updates, - } - ], - "extra_state": extra_state, - "task_state": task.state_dict() if task is not None else {}, - } - if utils.has_parameters(criterion): - state_dict["criterion"] = criterion.state_dict() - - if cfg is None: - cfg = state_dict["args"] - assert cfg is not None, "must provide cfg or args" - - if isinstance(cfg, DictConfig): - no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state - else: - no_save_optimizer_state = cfg.no_save_optimizer_state - if not no_save_optimizer_state: - state_dict["last_optimizer_state"] = optimizer.state_dict() - - # keep everything on CPU - state_dict = utils.move_to_cpu(state_dict) - - torch_persistent_save(cfg.checkpoint, state_dict, filename) - - def _upgrade_state_dict(state): """Helper for upgrading old model checkpoints.""" from fairseq import models, registry, tasks @@ -529,7 +474,7 @@ def _upgrade_state_dict(state): if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 # old model checkpoints may not have separate source/target positions - if hasattr(state["args"], "max_positions") and not hasattr( + if "args" in state and hasattr(state["args"], "max_positions") and not hasattr( state["args"], "max_source_positions" ): state["args"].max_source_positions = state["args"].max_positions diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 39355b1ca..4d3c60bfd 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -618,7 +618,6 @@ class CheckpointConfig(FairseqDataclass): }, ) model_parallel_size: int = II("common.model_parallel_size") - distributed_rank: int = II("distributed_training.distributed_rank") @dataclass diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 680a7ee95..45d9591d7 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -25,6 +25,8 @@ from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from omegaconf import OmegaConf + logger = logging.getLogger(__name__) @@ -171,6 +173,16 @@ class Trainer(object): and not self.cfg.optimization.use_bmuf ) + @property + def should_save_checkpoint_on_current_rank(self) -> bool: + """Indicates whether to save checkpoints on the current DDP rank.""" + return self.is_data_parallel_master + + @property + def checkpoint_suffix(self) -> str: + """Suffix to add to the checkpoint file name.""" + return self.cfg.checkpoint.checkpoint_suffix or "" + @property def criterion(self): if self._wrapped_criterion is None: @@ -274,25 +286,50 @@ class Trainer(object): if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): self.optimizer.optimizer.consolidate_state_dict() + def state_dict(self): + state_dict = { + "args": None, # legacy + "cfg": ( + OmegaConf.to_container(self.cfg) + if OmegaConf.is_config(self.cfg) else self.cfg + ), + "model": self.model.state_dict(), + "criterion": ( + self.criterion.state_dict() + if utils.has_parameters(self.criterion) else None + ), + "optimizer_history": (self._optim_history or []) + + [ + { + "criterion_name": self.get_criterion().__class__.__name__, + "optimizer_name": self.optimizer.__class__.__name__, + "lr_scheduler_state": self.lr_scheduler.state_dict(), + "num_updates": self.get_num_updates(), + } + ], + "task_state": self.task.state_dict() if self.task is not None else {}, + "extra_state": { + "metrics": metrics.state_dict(), + "previous_training_time": self.cumulative_training_time(), + } + } + if not self.cfg.checkpoint.no_save_optimizer_state: + state_dict["last_optimizer_state"] = self.optimizer.state_dict() + return state_dict + def save_checkpoint(self, filename, extra_state): """Save all training state in a checkpoint file.""" - if self.is_data_parallel_master: # only save one checkpoint - logger.info(f"Saving checkpoint to {filename}") - extra_state["metrics"] = metrics.state_dict() - extra_state["previous_training_time"] = self.cumulative_training_time() - checkpoint_utils.save_state( + logger.info(f"Saving checkpoint to {filename}") + # call state_dict on all ranks in case it needs internal communication + state_dict = utils.move_to_cpu(self.state_dict()) + state_dict["extra_state"].update(extra_state) + if self.should_save_checkpoint_on_current_rank: + checkpoint_utils.torch_persistent_save( + state_dict, filename, - self.cfg, - self.model.state_dict(), - self.get_criterion(), - self.optimizer, - self.lr_scheduler, - self.get_num_updates(), - optim_history=self._optim_history, - extra_state=extra_state, - task=self.task, + async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, ) - logger.info(f"Finished saving checkpoint to {filename}") + logger.info(f"Finished saving checkpoint to {filename}") def load_checkpoint( self, diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py index 3278de6b9..0f2822263 100644 --- a/tests/test_checkpoint_utils.py +++ b/tests/test_checkpoint_utils.py @@ -90,15 +90,14 @@ class TestCheckpointUtils(unittest.TestCase): self.assertEqual(len(ensemble[0].decoder.layers), 1) def test_torch_persistent_save_async(self): - cfg = OmegaConf.create() - cfg.dataset = OmegaConf.create() - cfg.dataset.write_checkpoints_asynchronously = True state_dict = {} filename = "async_checkpoint.pt" with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena: with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save: - checkpoint_utils.torch_persistent_save(cfg.dataset, state_dict, filename) + checkpoint_utils.torch_persistent_save( + state_dict, filename, async_write=True + ) mock_opena.assert_called_with(filename, "wb") mock_save.assert_called() diff --git a/tests/test_train.py b/tests/test_train.py index 57daa194b..65f4683bc 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -68,6 +68,7 @@ def get_mock_cfg(finetune_from_model): "reset_lr_scheduler": False, "finetune_from_model": finetune_from_model, "model_parallel_size": 1, + "restore_file": "checkpoint_last.pt", }, "common": { "model_parallel_size": 1,