mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-09-11 17:25:31 +03:00
9cf0bd96d6
Summary:
- add test for loading ensemble checkpoints (and confirmed it fails if I revert 265791b727)
- add test for LayerDrop (and fix it)
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1468
Reviewed By: alexeib
Differential Revision: D25223272
Pulled By: myleott
fbshipit-source-id: 3f06f753605af251567c70d2961f5506ea423499
90 lines
3.1 KiB
Python
90 lines
3.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import contextlib
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import unittest
|
|
from io import StringIO
|
|
|
|
from fairseq import checkpoint_utils
|
|
|
|
from tests.utils import (
|
|
create_dummy_data,
|
|
preprocess_translation_data,
|
|
train_translation_model,
|
|
)
|
|
|
|
|
|
class TestCheckpointUtils(unittest.TestCase):
    """Tests for loading checkpoints through ``checkpoint_utils``."""

    def setUp(self):
        # Silence all log output (up to and including CRITICAL) while a test runs.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        # Restore normal logging once the test has finished.
        logging.disable(logging.NOTSET)

    @contextlib.contextmanager
    def _train_transformer(self, seed, extra_args=None):
        """Train a tiny translation model and yield its last-checkpoint path.

        A temporary directory is created, filled via ``create_dummy_data`` and
        ``preprocess_translation_data``, and then handed to
        ``train_translation_model`` with the ``transformer_iwslt_de_en``
        architecture (3 encoder/decoder layers, embedding dim 8).  The
        directory only exists while the context is active.

        Args:
            seed: value passed as ``--seed``; also part of the temp dir suffix.
            extra_args: optional list of additional command-line style
                arguments appended after the fixed ones.

        Yields:
            The path ``<tmp dir>/checkpoint_last.pt``.
        """
        extra_args = [] if extra_args is None else extra_args
        dir_suffix = f"_train_transformer_seed{seed}"
        with tempfile.TemporaryDirectory(suffix=dir_suffix) as workdir:
            create_dummy_data(workdir)
            preprocess_translation_data(workdir)
            base_flags = [
                "--encoder-layers",
                "3",
                "--decoder-layers",
                "3",
                "--encoder-embed-dim",
                "8",
                "--decoder-embed-dim",
                "8",
                "--seed",
                str(seed),
            ]
            train_translation_model(
                workdir,
                "transformer_iwslt_de_en",
                base_flags + extra_args,
            )
            yield os.path.join(workdir, "checkpoint_last.pt")

    def test_load_model_ensemble_and_task(self):
        """Two checkpoints load as a two-model ensemble, in the given order."""
        with contextlib.ExitStack() as stack:
            # Contexts are entered in this order and unwound in reverse,
            # exactly like three nested ``with`` statements.
            stack.enter_context(contextlib.redirect_stdout(StringIO()))
            first_ckpt = stack.enter_context(self._train_transformer(seed=123))
            second_ckpt = stack.enter_context(self._train_transformer(seed=456))

            loaded = checkpoint_utils.load_model_ensemble_and_task(
                filenames=[first_ckpt, second_ckpt]
            )
            models, _cfg, task = loaded
            self.assertEqual(len(models), 2)

            # after Transformer has been migrated to Hydra, this will probably
            # become cfg.common.seed
            self.assertEqual(models[0].args.seed, 123)
            self.assertEqual(models[1].args.seed, 456)

            # the task from the first model should be returned
            self.assertEqual(task.args.seed, 123)

    def test_prune_state_dict(self):
        """``*_layers_to_keep`` overrides reduce the loaded layer counts."""
        layerdrop_flags = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"]
        with contextlib.redirect_stdout(StringIO()), self._train_transformer(
            seed=1, extra_args=layerdrop_flags
        ) as ckpt_path:
            overrides = {
                "encoder_layers_to_keep": "0,2",
                "decoder_layers_to_keep": "1",
            }
            models, _cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
                filenames=[ckpt_path],
                arg_overrides=overrides,
            )
            self.assertEqual(len(models), 1)

            # Two encoder layers ("0,2") and one decoder layer ("1") were kept.
            only_model = models[0]
            self.assertEqual(len(only_model.encoder.layers), 2)
            self.assertEqual(len(only_model.decoder.layers), 1)
|
|
|
|
|
|
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|