fixes tests/test_train.py to mock checkpoint.save_dir config node (#3675)

Summary:
## What does this PR do?
Some downstream users reported errors when passing a Namespace to load_checkpoint().

A recent change made the assumption that the passed object is dict-like (a dict or DictConfig) that has a get function.
This change removes that assumption and makes sure the mocked config has checkpoint.save_dir, to allow the test to run.

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3675

Reviewed By: omry

Differential Revision: D29564805

Pulled By: lematt1991

fbshipit-source-id: 89308811da382667f6c5d3152ee2d6480416ee62
This commit is contained in:
Omry Yadan 2021-07-06 15:06:07 -07:00 committed by Facebook GitHub Bot
parent cdc1a553eb
commit dd106d9534
2 changed files with 2 additions and 1 deletions

View File

@ -203,7 +203,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
cfg.restore_file == "checkpoint_last.pt"
): # default value of restore_file is 'checkpoint_last.pt'
checkpoint_path = os.path.join(
cfg.get("save_dir"), "checkpoint_last{}.pt".format(suffix)
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
)
first_launch = not PathManager.exists(checkpoint_path)
if cfg.finetune_from_model is not None and first_launch:

View File

@ -61,6 +61,7 @@ def get_mock_cfg(finetune_from_model):
cfg_mock = OmegaConf.create(
{
"checkpoint": {
"save_dir": None,
"optimizer_overrides": "{}",
"reset_dataloader": False,
"reset_meters": False,