From dd106d9534b22e7db859a6b87ffd7780c38341f8 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Tue, 6 Jul 2021 15:06:07 -0700 Subject: [PATCH] fixes tests/test_train.py to mock checkpoint.save_dir config node (#3675) Summary: ## What does this PR do? Some downstream users reported errors when passing a Namespace to load_checkpoint(). A recent change made the assumption that the passed object is dict-like (dict or DictConfig) and has a get function. This changes that and makes sure the mocked config has checkpoint.save_dir to allow the test to run. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3675 Reviewed By: omry Differential Revision: D29564805 Pulled By: lematt1991 fbshipit-source-id: 89308811da382667f6c5d3152ee2d6480416ee62 --- fairseq/checkpoint_utils.py | 2 +- tests/test_train.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index d22d9870..627f1416 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -203,7 +203,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' checkpoint_path = os.path.join( - cfg.get("save_dir"), "checkpoint_last{}.pt".format(suffix) + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) ) first_launch = not PathManager.exists(checkpoint_path) if cfg.finetune_from_model is not None and first_launch: diff --git a/tests/test_train.py b/tests/test_train.py index 65f4683b..02ef94cc 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -61,6 +61,7 @@ def get_mock_cfg(finetune_from_model): cfg_mock = OmegaConf.create( { "checkpoint": { + "save_dir": None, "optimizer_overrides": "{}", "reset_dataloader": False, "reset_meters": False,