remove missing config entries when loading task from checkpoint (#4905)

Alexei Baevski 2022-12-12 23:36:17 -08:00 committed by GitHub
parent d871f6169f
commit 902f4aa5e8


@@ -440,11 +440,13 @@ def load_model_ensemble_and_task(
         )
         if task is None:
-            task = tasks.setup_task(cfg.task)
+            task = tasks.setup_task(cfg.task, from_checkpoint=True)
 
         if "task_state" in state:
             task.load_state_dict(state["task_state"])
 
+        argspec = inspect.getfullargspec(task.build_model)
+
         if "fsdp_metadata" in state and num_shards > 1:
             model_shard_state["shard_weights"].append(state["model"])
             model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
@@ -459,7 +461,10 @@
                     shard_weights=model_shard_state["shard_weights"],
                     shard_metadata=model_shard_state["shard_metadata"],
                 )
-                model = task.build_model(cfg.model)
+                if "from_checkpoint" in argspec.args:
+                    model = task.build_model(cfg.model, from_checkpoint=True)
+                else:
+                    model = task.build_model(cfg.model)
                 if (
                     "optimizer_history" in state
                     and len(state["optimizer_history"]) > 0
@@ -475,7 +480,6 @@
             # model parallel checkpoint or unsharded checkpoint
             # support old external tasks
-            argspec = inspect.getfullargspec(task.build_model)
             if "from_checkpoint" in argspec.args:
                 model = task.build_model(cfg.model, from_checkpoint=True)
             else:
                 model = task.build_model(cfg.model)
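
The `argspec` check (now computed once, earlier in the function) is plain feature detection: `inspect.getfullargspec` lists the named parameters of `build_model`, so `from_checkpoint=True` is forwarded only to tasks whose `build_model` declares it, and old external tasks with the two-argument signature keep working unchanged. A self-contained sketch of the same dispatch, with hypothetical task classes standing in for fairseq tasks:

    import inspect

    class LegacyTask:
        # old external task: build_model predates the new keyword
        def build_model(self, cfg):
            return f"model({cfg})"

    class ModernTask:
        def build_model(self, cfg, from_checkpoint=False):
            return f"model({cfg}, from_checkpoint={from_checkpoint})"

    def build(task, cfg):
        # Forward the kwarg only if build_model declares it; note that
        # getfullargspec on a bound method lists "self" too, which is
        # harmless since we only test membership of "from_checkpoint".
        argspec = inspect.getfullargspec(task.build_model)
        if "from_checkpoint" in argspec.args:
            return task.build_model(cfg, from_checkpoint=True)
        return task.build_model(cfg)

    print(build(LegacyTask(), "cfg"))  # model(cfg)
    print(build(ModernTask(), "cfg"))  # model(cfg, from_checkpoint=True)
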