mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-10-26 17:32:57 +03:00
6d23cc7e7c
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1666 Context: the checkpoint saving call stack has become a bit convoluted: ``` train.py + checkpoint_utils.save_checkpoint + trainer.save_checkpoint + checkpoint_utils.save_state + checkpoint_utils.torch_persistent_save ``` This diff slightly simplifies the checkpoint saving logic by exposing a `state_dict` method inside the Trainer. This simplifies the call stack to: ``` train.py + checkpoint_utils.save_checkpoint + trainer.save_checkpoint + checkpoint_utils.torch_persistent_save ``` This new structure is important for the FullyShardedDataParallel diff (next diff in the stack), since it enables the Trainer to save multiple checkpoints for the different optimizer state shards. Test Plan: - unit tests - trained WMT En-De models; confirmed checkpoints save/load properly, resuming from a checkpoint gives identical results - `buck test fblearner/flow/projects/langtech/translation:tests` (2 failures are in trunk too): https://www.internalfb.com/intern/testinfra/testconsole/testrun/2533274840914654/ Reviewed By: zhengwy888 Differential Revision: D26771146 Pulled By: myleott fbshipit-source-id: 10f91979cd42205c1d8abcaa9ab56f63eba31e93
107 lines
3.8 KiB
Python
107 lines
3.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import contextlib
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
from io import StringIO
|
|
from unittest.mock import patch
|
|
|
|
from fairseq import checkpoint_utils
|
|
from omegaconf import OmegaConf
|
|
|
|
from tests.utils import (
|
|
create_dummy_data,
|
|
preprocess_translation_data,
|
|
train_translation_model,
|
|
)
|
|
|
|
|
|
class TestCheckpointUtils(unittest.TestCase):
    """Integration tests for ``fairseq.checkpoint_utils`` covering ensemble
    loading, state-dict pruning, and asynchronous checkpoint saving."""

    def setUp(self):
        # Silence all log output produced while a test trains tiny models.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        # Restore normal logging so other test modules are unaffected.
        logging.disable(logging.NOTSET)

    @contextlib.contextmanager
    def _train_transformer(self, seed, extra_args=None):
        """Train a tiny transformer_iwslt_de_en model on dummy data.

        Args:
            seed: random seed; also embedded in the temp-dir name so the
                task's data path is identifiable (see the "seed123" check).
            extra_args: optional extra CLI flags forwarded to training.

        Yields:
            Path to ``checkpoint_last.pt``. The backing temporary directory
            (and the checkpoint with it) is deleted when the context exits.
        """
        if extra_args is None:
            extra_args = []
        with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir:
            create_dummy_data(data_dir)
            preprocess_translation_data(data_dir)
            train_translation_model(
                data_dir,
                "transformer_iwslt_de_en",
                [
                    "--encoder-layers",
                    "3",
                    "--decoder-layers",
                    "3",
                    "--encoder-embed-dim",
                    "8",
                    "--decoder-embed-dim",
                    "8",
                    "--seed",
                    str(seed),
                ]
                + extra_args,
            )
            yield os.path.join(data_dir, "checkpoint_last.pt")

    def test_load_model_ensemble_and_task(self):
        """Loading two checkpoints returns both models, the first model's
        task, and the last checkpoint's cfg."""
        # Suppress training stdout, consistent with test_prune_state_dict.
        with contextlib.redirect_stdout(StringIO()):
            with self._train_transformer(seed=123) as model1:
                with self._train_transformer(seed=456) as model2:
                    ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                        filenames=[model1, model2]
                    )
                    self.assertEqual(len(ensemble), 2)

                    # after Transformer has been migrated to Hydra, this will
                    # probably become cfg.common.seed
                    self.assertEqual(ensemble[0].args.seed, 123)
                    self.assertEqual(ensemble[1].args.seed, 456)

                    # the task from the first model should be returned
                    self.assertTrue("seed123" in task.cfg.data)

                    # last cfg is saved
                    self.assertEqual(cfg.common.seed, 456)

    def test_prune_state_dict(self):
        """arg_overrides with *_layers_to_keep prune layers at load time."""
        with contextlib.redirect_stdout(StringIO()):
            # LayerDrop must be enabled during training for pruning to apply.
            extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"]
            with self._train_transformer(seed=1, extra_args=extra_args) as model:
                ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                    filenames=[model],
                    arg_overrides={
                        "encoder_layers_to_keep": "0,2",
                        "decoder_layers_to_keep": "1",
                    },
                )
                self.assertEqual(len(ensemble), 1)
                # 3 trained encoder layers pruned to the 2 kept ("0,2");
                # 3 decoder layers pruned to the 1 kept ("1").
                self.assertEqual(len(ensemble[0].encoder.layers), 2)
                self.assertEqual(len(ensemble[0].decoder.layers), 1)

    def test_torch_persistent_save_async(self):
        """async_write=True must open the file via PathManager.opena ("wb")
        and delegate the actual serialization to _torch_persistent_save."""
        state_dict = {}
        filename = "async_checkpoint.pt"

        # Patch both collaborators in one statement; nothing touches disk.
        with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena, patch(
            f"{checkpoint_utils.__name__}._torch_persistent_save"
        ) as mock_save:
            checkpoint_utils.torch_persistent_save(
                state_dict, filename, async_write=True
            )
            mock_opena.assert_called_with(filename, "wb")
            mock_save.assert_called()
|
|
|
# Allow running this test module directly (e.g. `python test_checkpoint_utils.py`).
if __name__ == "__main__":
    unittest.main()
|