mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-08-16 12:00:25 +03:00
fixes tests/test_train.py to mock checkpoint.save_dir config node (#3675)
Summary: ## What does this PR do? Some downstream users reported that errors when passing Namespace to load_checkpoint(). A recent change made the assumption that the passed object is dict like (dict or DictConfig) that have a get function. This changes that and make sure the mocked config have checkpoint.save_dir to allow the test to run. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3675 Reviewed By: omry Differential Revision: D29564805 Pulled By: lematt1991 fbshipit-source-id: 89308811da382667f6c5d3152ee2d6480416ee62
This commit is contained in:
parent
cdc1a553eb
commit
dd106d9534
@ -203,7 +203,7 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
||||
cfg.restore_file == "checkpoint_last.pt"
|
||||
): # default value of restore_file is 'checkpoint_last.pt'
|
||||
checkpoint_path = os.path.join(
|
||||
cfg.get("save_dir"), "checkpoint_last{}.pt".format(suffix)
|
||||
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
|
||||
)
|
||||
first_launch = not PathManager.exists(checkpoint_path)
|
||||
if cfg.finetune_from_model is not None and first_launch:
|
||||
|
@ -61,6 +61,7 @@ def get_mock_cfg(finetune_from_model):
|
||||
cfg_mock = OmegaConf.create(
|
||||
{
|
||||
"checkpoint": {
|
||||
"save_dir": None,
|
||||
"optimizer_overrides": "{}",
|
||||
"reset_dataloader": False,
|
||||
"reset_meters": False,
|
||||
|
Loading…
Reference in New Issue
Block a user