fairseq/tests/test_checkpoint_utils.py
dianaml0 0dfd6b6240 Add linting with black (#2678)
Summary:
# Before submitting

- [ ] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes # (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/2678

Reviewed By: Mortimerp9

Differential Revision: D32653381

Pulled By: dianaml0

fbshipit-source-id: 2810d14867cd7d64f4d340740e2b590b82de47fe
2021-11-29 12:32:59 -08:00

109 lines
3.7 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os
import tempfile
import unittest
from io import StringIO
from unittest.mock import patch
from omegaconf import OmegaConf
from fairseq import checkpoint_utils
from tests.utils import (
create_dummy_data,
preprocess_translation_data,
train_translation_model,
)
class TestCheckpointUtils(unittest.TestCase):
    """Tests for ensemble loading, layer pruning and async saving in ``checkpoint_utils``."""

    def setUp(self):
        # Suppress log records up to and including CRITICAL while a test runs.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        # Undo the suppression installed by setUp.
        logging.disable(logging.NOTSET)

    @contextlib.contextmanager
    def _train_transformer(self, seed, extra_args=None):
        """Train a tiny ``transformer_iwslt_de_en`` model and yield a checkpoint path.

        The model is trained with 3 encoder layers, 3 decoder layers and an
        embedding dimension of 8, inside a temporary directory whose name
        ends with ``_train_transformer_seed<seed>``.

        Args:
            seed: Value passed to training as ``--seed``; it is also embedded
                in the temporary directory's name suffix.
            extra_args: Optional list of additional command-line arguments
                appended after the default ones.

        Yields:
            The path ``<tmpdir>/checkpoint_last.pt`` (presumably written by
            ``train_translation_model`` -- confirm against ``tests.utils``).
            The directory is deleted when the context exits, so the
            checkpoint has to be consumed inside the ``with`` block.
        """
        if extra_args is None:
            extra_args = []

        base_args = [
            "--encoder-layers",
            "3",
            "--decoder-layers",
            "3",
            "--encoder-embed-dim",
            "8",
            "--decoder-embed-dim",
            "8",
            "--seed",
            str(seed),
        ]

        # The seed goes into the directory *suffix* so that tests can tell
        # from a data path which training run it belongs to.
        with tempfile.TemporaryDirectory(
            suffix=f"_train_transformer_seed{seed}"
        ) as data_dir:
            create_dummy_data(data_dir)
            preprocess_translation_data(data_dir)
            train_translation_model(
                data_dir,
                "transformer_iwslt_de_en",
                base_args + extra_args,
            )
            yield os.path.join(data_dir, "checkpoint_last.pt")

    def test_load_model_ensemble_and_task(self):
        """Load two checkpoints as one ensemble and check models, task and cfg."""
        # NOTE: unlike test_prune_state_dict, stdout is not redirected here.
        with self._train_transformer(seed=123) as model1:
            with self._train_transformer(seed=456) as model2:
                ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                    filenames=[model1, model2]
                )
                self.assertEqual(len(ensemble), 2)

                # after Transformer has been migrated to Hydra, this will probably
                # become cfg.common.seed
                self.assertEqual(ensemble[0].args.seed, 123)
                self.assertEqual(ensemble[1].args.seed, 456)

                # the task from the first model should be returned
                self.assertIn("seed123", task.cfg.data)

                # last cfg is saved
                self.assertEqual(cfg.common.seed, 456)

    def test_prune_state_dict(self):
        """Load a layerdrop-trained checkpoint keeping only a subset of layers."""
        with contextlib.redirect_stdout(StringIO()):
            extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"]
            with self._train_transformer(seed=1, extra_args=extra_args) as model:
                ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                    filenames=[model],
                    arg_overrides={
                        "encoder_layers_to_keep": "0,2",
                        "decoder_layers_to_keep": "1",
                    },
                )
                self.assertEqual(len(ensemble), 1)
                # Two encoder indices and one decoder index were requested.
                self.assertEqual(len(ensemble[0].encoder.layers), 2)
                self.assertEqual(len(ensemble[0].decoder.layers), 1)

    def test_torch_persistent_save_async(self):
        """Check that ``async_write=True`` opens the file via ``PathManager.opena``."""
        state_dict = {}
        filename = "async_checkpoint.pt"
        module = checkpoint_utils.__name__

        # Both the async opener and the low-level save routine are mocked, so
        # nothing is written to disk.
        with patch(f"{module}.PathManager.opena") as mock_opena, patch(
            f"{module}._torch_persistent_save"
        ) as mock_save:
            checkpoint_utils.torch_persistent_save(
                state_dict, filename, async_write=True
            )
            mock_opena.assert_called_with(filename, "wb")
            mock_save.assert_called()
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()