diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index d22d9870..627f1416 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -203,7 +203,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args): cfg.restore_file == "checkpoint_last.pt" ): # default value of restore_file is 'checkpoint_last.pt' checkpoint_path = os.path.join( - cfg.get("save_dir"), "checkpoint_last{}.pt".format(suffix) + cfg.save_dir, "checkpoint_last{}.pt".format(suffix) ) first_launch = not PathManager.exists(checkpoint_path) if cfg.finetune_from_model is not None and first_launch: diff --git a/tests/test_train.py b/tests/test_train.py index 65f4683b..02ef94cc 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -61,6 +61,7 @@ def get_mock_cfg(finetune_from_model): cfg_mock = OmegaConf.create( { "checkpoint": { + "save_dir": None, "optimizer_overrides": "{}", "reset_dataloader": False, "reset_meters": False,