diff --git a/scripts/sd_utils/__init__.py b/scripts/sd_utils/__init__.py index 66977b9..87c6f50 100644 --- a/scripts/sd_utils/__init__.py +++ b/scripts/sd_utils/__init__.py @@ -498,7 +498,7 @@ def load_model_from_config(config, ckpt, verbose=False): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: logger.info(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] + sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd model = instantiate_from_config(config.model, personalization_config='') m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: