Move checkpoint state_dict creation into Trainer (#1666)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1666

Context: the checkpoint saving call stack has become a bit convoluted:
```
train.py
+ checkpoint_utils.save_checkpoint
 + trainer.save_checkpoint
  + checkpoint_utils.save_state
   + checkpoint_utils.torch_persistent_save
```

This diff simplifies the checkpoint saving logic by exposing a `state_dict` method on the Trainer, which flattens the call stack to:
```
train.py
+ checkpoint_utils.save_checkpoint
 + trainer.save_checkpoint
  + checkpoint_utils.torch_persistent_save
```

This new structure is important for the FullyShardedDataParallel diff (next diff in the stack), since it enables the Trainer to save multiple checkpoints for the different optimizer state shards.
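
For illustration only, here is a minimal sketch (not part of this commit) of how a trainer with sharded optimizer state could plug into the new structure. The `ShardedTrainer` subclass and `save_sharded_checkpoint` helper are hypothetical names; the `checkpoint_suffix` and `should_save_checkpoint_on_current_rank` properties, `Trainer.state_dict`, and the new `torch_persistent_save` signature are the pieces introduced in this diff.
```python
# Hypothetical sketch: every data-parallel rank writes its own checkpoint shard
# by overriding the hooks added in this diff.
from fairseq import checkpoint_utils, utils
from fairseq.trainer import Trainer


class ShardedTrainer(Trainer):  # hypothetical subclass
    @property
    def checkpoint_suffix(self) -> str:
        # each data-parallel rank gets its own file, e.g. checkpoint1-shard3.pt
        return f"-shard{self.data_parallel_rank}"

    @property
    def should_save_checkpoint_on_current_rank(self) -> bool:
        # every rank holds a distinct optimizer-state shard, so every rank saves
        return True


def save_sharded_checkpoint(trainer: Trainer, filename: str) -> None:
    # mirrors the new Trainer.save_checkpoint flow: build the state dict on all
    # ranks (it may require communication), move it to CPU, then let each
    # eligible rank persist its own shard
    state_dict = utils.move_to_cpu(trainer.state_dict())
    if trainer.should_save_checkpoint_on_current_rank:
        checkpoint_utils.torch_persistent_save(state_dict, filename)
```
Loading works symmetrically: `load_checkpoint` now resolves the per-rank filename through `trainer.checkpoint_suffix` (the `elif suffix is not None:` branch in the diff below).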

Test Plan:
- unit tests
- trained WMT En-De models; confirmed that checkpoints save/load properly and that resuming from a checkpoint gives identical results
- `buck test fblearner/flow/projects/langtech/translation:tests` (2 failures are in trunk too): https://www.internalfb.com/intern/testinfra/testconsole/testrun/2533274840914654/

Reviewed By: zhengwy888

Differential Revision: D26771146

Pulled By: myleott

fbshipit-source-id: 10f91979cd42205c1d8abcaa9ab56f63eba31e93
Myle Ott (committed by Facebook GitHub Bot) on 2021-03-04 13:31:02 -08:00
commit 6d23cc7e7c (parent f1c595beb8)
5 changed files with 64 additions and 83 deletions


```diff
@@ -31,7 +31,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
     from fairseq import meters
 
     # only one worker should attempt to create the required dir
-    if cfg.distributed_rank == 0:
+    if trainer.data_parallel_rank == 0:
         os.makedirs(cfg.save_dir, exist_ok=True)
 
     prev_best = getattr(save_checkpoint, "best", val_loss)
@@ -44,7 +44,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
     trainer.consolidate_optimizer()
 
-    if not trainer.is_data_parallel_master:
+    if not trainer.should_save_checkpoint_on_current_rank:
         return
 
     write_timer = meters.StopwatchMeter()
@@ -59,7 +59,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
     def is_better(a, b):
         return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
 
-    suffix = cfg.checkpoint_suffix or ""
+    suffix = trainer.checkpoint_suffix
     checkpoint_conds = collections.OrderedDict()
     checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
         end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
@@ -165,7 +165,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
             " or reset_lr_scheduler or reset_meters or reset_dataloader"
         )
 
-    suffix = cfg.checkpoint_suffix
+    suffix = trainer.checkpoint_suffix
     if (
         cfg.restore_file == "checkpoint_last.pt"
     ):  # default value of restore_file is 'checkpoint_last.pt'
@@ -190,7 +190,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
             raise ValueError(
                 f"--funetune-from-model {cfg.finetune_from_model} does not exist"
             )
-    elif cfg.model_parallel_size > 1:
+    elif suffix is not None:
         checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
     else:
         checkpoint_path = cfg.restore_file
@@ -405,8 +405,8 @@ def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
     return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
 
 
-def torch_persistent_save(cfg: CheckpointConfig, obj, filename):
-    if cfg.write_checkpoints_asynchronously:
+def torch_persistent_save(obj, filename, async_write: bool = False):
+    if async_write:
         with PathManager.opena(filename, "wb") as f:
             _torch_persistent_save(obj, f)
     else:
@@ -434,61 +434,6 @@ def _torch_persistent_save(obj, f):
                 logger.error(traceback.format_exc())
 
 
-def save_state(
-    filename,
-    cfg: FairseqConfig,
-    model_state_dict,
-    criterion,
-    optimizer,
-    lr_scheduler,
-    num_updates,
-    optim_history=None,
-    extra_state=None,
-    task=None,
-    **kwargs,
-):
-    from fairseq import utils
-
-    if optim_history is None:
-        optim_history = []
-    if extra_state is None:
-        extra_state = {}
-    state_dict = {
-        "cfg": OmegaConf.to_container(cfg) if OmegaConf.is_config(cfg) else cfg,
-        "args": kwargs.get("args", None),
-        "model": model_state_dict or {},
-        "optimizer_history": optim_history
-        + [
-            {
-                "criterion_name": criterion.__class__.__name__,
-                "optimizer_name": optimizer.__class__.__name__,
-                "lr_scheduler_state": lr_scheduler.state_dict(),
-                "num_updates": num_updates,
-            }
-        ],
-        "extra_state": extra_state,
-        "task_state": task.state_dict() if task is not None else {},
-    }
-    if utils.has_parameters(criterion):
-        state_dict["criterion"] = criterion.state_dict()
-
-    if cfg is None:
-        cfg = state_dict["args"]
-        assert cfg is not None, "must provide cfg or args"
-
-    if isinstance(cfg, DictConfig):
-        no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
-    else:
-        no_save_optimizer_state = cfg.no_save_optimizer_state
-    if not no_save_optimizer_state:
-        state_dict["last_optimizer_state"] = optimizer.state_dict()
-
-    # keep everything on CPU
-    state_dict = utils.move_to_cpu(state_dict)
-
-    torch_persistent_save(cfg.checkpoint, state_dict, filename)
-
-
 def _upgrade_state_dict(state):
     """Helper for upgrading old model checkpoints."""
     from fairseq import models, registry, tasks
@@ -529,7 +474,7 @@ def _upgrade_state_dict(state):
     if "num_updates" not in state["optimizer_history"][-1]:
         state["optimizer_history"][-1]["num_updates"] = 0
     # old model checkpoints may not have separate source/target positions
-    if hasattr(state["args"], "max_positions") and not hasattr(
+    if "args" in state and hasattr(state["args"], "max_positions") and not hasattr(
         state["args"], "max_source_positions"
     ):
         state["args"].max_source_positions = state["args"].max_positions
```


```diff
@@ -618,7 +618,6 @@ class CheckpointConfig(FairseqDataclass):
         },
     )
     model_parallel_size: int = II("common.model_parallel_size")
-    distributed_rank: int = II("distributed_training.distributed_rank")
 
 
 @dataclass
```


```diff
@@ -25,6 +25,8 @@ from fairseq.logging import meters, metrics
 from fairseq.nan_detector import NanDetector
 from fairseq.optim import lr_scheduler
+from omegaconf import OmegaConf
+
 
 logger = logging.getLogger(__name__)
@@ -171,6 +173,16 @@ class Trainer(object):
             and not self.cfg.optimization.use_bmuf
         )
 
+    @property
+    def should_save_checkpoint_on_current_rank(self) -> bool:
+        """Indicates whether to save checkpoints on the current DDP rank."""
+        return self.is_data_parallel_master
+
+    @property
+    def checkpoint_suffix(self) -> str:
+        """Suffix to add to the checkpoint file name."""
+        return self.cfg.checkpoint.checkpoint_suffix or ""
+
     @property
     def criterion(self):
         if self._wrapped_criterion is None:
@@ -274,25 +286,50 @@ class Trainer(object):
         if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
             self.optimizer.optimizer.consolidate_state_dict()
 
+    def state_dict(self):
+        state_dict = {
+            "args": None,  # legacy
+            "cfg": (
+                OmegaConf.to_container(self.cfg)
+                if OmegaConf.is_config(self.cfg) else self.cfg
+            ),
+            "model": self.model.state_dict(),
+            "criterion": (
+                self.criterion.state_dict()
+                if utils.has_parameters(self.criterion) else None
+            ),
+            "optimizer_history": (self._optim_history or [])
+            + [
+                {
+                    "criterion_name": self.get_criterion().__class__.__name__,
+                    "optimizer_name": self.optimizer.__class__.__name__,
+                    "lr_scheduler_state": self.lr_scheduler.state_dict(),
+                    "num_updates": self.get_num_updates(),
+                }
+            ],
+            "task_state": self.task.state_dict() if self.task is not None else {},
+            "extra_state": {
+                "metrics": metrics.state_dict(),
+                "previous_training_time": self.cumulative_training_time(),
+            }
+        }
+        if not self.cfg.checkpoint.no_save_optimizer_state:
+            state_dict["last_optimizer_state"] = self.optimizer.state_dict()
+        return state_dict
+
     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
-        if self.is_data_parallel_master:  # only save one checkpoint
-            logger.info(f"Saving checkpoint to {filename}")
-            extra_state["metrics"] = metrics.state_dict()
-            extra_state["previous_training_time"] = self.cumulative_training_time()
-            checkpoint_utils.save_state(
+        logger.info(f"Saving checkpoint to {filename}")
+        # call state_dict on all ranks in case it needs internal communication
+        state_dict = utils.move_to_cpu(self.state_dict())
+        state_dict["extra_state"].update(extra_state)
+        if self.should_save_checkpoint_on_current_rank:
+            checkpoint_utils.torch_persistent_save(
+                state_dict,
                 filename,
-                self.cfg,
-                self.model.state_dict(),
-                self.get_criterion(),
-                self.optimizer,
-                self.lr_scheduler,
-                self.get_num_updates(),
-                optim_history=self._optim_history,
-                extra_state=extra_state,
-                task=self.task,
+                async_write=self.cfg.checkpoint.write_checkpoints_asynchronously,
             )
-            logger.info(f"Finished saving checkpoint to {filename}")
+        logger.info(f"Finished saving checkpoint to {filename}")
 
     def load_checkpoint(
         self,
```


```diff
@@ -90,15 +90,14 @@ class TestCheckpointUtils(unittest.TestCase):
         self.assertEqual(len(ensemble[0].decoder.layers), 1)
 
     def test_torch_persistent_save_async(self):
-        cfg = OmegaConf.create()
-        cfg.dataset = OmegaConf.create()
-        cfg.dataset.write_checkpoints_asynchronously = True
         state_dict = {}
         filename = "async_checkpoint.pt"
 
         with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena:
             with patch(f"{checkpoint_utils.__name__}._torch_persistent_save") as mock_save:
-                checkpoint_utils.torch_persistent_save(cfg.dataset, state_dict, filename)
+                checkpoint_utils.torch_persistent_save(
+                    state_dict, filename, async_write=True
+                )
                 mock_opena.assert_called_with(filename, "wb")
                 mock_save.assert_called()
```


```diff
@@ -68,6 +68,7 @@ def get_mock_cfg(finetune_from_model):
             "reset_lr_scheduler": False,
             "finetune_from_model": finetune_from_model,
             "model_parallel_size": 1,
+            "restore_file": "checkpoint_last.pt",
         },
         "common": {
             "model_parallel_size": 1,
```