Lora: show warnings in the UI instead of failing when a LoRA does not fit the model; use logging instead of print for console error output

This commit is contained in:
AUTOMATIC1111 2023-08-13 15:07:37 +03:00
parent da80d649fd
commit d8419762c1
4 changed files with 57 additions and 33 deletions

View File

@ -1,4 +1,4 @@
from modules import extra_networks, shared from modules import extra_networks, shared, sd_hijack
import networks import networks
@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
    """Create the handler, passing the name 'lora' to the base-class constructor."""
    super().__init__('lora')

    self.errors = {}
    """mapping of network names to the number of errors the network had during operation"""
def activate(self, p, params_list): def activate(self, p, params_list):
additional = shared.opts.sd_lora additional = shared.opts.sd_lora
self.errors.clear()
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
def deactivate(self, p):
    """Report networks that had errors as a comment on *p*, then reset the error counts.

    Does nothing when no errors were recorded. Otherwise a single comment listing
    every failing network as ``name (count)`` is passed to ``p.comment`` and
    ``self.errors`` is emptied.
    """
    if not self.errors:
        return

    p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))

    # start from a clean slate so stale counts are never reported twice
    self.errors.clear()

View File

@ -1,3 +1,4 @@
import logging
import os import os
import re import re
@ -194,7 +195,7 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module net.modules[key] = net_module
if keys_failed_to_match: if keys_failed_to_match:
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
return net return net
@ -207,7 +208,6 @@ def purge_networks_from_memory():
devices.torch_gc() devices.torch_gc()
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {} already_loaded = {}
@ -248,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
if net is None: if net is None:
failed_to_load_networks.append(name) failed_to_load_networks.append(name)
print(f"Couldn't find network with name {name}") logging.info(f"Couldn't find network with name {name}")
continue continue
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@ -257,7 +257,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
loaded_networks.append(net) loaded_networks.append(net)
if failed_to_load_networks: if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory() purge_networks_from_memory()
@ -314,17 +314,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
for net in loaded_networks: for net in loaded_networks:
module = net.modules.get(network_layer_name, None) module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'): if module is not None and hasattr(self, 'weight'):
with torch.no_grad(): try:
updown, ex_bias = module.calc_updown(self.weight) with torch.no_grad():
updown, ex_bias = module.calc_updown(self.weight)
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9 # inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
self.weight += updown self.weight += updown
if ex_bias is not None and getattr(self, 'bias', None) is not None: if ex_bias is not None and getattr(self, 'bias', None) is not None:
self.bias += ex_bias self.bias += ex_bias
continue except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
module_q = net.modules.get(network_layer_name + "_q_proj", None) module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None) module_k = net.modules.get(network_layer_name + "_k_proj", None)
@ -332,21 +337,28 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
module_out = net.modules.get(network_layer_name + "_out_proj", None) module_out = net.modules.get(network_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
with torch.no_grad(): try:
updown_q = module_q.calc_updown(self.in_proj_weight) with torch.no_grad():
updown_k = module_k.calc_updown(self.in_proj_weight) updown_q = module_q.calc_updown(self.in_proj_weight)
updown_v = module_v.calc_updown(self.in_proj_weight) updown_k = module_k.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) updown_v = module_v.calc_updown(self.in_proj_weight)
updown_out = module_out.calc_updown(self.out_proj.weight) updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
updown_out = module_out.calc_updown(self.out_proj.weight)
self.in_proj_weight += updown_qkv self.in_proj_weight += updown_qkv
self.out_proj.weight += updown_out self.out_proj.weight += updown_out
continue
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
if module is None: if module is None:
continue continue
print(f'failed to calculate network weights for layer {network_layer_name}') logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
self.network_current_names = wanted_names self.network_current_names = wanted_names
@ -519,6 +531,7 @@ def infotext_pasted(infotext, params):
if added: if added:
params["Prompt"] += "\n" + "".join(added) params["Prompt"] += "\n" + "".join(added)
# The ExtraNetworkLora instance; stays None until the extension's before_ui()
# assigns it. network_apply_weights() records per-network error counts on it.
extra_network_lora = None

# Registry of networks; ExtraNetworkLora.activate() tests membership by network name.
available_networks = {}
available_network_aliases = {}

View File

@ -23,9 +23,9 @@ def unload():
def before_ui(): def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
extra_network = extra_networks_lora.ExtraNetworkLora() networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
extra_networks.register_extra_network(extra_network) extra_networks.register_extra_network(networks.extra_network_lora)
extra_networks.register_extra_network_alias(extra_network, "lyco") extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
if not hasattr(torch.nn, 'Linear_forward_before_network'): if not hasattr(torch.nn, 'Linear_forward_before_network'):

View File

@ -157,6 +157,7 @@ class StableDiffusionProcessing:
cached_uc = [None, None] cached_uc = [None, None]
cached_c = [None, None] cached_c = [None, None]
comments: dict = None
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False) sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
is_using_inpainting_conditioning: bool = field(default=False, init=False) is_using_inpainting_conditioning: bool = field(default=False, init=False)
paste_to: tuple | None = field(default=None, init=False) paste_to: tuple | None = field(default=None, init=False)
@ -196,6 +197,8 @@ class StableDiffusionProcessing:
if self.sampler_index is not None: if self.sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
self.comments = {}
self.sampler_noise_scheduler_override = None self.sampler_noise_scheduler_override = None
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
@ -226,6 +229,9 @@ class StableDiffusionProcessing:
def sd_model(self, value): def sd_model(self, value):
pass pass
def comment(self, text):
    """Record *text* as a comment for this processing run.

    Comments are stored as keys of the ``self.comments`` dict, so adding the
    same text again does not create a duplicate and the order in which
    distinct comments were first added is preserved.
    """
    self.comments[text] = 1
def txt2img_image_conditioning(self, x, width=None, height=None): def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@ -429,7 +435,7 @@ class Processed:
self.subseed = subseed self.subseed = subseed
self.subseed_strength = p.subseed_strength self.subseed_strength = p.subseed_strength
self.info = info self.info = info
self.comments = comments self.comments = "".join(f"{comment}\n" for comment in p.comments)
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_name = p.sampler_name self.sampler_name = p.sampler_name
@ -720,8 +726,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments() modules.sd_hijack.model_hijack.clear_comments()
comments = {}
p.setup_prompts() p.setup_prompts()
if type(seed) == list: if type(seed) == list:
@ -801,7 +805,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds() p.setup_conds()
for comment in model_hijack.comments: for comment in model_hijack.comments:
comments[comment] = 1 p.comment(comment)
p.extra_generation_params.update(model_hijack.extra_generation_params) p.extra_generation_params.update(model_hijack.extra_generation_params)
@ -930,7 +934,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images_list=output_images, images_list=output_images,
seed=p.all_seeds[0], seed=p.all_seeds[0],
info=infotexts[0], info=infotexts[0],
comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0], subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image, index_of_first_image=index_of_first_image,
infotexts=infotexts, infotexts=infotexts,