mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-11-13 07:41:39 +03:00
remove missing config entries when loading task from checkpoint (#4905)
This commit is contained in:
parent
d871f6169f
commit
902f4aa5e8
@ -440,11 +440,13 @@ def load_model_ensemble_and_task(
|
||||
)
|
||||
|
||||
if task is None:
|
||||
task = tasks.setup_task(cfg.task)
|
||||
task = tasks.setup_task(cfg.task, from_checkpoint=True)
|
||||
|
||||
if "task_state" in state:
|
||||
task.load_state_dict(state["task_state"])
|
||||
|
||||
argspec = inspect.getfullargspec(task.build_model)
|
||||
|
||||
if "fsdp_metadata" in state and num_shards > 1:
|
||||
model_shard_state["shard_weights"].append(state["model"])
|
||||
model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
|
||||
@ -459,7 +461,10 @@ def load_model_ensemble_and_task(
|
||||
shard_weights=model_shard_state["shard_weights"],
|
||||
shard_metadata=model_shard_state["shard_metadata"],
|
||||
)
|
||||
model = task.build_model(cfg.model)
|
||||
if "from_checkpoint" in argspec.args:
|
||||
model = task.build_model(cfg.model, from_checkpoint=True)
|
||||
else:
|
||||
model = task.build_model(cfg.model)
|
||||
if (
|
||||
"optimizer_history" in state
|
||||
and len(state["optimizer_history"]) > 0
|
||||
@ -475,7 +480,6 @@ def load_model_ensemble_and_task(
|
||||
# model parallel checkpoint or unsharded checkpoint
|
||||
# support old external tasks
|
||||
|
||||
argspec = inspect.getfullargspec(task.build_model)
|
||||
if "from_checkpoint" in argspec.args:
|
||||
model = task.build_model(cfg.model, from_checkpoint=True)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user