diff --git a/.gitignore b/.gitignore index 5381c515..78cf719e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ __pycache__ /outputs /config.json /log -webui.settings.bat \ No newline at end of file +/webui.settings.bat +/embeddings diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 1084e248..db9952a5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -73,11 +73,21 @@ class StableDiffusionModelHijack: name = os.path.splitext(filename)[0] data = torch.load(path) - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + self.word_embeddings[name] = emb.detach() self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'