diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index 97f22041b..5a98dad2a 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -31,7 +31,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
     from fairseq import meters
 
     # only one worker should attempt to create the required dir
-    if cfg.distributed_rank == 0:
+    if trainer.data_parallel_rank == 0:
         os.makedirs(cfg.save_dir, exist_ok=True)
 
     prev_best = getattr(save_checkpoint, "best", val_loss)
@@ -44,7 +44,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
 
     trainer.consolidate_optimizer()
 
-    if not trainer.is_data_parallel_master:
+    if not trainer.should_save_checkpoint_on_current_rank:
         return
 
     write_timer = meters.StopwatchMeter()
@@ -59,7 +59,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
     def is_better(a, b):
         return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
 
-    suffix = cfg.checkpoint_suffix or ""
+    suffix = trainer.checkpoint_suffix
     checkpoint_conds = collections.OrderedDict()
     checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
         end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
@@ -165,7 +165,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
             " or reset_lr_scheduler or reset_meters or reset_dataloader"
         )
 
-    suffix = cfg.checkpoint_suffix
+    suffix = trainer.checkpoint_suffix
     if (
         cfg.restore_file == "checkpoint_last.pt"
     ):  # default value of restore_file is 'checkpoint_last.pt'
@@ -190,7 +190,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
             raise ValueError(
                 f"--funetune-from-model {cfg.finetune_from_model} does not exist"
             )
-    elif cfg.model_parallel_size > 1:
+    elif suffix is not None:
         checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
     else:
         checkpoint_path = cfg.restore_file
@@ -405,8 +405,8 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
     return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
 
 
-def torch_persistent_save(cfg: CheckpointConfig, obj, filename):
-    if cfg.write_checkpoints_asynchronously:
+def torch_persistent_save(obj, filename, async_write: bool = False):
+    if async_write:
         with PathManager.opena(filename, "wb") as f:
             _torch_persistent_save(obj, f)
     else:
@@ -434,61 +434,6 @@ def _torch_persistent_save(obj, f):
         logger.error(traceback.format_exc())
 
 
-def save_state(
-    filename,
-    cfg: FairseqConfig,
-    model_state_dict,
-    criterion,
-    optimizer,
-    lr_scheduler,
-    num_updates,
-    optim_history=None,
-    extra_state=None,
-    task=None,
-    **kwargs,
-):
-    from fairseq import utils
-
-    if optim_history is None:
-        optim_history = []
-    if extra_state is None:
-        extra_state = {}
-    state_dict = {
-        "cfg": OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else cfg,
-        "args": kwargs.get("args", None),
-        "model": model_state_dict or {},
-        "optimizer_history": optim_history
-        + [
-            {
-                "criterion_name": criterion.__class__.__name__,
-                "optimizer_name": optimizer.__class__.__name__,
-                "lr_scheduler_state": lr_scheduler.state_dict(),
-                "num_updates": num_updates,
-            }
-        ],
-        "extra_state": extra_state,
-        "task_state": task.state_dict() if task is not None else {},
-    }
-    if utils.has_parameters(criterion):
-        state_dict["criterion"] = criterion.state_dict()
-
-    if cfg is None:
-        cfg = state_dict["args"]
-        assert cfg is not None, "must provide cfg or args"
-
-    if isinstance(cfg, DictConfig):
-        no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
-    else:
-        no_save_optimizer_state = cfg.no_save_optimizer_state
-    if not no_save_optimizer_state:
-        state_dict["last_optimizer_state"] = optimizer.state_dict()
-
-    # keep everything on CPU
-    state_dict = utils.move_to_cpu(state_dict)
-
-    torch_persistent_save(cfg.checkpoint, state_dict, filename)
-
-
 def _upgrade_state_dict(state):
     """Helper for upgrading old model checkpoints."""
     from fairseq import models, registry, tasks
@@ -529,7 +474,7 @@ def _upgrade_state_dict(state):
     if "num_updates" not in state["optimizer_history"][-1]:
         state["optimizer_history"][-1]["num_updates"] = 0
     # old model checkpoints may not have separate source/target positions
-    if hasattr(state["args"], "max_positions") and not hasattr(
+    if "args" in state and hasattr(state["args"], "max_positions") and not hasattr(
         state["args"], "max_source_positions"
     ):
         state["args"].max_source_positions = state["args"].max_positions
diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py
index 39355b1ca..4d3c60bfd 100644
--- a/fairseq/dataclass/configs.py
+++ b/fairseq/dataclass/configs.py
@@ -618,7 +618,6 @@ class CheckpointConfig(FairseqDataclass):
         },
     )
     model_parallel_size: int = II("common.model_parallel_size")
-    distributed_rank: int = II("distributed_training.distributed_rank")
 
 
 @dataclass
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index 680a7ee95..45d9591d7 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -25,6 +25,8 @@ from fairseq.logging import meters, metrics
 from fairseq.nan_detector import NanDetector
 from fairseq.optim import lr_scheduler
 
+from omegaconf import OmegaConf
+
 
 logger = logging.getLogger(__name__)
 
@@ -171,6 +173,16 @@ class Trainer(object):
             and not self.cfg.optimization.use_bmuf
         )
 
+    @property
+    def should_save_checkpoint_on_current_rank(self) -> bool:
+        """Indicates whether to save checkpoints on the current DDP rank."""
+        return self.is_data_parallel_master
+
+    @property
+    def checkpoint_suffix(self) -> str:
+        """Suffix to add to the checkpoint file name."""
+        return self.cfg.checkpoint.checkpoint_suffix or ""
+
     @property
     def criterion(self):
         if self._wrapped_criterion is None:
@@ -274,25 +286,50 @@ class Trainer(object):
         if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
             self.optimizer.optimizer.consolidate_state_dict()
 
+    def state_dict(self):
+        state_dict = {
+            "args": None,  # legacy
+            "cfg": (
+                OmegaConf.to_container(self.cfg)
+                if OmegaConf.is_config(self.cfg) else self.cfg
+            ),
+            "model": self.model.state_dict(),
+            "criterion": (
+                self.criterion.state_dict()
+                if utils.has_parameters(self.criterion) else None
+            ),
+            "optimizer_history": (self._optim_history or [])
+            + [
+                {
+                    "criterion_name": self.get_criterion().__class__.__name__,
+                    "optimizer_name": self.optimizer.__class__.__name__,
+                    "lr_scheduler_state": self.lr_scheduler.state_dict(),
+                    "num_updates": self.get_num_updates(),
+                }
+            ],
+            "task_state": self.task.state_dict() if self.task is not None else {},
+            "extra_state": {
+                "metrics": metrics.state_dict(),
+                "previous_training_time": self.cumulative_training_time(),
+            }
+        }
+        if not self.cfg.checkpoint.no_save_optimizer_state:
+            state_dict["last_optimizer_state"] = self.optimizer.state_dict()
+        return state_dict
+
     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
-        if self.is_data_parallel_master:  # only save one checkpoint
-            logger.info(f"Saving checkpoint to {filename}")
-            extra_state["metrics"] = metrics.state_dict()
-            extra_state["previous_training_time"] = self.cumulative_training_time()
-            checkpoint_utils.save_state(
+        logger.info(f"Saving checkpoint to {filename}")
+        # call state_dict on all ranks in case it needs internal communication
+        state_dict = utils.move_to_cpu(self.state_dict())
+        state_dict["extra_state"].update(extra_state)
+        if self.should_save_checkpoint_on_current_rank:
+            checkpoint_utils.torch_persistent_save(
+                state_dict,
                 filename,
-                self.cfg,
-                self.model.state_dict(),
-                self.get_criterion(),
-                self.optimizer,
-                self.lr_scheduler,
-                self.get_num_updates(),
-                optim_history=self._optim_history,
-                extra_state=extra_state,
-                task=self.task,
+                async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
             )
-            logger.info(f"Finished saving checkpoint to {filename}")
+        logger.info(f"Finished saving checkpoint to {filename}")
 
     def load_checkpoint(
         self,
diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py
index 3278de6b9..0f2822263 100644
--- a/tests/test_checkpoint_utils.py
+++ b/tests/test_checkpoint_utils.py
@@ -90,15 +90,14 @@ class TestCheckpointUtils(unittest.TestCase):
             self.assertEqual(len(ensemble[0].decoder.layers), 1)
 
     def test_torch_persistent_save_async(self):
-        cfg = OmegaConf.create()
-        cfg.dataset = OmegaConf.create()
-        cfg.dataset.write_checkpoints_asynchronously = True
         state_dict = {}
         filename = "async_checkpoint.pt"
 
         with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena:
             with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save:
-                checkpoint_utils.torch_persistent_save(cfg.dataset, state_dict, filename)
+                checkpoint_utils.torch_persistent_save(
+                    state_dict, filename, async_write=True
+                )
                 mock_opena.assert_called_with(filename, "wb")
                 mock_save.assert_called()
 
diff --git a/tests/test_train.py b/tests/test_train.py
index 57daa194b..65f4683bc 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -68,6 +68,7 @@ def get_mock_cfg(finetune_from_model):
                 "reset_lr_scheduler": False,
                 "finetune_from_model": finetune_from_model,
                 "model_parallel_size": 1,
+                "restore_file": "checkpoint_last.pt",
             },
             "common": {
                 "model_parallel_size": 1,
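
For reference, a minimal sketch of the refactored torch_persistent_save call pattern introduced above. This is not part of the patch; the state payload and output file name below are illustrative placeholders, and it assumes a fairseq checkout with these changes applied.

    # Usage sketch for the new signature: torch_persistent_save(obj, filename, async_write=False)
    from fairseq import checkpoint_utils

    demo_state = {"model": {}, "extra_state": {"epoch": 1}}  # placeholder payload

    # Synchronous write (default path): falls through to _torch_persistent_save.
    checkpoint_utils.torch_persistent_save(demo_state, "checkpoint_demo.pt")

    # Asynchronous write via PathManager.opena, as Trainer.save_checkpoint now does
    # when cfg.checkpoint.write_checkpoints_asynchronously is set.
    checkpoint_utils.torch_persistent_save(
        demo_state, "checkpoint_demo.pt", async_write=True
    )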