mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 22:53:25 +03:00
Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console
This commit is contained in:
parent
da80d649fd
commit
d8419762c1
@ -1,4 +1,4 @@
|
||||
from modules import extra_networks, shared
|
||||
from modules import extra_networks, shared, sd_hijack
|
||||
import networks
|
||||
|
||||
|
||||
@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
def __init__(self):
|
||||
super().__init__('lora')
|
||||
|
||||
self.errors = {}
|
||||
"""mapping of network names to the number of errors the network had during operation"""
|
||||
|
||||
def activate(self, p, params_list):
|
||||
additional = shared.opts.sd_lora
|
||||
|
||||
self.errors.clear()
|
||||
|
||||
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||
@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
||||
|
||||
def deactivate(self, p):
|
||||
pass
|
||||
if self.errors:
|
||||
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
|
||||
|
||||
self.errors.clear()
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
@ -194,7 +195,7 @@ def load_network(name, network_on_disk):
|
||||
net.modules[key] = net_module
|
||||
|
||||
if keys_failed_to_match:
|
||||
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
|
||||
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
||||
|
||||
return net
|
||||
|
||||
@ -207,7 +208,6 @@ def purge_networks_from_memory():
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
|
||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||
already_loaded = {}
|
||||
|
||||
@ -248,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
|
||||
if net is None:
|
||||
failed_to_load_networks.append(name)
|
||||
print(f"Couldn't find network with name {name}")
|
||||
logging.info(f"Couldn't find network with name {name}")
|
||||
continue
|
||||
|
||||
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
||||
@ -257,7 +257,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
loaded_networks.append(net)
|
||||
|
||||
if failed_to_load_networks:
|
||||
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
||||
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||
|
||||
purge_networks_from_memory()
|
||||
|
||||
@ -314,17 +314,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
for net in loaded_networks:
|
||||
module = net.modules.get(network_layer_name, None)
|
||||
if module is not None and hasattr(self, 'weight'):
|
||||
with torch.no_grad():
|
||||
updown, ex_bias = module.calc_updown(self.weight)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
updown, ex_bias = module.calc_updown(self.weight)
|
||||
|
||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||
|
||||
self.weight += updown
|
||||
if ex_bias is not None and getattr(self, 'bias', None) is not None:
|
||||
self.bias += ex_bias
|
||||
continue
|
||||
self.weight += updown
|
||||
if ex_bias is not None and getattr(self, 'bias', None) is not None:
|
||||
self.bias += ex_bias
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
|
||||
continue
|
||||
|
||||
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
||||
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
||||
@ -332,21 +337,28 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
||||
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
||||
|
||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||
with torch.no_grad():
|
||||
updown_q = module_q.calc_updown(self.in_proj_weight)
|
||||
updown_k = module_k.calc_updown(self.in_proj_weight)
|
||||
updown_v = module_v.calc_updown(self.in_proj_weight)
|
||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||
updown_out = module_out.calc_updown(self.out_proj.weight)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
updown_q = module_q.calc_updown(self.in_proj_weight)
|
||||
updown_k = module_k.calc_updown(self.in_proj_weight)
|
||||
updown_v = module_v.calc_updown(self.in_proj_weight)
|
||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||
updown_out = module_out.calc_updown(self.out_proj.weight)
|
||||
|
||||
self.in_proj_weight += updown_qkv
|
||||
self.out_proj.weight += updown_out
|
||||
continue
|
||||
self.in_proj_weight += updown_qkv
|
||||
self.out_proj.weight += updown_out
|
||||
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
|
||||
continue
|
||||
|
||||
if module is None:
|
||||
continue
|
||||
|
||||
print(f'failed to calculate network weights for layer {network_layer_name}')
|
||||
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
|
||||
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||
|
||||
self.network_current_names = wanted_names
|
||||
|
||||
@ -519,6 +531,7 @@ def infotext_pasted(infotext, params):
|
||||
if added:
|
||||
params["Prompt"] += "\n" + "".join(added)
|
||||
|
||||
extra_network_lora = None
|
||||
|
||||
available_networks = {}
|
||||
available_network_aliases = {}
|
||||
|
@ -23,9 +23,9 @@ def unload():
|
||||
def before_ui():
|
||||
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
||||
|
||||
extra_network = extra_networks_lora.ExtraNetworkLora()
|
||||
extra_networks.register_extra_network(extra_network)
|
||||
extra_networks.register_extra_network_alias(extra_network, "lyco")
|
||||
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
|
||||
extra_networks.register_extra_network(networks.extra_network_lora)
|
||||
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
|
||||
|
||||
|
||||
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
||||
|
@ -157,6 +157,7 @@ class StableDiffusionProcessing:
|
||||
cached_uc = [None, None]
|
||||
cached_c = [None, None]
|
||||
|
||||
comments: dict = None
|
||||
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
|
||||
is_using_inpainting_conditioning: bool = field(default=False, init=False)
|
||||
paste_to: tuple | None = field(default=None, init=False)
|
||||
@ -196,6 +197,8 @@ class StableDiffusionProcessing:
|
||||
if self.sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
self.comments = {}
|
||||
|
||||
self.sampler_noise_scheduler_override = None
|
||||
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
|
||||
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
|
||||
@ -226,6 +229,9 @@ class StableDiffusionProcessing:
|
||||
def sd_model(self, value):
|
||||
pass
|
||||
|
||||
def comment(self, text):
|
||||
self.comments[text] = 1
|
||||
|
||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||
|
||||
@ -429,7 +435,7 @@ class Processed:
|
||||
self.subseed = subseed
|
||||
self.subseed_strength = p.subseed_strength
|
||||
self.info = info
|
||||
self.comments = comments
|
||||
self.comments = "".join(f"{comment}\n" for comment in p.comments)
|
||||
self.width = p.width
|
||||
self.height = p.height
|
||||
self.sampler_name = p.sampler_name
|
||||
@ -720,8 +726,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||
modules.sd_hijack.model_hijack.clear_comments()
|
||||
|
||||
comments = {}
|
||||
|
||||
p.setup_prompts()
|
||||
|
||||
if type(seed) == list:
|
||||
@ -801,7 +805,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
p.setup_conds()
|
||||
|
||||
for comment in model_hijack.comments:
|
||||
comments[comment] = 1
|
||||
p.comment(comment)
|
||||
|
||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||
|
||||
@ -930,7 +934,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
images_list=output_images,
|
||||
seed=p.all_seeds[0],
|
||||
info=infotexts[0],
|
||||
comments="".join(f"{comment}\n" for comment in comments),
|
||||
subseed=p.all_subseeds[0],
|
||||
index_of_first_image=index_of_first_image,
|
||||
infotexts=infotexts,
|
||||
|
Loading…
Reference in New Issue
Block a user