fairseq/tests/test_checkpoint_utils.py
Myle Ott 6d23cc7e7c Move checkpoint state_dict creation into Trainer (#1666)
Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1666

Context: the checkpoint saving call stack has become a bit convoluted:
```
train.py
+ checkpoint_utils.save_checkpoint
 + trainer.save_checkpoint
  + checkpoint_utils.save_state
   + checkpoint_utils.torch_persistent_save
```

This diff simplifies the checkpoint saving logic by exposing a `state_dict` method on the Trainer, which flattens the call stack to:
```
train.py
+ checkpoint_utils.save_checkpoint
 + trainer.save_checkpoint
  + checkpoint_utils.torch_persistent_save
```

This new structure is important for the FullyShardedDataParallel diff (next diff in the stack), since it enables the Trainer to save multiple checkpoints for the different optimizer state shards.
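
To make the new structure concrete, here is a minimal sketch of the simplified save path. The `Trainer` class below is a toy stand-in, not fairseq's actual implementation, and `torch.save` stands in for `checkpoint_utils.torch_persistent_save`:
```python
import torch


class Trainer:
    """Toy stand-in for fairseq's Trainer (illustrative only)."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # Checkpoint state_dict creation now lives in the Trainer, so it
        # can later emit one state_dict per optimizer state shard under
        # FullyShardedDataParallel.
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def save_checkpoint(self, filename):
        # train.py -> checkpoint_utils.save_checkpoint -> here -> persistent save
        torch.save(self.state_dict(), filename)


model = torch.nn.Linear(4, 4)
trainer = Trainer(model, torch.optim.SGD(model.parameters(), lr=0.1))
trainer.save_checkpoint("checkpoint_last.pt")
```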

Test Plan:
- unit tests
- trained WMT En-De models; confirmed that checkpoints save/load properly and that resuming from a checkpoint gives identical results
- `buck test fblearner/flow/projects/langtech/translation:tests` (2 failures are in trunk too): https://www.internalfb.com/intern/testinfra/testconsole/testrun/2533274840914654/

Reviewed By: zhengwy888

Differential Revision: D26771146

Pulled By: myleott

fbshipit-source-id: 10f91979cd42205c1d8abcaa9ab56f63eba31e93
2021-03-04 13:32:44 -08:00

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os
import tempfile
import unittest
from io import StringIO
from unittest.mock import patch

from fairseq import checkpoint_utils
from omegaconf import OmegaConf

from tests.utils import (
    create_dummy_data,
    preprocess_translation_data,
    train_translation_model,
)


class TestCheckpointUtils(unittest.TestCase):
    def setUp(self):
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        logging.disable(logging.NOTSET)

    @contextlib.contextmanager
    def _train_transformer(self, seed, extra_args=None):
        if extra_args is None:
            extra_args = []
        with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir:
            create_dummy_data(data_dir)
            preprocess_translation_data(data_dir)
            train_translation_model(
                data_dir,
                "transformer_iwslt_de_en",
                [
                    "--encoder-layers",
                    "3",
                    "--decoder-layers",
                    "3",
                    "--encoder-embed-dim",
                    "8",
                    "--decoder-embed-dim",
                    "8",
                    "--seed",
                    str(seed),
                ]
                + extra_args,
            )
            yield os.path.join(data_dir, "checkpoint_last.pt")

    def test_load_model_ensemble_and_task(self):
        # with contextlib.redirect_stdout(StringIO()):
        with self._train_transformer(seed=123) as model1:
            with self._train_transformer(seed=456) as model2:
                ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                    filenames=[model1, model2]
                )
                self.assertEqual(len(ensemble), 2)

                # after Transformer has been migrated to Hydra, this will probably
                # become cfg.common.seed
                self.assertEqual(ensemble[0].args.seed, 123)
                self.assertEqual(ensemble[1].args.seed, 456)

                # the task from the first model should be returned
                self.assertTrue("seed123" in task.cfg.data)

                # last cfg is saved
                self.assertEqual(cfg.common.seed, 456)

    def test_prune_state_dict(self):
        with contextlib.redirect_stdout(StringIO()):
            extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"]
            with self._train_transformer(seed=1, extra_args=extra_args) as model:
                ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                    filenames=[model],
                    arg_overrides={
                        "encoder_layers_to_keep": "0,2",
                        "decoder_layers_to_keep": "1",
                    },
                )
                self.assertEqual(len(ensemble), 1)
                self.assertEqual(len(ensemble[0].encoder.layers), 2)
                self.assertEqual(len(ensemble[0].decoder.layers), 1)

    def test_torch_persistent_save_async(self):
        state_dict = {}
        filename = "async_checkpoint.pt"

        with patch(f"{checkpoint_utils.__name__}.PathManager.opena") as mock_opena:
            with patch(
                f"{checkpoint_utils.__name__}._torch_persistent_save"
            ) as mock_save:
                checkpoint_utils.torch_persistent_save(
                    state_dict, filename, async_write=True
                )
                mock_opena.assert_called_with(filename, "wb")
                mock_save.assert_called()


if __name__ == "__main__":
    unittest.main()