mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-15 15:13:45 +03:00
Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console
This commit is contained in:
parent
da80d649fd
commit
d8419762c1
@ -1,4 +1,4 @@
|
|||||||
from modules import extra_networks, shared
|
from modules import extra_networks, shared, sd_hijack
|
||||||
import networks
|
import networks
|
||||||
|
|
||||||
|
|
||||||
@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('lora')
|
super().__init__('lora')
|
||||||
|
|
||||||
|
self.errors = {}
|
||||||
|
"""mapping of network names to the number of errors the network had during operation"""
|
||||||
|
|
||||||
def activate(self, p, params_list):
|
def activate(self, p, params_list):
|
||||||
additional = shared.opts.sd_lora
|
additional = shared.opts.sd_lora
|
||||||
|
|
||||||
|
self.errors.clear()
|
||||||
|
|
||||||
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
||||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
||||||
|
|
||||||
def deactivate(self, p):
|
def deactivate(self, p):
|
||||||
pass
|
if self.errors:
|
||||||
|
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
|
||||||
|
|
||||||
|
self.errors.clear()
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -194,7 +195,7 @@ def load_network(name, network_on_disk):
|
|||||||
net.modules[key] = net_module
|
net.modules[key] = net_module
|
||||||
|
|
||||||
if keys_failed_to_match:
|
if keys_failed_to_match:
|
||||||
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
|
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
||||||
|
|
||||||
return net
|
return net
|
||||||
|
|
||||||
@ -207,7 +208,6 @@ def purge_networks_from_memory():
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||||
already_loaded = {}
|
already_loaded = {}
|
||||||
|
|
||||||
@ -248,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
|
|
||||||
if net is None:
|
if net is None:
|
||||||
failed_to_load_networks.append(name)
|
failed_to_load_networks.append(name)
|
||||||
print(f"Couldn't find network with name {name}")
|
logging.info(f"Couldn't find network with name {name}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
||||||
@ -257,7 +257,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
loaded_networks.append(net)
|
loaded_networks.append(net)
|
||||||
|
|
||||||
if failed_to_load_networks:
|
if failed_to_load_networks:
|
||||||
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||||
|
|
||||||
purge_networks_from_memory()
|
purge_networks_from_memory()
|
||||||
|
|
||||||
@ -314,17 +314,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
for net in loaded_networks:
|
for net in loaded_networks:
|
||||||
module = net.modules.get(network_layer_name, None)
|
module = net.modules.get(network_layer_name, None)
|
||||||
if module is not None and hasattr(self, 'weight'):
|
if module is not None and hasattr(self, 'weight'):
|
||||||
with torch.no_grad():
|
try:
|
||||||
updown, ex_bias = module.calc_updown(self.weight)
|
with torch.no_grad():
|
||||||
|
updown, ex_bias = module.calc_updown(self.weight)
|
||||||
|
|
||||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||||
|
|
||||||
self.weight += updown
|
self.weight += updown
|
||||||
if ex_bias is not None and getattr(self, 'bias', None) is not None:
|
if ex_bias is not None and getattr(self, 'bias', None) is not None:
|
||||||
self.bias += ex_bias
|
self.bias += ex_bias
|
||||||
continue
|
except RuntimeError as e:
|
||||||
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
||||||
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
||||||
@ -332,21 +337,28 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||||
with torch.no_grad():
|
try:
|
||||||
updown_q = module_q.calc_updown(self.in_proj_weight)
|
with torch.no_grad():
|
||||||
updown_k = module_k.calc_updown(self.in_proj_weight)
|
updown_q = module_q.calc_updown(self.in_proj_weight)
|
||||||
updown_v = module_v.calc_updown(self.in_proj_weight)
|
updown_k = module_k.calc_updown(self.in_proj_weight)
|
||||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
updown_v = module_v.calc_updown(self.in_proj_weight)
|
||||||
updown_out = module_out.calc_updown(self.out_proj.weight)
|
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||||
|
updown_out = module_out.calc_updown(self.out_proj.weight)
|
||||||
|
|
||||||
self.in_proj_weight += updown_qkv
|
self.in_proj_weight += updown_qkv
|
||||||
self.out_proj.weight += updown_out
|
self.out_proj.weight += updown_out
|
||||||
continue
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
if module is None:
|
if module is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f'failed to calculate network weights for layer {network_layer_name}')
|
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
self.network_current_names = wanted_names
|
self.network_current_names = wanted_names
|
||||||
|
|
||||||
@ -519,6 +531,7 @@ def infotext_pasted(infotext, params):
|
|||||||
if added:
|
if added:
|
||||||
params["Prompt"] += "\n" + "".join(added)
|
params["Prompt"] += "\n" + "".join(added)
|
||||||
|
|
||||||
|
extra_network_lora = None
|
||||||
|
|
||||||
available_networks = {}
|
available_networks = {}
|
||||||
available_network_aliases = {}
|
available_network_aliases = {}
|
||||||
|
@ -23,9 +23,9 @@ def unload():
|
|||||||
def before_ui():
|
def before_ui():
|
||||||
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
||||||
|
|
||||||
extra_network = extra_networks_lora.ExtraNetworkLora()
|
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
|
||||||
extra_networks.register_extra_network(extra_network)
|
extra_networks.register_extra_network(networks.extra_network_lora)
|
||||||
extra_networks.register_extra_network_alias(extra_network, "lyco")
|
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
||||||
|
@ -157,6 +157,7 @@ class StableDiffusionProcessing:
|
|||||||
cached_uc = [None, None]
|
cached_uc = [None, None]
|
||||||
cached_c = [None, None]
|
cached_c = [None, None]
|
||||||
|
|
||||||
|
comments: dict = None
|
||||||
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
|
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
|
||||||
is_using_inpainting_conditioning: bool = field(default=False, init=False)
|
is_using_inpainting_conditioning: bool = field(default=False, init=False)
|
||||||
paste_to: tuple | None = field(default=None, init=False)
|
paste_to: tuple | None = field(default=None, init=False)
|
||||||
@ -196,6 +197,8 @@ class StableDiffusionProcessing:
|
|||||||
if self.sampler_index is not None:
|
if self.sampler_index is not None:
|
||||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||||
|
|
||||||
|
self.comments = {}
|
||||||
|
|
||||||
self.sampler_noise_scheduler_override = None
|
self.sampler_noise_scheduler_override = None
|
||||||
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
|
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
|
||||||
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
|
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
|
||||||
@ -226,6 +229,9 @@ class StableDiffusionProcessing:
|
|||||||
def sd_model(self, value):
|
def sd_model(self, value):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def comment(self, text):
|
||||||
|
self.comments[text] = 1
|
||||||
|
|
||||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||||
|
|
||||||
@ -429,7 +435,7 @@ class Processed:
|
|||||||
self.subseed = subseed
|
self.subseed = subseed
|
||||||
self.subseed_strength = p.subseed_strength
|
self.subseed_strength = p.subseed_strength
|
||||||
self.info = info
|
self.info = info
|
||||||
self.comments = comments
|
self.comments = "".join(f"{comment}\n" for comment in p.comments)
|
||||||
self.width = p.width
|
self.width = p.width
|
||||||
self.height = p.height
|
self.height = p.height
|
||||||
self.sampler_name = p.sampler_name
|
self.sampler_name = p.sampler_name
|
||||||
@ -720,8 +726,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
modules.sd_hijack.model_hijack.clear_comments()
|
modules.sd_hijack.model_hijack.clear_comments()
|
||||||
|
|
||||||
comments = {}
|
|
||||||
|
|
||||||
p.setup_prompts()
|
p.setup_prompts()
|
||||||
|
|
||||||
if type(seed) == list:
|
if type(seed) == list:
|
||||||
@ -801,7 +805,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
p.setup_conds()
|
p.setup_conds()
|
||||||
|
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
comments[comment] = 1
|
p.comment(comment)
|
||||||
|
|
||||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
|
||||||
@ -930,7 +934,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
images_list=output_images,
|
images_list=output_images,
|
||||||
seed=p.all_seeds[0],
|
seed=p.all_seeds[0],
|
||||||
info=infotexts[0],
|
info=infotexts[0],
|
||||||
comments="".join(f"{comment}\n" for comment in comments),
|
|
||||||
subseed=p.all_subseeds[0],
|
subseed=p.all_subseeds[0],
|
||||||
index_of_first_image=index_of_first_image,
|
index_of_first_image=index_of_first_image,
|
||||||
infotexts=infotexts,
|
infotexts=infotexts,
|
||||||
|
Loading…
Reference in New Issue
Block a user