Fixes checkpoint_path when loading a model-parallel checkpoint (#2365)

Summary:
Fixes https://github.com/pytorch/fairseq/issues/2351

When model_parallel_size > 1 and --restore-file points to anything other than checkpoint_last.pt, load_checkpoint ignored the per-partition checkpoint suffix, so every model-parallel rank tried to load the same file. The suffix is now spliced in before the .pt extension in that case.

Pull Request resolved: https://github.com/pytorch/fairseq/pull/2365

Reviewed By: pipibjc

Differential Revision: D22727384

Pulled By: myleott

fbshipit-source-id: e2ff703181a6b8f10df9b4ee7aa3f9e128c04b4e
Author:    Rakesh Chada
Date:      2020-08-04 08:24:25 -07:00
Committer: Facebook GitHub Bot
Parent:    33cefe3728
Commit:    b040dae714

2 files changed, 3 insertions(+), 0 deletions(-)

@@ -126,6 +126,8 @@ def load_checkpoint(args, trainer, **passthrough_args):
     suffix = getattr(args, "checkpoint_suffix", "")
     if args.restore_file == "checkpoint_last.pt":
         checkpoint_path = os.path.join(args.save_dir, "checkpoint_last{}.pt".format(suffix))
+    elif getattr(args, "model_parallel_size", 1) > 1:
+        checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
     else:
         checkpoint_path = args.restore_file
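
For illustration, a minimal standalone sketch of the path resolution this hunk implements; the function name resolve_checkpoint_path and the example suffix "-model_part-0" are hypothetical, not part of the patch:

import os

def resolve_checkpoint_path(restore_file, save_dir, suffix="", model_parallel_size=1):
    # Mirrors the branching in load_checkpoint() above.
    if restore_file == "checkpoint_last.pt":
        # Default restore file: splice the suffix into the standard name.
        return os.path.join(save_dir, "checkpoint_last{}.pt".format(suffix))
    elif model_parallel_size > 1:
        # Model-parallel run restoring a custom file: insert the per-partition
        # suffix before the extension so each rank loads its own shard rather
        # than every rank opening the same file (the failure in the linked issue).
        return restore_file.replace(".pt", suffix + ".pt")
    else:
        return restore_file

# With a hypothetical per-partition suffix:
assert resolve_checkpoint_path(
    "checkpoints/averaged.pt", "checkpoints",
    suffix="-model_part-0", model_parallel_size=2,
) == "checkpoints/averaged-model_part-0.pt"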

@@ -57,6 +57,7 @@ class TestLoadCheckpoint(unittest.TestCase):
         self.args_mock.reset_dataloader = False
         self.args_mock.reset_meters = False
         self.args_mock.reset_optimizer = False
+        self.args_mock.model_parallel_size = 1
         self.patches = {
             'os.makedirs': MagicMock(),
             'os.path.join': MagicMock(),
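
The new default of 1 keeps the existing tests on the single-file path. A hedged sketch of a test that would cover the model-parallel branch instead; the class name, method name, and suffix value are assumptions, not part of this change:

import unittest
from unittest.mock import MagicMock


class TestModelParallelRestorePath(unittest.TestCase):
    def test_suffix_spliced_before_extension(self):
        args = MagicMock()
        args.restore_file = "checkpoints/model.pt"
        args.model_parallel_size = 2
        args.checkpoint_suffix = "-model_part-0"
        # Reproduce the new elif branch from load_checkpoint().
        suffix = getattr(args, "checkpoint_suffix", "")
        checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt")
        self.assertEqual(checkpoint_path, "checkpoints/model-model_part-0.pt")


if __name__ == "__main__":
    unittest.main()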