Fixed KeyError: 'state_dict' when loading certain models, thanks to ResidentChiefNZ and hlky for the fix. (#1739)

This commit is contained in:
Alejandro Gil 2023-01-21 16:52:03 -08:00 committed by GitHub
commit c47243e3a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -498,7 +498,7 @@ def load_model_from_config(config, ckpt, verbose=False):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
logger.info(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
model = instantiate_from_config(config.model, personalization_config='')
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose: