Add/fix tests (#1468)

Summary:
- add test for loading ensemble checkpoints (and confirmed it fails if commit 265791b727 is reverted)
- add test for LayerDrop, and fix prune_state_dict so it also works when the model config is not a DictConfig

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1468

Reviewed By: alexeib

Differential Revision: D25223272

Pulled By: myleott

fbshipit-source-id: 3f06f753605af251567c70d2961f5506ea423499
Commit: 9cf0bd96d6 (parent: 4ba5b4be98)
Author: Myle Ott, 2020-11-30 14:19:25 -08:00 (committed by Facebook GitHub Bot)
3 changed files with 141 additions and 4 deletions

--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py

@@ -5,6 +5,7 @@
 import ast
 import collections
+import contextlib
 import logging
 import os
 import re
@@ -239,7 +240,13 @@ def load_checkpoint_to_cpu(path, arg_overrides=None):
 def load_model_ensemble(
-    filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None
+    filenames,
+    arg_overrides=None,
+    task=None,
+    strict=True,
+    suffix="",
+    num_shards=1,
+    state=None,
 ):
     """Loads an ensemble of models.
@@ -265,7 +272,13 @@ def load_model_ensemble(
 def load_model_ensemble_and_task(
-    filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1, state=None
+    filenames,
+    arg_overrides=None,
+    task=None,
+    strict=True,
+    suffix="",
+    num_shards=1,
+    state=None,
 ):
     assert state is None or len(filenames) == 1
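
For reference, a minimal sketch of calling the reformatted entry point. The checkpoint paths and the data override below are hypothetical; the three-value return is exercised by the new test added further down.

from fairseq import checkpoint_utils

# hypothetical paths to two checkpoints trained on the same task
ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    filenames=["checkpoints/model1.pt", "checkpoints/model2.pt"],
    arg_overrides={"data": "data-bin/example"},  # hypothetical override
    strict=True,
)
assert len(ensemble) == 2  # one model per checkpoint file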
@@ -563,8 +576,11 @@ def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
     # Since layers are now pruned, *_layers_to_keep are no longer needed.
     # This is more of "It would make it work fix" rather than a proper fix.
-    with open_dict(model_cfg):
+    if isinstance(model_cfg, DictConfig):
+        context = open_dict(model_cfg)
+    else:
+        context = contextlib.ExitStack()
+    with context:
         if hasattr(model_cfg, "encoder_layers_to_keep"):
             model_cfg.encoder_layers_to_keep = None
         if hasattr(model_cfg, "decoder_layers_to_keep"):
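
Why this fix works: open_dict is an omegaconf helper that only accepts DictConfig objects, but prune_state_dict may also receive a legacy argparse.Namespace config; an empty contextlib.ExitStack() then stands in as a do-nothing context manager, so the same with block covers both cases. A minimal standalone sketch of the pattern, with an illustrative helper name (not fairseq API):

import contextlib
from argparse import Namespace

try:
    from omegaconf import DictConfig, open_dict  # optional here
except ImportError:
    DictConfig, open_dict = (), None  # isinstance(x, ()) is always False

def clear_layers_to_keep(cfg):
    # A struct-mode DictConfig rejects writes to unknown keys unless
    # wrapped in open_dict(); a plain Namespace needs no guard, so an
    # empty ExitStack serves as a no-op context manager.
    if open_dict is not None and isinstance(cfg, DictConfig):
        context = open_dict(cfg)
    else:
        context = contextlib.ExitStack()
    with context:
        if hasattr(cfg, "encoder_layers_to_keep"):
            cfg.encoder_layers_to_keep = None

clear_layers_to_keep(Namespace(encoder_layers_to_keep="0,2"))  # runs without omegaconf

On Python 3.7+, contextlib.nullcontext() expresses the same no-op intent more directly; ExitStack() has the advantage of also working on 3.6.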

--- a/tests/test_binaries.py
+++ b/tests/test_binaries.py

@@ -925,6 +925,38 @@ class TestTranslation(unittest.TestCase):
             )
             generate_main(data_dir)
 
+    def test_transformer_layerdrop(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory("test_transformer_layerdrop") as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(
+                    data_dir,
+                    "transformer_iwslt_de_en",
+                    [
+                        "--encoder-layers",
+                        "3",
+                        "--decoder-layers",
+                        "3",
+                        "--encoder-embed-dim",
+                        "8",
+                        "--decoder-embed-dim",
+                        "8",
+                        "--encoder-layerdrop",
+                        "0.01",
+                        "--decoder-layerdrop",
+                        "0.01",
+                    ],
+                )
+                generate_main(data_dir)
+                generate_main(
+                    data_dir,
+                    [
+                        "--model-overrides",
+                        "{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}",
+                    ],
+                )
+
 
 class TestStories(unittest.TestCase):
     def setUp(self):

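The --model-overrides value in the new test is a Python dict literal passed as a string. A small sketch of how such a string parses into the kind of arg_overrides dict that load_model_ensemble_and_task accepts; fairseq's own CLI plumbing may differ in detail:

import ast

overrides = ast.literal_eval(
    "{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}"
)
# a plain dict, usable as arg_overrides
assert overrides == {"encoder_layers_to_keep": "0,2", "decoder_layers_to_keep": "1"}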
--- /dev/null
+++ b/tests/test_checkpoint_utils.py

@@ -0,0 +1,89 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import logging
+import os
+import tempfile
+import unittest
+from io import StringIO
+
+from fairseq import checkpoint_utils
+from tests.utils import (
+    create_dummy_data,
+    preprocess_translation_data,
+    train_translation_model,
+)
+
+
+class TestCheckpointUtils(unittest.TestCase):
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    @contextlib.contextmanager
+    def _train_transformer(self, seed, extra_args=None):
+        if extra_args is None:
+            extra_args = []
+        with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir:
+            create_dummy_data(data_dir)
+            preprocess_translation_data(data_dir)
+            train_translation_model(
+                data_dir,
+                "transformer_iwslt_de_en",
+                [
+                    "--encoder-layers",
+                    "3",
+                    "--decoder-layers",
+                    "3",
+                    "--encoder-embed-dim",
+                    "8",
+                    "--decoder-embed-dim",
+                    "8",
+                    "--seed",
+                    str(seed),
+                ]
+                + extra_args,
+            )
+            yield os.path.join(data_dir, "checkpoint_last.pt")
+
+    def test_load_model_ensemble_and_task(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with self._train_transformer(seed=123) as model1:
+                with self._train_transformer(seed=456) as model2:
+                    ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+                        filenames=[model1, model2]
+                    )
+                    self.assertEqual(len(ensemble), 2)
+
+                    # after Transformer has been migrated to Hydra, this will
+                    # probably become cfg.common.seed
+                    self.assertEqual(ensemble[0].args.seed, 123)
+                    self.assertEqual(ensemble[1].args.seed, 456)
+
+                    # the task from the first model should be returned
+                    self.assertEqual(task.args.seed, 123)
+
+    def test_prune_state_dict(self):
+        with contextlib.redirect_stdout(StringIO()):
+            extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"]
+            with self._train_transformer(seed=1, extra_args=extra_args) as model:
+                ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+                    filenames=[model],
+                    arg_overrides={
+                        "encoder_layers_to_keep": "0,2",
+                        "decoder_layers_to_keep": "1",
+                    },
+                )
+                self.assertEqual(len(ensemble), 1)
+                self.assertEqual(len(ensemble[0].encoder.layers), 2)
+                self.assertEqual(len(ensemble[0].decoder.layers), 1)
+
+
+if __name__ == "__main__":
+    unittest.main()
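
To run just the new suite locally, the standard unittest entry point should work from the repository root:

python -m unittest tests.test_checkpoint_utils -v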