From 902f4aa5e82d40a78e81e49b3b2f3cc434539e23 Mon Sep 17 00:00:00 2001
From: Alexei Baevski
Date: Mon, 12 Dec 2022 23:36:17 -0800
Subject: [PATCH] remove missing config entries when loading task from
 checkpoint (#4905)

---
 fairseq/checkpoint_utils.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py
index ff1da2553..fb9a6679b 100644
--- a/fairseq/checkpoint_utils.py
+++ b/fairseq/checkpoint_utils.py
@@ -440,11 +440,13 @@ def load_model_ensemble_and_task(
             )
 
         if task is None:
-            task = tasks.setup_task(cfg.task)
+            task = tasks.setup_task(cfg.task, from_checkpoint=True)
 
         if "task_state" in state:
             task.load_state_dict(state["task_state"])
 
+        argspec = inspect.getfullargspec(task.build_model)
+
         if "fsdp_metadata" in state and num_shards > 1:
             model_shard_state["shard_weights"].append(state["model"])
             model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
@@ -459,7 +461,10 @@ def load_model_ensemble_and_task(
                     shard_weights=model_shard_state["shard_weights"],
                     shard_metadata=model_shard_state["shard_metadata"],
                 )
-                model = task.build_model(cfg.model)
+                if "from_checkpoint" in argspec.args:
+                    model = task.build_model(cfg.model, from_checkpoint=True)
+                else:
+                    model = task.build_model(cfg.model)
                 if (
                     "optimizer_history" in state
                     and len(state["optimizer_history"]) > 0
@@ -475,7 +480,6 @@ def load_model_ensemble_and_task(
             # model parallel checkpoint or unsharded checkpoint
             # support old external tasks
-            argspec = inspect.getfullargspec(task.build_model)
             if "from_checkpoint" in argspec.args:
                 model = task.build_model(cfg.model, from_checkpoint=True)
             else:
                 model = task.build_model(cfg.model)
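
The patch's backward-compat trick is to probe the task's build_model signature
with inspect.getfullargspec and pass the newer from_checkpoint keyword only to
tasks whose signature declares it, so older external tasks keep working. Below
is a minimal standalone sketch of that pattern; LegacyTask and NewTask are
hypothetical stand-ins for illustration, not fairseq classes.

    import inspect

    # Hypothetical old-style task: build_model() predates from_checkpoint.
    class LegacyTask:
        def build_model(self, model_cfg):
            return ("legacy", model_cfg)

    # Hypothetical new-style task: build_model() accepts from_checkpoint.
    class NewTask:
        def build_model(self, model_cfg, from_checkpoint=False):
            return ("new", model_cfg, from_checkpoint)

    def build_model_compat(task, model_cfg):
        # Same check the patch hoists above the FSDP branch: pass
        # from_checkpoint=True only when the signature declares it.
        argspec = inspect.getfullargspec(task.build_model)
        if "from_checkpoint" in argspec.args:
            return task.build_model(model_cfg, from_checkpoint=True)
        return task.build_model(model_cfg)

    print(build_model_compat(LegacyTask(), "cfg"))  # ('legacy', 'cfg')
    print(build_model_compat(NewTask(), "cfg"))     # ('new', 'cfg', True)

Computing argspec once, right after the task is set up, also avoids repeating
the inspection in both the FSDP-consolidation branch and the unsharded branch,
which is why the patch deletes the duplicate getfullargspec call from the else
branch.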