mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-15 07:03:06 +03:00
Merge remote-tracking branch 'auto1111/dev' into shared-hires-prompt-test
This commit is contained in:
commit
d61e31bae6
@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('lora')
|
super().__init__('lora')
|
||||||
|
|
||||||
|
self.errors = {}
|
||||||
|
"""mapping of network names to the number of errors the network had during operation"""
|
||||||
|
|
||||||
def activate(self, p, params_list):
|
def activate(self, p, params_list):
|
||||||
additional = shared.opts.sd_lora
|
additional = shared.opts.sd_lora
|
||||||
|
|
||||||
|
self.errors.clear()
|
||||||
|
|
||||||
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
|
||||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
|
||||||
|
|
||||||
def deactivate(self, p):
|
def deactivate(self, p):
|
||||||
pass
|
if self.errors:
|
||||||
|
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
|
||||||
|
|
||||||
|
self.errors.clear()
|
||||||
|
@ -133,7 +133,7 @@ class NetworkModule:
|
|||||||
|
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
def finalize_updown(self, updown, orig_weight, output_shape):
|
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
updown = updown.reshape(self.bias.shape)
|
updown = updown.reshape(self.bias.shape)
|
||||||
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
@ -145,7 +145,10 @@ class NetworkModule:
|
|||||||
if orig_weight.size().numel() == updown.size().numel():
|
if orig_weight.size().numel() == updown.size().numel():
|
||||||
updown = updown.reshape(orig_weight.shape)
|
updown = updown.reshape(orig_weight.shape)
|
||||||
|
|
||||||
return updown * self.calc_scale() * self.multiplier()
|
if ex_bias is not None:
|
||||||
|
ex_bias = ex_bias * self.multiplier()
|
||||||
|
|
||||||
|
return updown * self.calc_scale() * self.multiplier(), ex_bias
|
||||||
|
|
||||||
def calc_updown(self, target):
|
def calc_updown(self, target):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
28
extensions-builtin/Lora/network_norm.py
Normal file
28
extensions-builtin/Lora/network_norm.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import network
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleTypeNorm(network.ModuleType):
|
||||||
|
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
if all(x in weights.w for x in ["w_norm", "b_norm"]):
|
||||||
|
return NetworkModuleNorm(net, weights)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkModuleNorm(network.NetworkModule):
|
||||||
|
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||||
|
super().__init__(net, weights)
|
||||||
|
|
||||||
|
self.w_norm = weights.w.get("w_norm")
|
||||||
|
self.b_norm = weights.w.get("b_norm")
|
||||||
|
|
||||||
|
def calc_updown(self, orig_weight):
|
||||||
|
output_shape = self.w_norm.shape
|
||||||
|
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
|
||||||
|
if self.b_norm is not None:
|
||||||
|
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||||
|
else:
|
||||||
|
ex_bias = None
|
||||||
|
|
||||||
|
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -7,6 +8,7 @@ import network_hada
|
|||||||
import network_ia3
|
import network_ia3
|
||||||
import network_lokr
|
import network_lokr
|
||||||
import network_full
|
import network_full
|
||||||
|
import network_norm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@ -19,6 +21,7 @@ module_types = [
|
|||||||
network_ia3.ModuleTypeIa3(),
|
network_ia3.ModuleTypeIa3(),
|
||||||
network_lokr.ModuleTypeLokr(),
|
network_lokr.ModuleTypeLokr(),
|
||||||
network_full.ModuleTypeFull(),
|
network_full.ModuleTypeFull(),
|
||||||
|
network_norm.ModuleTypeNorm(),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -31,6 +34,8 @@ suffix_conversion = {
|
|||||||
"resnets": {
|
"resnets": {
|
||||||
"conv1": "in_layers_2",
|
"conv1": "in_layers_2",
|
||||||
"conv2": "out_layers_3",
|
"conv2": "out_layers_3",
|
||||||
|
"norm1": "in_layers_0",
|
||||||
|
"norm2": "out_layers_0",
|
||||||
"time_emb_proj": "emb_layers_1",
|
"time_emb_proj": "emb_layers_1",
|
||||||
"conv_shortcut": "skip_connection",
|
"conv_shortcut": "skip_connection",
|
||||||
}
|
}
|
||||||
@ -190,7 +195,7 @@ def load_network(name, network_on_disk):
|
|||||||
net.modules[key] = net_module
|
net.modules[key] = net_module
|
||||||
|
|
||||||
if keys_failed_to_match:
|
if keys_failed_to_match:
|
||||||
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
|
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
||||||
|
|
||||||
return net
|
return net
|
||||||
|
|
||||||
@ -203,7 +208,6 @@ def purge_networks_from_memory():
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||||
already_loaded = {}
|
already_loaded = {}
|
||||||
|
|
||||||
@ -244,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
|
|
||||||
if net is None:
|
if net is None:
|
||||||
failed_to_load_networks.append(name)
|
failed_to_load_networks.append(name)
|
||||||
print(f"Couldn't find network with name {name}")
|
logging.info(f"Couldn't find network with name {name}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
|
||||||
@ -253,25 +257,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
|||||||
loaded_networks.append(net)
|
loaded_networks.append(net)
|
||||||
|
|
||||||
if failed_to_load_networks:
|
if failed_to_load_networks:
|
||||||
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
|
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||||
|
|
||||||
purge_networks_from_memory()
|
purge_networks_from_memory()
|
||||||
|
|
||||||
|
|
||||||
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||||
weights_backup = getattr(self, "network_weights_backup", None)
|
weights_backup = getattr(self, "network_weights_backup", None)
|
||||||
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
|
||||||
if weights_backup is None:
|
if weights_backup is None and bias_backup is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
if weights_backup is not None:
|
||||||
self.in_proj_weight.copy_(weights_backup[0])
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
self.out_proj.weight.copy_(weights_backup[1])
|
self.in_proj_weight.copy_(weights_backup[0])
|
||||||
|
self.out_proj.weight.copy_(weights_backup[1])
|
||||||
|
else:
|
||||||
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
|
if bias_backup is not None:
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.out_proj.bias.copy_(bias_backup)
|
||||||
|
else:
|
||||||
|
self.bias.copy_(bias_backup)
|
||||||
else:
|
else:
|
||||||
self.weight.copy_(weights_backup)
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.out_proj.bias = None
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
|
||||||
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
|
||||||
"""
|
"""
|
||||||
Applies the currently selected set of networks to the weights of torch layer self.
|
Applies the currently selected set of networks to the weights of torch layer self.
|
||||||
If weights already have this particular set of networks applied, does nothing.
|
If weights already have this particular set of networks applied, does nothing.
|
||||||
@ -294,21 +311,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
|
|
||||||
self.network_weights_backup = weights_backup
|
self.network_weights_backup = weights_backup
|
||||||
|
|
||||||
|
bias_backup = getattr(self, "network_bias_backup", None)
|
||||||
|
if bias_backup is None:
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
|
||||||
|
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
|
||||||
|
elif getattr(self, 'bias', None) is not None:
|
||||||
|
bias_backup = self.bias.to(devices.cpu, copy=True)
|
||||||
|
else:
|
||||||
|
bias_backup = None
|
||||||
|
self.network_bias_backup = bias_backup
|
||||||
|
|
||||||
if current_names != wanted_names:
|
if current_names != wanted_names:
|
||||||
network_restore_weights_from_backup(self)
|
network_restore_weights_from_backup(self)
|
||||||
|
|
||||||
for net in loaded_networks:
|
for net in loaded_networks:
|
||||||
module = net.modules.get(network_layer_name, None)
|
module = net.modules.get(network_layer_name, None)
|
||||||
if module is not None and hasattr(self, 'weight'):
|
if module is not None and hasattr(self, 'weight'):
|
||||||
with torch.no_grad():
|
try:
|
||||||
updown = module.calc_updown(self.weight)
|
with torch.no_grad():
|
||||||
|
updown, ex_bias = module.calc_updown(self.weight)
|
||||||
|
|
||||||
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
|
||||||
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
# inpainting model. zero pad updown to make channel[1] 4 to 9
|
||||||
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
|
||||||
|
|
||||||
self.weight += updown
|
self.weight += updown
|
||||||
continue
|
if ex_bias is not None and hasattr(self, 'bias'):
|
||||||
|
if self.bias is None:
|
||||||
|
self.bias = torch.nn.Parameter(ex_bias)
|
||||||
|
else:
|
||||||
|
self.bias += ex_bias
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
module_q = net.modules.get(network_layer_name + "_q_proj", None)
|
||||||
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
module_k = net.modules.get(network_layer_name + "_k_proj", None)
|
||||||
@ -316,21 +353,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
|
|||||||
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
module_out = net.modules.get(network_layer_name + "_out_proj", None)
|
||||||
|
|
||||||
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||||
with torch.no_grad():
|
try:
|
||||||
updown_q = module_q.calc_updown(self.in_proj_weight)
|
with torch.no_grad():
|
||||||
updown_k = module_k.calc_updown(self.in_proj_weight)
|
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
|
||||||
updown_v = module_v.calc_updown(self.in_proj_weight)
|
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
|
||||||
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
|
||||||
updown_out = module_out.calc_updown(self.out_proj.weight)
|
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||||
|
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
|
||||||
|
|
||||||
self.in_proj_weight += updown_qkv
|
self.in_proj_weight += updown_qkv
|
||||||
self.out_proj.weight += updown_out
|
self.out_proj.weight += updown_out
|
||||||
continue
|
if ex_bias is not None:
|
||||||
|
if self.out_proj.bias is None:
|
||||||
|
self.out_proj.bias = torch.nn.Parameter(ex_bias)
|
||||||
|
else:
|
||||||
|
self.out_proj.bias += ex_bias
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
if module is None:
|
if module is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f'failed to calculate network weights for layer {network_layer_name}')
|
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
|
||||||
|
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
|
||||||
|
|
||||||
self.network_current_names = wanted_names
|
self.network_current_names = wanted_names
|
||||||
|
|
||||||
@ -357,7 +406,7 @@ def network_forward(module, input, original_forward):
|
|||||||
if module is None:
|
if module is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
y = module.forward(y, input)
|
y = module.forward(input, y)
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@ -397,6 +446,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
|
|||||||
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
|
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def network_GroupNorm_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
|
||||||
|
|
||||||
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
return torch.nn.GroupNorm_forward_before_network(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
|
||||||
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def network_LayerNorm_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
|
||||||
|
|
||||||
|
network_apply_weights(self)
|
||||||
|
|
||||||
|
return torch.nn.LayerNorm_forward_before_network(self, input)
|
||||||
|
|
||||||
|
|
||||||
|
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
|
||||||
|
network_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
def network_MultiheadAttention_forward(self, *args, **kwargs):
|
||||||
network_apply_weights(self)
|
network_apply_weights(self)
|
||||||
|
|
||||||
@ -473,6 +552,7 @@ def infotext_pasted(infotext, params):
|
|||||||
if added:
|
if added:
|
||||||
params["Prompt"] += "\n" + "".join(added)
|
params["Prompt"] += "\n" + "".join(added)
|
||||||
|
|
||||||
|
extra_network_lora = None
|
||||||
|
|
||||||
available_networks = {}
|
available_networks = {}
|
||||||
available_network_aliases = {}
|
available_network_aliases = {}
|
||||||
|
@ -23,9 +23,9 @@ def unload():
|
|||||||
def before_ui():
|
def before_ui():
|
||||||
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
|
||||||
|
|
||||||
extra_network = extra_networks_lora.ExtraNetworkLora()
|
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
|
||||||
extra_networks.register_extra_network(extra_network)
|
extra_networks.register_extra_network(networks.extra_network_lora)
|
||||||
extra_networks.register_extra_network_alias(extra_network, "lyco")
|
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
if not hasattr(torch.nn, 'Linear_forward_before_network'):
|
||||||
@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
|
|||||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
|
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
|
||||||
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
|
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
|
||||||
|
torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
|
||||||
|
torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
|
||||||
|
torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
|
||||||
|
torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
|
||||||
|
|
||||||
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
|
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
|
||||||
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
|
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
|
||||||
|
|
||||||
@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
|
|||||||
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
|
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
|
||||||
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
|
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
|
||||||
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
|
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
|
||||||
|
torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
|
||||||
|
torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
|
||||||
|
torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
|
||||||
|
torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
|
||||||
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
|
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
|
||||||
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
|
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
|
||||||
|
|
||||||
|
@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
item = {
|
item = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": lora_on_disk.filename,
|
"filename": lora_on_disk.filename,
|
||||||
|
"shorthash": lora_on_disk.shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
"search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
|
||||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
"metadata": lora_on_disk.metadata,
|
"metadata": lora_on_disk.metadata,
|
||||||
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
|
||||||
|
@ -12,6 +12,7 @@ onUiLoaded(async() => {
|
|||||||
"Sketch": elementIDs.sketch
|
"Sketch": elementIDs.sketch
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
// Get active tab
|
// Get active tab
|
||||||
function getActiveTab(elements, all = false) {
|
function getActiveTab(elements, all = false) {
|
||||||
@ -377,6 +378,11 @@ onUiLoaded(async() => {
|
|||||||
toggleOverlap("off");
|
toggleOverlap("off");
|
||||||
fullScreenMode = false;
|
fullScreenMode = false;
|
||||||
|
|
||||||
|
const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']");
|
||||||
|
if (closeBtn) {
|
||||||
|
closeBtn.addEventListener("click", resetZoom);
|
||||||
|
}
|
||||||
|
|
||||||
if (
|
if (
|
||||||
canvas &&
|
canvas &&
|
||||||
parseFloat(canvas.style.width) > 865 &&
|
parseFloat(canvas.style.width) > 865 &&
|
||||||
@ -657,17 +663,20 @@ onUiLoaded(async() => {
|
|||||||
// Simulation of the function to put a long image into the screen.
|
// Simulation of the function to put a long image into the screen.
|
||||||
// We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
|
// We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
|
||||||
// We hide the image and show it to the user when it is ready.
|
// We hide the image and show it to the user when it is ready.
|
||||||
function autoExpand(e) {
|
|
||||||
|
targetElement.isExpanded = false;
|
||||||
|
function autoExpand() {
|
||||||
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
|
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
|
||||||
const isMainTab = activeElement === elementIDs.inpaint || activeElement === elementIDs.inpaintSketch || activeElement === elementIDs.sketch;
|
const isMainTab = activeElement === elementIDs.inpaint || activeElement === elementIDs.inpaintSketch || activeElement === elementIDs.sketch;
|
||||||
|
|
||||||
if (canvas && isMainTab) {
|
if (canvas && isMainTab) {
|
||||||
if (hasHorizontalScrollbar(targetElement)) {
|
if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
|
||||||
targetElement.style.visibility = "hidden";
|
targetElement.style.visibility = "hidden";
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
fitToScreen();
|
fitToScreen();
|
||||||
resetZoom();
|
resetZoom();
|
||||||
targetElement.style.visibility = "visible";
|
targetElement.style.visibility = "visible";
|
||||||
|
targetElement.isExpanded = true;
|
||||||
}, 10);
|
}, 10);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -675,9 +684,24 @@ onUiLoaded(async() => {
|
|||||||
|
|
||||||
targetElement.addEventListener("mousemove", getMousePosition);
|
targetElement.addEventListener("mousemove", getMousePosition);
|
||||||
|
|
||||||
|
//observers
|
||||||
|
// Creating an observer with a callback function to handle DOM changes
|
||||||
|
const observer = new MutationObserver((mutationsList, observer) => {
|
||||||
|
for (let mutation of mutationsList) {
|
||||||
|
// If the style attribute of the canvas has changed, by observation it happens only when the picture changes
|
||||||
|
if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&
|
||||||
|
mutation.target.tagName.toLowerCase() === 'canvas') {
|
||||||
|
targetElement.isExpanded = false;
|
||||||
|
setTimeout(resetZoom, 10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// Apply auto expand if enabled
|
// Apply auto expand if enabled
|
||||||
if (hotkeysConfig.canvas_auto_expand) {
|
if (hotkeysConfig.canvas_auto_expand) {
|
||||||
targetElement.addEventListener("mousemove", autoExpand);
|
targetElement.addEventListener("mousemove", autoExpand);
|
||||||
|
// Set up an observer to track attribute changes
|
||||||
|
observer.observe(targetElement, {attributes: true, childList: true, subtree: true});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle events only inside the targetElement
|
// Handle events only inside the targetElement
|
||||||
|
@ -50,10 +50,12 @@ class PydanticModelGenerator:
|
|||||||
additional_fields = None,
|
additional_fields = None,
|
||||||
):
|
):
|
||||||
def field_type_generator(k, v):
|
def field_type_generator(k, v):
|
||||||
# field_type = str if not overrides.get(k) else overrides[k]["type"]
|
|
||||||
# print(k, v.annotation, v.default)
|
|
||||||
field_type = v.annotation
|
field_type = v.annotation
|
||||||
|
|
||||||
|
if field_type == 'Image':
|
||||||
|
# images are sent as base64 strings via API
|
||||||
|
field_type = 'str'
|
||||||
|
|
||||||
return Optional[field_type]
|
return Optional[field_type]
|
||||||
|
|
||||||
def merge_class_params(class_):
|
def merge_class_params(class_):
|
||||||
@ -63,7 +65,6 @@ class PydanticModelGenerator:
|
|||||||
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
|
||||||
return parameters
|
return parameters
|
||||||
|
|
||||||
|
|
||||||
self._model_name = model_name
|
self._model_name = model_name
|
||||||
self._class_data = merge_class_params(class_instance)
|
self._class_data = merge_class_params(class_instance)
|
||||||
|
|
||||||
@ -72,7 +73,7 @@ class PydanticModelGenerator:
|
|||||||
field=underscore(k),
|
field=underscore(k),
|
||||||
field_alias=k,
|
field_alias=k,
|
||||||
field_type=field_type_generator(k, v),
|
field_type=field_type_generator(k, v),
|
||||||
field_value=v.default
|
field_value=None if isinstance(v.default, property) else v.default
|
||||||
)
|
)
|
||||||
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
|
||||||
]
|
]
|
||||||
|
@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
|||||||
process_images(p)
|
process_images(p)
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
@ -166,12 +166,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
seed=seed,
|
|
||||||
subseed=subseed,
|
|
||||||
subseed_strength=subseed_strength,
|
|
||||||
seed_resize_from_h=seed_resize_from_h,
|
|
||||||
seed_resize_from_w=seed_resize_from_w,
|
|
||||||
seed_enable_extras=seed_enable_extras,
|
|
||||||
sampler_name=sampler_name,
|
sampler_name=sampler_name,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
|
@ -173,9 +173,12 @@ def git_clone(url, dir, name, commithash=None):
|
|||||||
if current_hash == commithash:
|
if current_hash == commithash:
|
||||||
return
|
return
|
||||||
|
|
||||||
run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
|
||||||
|
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
|
||||||
|
|
||||||
run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
|
||||||
|
|
||||||
|
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -243,7 +246,7 @@ def list_extensions(settings_file):
|
|||||||
disabled_extensions = set(settings.get('disabled_extensions', []))
|
disabled_extensions = set(settings.get('disabled_extensions', []))
|
||||||
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
disable_all_extensions = settings.get('disable_all_extensions', 'none')
|
||||||
|
|
||||||
if disable_all_extensions != 'none':
|
if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
|
||||||
@ -319,12 +322,12 @@ def prepare_environment():
|
|||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||||
os.remove(os.path.join(script_path, "tmp", "restart"))
|
os.remove(os.path.join(script_path, "tmp", "restart"))
|
||||||
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||||
except OSError:
|
except OSError:
|
||||||
|
@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
if has_mps:
|
if has_mps:
|
||||||
# MPS fix for randn in torchsde
|
|
||||||
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
|
|
||||||
|
|
||||||
if platform.mac_ver()[0].startswith("13.2."):
|
if platform.mac_ver()[0].startswith("13.2."):
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
|
||||||
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
|
||||||
|
@ -11,37 +11,32 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
shared.state.begin(job="extras")
|
shared.state.begin(job="extras")
|
||||||
|
|
||||||
image_data = []
|
|
||||||
image_names = []
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
if extras_mode == 1:
|
def get_images(extras_mode, image, image_folder, input_dir):
|
||||||
for img in image_folder:
|
if extras_mode == 1:
|
||||||
if isinstance(img, Image.Image):
|
for img in image_folder:
|
||||||
image = img
|
if isinstance(img, Image.Image):
|
||||||
fn = ''
|
image = img
|
||||||
else:
|
fn = ''
|
||||||
image = Image.open(os.path.abspath(img.name))
|
else:
|
||||||
fn = os.path.splitext(img.orig_name)[0]
|
image = Image.open(os.path.abspath(img.name))
|
||||||
image_data.append(image)
|
fn = os.path.splitext(img.orig_name)[0]
|
||||||
image_names.append(fn)
|
yield image, fn
|
||||||
elif extras_mode == 2:
|
elif extras_mode == 2:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||||
assert input_dir, 'input directory not selected'
|
assert input_dir, 'input directory not selected'
|
||||||
|
|
||||||
image_list = shared.listfiles(input_dir)
|
image_list = shared.listfiles(input_dir)
|
||||||
for filename in image_list:
|
for filename in image_list:
|
||||||
try:
|
try:
|
||||||
image = Image.open(filename)
|
image = Image.open(filename)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
image_data.append(image)
|
yield image, filename
|
||||||
image_names.append(filename)
|
else:
|
||||||
else:
|
assert image, 'image not selected'
|
||||||
assert image, 'image not selected'
|
yield image, None
|
||||||
|
|
||||||
image_data.append(image)
|
|
||||||
image_names.append(None)
|
|
||||||
|
|
||||||
if extras_mode == 2 and output_dir != '':
|
if extras_mode == 2 and output_dir != '':
|
||||||
outpath = output_dir
|
outpath = output_dir
|
||||||
@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
infotext = ''
|
infotext = ''
|
||||||
|
|
||||||
for image, name in zip(image_data, image_names):
|
for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
|
||||||
|
image_data: Image.Image
|
||||||
|
|
||||||
shared.state.textinfo = name
|
shared.state.textinfo = name
|
||||||
|
|
||||||
parameters, existing_pnginfo = images.read_info_from_image(image)
|
parameters, existing_pnginfo = images.read_info_from_image(image_data)
|
||||||
if parameters:
|
if parameters:
|
||||||
existing_pnginfo["parameters"] = parameters
|
existing_pnginfo["parameters"] = parameters
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
|
pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
|
||||||
|
|
||||||
scripts.scripts_postproc.run(pp, args)
|
scripts.scripts_postproc.run(pp, args)
|
||||||
|
|
||||||
@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
if extras_mode != 2 or show_extras_results:
|
if extras_mode != 2 or show_extras_results:
|
||||||
outputs.append(pp.image)
|
outputs.append(pp.image)
|
||||||
|
|
||||||
|
image_data.close()
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
return outputs, ui_common.plaintext_to_html(infotext), ''
|
return outputs, ui_common.plaintext_to_html(infotext), ''
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
|
from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -11,7 +13,7 @@ from PIL import Image, ImageOps
|
|||||||
import random
|
import random
|
||||||
import cv2
|
import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
||||||
@ -57,7 +59,7 @@ def apply_color_correction(correction, original_image):
|
|||||||
|
|
||||||
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
|
image = blendLayers(image, original_image, BlendType.LUMINOSITY)
|
||||||
|
|
||||||
return image
|
return image.convert('RGB')
|
||||||
|
|
||||||
|
|
||||||
def apply_overlay(image, paste_loc, index, overlays):
|
def apply_overlay(image, paste_loc, index, overlays):
|
||||||
@ -104,97 +106,163 @@ def txt2img_image_conditioning(sd_model, x, width, height):
|
|||||||
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(repr=False)
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing:
|
||||||
"""
|
sd_model: object = None
|
||||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
outpath_samples: str = None
|
||||||
"""
|
outpath_grids: str = None
|
||||||
|
prompt: str = ""
|
||||||
|
prompt_for_display: str = None
|
||||||
|
negative_prompt: str = ""
|
||||||
|
styles: list[str] = None
|
||||||
|
seed: int = -1
|
||||||
|
subseed: int = -1
|
||||||
|
subseed_strength: float = 0
|
||||||
|
seed_resize_from_h: int = -1
|
||||||
|
seed_resize_from_w: int = -1
|
||||||
|
seed_enable_extras: bool = True
|
||||||
|
sampler_name: str = None
|
||||||
|
batch_size: int = 1
|
||||||
|
n_iter: int = 1
|
||||||
|
steps: int = 50
|
||||||
|
cfg_scale: float = 7.0
|
||||||
|
width: int = 512
|
||||||
|
height: int = 512
|
||||||
|
restore_faces: bool = None
|
||||||
|
tiling: bool = None
|
||||||
|
do_not_save_samples: bool = False
|
||||||
|
do_not_save_grid: bool = False
|
||||||
|
extra_generation_params: dict[str, Any] = None
|
||||||
|
overlay_images: list = None
|
||||||
|
eta: float = None
|
||||||
|
do_not_reload_embeddings: bool = False
|
||||||
|
denoising_strength: float = 0
|
||||||
|
ddim_discretize: str = None
|
||||||
|
s_min_uncond: float = None
|
||||||
|
s_churn: float = None
|
||||||
|
s_tmax: float = None
|
||||||
|
s_tmin: float = None
|
||||||
|
s_noise: float = None
|
||||||
|
override_settings: dict[str, Any] = None
|
||||||
|
override_settings_restore_afterwards: bool = True
|
||||||
|
sampler_index: int = None
|
||||||
|
refiner_checkpoint: str = None
|
||||||
|
refiner_switch_at: float = None
|
||||||
|
token_merging_ratio = 0
|
||||||
|
token_merging_ratio_hr = 0
|
||||||
|
disable_extra_networks: bool = False
|
||||||
|
|
||||||
|
scripts_value: scripts.ScriptRunner = field(default=None, init=False)
|
||||||
|
script_args_value: list = field(default=None, init=False)
|
||||||
|
scripts_setup_complete: bool = field(default=False, init=False)
|
||||||
|
|
||||||
cached_uc = [None, None]
|
cached_uc = [None, None]
|
||||||
cached_c = [None, None]
|
cached_c = [None, None]
|
||||||
|
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
comments: dict = None
|
||||||
if sampler_index is not None:
|
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
|
||||||
|
is_using_inpainting_conditioning: bool = field(default=False, init=False)
|
||||||
|
paste_to: tuple | None = field(default=None, init=False)
|
||||||
|
|
||||||
|
is_hr_pass: bool = field(default=False, init=False)
|
||||||
|
|
||||||
|
c: tuple = field(default=None, init=False)
|
||||||
|
uc: tuple = field(default=None, init=False)
|
||||||
|
|
||||||
|
rng: rng.ImageRNG | None = field(default=None, init=False)
|
||||||
|
step_multiplier: int = field(default=1, init=False)
|
||||||
|
color_corrections: list = field(default=None, init=False)
|
||||||
|
|
||||||
|
all_prompts: list = field(default=None, init=False)
|
||||||
|
all_negative_prompts: list = field(default=None, init=False)
|
||||||
|
all_seeds: list = field(default=None, init=False)
|
||||||
|
all_subseeds: list = field(default=None, init=False)
|
||||||
|
iteration: int = field(default=0, init=False)
|
||||||
|
main_prompt: str = field(default=None, init=False)
|
||||||
|
main_negative_prompt: str = field(default=None, init=False)
|
||||||
|
|
||||||
|
prompts: list = field(default=None, init=False)
|
||||||
|
negative_prompts: list = field(default=None, init=False)
|
||||||
|
seeds: list = field(default=None, init=False)
|
||||||
|
subseeds: list = field(default=None, init=False)
|
||||||
|
extra_network_data: dict = field(default=None, init=False)
|
||||||
|
|
||||||
|
user: str = field(default=None, init=False)
|
||||||
|
|
||||||
|
sd_model_name: str = field(default=None, init=False)
|
||||||
|
sd_model_hash: str = field(default=None, init=False)
|
||||||
|
sd_vae_name: str = field(default=None, init=False)
|
||||||
|
sd_vae_hash: str = field(default=None, init=False)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.sampler_index is not None:
|
||||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||||
|
|
||||||
self.outpath_samples: str = outpath_samples
|
self.comments = {}
|
||||||
self.outpath_grids: str = outpath_grids
|
|
||||||
self.prompt: str = prompt
|
|
||||||
self.prompt_for_display: str = None
|
|
||||||
self.negative_prompt: str = (negative_prompt or "")
|
|
||||||
self.styles: list = styles or []
|
|
||||||
self.seed: int = seed
|
|
||||||
self.subseed: int = subseed
|
|
||||||
self.subseed_strength: float = subseed_strength
|
|
||||||
self.seed_resize_from_h: int = seed_resize_from_h
|
|
||||||
self.seed_resize_from_w: int = seed_resize_from_w
|
|
||||||
self.sampler_name: str = sampler_name
|
|
||||||
self.batch_size: int = batch_size
|
|
||||||
self.n_iter: int = n_iter
|
|
||||||
self.steps: int = steps
|
|
||||||
self.cfg_scale: float = cfg_scale
|
|
||||||
self.width: int = width
|
|
||||||
self.height: int = height
|
|
||||||
self.restore_faces: bool = restore_faces
|
|
||||||
self.tiling: bool = tiling
|
|
||||||
self.do_not_save_samples: bool = do_not_save_samples
|
|
||||||
self.do_not_save_grid: bool = do_not_save_grid
|
|
||||||
self.extra_generation_params: dict = extra_generation_params or {}
|
|
||||||
self.overlay_images = overlay_images
|
|
||||||
self.eta = eta
|
|
||||||
self.do_not_reload_embeddings = do_not_reload_embeddings
|
|
||||||
self.paste_to = None
|
|
||||||
self.color_corrections = None
|
|
||||||
self.denoising_strength: float = denoising_strength
|
|
||||||
self.sampler_noise_scheduler_override = None
|
|
||||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
|
||||||
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
|
|
||||||
self.s_churn = s_churn or opts.s_churn
|
|
||||||
self.s_tmin = s_tmin or opts.s_tmin
|
|
||||||
self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
|
|
||||||
self.s_noise = s_noise if s_noise is not None else opts.s_noise
|
|
||||||
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
|
|
||||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
|
||||||
self.is_using_inpainting_conditioning = False
|
|
||||||
self.disable_extra_networks = False
|
|
||||||
self.token_merging_ratio = 0
|
|
||||||
self.token_merging_ratio_hr = 0
|
|
||||||
|
|
||||||
if not seed_enable_extras:
|
if self.styles is None:
|
||||||
|
self.styles = []
|
||||||
|
|
||||||
|
self.sampler_noise_scheduler_override = None
|
||||||
|
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
|
||||||
|
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
|
||||||
|
self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
|
||||||
|
self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
|
||||||
|
self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
|
||||||
|
|
||||||
|
self.extra_generation_params = self.extra_generation_params or {}
|
||||||
|
self.override_settings = self.override_settings or {}
|
||||||
|
self.script_args = self.script_args or {}
|
||||||
|
|
||||||
|
self.refiner_checkpoint_info = None
|
||||||
|
|
||||||
|
if not self.seed_enable_extras:
|
||||||
self.subseed = -1
|
self.subseed = -1
|
||||||
self.subseed_strength = 0
|
self.subseed_strength = 0
|
||||||
self.seed_resize_from_h = 0
|
self.seed_resize_from_h = 0
|
||||||
self.seed_resize_from_w = 0
|
self.seed_resize_from_w = 0
|
||||||
|
|
||||||
self.scripts = None
|
|
||||||
self.script_args = script_args
|
|
||||||
self.all_prompts = None
|
|
||||||
self.all_negative_prompts = None
|
|
||||||
self.all_seeds = None
|
|
||||||
self.all_subseeds = None
|
|
||||||
self.iteration = 0
|
|
||||||
self.is_hr_pass = False
|
|
||||||
self.sampler = None
|
|
||||||
self.main_prompt = None
|
|
||||||
self.main_negative_prompt = None
|
|
||||||
|
|
||||||
self.prompts = None
|
|
||||||
self.negative_prompts = None
|
|
||||||
self.extra_network_data = None
|
|
||||||
self.seeds = None
|
|
||||||
self.subseeds = None
|
|
||||||
|
|
||||||
self.step_multiplier = 1
|
|
||||||
self.cached_uc = StableDiffusionProcessing.cached_uc
|
self.cached_uc = StableDiffusionProcessing.cached_uc
|
||||||
self.cached_c = StableDiffusionProcessing.cached_c
|
self.cached_c = StableDiffusionProcessing.cached_c
|
||||||
self.uc = None
|
|
||||||
self.c = None
|
|
||||||
self.rng: rng.ImageRNG = None
|
|
||||||
|
|
||||||
self.user = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
return shared.sd_model
|
return shared.sd_model
|
||||||
|
|
||||||
|
@sd_model.setter
|
||||||
|
def sd_model(self, value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def scripts(self):
|
||||||
|
return self.scripts_value
|
||||||
|
|
||||||
|
@scripts.setter
|
||||||
|
def scripts(self, value):
|
||||||
|
self.scripts_value = value
|
||||||
|
|
||||||
|
if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
|
||||||
|
self.setup_scripts()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def script_args(self):
|
||||||
|
return self.script_args_value
|
||||||
|
|
||||||
|
@script_args.setter
|
||||||
|
def script_args(self, value):
|
||||||
|
self.script_args_value = value
|
||||||
|
|
||||||
|
if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
|
||||||
|
self.setup_scripts()
|
||||||
|
|
||||||
|
def setup_scripts(self):
|
||||||
|
self.scripts_setup_complete = True
|
||||||
|
|
||||||
|
self.scripts.setup_scrips(self)
|
||||||
|
|
||||||
|
def comment(self, text):
|
||||||
|
self.comments[text] = 1
|
||||||
|
|
||||||
def txt2img_image_conditioning(self, x, width=None, height=None):
|
def txt2img_image_conditioning(self, x, width=None, height=None):
|
||||||
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
|
||||||
|
|
||||||
@ -343,7 +411,7 @@ class StableDiffusionProcessing:
|
|||||||
self.height,
|
self.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conds_with_caching(self, function, required_prompts, steps, hires_steps, caches, extra_network_data):
|
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
|
||||||
"""
|
"""
|
||||||
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
Returns the result of calling function(shared.sd_model, required_prompts, steps)
|
||||||
using a cache to store the result if the same arguments have been used before.
|
using a cache to store the result if the same arguments have been used before.
|
||||||
@ -375,10 +443,12 @@ class StableDiffusionProcessing:
|
|||||||
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
||||||
|
|
||||||
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
||||||
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
|
||||||
self.firstpass_steps = self.steps * self.step_multiplier
|
self.step_multiplier = total_steps // self.steps
|
||||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.firstpass_steps, None, [self.cached_uc], self.extra_network_data)
|
self.firstpass_steps = total_steps
|
||||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.firstpass_steps, None, [self.cached_c], self.extra_network_data)
|
|
||||||
|
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data, None)
|
||||||
|
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data, None )
|
||||||
|
|
||||||
def get_conds(self):
|
def get_conds(self):
|
||||||
return self.c, self.uc
|
return self.c, self.uc
|
||||||
@ -400,7 +470,7 @@ class Processed:
|
|||||||
self.subseed = subseed
|
self.subseed = subseed
|
||||||
self.subseed_strength = p.subseed_strength
|
self.subseed_strength = p.subseed_strength
|
||||||
self.info = info
|
self.info = info
|
||||||
self.comments = comments
|
self.comments = "".join(f"{comment}\n" for comment in p.comments)
|
||||||
self.width = p.width
|
self.width = p.width
|
||||||
self.height = p.height
|
self.height = p.height
|
||||||
self.sampler_name = p.sampler_name
|
self.sampler_name = p.sampler_name
|
||||||
@ -410,7 +480,10 @@ class Processed:
|
|||||||
self.batch_size = p.batch_size
|
self.batch_size = p.batch_size
|
||||||
self.restore_faces = p.restore_faces
|
self.restore_faces = p.restore_faces
|
||||||
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
||||||
self.sd_model_hash = shared.sd_model.sd_model_hash
|
self.sd_model_name = p.sd_model_name
|
||||||
|
self.sd_model_hash = p.sd_model_hash
|
||||||
|
self.sd_vae_name = p.sd_vae_name
|
||||||
|
self.sd_vae_hash = p.sd_vae_hash
|
||||||
self.seed_resize_from_w = p.seed_resize_from_w
|
self.seed_resize_from_w = p.seed_resize_from_w
|
||||||
self.seed_resize_from_h = p.seed_resize_from_h
|
self.seed_resize_from_h = p.seed_resize_from_h
|
||||||
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
||||||
@ -461,7 +534,10 @@ class Processed:
|
|||||||
"batch_size": self.batch_size,
|
"batch_size": self.batch_size,
|
||||||
"restore_faces": self.restore_faces,
|
"restore_faces": self.restore_faces,
|
||||||
"face_restoration_model": self.face_restoration_model,
|
"face_restoration_model": self.face_restoration_model,
|
||||||
|
"sd_model_name": self.sd_model_name,
|
||||||
"sd_model_hash": self.sd_model_hash,
|
"sd_model_hash": self.sd_model_hash,
|
||||||
|
"sd_vae_name": self.sd_vae_name,
|
||||||
|
"sd_vae_hash": self.sd_vae_hash,
|
||||||
"seed_resize_from_w": self.seed_resize_from_w,
|
"seed_resize_from_w": self.seed_resize_from_w,
|
||||||
"seed_resize_from_h": self.seed_resize_from_h,
|
"seed_resize_from_h": self.seed_resize_from_h,
|
||||||
"denoising_strength": self.denoising_strength,
|
"denoising_strength": self.denoising_strength,
|
||||||
@ -580,10 +656,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
|
||||||
"Face restoration": opts.face_restoration_model if p.restore_faces else None,
|
"Face restoration": opts.face_restoration_model if p.restore_faces else None,
|
||||||
"Size": f"{p.width}x{p.height}",
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
|
||||||
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
|
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
|
||||||
"VAE hash": sd_vae.get_loaded_vae_hash() if opts.add_model_hash_to_info else None,
|
"VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
|
||||||
"VAE": sd_vae.get_loaded_vae_name() if opts.add_model_name_to_info else None,
|
"VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
|
||||||
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
|
||||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
@ -672,11 +748,19 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.tiling is None:
|
if p.tiling is None:
|
||||||
p.tiling = opts.tiling
|
p.tiling = opts.tiling
|
||||||
|
|
||||||
|
if p.refiner_checkpoint not in (None, "", "None"):
|
||||||
|
p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
|
||||||
|
if p.refiner_checkpoint_info is None:
|
||||||
|
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
|
||||||
|
|
||||||
|
p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
|
||||||
|
p.sd_model_hash = shared.sd_model.sd_model_hash
|
||||||
|
p.sd_vae_name = sd_vae.get_loaded_vae_name()
|
||||||
|
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
||||||
modules.sd_hijack.model_hijack.clear_comments()
|
modules.sd_hijack.model_hijack.clear_comments()
|
||||||
|
|
||||||
comments = {}
|
|
||||||
|
|
||||||
p.setup_prompts()
|
p.setup_prompts()
|
||||||
|
|
||||||
if type(seed) == list:
|
if type(seed) == list:
|
||||||
@ -756,7 +840,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
p.setup_conds()
|
p.setup_conds()
|
||||||
|
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
comments[comment] = 1
|
p.comment(comment)
|
||||||
|
|
||||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||||
|
|
||||||
@ -885,7 +969,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
images_list=output_images,
|
images_list=output_images,
|
||||||
seed=p.all_seeds[0],
|
seed=p.all_seeds[0],
|
||||||
info=infotexts[0],
|
info=infotexts[0],
|
||||||
comments="".join(f"{comment}\n" for comment in comments),
|
|
||||||
subseed=p.all_subseeds[0],
|
subseed=p.all_subseeds[0],
|
||||||
index_of_first_image=index_of_first_image,
|
index_of_first_image=index_of_first_image,
|
||||||
infotexts=infotexts,
|
infotexts=infotexts,
|
||||||
@ -909,49 +992,51 @@ def old_hires_fix_first_pass_dimensions(width, height):
|
|||||||
return width, height
|
return width, height
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(repr=False)
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
enable_hr: bool = False
|
||||||
|
denoising_strength: float = 0.75
|
||||||
|
firstphase_width: int = 0
|
||||||
|
firstphase_height: int = 0
|
||||||
|
hr_scale: float = 2.0
|
||||||
|
hr_upscaler: str = None
|
||||||
|
hr_second_pass_steps: int = 0
|
||||||
|
hr_resize_x: int = 0
|
||||||
|
hr_resize_y: int = 0
|
||||||
|
hr_checkpoint_name: str = None
|
||||||
|
hr_sampler_name: str = None
|
||||||
|
hr_prompt: str = ''
|
||||||
|
hr_negative_prompt: str = ''
|
||||||
|
|
||||||
cached_hr_uc = [None, None]
|
cached_hr_uc = [None, None]
|
||||||
cached_hr_c = [None, None]
|
cached_hr_c = [None, None]
|
||||||
|
|
||||||
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
|
hr_checkpoint_info: dict = field(default=None, init=False)
|
||||||
super().__init__(**kwargs)
|
hr_upscale_to_x: int = field(default=0, init=False)
|
||||||
self.enable_hr = enable_hr
|
hr_upscale_to_y: int = field(default=0, init=False)
|
||||||
self.denoising_strength = denoising_strength
|
truncate_x: int = field(default=0, init=False)
|
||||||
self.hr_scale = hr_scale
|
truncate_y: int = field(default=0, init=False)
|
||||||
self.hr_upscaler = hr_upscaler
|
applied_old_hires_behavior_to: tuple = field(default=None, init=False)
|
||||||
self.hr_second_pass_steps = hr_second_pass_steps
|
latent_scale_mode: dict = field(default=None, init=False)
|
||||||
self.hr_resize_x = hr_resize_x
|
hr_c: tuple | None = field(default=None, init=False)
|
||||||
self.hr_resize_y = hr_resize_y
|
hr_uc: tuple | None = field(default=None, init=False)
|
||||||
self.hr_upscale_to_x = hr_resize_x
|
all_hr_prompts: list = field(default=None, init=False)
|
||||||
self.hr_upscale_to_y = hr_resize_y
|
all_hr_negative_prompts: list = field(default=None, init=False)
|
||||||
self.hr_checkpoint_name = hr_checkpoint_name
|
hr_prompts: list = field(default=None, init=False)
|
||||||
self.hr_checkpoint_info = None
|
hr_negative_prompts: list = field(default=None, init=False)
|
||||||
self.hr_sampler_name = hr_sampler_name
|
hr_extra_network_data: list = field(default=None, init=False)
|
||||||
self.hr_prompt = hr_prompt
|
|
||||||
self.hr_negative_prompt = hr_negative_prompt
|
|
||||||
self.all_hr_prompts = None
|
|
||||||
self.all_hr_negative_prompts = None
|
|
||||||
self.latent_scale_mode = None
|
|
||||||
|
|
||||||
if firstphase_width != 0 or firstphase_height != 0:
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
|
||||||
|
if self.firstphase_width != 0 or self.firstphase_height != 0:
|
||||||
self.hr_upscale_to_x = self.width
|
self.hr_upscale_to_x = self.width
|
||||||
self.hr_upscale_to_y = self.height
|
self.hr_upscale_to_y = self.height
|
||||||
self.width = firstphase_width
|
self.width = self.firstphase_width
|
||||||
self.height = firstphase_height
|
self.height = self.firstphase_height
|
||||||
|
|
||||||
self.truncate_x = 0
|
|
||||||
self.truncate_y = 0
|
|
||||||
self.applied_old_hires_behavior_to = None
|
|
||||||
|
|
||||||
self.hr_prompts = None
|
|
||||||
self.hr_negative_prompts = None
|
|
||||||
self.hr_extra_network_data = None
|
|
||||||
|
|
||||||
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
|
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
|
||||||
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
|
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
|
||||||
self.hr_c = None
|
|
||||||
self.hr_uc = None
|
|
||||||
|
|
||||||
def calculate_target_resolution(self):
|
def calculate_target_resolution(self):
|
||||||
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
||||||
@ -1145,6 +1230,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||||
|
|
||||||
|
self.sampler = None
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
@ -1191,11 +1279,20 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
|
||||||
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
|
||||||
|
|
||||||
hires_steps = (self.hr_second_pass_steps or self.steps) * self.step_multiplier
|
sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
|
||||||
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
|
steps = self.hr_second_pass_steps or self.steps
|
||||||
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
total_steps = sampler_config.total_steps(steps) if sampler_config else steps
|
||||||
|
|
||||||
|
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
|
||||||
|
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
|
||||||
|
|
||||||
def setup_conds(self):
|
def setup_conds(self):
|
||||||
|
if self.is_hr_pass:
|
||||||
|
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
|
||||||
|
self.hr_c = None
|
||||||
|
self.calculate_hr_conds()
|
||||||
|
return
|
||||||
|
|
||||||
super().setup_conds()
|
super().setup_conds()
|
||||||
|
|
||||||
self.hr_uc = None
|
self.hr_uc = None
|
||||||
@ -1220,7 +1317,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
return super().get_conds()
|
return super().get_conds()
|
||||||
|
|
||||||
|
|
||||||
def parse_extra_network_prompts(self):
|
def parse_extra_network_prompts(self):
|
||||||
res = super().parse_extra_network_prompts()
|
res = super().parse_extra_network_prompts()
|
||||||
|
|
||||||
@ -1233,35 +1329,53 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(repr=False)
|
||||||
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||||
sampler = None
|
init_images: list = None
|
||||||
|
resize_mode: int = 0
|
||||||
|
denoising_strength: float = 0.75
|
||||||
|
image_cfg_scale: float = None
|
||||||
|
mask: Any = None
|
||||||
|
mask_blur_x: int = 4
|
||||||
|
mask_blur_y: int = 4
|
||||||
|
mask_blur: int = None
|
||||||
|
inpainting_fill: int = 0
|
||||||
|
inpaint_full_res: bool = True
|
||||||
|
inpaint_full_res_padding: int = 0
|
||||||
|
inpainting_mask_invert: int = 0
|
||||||
|
initial_noise_multiplier: float = None
|
||||||
|
latent_mask: Image = None
|
||||||
|
|
||||||
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
|
image_mask: Any = field(default=None, init=False)
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.init_images = init_images
|
nmask: torch.Tensor = field(default=None, init=False)
|
||||||
self.resize_mode: int = resize_mode
|
image_conditioning: torch.Tensor = field(default=None, init=False)
|
||||||
self.denoising_strength: float = denoising_strength
|
init_img_hash: str = field(default=None, init=False)
|
||||||
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
mask_for_overlay: Image = field(default=None, init=False)
|
||||||
self.init_latent = None
|
init_latent: torch.Tensor = field(default=None, init=False)
|
||||||
self.image_mask = mask
|
|
||||||
self.latent_mask = None
|
def __post_init__(self):
|
||||||
self.mask_for_overlay = None
|
super().__post_init__()
|
||||||
if mask_blur is not None:
|
|
||||||
mask_blur_x = mask_blur
|
self.image_mask = self.mask
|
||||||
mask_blur_y = mask_blur
|
|
||||||
self.mask_blur_x = mask_blur_x
|
|
||||||
self.mask_blur_y = mask_blur_y
|
|
||||||
self.inpainting_fill = inpainting_fill
|
|
||||||
self.inpaint_full_res = inpaint_full_res
|
|
||||||
self.inpaint_full_res_padding = inpaint_full_res_padding
|
|
||||||
self.inpainting_mask_invert = inpainting_mask_invert
|
|
||||||
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
|
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
|
||||||
self.image_conditioning = None
|
|
||||||
|
@property
|
||||||
|
def mask_blur(self):
|
||||||
|
if self.mask_blur_x == self.mask_blur_y:
|
||||||
|
return self.mask_blur_x
|
||||||
|
return None
|
||||||
|
|
||||||
|
@mask_blur.setter
|
||||||
|
def mask_blur(self, value):
|
||||||
|
if isinstance(value, int):
|
||||||
|
self.mask_blur_x = value
|
||||||
|
self.mask_blur_y = value
|
||||||
|
|
||||||
def init(self, all_prompts, all_seeds, all_subseeds):
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
||||||
|
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
|
||||||
|
|
||||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||||
crop_region = None
|
crop_region = None
|
||||||
|
|
||||||
@ -1275,13 +1389,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
if self.mask_blur_x > 0:
|
if self.mask_blur_x > 0:
|
||||||
np_mask = np.array(image_mask)
|
np_mask = np.array(image_mask)
|
||||||
kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
|
kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
|
||||||
np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
|
np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
|
||||||
image_mask = Image.fromarray(np_mask)
|
image_mask = Image.fromarray(np_mask)
|
||||||
|
|
||||||
if self.mask_blur_y > 0:
|
if self.mask_blur_y > 0:
|
||||||
np_mask = np.array(image_mask)
|
np_mask = np.array(image_mask)
|
||||||
kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
|
kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
|
||||||
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
|
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
|
||||||
image_mask = Image.fromarray(np_mask)
|
image_mask = Image.fromarray(np_mask)
|
||||||
|
|
||||||
|
49
modules/processing_scripts/refiner.py
Normal file
49
modules/processing_scripts/refiner.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts, sd_models
|
||||||
|
from modules.ui_common import create_refresh_button
|
||||||
|
from modules.ui_components import InputAccordion
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptRefiner(scripts.Script):
|
||||||
|
section = "accordions"
|
||||||
|
create_group = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Refiner"
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
|
||||||
|
with gr.Row():
|
||||||
|
refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
|
||||||
|
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
|
||||||
|
|
||||||
|
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
|
||||||
|
|
||||||
|
def lookup_checkpoint(title):
|
||||||
|
info = sd_models.get_closet_checkpoint_match(title)
|
||||||
|
return None if info is None else info.title
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
(enable_refiner, lambda d: 'Refiner' in d),
|
||||||
|
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
|
||||||
|
(refiner_switch_at, 'Refiner switch at'),
|
||||||
|
]
|
||||||
|
|
||||||
|
return enable_refiner, refiner_checkpoint, refiner_switch_at
|
||||||
|
|
||||||
|
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
|
||||||
|
# the actual implementation is in sd_samplers_common.py, apply_refiner
|
||||||
|
|
||||||
|
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
|
||||||
|
p.refiner_checkpoint_info = None
|
||||||
|
p.refiner_switch_at = None
|
||||||
|
else:
|
||||||
|
p.refiner_checkpoint = refiner_checkpoint
|
||||||
|
p.refiner_switch_at = refiner_switch_at
|
111
modules/processing_scripts/seed.py
Normal file
111
modules/processing_scripts/seed.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
from modules import scripts, ui, errors
|
||||||
|
from modules.shared import cmd_opts
|
||||||
|
from modules.ui_components import ToolButton
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptSeed(scripts.ScriptBuiltin):
|
||||||
|
section = "seed"
|
||||||
|
create_group = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.seed = None
|
||||||
|
self.reuse_seed = None
|
||||||
|
self.reuse_subseed = None
|
||||||
|
|
||||||
|
def title(self):
|
||||||
|
return "Seed"
|
||||||
|
|
||||||
|
def show(self, is_img2img):
|
||||||
|
return scripts.AlwaysVisible
|
||||||
|
|
||||||
|
def ui(self, is_img2img):
|
||||||
|
with gr.Row(elem_id=self.elem_id("seed_row")):
|
||||||
|
if cmd_opts.use_textbox_seed:
|
||||||
|
self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100)
|
||||||
|
else:
|
||||||
|
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
|
||||||
|
|
||||||
|
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
|
||||||
|
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
|
||||||
|
|
||||||
|
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
|
||||||
|
|
||||||
|
with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras:
|
||||||
|
with gr.Row(elem_id=self.elem_id("subseed_row")):
|
||||||
|
subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0)
|
||||||
|
random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed"))
|
||||||
|
reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed"))
|
||||||
|
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength"))
|
||||||
|
|
||||||
|
with gr.Row(elem_id=self.elem_id("seed_resize_from_row")):
|
||||||
|
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w"))
|
||||||
|
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h"))
|
||||||
|
|
||||||
|
random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||||
|
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[])
|
||||||
|
|
||||||
|
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
|
||||||
|
|
||||||
|
self.infotext_fields = [
|
||||||
|
(self.seed, "Seed"),
|
||||||
|
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
||||||
|
(subseed, "Variation seed"),
|
||||||
|
(subseed_strength, "Variation seed strength"),
|
||||||
|
(seed_resize_from_w, "Seed resize from-1"),
|
||||||
|
(seed_resize_from_h, "Seed resize from-2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
|
||||||
|
self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')
|
||||||
|
|
||||||
|
return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h
|
||||||
|
|
||||||
|
def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):
|
||||||
|
p.seed = seed
|
||||||
|
|
||||||
|
if seed_checkbox and subseed_strength > 0:
|
||||||
|
p.subseed = subseed
|
||||||
|
p.subseed_strength = subseed_strength
|
||||||
|
|
||||||
|
if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:
|
||||||
|
p.seed_resize_from_w = seed_resize_from_w
|
||||||
|
p.seed_resize_from_h = seed_resize_from_h
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
|
||||||
|
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
||||||
|
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
||||||
|
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
|
||||||
|
|
||||||
|
def copy_seed(gen_info_string: str, index):
|
||||||
|
res = -1
|
||||||
|
|
||||||
|
try:
|
||||||
|
gen_info = json.loads(gen_info_string)
|
||||||
|
index -= gen_info.get('index_of_first_image', 0)
|
||||||
|
|
||||||
|
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
||||||
|
all_subseeds = gen_info.get('all_subseeds', [-1])
|
||||||
|
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
||||||
|
else:
|
||||||
|
all_seeds = gen_info.get('all_seeds', [-1])
|
||||||
|
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
||||||
|
|
||||||
|
except json.decoder.JSONDecodeError:
|
||||||
|
if gen_info_string:
|
||||||
|
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
||||||
|
|
||||||
|
return [res, gr.update()]
|
||||||
|
|
||||||
|
reuse_seed.click(
|
||||||
|
fn=copy_seed,
|
||||||
|
_js="(x, y) => [x, selected_gallery_index()]",
|
||||||
|
show_progress=False,
|
||||||
|
inputs=[generation_info, seed],
|
||||||
|
outputs=[seed, seed]
|
||||||
|
)
|
@ -3,6 +3,7 @@ import re
|
|||||||
import sys
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
@ -21,6 +22,11 @@ class PostprocessBatchListArgs:
|
|||||||
self.images = images
|
self.images = images
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OnComponent:
|
||||||
|
component: gr.blocks.Block
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
name = None
|
name = None
|
||||||
"""script's internal name derived from title"""
|
"""script's internal name derived from title"""
|
||||||
@ -35,9 +41,13 @@ class Script:
|
|||||||
|
|
||||||
is_txt2img = False
|
is_txt2img = False
|
||||||
is_img2img = False
|
is_img2img = False
|
||||||
|
tabname = None
|
||||||
|
|
||||||
group = None
|
group = None
|
||||||
"""A gr.Group component that has all script's UI inside it"""
|
"""A gr.Group component that has all script's UI inside it."""
|
||||||
|
|
||||||
|
create_group = True
|
||||||
|
"""If False, for alwayson scripts, a group component will not be created."""
|
||||||
|
|
||||||
infotext_fields = None
|
infotext_fields = None
|
||||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||||
@ -52,6 +62,12 @@ class Script:
|
|||||||
api_info = None
|
api_info = None
|
||||||
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
||||||
|
|
||||||
|
on_before_component_elem_id = None
|
||||||
|
"""list of callbacks to be called before a component with an elem_id is created"""
|
||||||
|
|
||||||
|
on_after_component_elem_id = None
|
||||||
|
"""list of callbacks to be called after a component with an elem_id is created"""
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||||
|
|
||||||
@ -90,9 +106,16 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def setup(self, p, *args):
|
||||||
|
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
|
||||||
|
args contains all values returned by components from ui().
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def before_process(self, p, *args):
|
def before_process(self, p, *args):
|
||||||
"""
|
"""
|
||||||
This function is called very early before processing begins for AlwaysVisible scripts.
|
This function is called very early during processing begins for AlwaysVisible scripts.
|
||||||
You can modify the processing object (p) here, inject hooks, etc.
|
You can modify the processing object (p) here, inject hooks, etc.
|
||||||
args contains all values returned by components from ui()
|
args contains all values returned by components from ui()
|
||||||
"""
|
"""
|
||||||
@ -212,6 +235,30 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_before_component(self, callback, *, elem_id):
|
||||||
|
"""
|
||||||
|
Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
|
||||||
|
|
||||||
|
May be called in show() or ui() - but it may be too late in latter as some components may already be created.
|
||||||
|
|
||||||
|
This function is an alternative to before_component in that it also cllows to run before a component is created, but
|
||||||
|
it doesn't require to be called for every created component - just for the one you need.
|
||||||
|
"""
|
||||||
|
if self.on_before_component_elem_id is None:
|
||||||
|
self.on_before_component_elem_id = []
|
||||||
|
|
||||||
|
self.on_before_component_elem_id.append((elem_id, callback))
|
||||||
|
|
||||||
|
def on_after_component(self, callback, *, elem_id):
|
||||||
|
"""
|
||||||
|
Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
|
||||||
|
"""
|
||||||
|
if self.on_after_component_elem_id is None:
|
||||||
|
self.on_after_component_elem_id = []
|
||||||
|
|
||||||
|
self.on_after_component_elem_id.append((elem_id, callback))
|
||||||
|
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
"""unused"""
|
"""unused"""
|
||||||
return ""
|
return ""
|
||||||
@ -232,6 +279,18 @@ class Script:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptBuiltin(Script):
|
||||||
|
|
||||||
|
def elem_id(self, item_id):
|
||||||
|
"""helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
|
||||||
|
|
||||||
|
need_tabname = self.show(True) == self.show(False)
|
||||||
|
tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
|
||||||
|
|
||||||
|
return f'{tabname}{item_id}'
|
||||||
|
|
||||||
|
|
||||||
current_basedir = paths.script_path
|
current_basedir = paths.script_path
|
||||||
|
|
||||||
|
|
||||||
@ -250,7 +309,7 @@ postprocessing_scripts_data = []
|
|||||||
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
|
||||||
|
|
||||||
|
|
||||||
def list_scripts(scriptdirname, extension):
|
def list_scripts(scriptdirname, extension, *, include_extensions=True):
|
||||||
scripts_list = []
|
scripts_list = []
|
||||||
|
|
||||||
basedir = os.path.join(paths.script_path, scriptdirname)
|
basedir = os.path.join(paths.script_path, scriptdirname)
|
||||||
@ -258,8 +317,9 @@ def list_scripts(scriptdirname, extension):
|
|||||||
for filename in sorted(os.listdir(basedir)):
|
for filename in sorted(os.listdir(basedir)):
|
||||||
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
|
||||||
|
|
||||||
for ext in extensions.active():
|
if include_extensions:
|
||||||
scripts_list += ext.list_files(scriptdirname, extension)
|
for ext in extensions.active():
|
||||||
|
scripts_list += ext.list_files(scriptdirname, extension)
|
||||||
|
|
||||||
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
|
||||||
|
|
||||||
@ -288,7 +348,7 @@ def load_scripts():
|
|||||||
postprocessing_scripts_data.clear()
|
postprocessing_scripts_data.clear()
|
||||||
script_callbacks.clear_callbacks()
|
script_callbacks.clear_callbacks()
|
||||||
|
|
||||||
scripts_list = list_scripts("scripts", ".py")
|
scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
|
||||||
|
|
||||||
syspath = sys.path
|
syspath = sys.path
|
||||||
|
|
||||||
@ -349,10 +409,17 @@ class ScriptRunner:
|
|||||||
self.selectable_scripts = []
|
self.selectable_scripts = []
|
||||||
self.alwayson_scripts = []
|
self.alwayson_scripts = []
|
||||||
self.titles = []
|
self.titles = []
|
||||||
|
self.title_map = {}
|
||||||
self.infotext_fields = []
|
self.infotext_fields = []
|
||||||
self.paste_field_names = []
|
self.paste_field_names = []
|
||||||
self.inputs = [None]
|
self.inputs = [None]
|
||||||
|
|
||||||
|
self.on_before_component_elem_id = {}
|
||||||
|
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
|
||||||
|
|
||||||
|
self.on_after_component_elem_id = {}
|
||||||
|
"""dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
|
||||||
|
|
||||||
def initialize_scripts(self, is_img2img):
|
def initialize_scripts(self, is_img2img):
|
||||||
from modules import scripts_auto_postprocessing
|
from modules import scripts_auto_postprocessing
|
||||||
|
|
||||||
@ -367,6 +434,7 @@ class ScriptRunner:
|
|||||||
script.filename = script_data.path
|
script.filename = script_data.path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
script.is_img2img = is_img2img
|
script.is_img2img = is_img2img
|
||||||
|
script.tabname = "img2img" if is_img2img else "txt2img"
|
||||||
|
|
||||||
visibility = script.show(script.is_img2img)
|
visibility = script.show(script.is_img2img)
|
||||||
|
|
||||||
@ -379,6 +447,28 @@ class ScriptRunner:
|
|||||||
self.scripts.append(script)
|
self.scripts.append(script)
|
||||||
self.selectable_scripts.append(script)
|
self.selectable_scripts.append(script)
|
||||||
|
|
||||||
|
self.apply_on_before_component_callbacks()
|
||||||
|
|
||||||
|
def apply_on_before_component_callbacks(self):
|
||||||
|
for script in self.scripts:
|
||||||
|
on_before = script.on_before_component_elem_id or []
|
||||||
|
on_after = script.on_after_component_elem_id or []
|
||||||
|
|
||||||
|
for elem_id, callback in on_before:
|
||||||
|
if elem_id not in self.on_before_component_elem_id:
|
||||||
|
self.on_before_component_elem_id[elem_id] = []
|
||||||
|
|
||||||
|
self.on_before_component_elem_id[elem_id].append((callback, script))
|
||||||
|
|
||||||
|
for elem_id, callback in on_after:
|
||||||
|
if elem_id not in self.on_after_component_elem_id:
|
||||||
|
self.on_after_component_elem_id[elem_id] = []
|
||||||
|
|
||||||
|
self.on_after_component_elem_id[elem_id].append((callback, script))
|
||||||
|
|
||||||
|
on_before.clear()
|
||||||
|
on_after.clear()
|
||||||
|
|
||||||
def create_script_ui(self, script):
|
def create_script_ui(self, script):
|
||||||
import modules.api.models as api_models
|
import modules.api.models as api_models
|
||||||
|
|
||||||
@ -429,15 +519,20 @@ class ScriptRunner:
|
|||||||
if script.alwayson and script.section != section:
|
if script.alwayson and script.section != section:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with gr.Group(visible=script.alwayson) as group:
|
if script.create_group:
|
||||||
self.create_script_ui(script)
|
with gr.Group(visible=script.alwayson) as group:
|
||||||
|
self.create_script_ui(script)
|
||||||
|
|
||||||
script.group = group
|
script.group = group
|
||||||
|
else:
|
||||||
|
self.create_script_ui(script)
|
||||||
|
|
||||||
def prepare_ui(self):
|
def prepare_ui(self):
|
||||||
self.inputs = [None]
|
self.inputs = [None]
|
||||||
|
|
||||||
def setup_ui(self):
|
def setup_ui(self):
|
||||||
|
all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
|
||||||
|
self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
|
||||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||||
|
|
||||||
self.setup_ui_for_section(None)
|
self.setup_ui_for_section(None)
|
||||||
@ -484,6 +579,8 @@ class ScriptRunner:
|
|||||||
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
|
||||||
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
|
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
|
||||||
|
|
||||||
|
self.apply_on_before_component_callbacks()
|
||||||
|
|
||||||
return self.inputs
|
return self.inputs
|
||||||
|
|
||||||
def run(self, p, *args):
|
def run(self, p, *args):
|
||||||
@ -577,6 +674,12 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def before_component(self, component, **kwargs):
|
def before_component(self, component, **kwargs):
|
||||||
|
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
|
||||||
|
try:
|
||||||
|
callback(OnComponent(component=component))
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.before_component(component, **kwargs)
|
script.before_component(component, **kwargs)
|
||||||
@ -584,12 +687,21 @@ class ScriptRunner:
|
|||||||
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def after_component(self, component, **kwargs):
|
def after_component(self, component, **kwargs):
|
||||||
|
for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):
|
||||||
|
try:
|
||||||
|
callback(OnComponent(component=component))
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
script.after_component(component, **kwargs)
|
script.after_component(component, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
|
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def script(self, title):
|
||||||
|
return self.title_map.get(title.lower())
|
||||||
|
|
||||||
def reload_sources(self, cache):
|
def reload_sources(self, cache):
|
||||||
for si, script in list(enumerate(self.scripts)):
|
for si, script in list(enumerate(self.scripts)):
|
||||||
args_from = script.args_from
|
args_from = script.args_from
|
||||||
@ -608,7 +720,6 @@ class ScriptRunner:
|
|||||||
self.scripts[si].args_from = args_from
|
self.scripts[si].args_from = args_from
|
||||||
self.scripts[si].args_to = args_to
|
self.scripts[si].args_to = args_to
|
||||||
|
|
||||||
|
|
||||||
def before_hr(self, p):
|
def before_hr(self, p):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
@ -617,6 +728,14 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
def setup_scrips(self, p):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.setup(p, *script_args)
|
||||||
|
except Exception:
|
||||||
|
errors.report(f"Error running setup: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
scripts_txt2img: ScriptRunner = None
|
scripts_txt2img: ScriptRunner = None
|
||||||
scripts_img2img: ScriptRunner = None
|
scripts_img2img: ScriptRunner = None
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import math
|
import math
|
||||||
import psutil
|
import psutil
|
||||||
|
import platform
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
|||||||
class SdOptimizationSubQuad(SdOptimization):
|
class SdOptimizationSubQuad(SdOptimization):
|
||||||
name = "sub-quadratic"
|
name = "sub-quadratic"
|
||||||
cmd_opt = "opt_sub_quad_attention"
|
cmd_opt = "opt_sub_quad_attention"
|
||||||
priority = 10
|
|
||||||
|
@property
|
||||||
|
def priority(self):
|
||||||
|
return 1000 if shared.device.type == 'mps' else 10
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||||
@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def priority(self):
|
def priority(self):
|
||||||
return 1000 if not torch.cuda.is_available() else 10
|
return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||||
@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||||
|
|
||||||
if chunk_threshold is None:
|
if chunk_threshold is None:
|
||||||
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
|
if q.device.type == 'mps':
|
||||||
|
chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
|
||||||
|
else:
|
||||||
|
chunk_threshold_bytes = int(get_available_vram() * 0.7)
|
||||||
elif chunk_threshold == 0:
|
elif chunk_threshold == 0:
|
||||||
chunk_threshold_bytes = None
|
chunk_threshold_bytes = None
|
||||||
else:
|
else:
|
||||||
|
@ -147,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
|
|||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(search_string):
|
def get_closet_checkpoint_match(search_string):
|
||||||
|
if not search_string:
|
||||||
|
return None
|
||||||
|
|
||||||
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
checkpoint_info = checkpoint_aliases.get(search_string, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
@ -45,18 +45,23 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
self.steps = None
|
self.steps = None
|
||||||
|
"""number of steps as specified by user in UI"""
|
||||||
|
|
||||||
|
self.total_steps = None
|
||||||
|
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
|
||||||
|
|
||||||
self.step = 0
|
self.step = 0
|
||||||
self.image_cfg_scale = None
|
self.image_cfg_scale = None
|
||||||
self.padded_cond_uncond = False
|
self.padded_cond_uncond = False
|
||||||
self.sampler = sampler
|
self.sampler = sampler
|
||||||
self.model_wrap = None
|
self.model_wrap = None
|
||||||
self.p = None
|
self.p = None
|
||||||
|
self.mask_before_denoising = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inner_model(self):
|
def inner_model(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||||
denoised = torch.clone(denoised_uncond)
|
denoised = torch.clone(denoised_uncond)
|
||||||
@ -100,7 +105,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask_before_denoising and self.mask is not None:
|
||||||
x = self.init_latent * self.mask + self.nmask * x
|
x = self.init_latent * self.mask + self.nmask * x
|
||||||
|
|
||||||
batch_size = len(conds_list)
|
batch_size = len(conds_list)
|
||||||
@ -202,6 +207,9 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||||
|
|
||||||
|
if not self.mask_before_denoising and self.mask is not None:
|
||||||
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
|
||||||
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||||
|
|
||||||
if opts.live_preview_content == "Prompt":
|
if opts.live_preview_content == "Prompt":
|
||||||
|
@ -7,7 +7,16 @@ from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, s
|
|||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import k_diffusion.sampling
|
import k_diffusion.sampling
|
||||||
|
|
||||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
|
||||||
|
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerData(SamplerDataTuple):
|
||||||
|
def total_steps(self, steps):
|
||||||
|
if self.options.get("second_order", False):
|
||||||
|
steps = steps * 2
|
||||||
|
|
||||||
|
return steps
|
||||||
|
|
||||||
|
|
||||||
def setup_img2img_steps(p, steps=None):
|
def setup_img2img_steps(p, steps=None):
|
||||||
@ -83,7 +92,15 @@ def images_tensor_to_samples(image, approximation=None, model=None):
|
|||||||
model = shared.sd_model
|
model = shared.sd_model
|
||||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||||
image = image * 2 - 1
|
image = image * 2 - 1
|
||||||
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
if len(image) > 1:
|
||||||
|
x_latent = torch.stack([
|
||||||
|
model.get_first_stage_encoding(
|
||||||
|
model.encode_first_stage(torch.unsqueeze(img, 0))
|
||||||
|
)[0]
|
||||||
|
for img in image
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||||
|
|
||||||
return x_latent
|
return x_latent
|
||||||
|
|
||||||
@ -131,31 +148,29 @@ def replace_torchsde_browinan():
|
|||||||
replace_torchsde_browinan()
|
replace_torchsde_browinan()
|
||||||
|
|
||||||
|
|
||||||
def apply_refiner(sampler):
|
def apply_refiner(cfg_denoiser):
|
||||||
completed_ratio = sampler.step / sampler.steps
|
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
|
||||||
|
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
|
||||||
|
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
|
||||||
|
|
||||||
if completed_ratio <= shared.opts.sd_refiner_switch_at:
|
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if shared.opts.sd_refiner_checkpoint == "None":
|
if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
|
if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
|
cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
||||||
if refiner_checkpoint_info is None:
|
cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
|
||||||
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
|
|
||||||
|
|
||||||
sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
|
|
||||||
sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
|
|
||||||
|
|
||||||
with sd_models.SkipWritingToConfig():
|
with sd_models.SkipWritingToConfig():
|
||||||
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
sd_models.reload_model_weights(info=refiner_checkpoint_info)
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
sampler.p.setup_conds()
|
cfg_denoiser.p.setup_conds()
|
||||||
sampler.update_inner_model()
|
cfg_denoiser.update_inner_model()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -192,7 +207,7 @@ class Sampler:
|
|||||||
self.sampler_noises = None
|
self.sampler_noises = None
|
||||||
self.stop_at = None
|
self.stop_at = None
|
||||||
self.eta = None
|
self.eta = None
|
||||||
self.config = None # set by the function calling the constructor
|
self.config: SamplerData = None # set by the function calling the constructor
|
||||||
self.last_latent = None
|
self.last_latent = None
|
||||||
self.s_min_uncond = None
|
self.s_min_uncond = None
|
||||||
self.s_churn = 0.0
|
self.s_churn = 0.0
|
||||||
@ -208,6 +223,7 @@ class Sampler:
|
|||||||
self.p = None
|
self.p = None
|
||||||
self.model_wrap_cfg = None
|
self.model_wrap_cfg = None
|
||||||
self.sampler_extra_args = None
|
self.sampler_extra_args = None
|
||||||
|
self.options = {}
|
||||||
|
|
||||||
def callback_state(self, d):
|
def callback_state(self, d):
|
||||||
step = d['i']
|
step = d['i']
|
||||||
@ -220,6 +236,7 @@ class Sampler:
|
|||||||
|
|
||||||
def launch_sampling(self, steps, func):
|
def launch_sampling(self, steps, func):
|
||||||
self.model_wrap_cfg.steps = steps
|
self.model_wrap_cfg.steps = steps
|
||||||
|
self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
|
||||||
state.sampling_steps = steps
|
state.sampling_steps = steps
|
||||||
state.sampling_step = 0
|
state.sampling_step = 0
|
||||||
|
|
||||||
@ -267,19 +284,19 @@ class Sampler:
|
|||||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||||
|
|
||||||
if s_churn != self.s_churn:
|
if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
|
||||||
extra_params_kwargs['s_churn'] = s_churn
|
extra_params_kwargs['s_churn'] = s_churn
|
||||||
p.s_churn = s_churn
|
p.s_churn = s_churn
|
||||||
p.extra_generation_params['Sigma churn'] = s_churn
|
p.extra_generation_params['Sigma churn'] = s_churn
|
||||||
if s_tmin != self.s_tmin:
|
if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
|
||||||
extra_params_kwargs['s_tmin'] = s_tmin
|
extra_params_kwargs['s_tmin'] = s_tmin
|
||||||
p.s_tmin = s_tmin
|
p.s_tmin = s_tmin
|
||||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||||
if s_tmax != self.s_tmax:
|
if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
|
||||||
extra_params_kwargs['s_tmax'] = s_tmax
|
extra_params_kwargs['s_tmax'] = s_tmax
|
||||||
p.s_tmax = s_tmax
|
p.s_tmax = s_tmax
|
||||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||||
if s_noise != self.s_noise:
|
if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
|
||||||
extra_params_kwargs['s_noise'] = s_noise
|
extra_params_kwargs['s_noise'] = s_noise
|
||||||
p.s_noise = s_noise
|
p.s_noise = s_noise
|
||||||
p.extra_generation_params['Sigma noise'] = s_noise
|
p.extra_generation_params['Sigma noise'] = s_noise
|
||||||
@ -296,5 +313,8 @@ class Sampler:
|
|||||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||||
|
|
||||||
|
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
|
raise NotImplementedError()
|
||||||
|
@ -22,6 +22,12 @@ samplers_k_diffusion = [
|
|||||||
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
|
||||||
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
|
||||||
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
|
||||||
|
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
|
||||||
|
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
|
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
|
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
|
||||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
|
||||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
|
||||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||||
@ -42,6 +48,12 @@ sampler_extra_params = {
|
|||||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||||
|
'sample_dpm_fast': ['s_noise'],
|
||||||
|
'sample_dpm_2_ancestral': ['s_noise'],
|
||||||
|
'sample_dpmpp_2s_ancestral': ['s_noise'],
|
||||||
|
'sample_dpmpp_sde': ['s_noise'],
|
||||||
|
'sample_dpmpp_2m_sde': ['s_noise'],
|
||||||
|
'sample_dpmpp_3m_sde': ['s_noise'],
|
||||||
}
|
}
|
||||||
|
|
||||||
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
|
||||||
@ -64,9 +76,12 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
|
|||||||
|
|
||||||
|
|
||||||
class KDiffusionSampler(sd_samplers_common.Sampler):
|
class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||||
def __init__(self, funcname, sd_model):
|
def __init__(self, funcname, sd_model, options=None):
|
||||||
super().__init__(funcname)
|
super().__init__(funcname)
|
||||||
|
|
||||||
|
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||||
|
|
||||||
|
self.options = options or {}
|
||||||
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
|
||||||
|
|
||||||
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
|
||||||
@ -149,6 +164,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
|
if self.config.options.get('solver_type', None) == 'heun':
|
||||||
|
extra_params_kwargs['solver_type'] = 'heun'
|
||||||
|
|
||||||
self.model_wrap_cfg.init_latent = x
|
self.model_wrap_cfg.init_latent = x
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
self.sampler_extra_args = {
|
self.sampler_extra_args = {
|
||||||
@ -190,6 +208,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
noise_sampler = self.create_noise_sampler(x, sigmas, p)
|
||||||
extra_params_kwargs['noise_sampler'] = noise_sampler
|
extra_params_kwargs['noise_sampler'] = noise_sampler
|
||||||
|
|
||||||
|
if self.config.options.get('solver_type', None) == 'heun':
|
||||||
|
extra_params_kwargs['solver_type'] = 'heun'
|
||||||
|
|
||||||
self.last_latent = x
|
self.last_latent = x
|
||||||
self.sampler_extra_args = {
|
self.sampler_extra_args = {
|
||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
@ -198,6 +219,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
|||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
's_min_uncond': self.s_min_uncond
|
's_min_uncond': self.s_min_uncond
|
||||||
}
|
}
|
||||||
|
|
||||||
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
if self.model_wrap_cfg.padded_cond_uncond:
|
if self.model_wrap_cfg.padded_cond_uncond:
|
||||||
|
@ -49,12 +49,12 @@ class CFGDenoiserTimesteps(CFGDenoiser):
|
|||||||
super().__init__(sampler)
|
super().__init__(sampler)
|
||||||
|
|
||||||
self.alphas = shared.sd_model.alphas_cumprod
|
self.alphas = shared.sd_model.alphas_cumprod
|
||||||
|
self.mask_before_denoising = True
|
||||||
|
|
||||||
def get_pred_x0(self, x_in, x_out, sigma):
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
ts = int(sigma.item())
|
ts = sigma.to(dtype=int)
|
||||||
|
|
||||||
s_in = x_in.new_ones([x_in.shape[0]])
|
a_t = self.alphas[ts][:, None, None, None]
|
||||||
a_t = self.alphas[ts].item() * s_in
|
|
||||||
sqrt_one_minus_at = (1 - a_t).sqrt()
|
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||||
|
|
||||||
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
||||||
|
@ -11,21 +11,22 @@ from modules.models.diffusion.uni_pc import uni_pc
|
|||||||
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
|
||||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
alphas = alphas_cumprod[timesteps]
|
alphas = alphas_cumprod[timesteps]
|
||||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
||||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones((x.shape[0]))
|
||||||
|
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||||
index = len(timesteps) - 1 - i
|
index = len(timesteps) - 1 - i
|
||||||
|
|
||||||
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
||||||
|
|
||||||
a_t = alphas[index].item() * s_in
|
a_t = alphas[index].item() * s_x
|
||||||
a_prev = alphas_prev[index].item() * s_in
|
a_prev = alphas_prev[index].item() * s_x
|
||||||
sigma_t = sigmas[index].item() * s_in
|
sigma_t = sigmas[index].item() * s_x
|
||||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||||
|
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
@ -42,18 +43,19 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
|
|||||||
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||||
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
|
||||||
alphas = alphas_cumprod[timesteps]
|
alphas = alphas_cumprod[timesteps]
|
||||||
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
|
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
|
||||||
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||||
old_eps = []
|
old_eps = []
|
||||||
|
|
||||||
def get_x_prev_and_pred_x0(e_t, index):
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
# select parameters corresponding to the currently considered timestep
|
# select parameters corresponding to the currently considered timestep
|
||||||
a_t = alphas[index].item() * s_in
|
a_t = alphas[index].item() * s_x
|
||||||
a_prev = alphas_prev[index].item() * s_in
|
a_prev = alphas_prev[index].item() * s_x
|
||||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||||
|
|
||||||
# current prediction for x_0
|
# current prediction for x_0
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
@ -31,7 +31,9 @@ def get_loaded_vae_hash():
|
|||||||
if loaded_vae_file is None:
|
if loaded_vae_file is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return hashes.sha256(loaded_vae_file, 'vae')[0:10]
|
sha256 = hashes.sha256(loaded_vae_file, 'vae')
|
||||||
|
|
||||||
|
return sha256[0:10] if sha256 else None
|
||||||
|
|
||||||
|
|
||||||
def get_base_vae(model):
|
def get_base_vae(model):
|
||||||
|
@ -69,10 +69,11 @@ def reload_hypernetworks():
|
|||||||
ui_reorder_categories_builtin_items = [
|
ui_reorder_categories_builtin_items = [
|
||||||
"inpaint",
|
"inpaint",
|
||||||
"sampler",
|
"sampler",
|
||||||
|
"accordions",
|
||||||
"checkboxes",
|
"checkboxes",
|
||||||
"hires_fix",
|
|
||||||
"dimensions",
|
"dimensions",
|
||||||
"cfg",
|
"cfg",
|
||||||
|
"denoising",
|
||||||
"seed",
|
"seed",
|
||||||
"batch",
|
"batch",
|
||||||
"override_settings",
|
"override_settings",
|
||||||
@ -86,7 +87,7 @@ def ui_reorder_categories():
|
|||||||
|
|
||||||
sections = {}
|
sections = {}
|
||||||
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
|
||||||
if isinstance(script.section, str):
|
if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
|
||||||
sections[script.section] = 1
|
sections[script.section] = 1
|
||||||
|
|
||||||
yield from sections
|
yield from sections
|
||||||
|
@ -140,8 +140,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
|
||||||
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
|
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
|
||||||
"sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"),
|
|
||||||
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
|
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||||
@ -288,12 +286,12 @@ options_templates.update(options_section(('ui', "Live previews"), {
|
|||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
|
||||||
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"),
|
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"),
|
||||||
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; applies to Euler a and other samplers that have a in them"),
|
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
||||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||||
|
@ -58,7 +58,7 @@ def _summarize_chunk(
|
|||||||
scale: float,
|
scale: float,
|
||||||
) -> AttnChunk:
|
) -> AttnChunk:
|
||||||
attn_weights = torch.baddbmm(
|
attn_weights = torch.baddbmm(
|
||||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
query,
|
query,
|
||||||
key.transpose(1,2),
|
key.transpose(1,2),
|
||||||
alpha=scale,
|
alpha=scale,
|
||||||
@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
|
|||||||
scale: float,
|
scale: float,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
attn_scores = torch.baddbmm(
|
attn_scores = torch.baddbmm(
|
||||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||||
query,
|
query,
|
||||||
key.transpose(1,2),
|
key.transpose(1,2),
|
||||||
alpha=scale,
|
alpha=scale,
|
||||||
|
@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
p = processing.StableDiffusionProcessingTxt2Img(
|
p = processing.StableDiffusionProcessingTxt2Img(
|
||||||
@ -19,12 +19,6 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
styles=prompt_styles,
|
styles=prompt_styles,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=seed,
|
|
||||||
subseed=subseed,
|
|
||||||
subseed_strength=subseed_strength,
|
|
||||||
seed_resize_from_h=seed_resize_from_h,
|
|
||||||
seed_resize_from_w=seed_resize_from_w,
|
|
||||||
seed_enable_extras=seed_enable_extras,
|
|
||||||
sampler_name=sampler_name,
|
sampler_name=sampler_name,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
|
177
modules/ui.py
177
modules/ui.py
@ -1,5 +1,4 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import json
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -13,7 +12,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
|
|||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import gradio_extensons # noqa: F401
|
from modules import gradio_extensons # noqa: F401
|
||||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
|
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
|
||||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion
|
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
from modules.ui_common import create_refresh_button
|
from modules.ui_common import create_refresh_button
|
||||||
@ -142,45 +141,6 @@ def interrogate_deepbooru(image):
|
|||||||
return gr.update() if prompt is None else prompt
|
return gr.update() if prompt is None else prompt
|
||||||
|
|
||||||
|
|
||||||
def create_seed_inputs(target_interface):
|
|
||||||
with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
|
|
||||||
if cmd_opts.use_textbox_seed:
|
|
||||||
seed = gr.Textbox(label='Seed', value="", elem_id=f"{target_interface}_seed")
|
|
||||||
else:
|
|
||||||
seed = gr.Number(label='Seed', value=-1, elem_id=f"{target_interface}_seed", precision=0)
|
|
||||||
|
|
||||||
random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
|
|
||||||
reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
|
|
||||||
|
|
||||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)
|
|
||||||
|
|
||||||
# Components to show/hide based on the 'Extra' checkbox
|
|
||||||
seed_extras = []
|
|
||||||
|
|
||||||
with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
|
|
||||||
seed_extras.append(seed_extra_row_1)
|
|
||||||
subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed", precision=0)
|
|
||||||
random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
|
|
||||||
reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
|
|
||||||
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
|
|
||||||
|
|
||||||
with FormRow(visible=False) as seed_extra_row_2:
|
|
||||||
seed_extras.append(seed_extra_row_2)
|
|
||||||
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
|
|
||||||
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
|
|
||||||
|
|
||||||
random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[])
|
|
||||||
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[])
|
|
||||||
|
|
||||||
def change_visibility(show):
|
|
||||||
return {comp: gr_show(show) for comp in seed_extras}
|
|
||||||
|
|
||||||
seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
|
|
||||||
|
|
||||||
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def connect_clear_prompt(button):
|
def connect_clear_prompt(button):
|
||||||
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
|
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
|
||||||
button.click(
|
button.click(
|
||||||
@ -191,39 +151,6 @@ def connect_clear_prompt(button):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
|
|
||||||
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
|
|
||||||
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
|
|
||||||
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
|
|
||||||
def copy_seed(gen_info_string: str, index):
|
|
||||||
res = -1
|
|
||||||
|
|
||||||
try:
|
|
||||||
gen_info = json.loads(gen_info_string)
|
|
||||||
index -= gen_info.get('index_of_first_image', 0)
|
|
||||||
|
|
||||||
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
|
|
||||||
all_subseeds = gen_info.get('all_subseeds', [-1])
|
|
||||||
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
|
|
||||||
else:
|
|
||||||
all_seeds = gen_info.get('all_seeds', [-1])
|
|
||||||
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
|
|
||||||
|
|
||||||
except json.decoder.JSONDecodeError:
|
|
||||||
if gen_info_string:
|
|
||||||
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
|
|
||||||
|
|
||||||
return [res, gr_show(False)]
|
|
||||||
|
|
||||||
reuse_seed.click(
|
|
||||||
fn=copy_seed,
|
|
||||||
_js="(x, y) => [x, selected_gallery_index()]",
|
|
||||||
show_progress=False,
|
|
||||||
inputs=[generation_info, dummy_component],
|
|
||||||
outputs=[seed, dummy_component]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def update_token_counter(text, steps):
|
def update_token_counter(text, steps):
|
||||||
try:
|
try:
|
||||||
text, _ = extra_networks.parse_prompt(text)
|
text, _ = extra_networks.parse_prompt(text)
|
||||||
@ -429,44 +356,45 @@ def create_ui():
|
|||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
|
||||||
|
|
||||||
elif category == "cfg":
|
elif category == "cfg":
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
|
with gr.Row():
|
||||||
|
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
|
||||||
elif category == "seed":
|
|
||||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
|
|
||||||
|
|
||||||
elif category == "checkboxes":
|
elif category == "checkboxes":
|
||||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif category == "hires_fix":
|
elif category == "accordions":
|
||||||
with InputAccordion(False, label="Hires. fix") as enable_hr:
|
with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
|
||||||
with enable_hr.extra():
|
with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
|
||||||
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
|
with enable_hr.extra():
|
||||||
|
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
|
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
|
||||||
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
|
||||||
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
|
||||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
|
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
|
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
|
||||||
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
|
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
|
||||||
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
|
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
|
||||||
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||||
|
|
||||||
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||||
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||||
|
|
||||||
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
|
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
|
||||||
with gr.Column(scale=80):
|
with gr.Column(scale=80):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
|
||||||
|
|
||||||
|
scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
elif category == "batch":
|
elif category == "batch":
|
||||||
if not opts.dimensions_and_batch_together:
|
if not opts.dimensions_and_batch_together:
|
||||||
@ -482,7 +410,7 @@ def create_ui():
|
|||||||
with FormGroup(elem_id="txt2img_script_container"):
|
with FormGroup(elem_id="txt2img_script_container"):
|
||||||
custom_inputs = scripts.scripts_txt2img.setup_ui()
|
custom_inputs = scripts.scripts_txt2img.setup_ui()
|
||||||
|
|
||||||
else:
|
if category not in {"accordions"}:
|
||||||
scripts.scripts_txt2img.setup_ui_for_section(category)
|
scripts.scripts_txt2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
|
||||||
@ -506,9 +434,6 @@ def create_ui():
|
|||||||
|
|
||||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
||||||
|
|
||||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
|
||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
|
||||||
|
|
||||||
txt2img_args = dict(
|
txt2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
||||||
_js="submit",
|
_js="submit",
|
||||||
@ -522,8 +447,6 @@ def create_ui():
|
|||||||
batch_count,
|
batch_count,
|
||||||
batch_size,
|
batch_size,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
seed,
|
|
||||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
enable_hr,
|
enable_hr,
|
||||||
@ -574,15 +497,9 @@ def create_ui():
|
|||||||
(steps, "Steps"),
|
(steps, "Steps"),
|
||||||
(sampler_name, "Sampler"),
|
(sampler_name, "Sampler"),
|
||||||
(cfg_scale, "CFG scale"),
|
(cfg_scale, "CFG scale"),
|
||||||
(seed, "Seed"),
|
|
||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
(height, "Size-2"),
|
(height, "Size-2"),
|
||||||
(batch_size, "Batch size"),
|
(batch_size, "Batch size"),
|
||||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
|
||||||
(subseed, "Variation seed"),
|
|
||||||
(subseed_strength, "Variation seed strength"),
|
|
||||||
(seed_resize_from_w, "Seed resize from-1"),
|
|
||||||
(seed_resize_from_h, "Seed resize from-2"),
|
|
||||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
|
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
|
||||||
@ -610,7 +527,7 @@ def create_ui():
|
|||||||
steps,
|
steps,
|
||||||
sampler_name,
|
sampler_name,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
seed,
|
scripts.scripts_txt2img.script('Seed').seed,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
]
|
]
|
||||||
@ -780,20 +697,22 @@ def create_ui():
|
|||||||
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
|
||||||
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
|
||||||
|
|
||||||
elif category == "cfg":
|
elif category == "denoising":
|
||||||
with FormGroup():
|
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
||||||
with FormRow():
|
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
|
||||||
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
|
|
||||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
|
||||||
|
|
||||||
elif category == "seed":
|
elif category == "cfg":
|
||||||
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img')
|
with gr.Row():
|
||||||
|
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||||
|
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
|
||||||
|
|
||||||
elif category == "checkboxes":
|
elif category == "checkboxes":
|
||||||
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
with FormRow(elem_classes="checkboxes-row", variant="compact"):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
elif category == "accordions":
|
||||||
|
with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
|
||||||
|
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
elif category == "batch":
|
elif category == "batch":
|
||||||
if not opts.dimensions_and_batch_together:
|
if not opts.dimensions_and_batch_together:
|
||||||
with FormRow(elem_id="img2img_column_batch"):
|
with FormRow(elem_id="img2img_column_batch"):
|
||||||
@ -836,14 +755,12 @@ def create_ui():
|
|||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[inpaint_controls, mask_alpha],
|
outputs=[inpaint_controls, mask_alpha],
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
|
if category not in {"accordions"}:
|
||||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||||
|
|
||||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||||
|
|
||||||
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
|
|
||||||
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
|
|
||||||
|
|
||||||
img2img_args = dict(
|
img2img_args = dict(
|
||||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||||
_js="submit_img2img",
|
_js="submit_img2img",
|
||||||
@ -870,8 +787,6 @@ def create_ui():
|
|||||||
cfg_scale,
|
cfg_scale,
|
||||||
image_cfg_scale,
|
image_cfg_scale,
|
||||||
denoising_strength,
|
denoising_strength,
|
||||||
seed,
|
|
||||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
|
||||||
selected_scale_tab,
|
selected_scale_tab,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
@ -958,15 +873,9 @@ def create_ui():
|
|||||||
(sampler_name, "Sampler"),
|
(sampler_name, "Sampler"),
|
||||||
(cfg_scale, "CFG scale"),
|
(cfg_scale, "CFG scale"),
|
||||||
(image_cfg_scale, "Image CFG scale"),
|
(image_cfg_scale, "Image CFG scale"),
|
||||||
(seed, "Seed"),
|
|
||||||
(width, "Size-1"),
|
(width, "Size-1"),
|
||||||
(height, "Size-2"),
|
(height, "Size-2"),
|
||||||
(batch_size, "Batch size"),
|
(batch_size, "Batch size"),
|
||||||
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
|
|
||||||
(subseed, "Variation seed"),
|
|
||||||
(subseed_strength, "Variation seed strength"),
|
|
||||||
(seed_resize_from_w, "Seed resize from-1"),
|
|
||||||
(seed_resize_from_h, "Seed resize from-2"),
|
|
||||||
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
|
||||||
(denoising_strength, "Denoising strength"),
|
(denoising_strength, "Denoising strength"),
|
||||||
(mask_blur, "Mask blur"),
|
(mask_blur, "Mask blur"),
|
||||||
|
@ -137,13 +137,17 @@ Requested path was: {f}
|
|||||||
generation_info = None
|
generation_info = None
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||||
open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config)
|
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
||||||
|
|
||||||
if tabname != "extras":
|
if tabname != "extras":
|
||||||
save = gr.Button('Save', elem_id=f'save_{tabname}')
|
save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).")
|
||||||
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
|
save_zip = ToolButton('🗃️', elem_id=f'save_zip_{tabname}', tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})")
|
||||||
|
|
||||||
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
|
buttons = {
|
||||||
|
'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
|
||||||
|
'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
|
||||||
|
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
|
||||||
|
}
|
||||||
|
|
||||||
open_folder_button.click(
|
open_folder_button.click(
|
||||||
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
|
||||||
|
@ -87,13 +87,23 @@ class InputAccordion(gr.Checkbox):
|
|||||||
self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
|
self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
|
||||||
InputAccordion.global_index += 1
|
InputAccordion.global_index += 1
|
||||||
|
|
||||||
kwargs['elem_id'] = self.accordion_id + "-checkbox"
|
kwargs_checkbox = {
|
||||||
kwargs['visible'] = False
|
**kwargs,
|
||||||
super().__init__(value, **kwargs)
|
"elem_id": f"{self.accordion_id}-checkbox",
|
||||||
|
"visible": False,
|
||||||
|
}
|
||||||
|
super().__init__(value, **kwargs_checkbox)
|
||||||
|
|
||||||
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
|
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
|
||||||
|
|
||||||
self.accordion = gr.Accordion(kwargs.get('label', 'Accordion'), open=value, elem_id=self.accordion_id, elem_classes=['input-accordion'])
|
kwargs_accordion = {
|
||||||
|
**kwargs,
|
||||||
|
"elem_id": self.accordion_id,
|
||||||
|
"label": kwargs.get('label', 'Accordion'),
|
||||||
|
"elem_classes": ['input-accordion'],
|
||||||
|
"open": value,
|
||||||
|
}
|
||||||
|
self.accordion = gr.Accordion(**kwargs_accordion)
|
||||||
|
|
||||||
def extra(self):
|
def extra(self):
|
||||||
"""Allows you to put something into the label of the accordion.
|
"""Allows you to put something into the label of the accordion.
|
||||||
|
@ -19,6 +19,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|||||||
return {
|
return {
|
||||||
"name": checkpoint.name_for_extra,
|
"name": checkpoint.name_for_extra,
|
||||||
"filename": checkpoint.filename,
|
"filename": checkpoint.filename,
|
||||||
|
"shorthash": checkpoint.shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
|
||||||
|
@ -2,6 +2,7 @@ import os
|
|||||||
|
|
||||||
from modules import shared, ui_extra_networks
|
from modules import shared, ui_extra_networks
|
||||||
from modules.ui_extra_networks import quote_js
|
from modules.ui_extra_networks import quote_js
|
||||||
|
from modules.hashes import sha256_from_cache
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
||||||
@ -14,13 +15,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def create_item(self, name, index=None, enable_filter=True):
|
def create_item(self, name, index=None, enable_filter=True):
|
||||||
full_path = shared.hypernetworks[name]
|
full_path = shared.hypernetworks[name]
|
||||||
path, ext = os.path.splitext(full_path)
|
path, ext = os.path.splitext(full_path)
|
||||||
|
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
|
||||||
|
shorthash = sha256[0:10] if sha256 else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": full_path,
|
"filename": full_path,
|
||||||
|
"shorthash": shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(path),
|
"search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
|
||||||
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
|
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
|
||||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||||
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
|
||||||
|
@ -19,9 +19,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
|||||||
return {
|
return {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": embedding.filename,
|
"filename": embedding.filename,
|
||||||
|
"shorthash": embedding.shorthash,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(embedding.filename),
|
"search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
|
||||||
"prompt": quote_js(embedding.name),
|
"prompt": quote_js(embedding.name),
|
||||||
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
|
||||||
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
|
||||||
|
@ -93,11 +93,13 @@ class UserMetadataEditor:
|
|||||||
item = self.page.items.get(name, {})
|
item = self.page.items.get(name, {})
|
||||||
try:
|
try:
|
||||||
filename = item["filename"]
|
filename = item["filename"]
|
||||||
|
shorthash = item.get("shorthash", None)
|
||||||
|
|
||||||
stats = os.stat(filename)
|
stats = os.stat(filename)
|
||||||
params = [
|
params = [
|
||||||
('Filename: ', os.path.basename(filename)),
|
('Filename: ', os.path.basename(filename)),
|
||||||
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
|
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
|
||||||
|
('Hash: ', shorthash),
|
||||||
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
|
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -115,7 +117,7 @@ class UserMetadataEditor:
|
|||||||
errors.display(e, f"reading metadata info for {name}")
|
errors.display(e, f"reading metadata info for {name}")
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
|
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'
|
||||||
|
|
||||||
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
|
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
|
||||||
|
|
||||||
|
@ -48,13 +48,13 @@ class UiLoadsave:
|
|||||||
elif condition and not condition(saved_value):
|
elif condition and not condition(saved_value):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
if isinstance(x, gr.Textbox) and field == 'value': # due to an undersirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
|
if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
|
||||||
saved_value = str(saved_value)
|
saved_value = str(saved_value)
|
||||||
elif isinstance(x, gr.Number) and field == 'value':
|
elif isinstance(x, gr.Number) and field == 'value':
|
||||||
try:
|
try:
|
||||||
saved_value = float(saved_value)
|
saved_value = float(saved_value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
saved_value = -1
|
return
|
||||||
|
|
||||||
setattr(obj, field, saved_value)
|
setattr(obj, field, saved_value)
|
||||||
if init_field is not None:
|
if init_field is not None:
|
||||||
|
@ -175,14 +175,22 @@ def do_nothing(p, x, xs):
|
|||||||
def format_nothing(p, opt, x):
|
def format_nothing(p, opt, x):
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def format_remove_path(p, opt, x):
|
def format_remove_path(p, opt, x):
|
||||||
return os.path.basename(x)
|
return os.path.basename(x)
|
||||||
|
|
||||||
|
|
||||||
def str_permutations(x):
|
def str_permutations(x):
|
||||||
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def list_to_csv_string(data_list):
|
||||||
|
with StringIO() as o:
|
||||||
|
csv.writer(o).writerow(data_list)
|
||||||
|
return o.getvalue().strip()
|
||||||
|
|
||||||
|
|
||||||
class AxisOption:
|
class AxisOption:
|
||||||
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
|
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
|
||||||
self.label = label
|
self.label = label
|
||||||
@ -199,6 +207,7 @@ class AxisOptionImg2Img(AxisOption):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.is_img2img = True
|
self.is_img2img = True
|
||||||
|
|
||||||
|
|
||||||
class AxisOptionTxt2Img(AxisOption):
|
class AxisOptionTxt2Img(AxisOption):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@ -286,11 +295,10 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
|||||||
cell_size = (processed_result.width, processed_result.height)
|
cell_size = (processed_result.width, processed_result.height)
|
||||||
if processed_result.images[0] is not None:
|
if processed_result.images[0] is not None:
|
||||||
cell_mode = processed_result.images[0].mode
|
cell_mode = processed_result.images[0].mode
|
||||||
#This corrects size in case of batches:
|
# This corrects size in case of batches:
|
||||||
cell_size = processed_result.images[0].size
|
cell_size = processed_result.images[0].size
|
||||||
processed_result.images[idx] = Image.new(cell_mode, cell_size)
|
processed_result.images[idx] = Image.new(cell_mode, cell_size)
|
||||||
|
|
||||||
|
|
||||||
if first_axes_processed == 'x':
|
if first_axes_processed == 'x':
|
||||||
for ix, x in enumerate(xs):
|
for ix, x in enumerate(xs):
|
||||||
if second_axes_processed == 'y':
|
if second_axes_processed == 'y':
|
||||||
@ -348,9 +356,9 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
|||||||
if draw_legend:
|
if draw_legend:
|
||||||
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
||||||
processed_result.images.insert(0, z_grid)
|
processed_result.images.insert(0, z_grid)
|
||||||
#TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
|
# TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
|
||||||
#processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
|
# processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
|
||||||
#processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
|
# processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
|
||||||
processed_result.infotexts.insert(0, processed_result.infotexts[0])
|
processed_result.infotexts.insert(0, processed_result.infotexts[0])
|
||||||
|
|
||||||
return processed_result
|
return processed_result
|
||||||
@ -374,8 +382,8 @@ class SharedSettingsStackHelper(object):
|
|||||||
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
||||||
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
|
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
|
||||||
|
|
||||||
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
|
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*])?\s*")
|
||||||
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
|
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*])?\s*")
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
@ -390,19 +398,19 @@ class Script(scripts.Script):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
|
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
|
||||||
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
|
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
|
||||||
x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
|
x_values_dropdown = gr.Dropdown(label="X values", visible=False, multiselect=True, interactive=True)
|
||||||
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
|
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
|
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
|
||||||
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
|
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
|
||||||
y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
|
y_values_dropdown = gr.Dropdown(label="Y values", visible=False, multiselect=True, interactive=True)
|
||||||
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
|
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
|
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
|
||||||
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
|
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
|
||||||
z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
|
z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True)
|
||||||
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
|
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
|
||||||
|
|
||||||
with gr.Row(variant="compact", elem_id="axis_options"):
|
with gr.Row(variant="compact", elem_id="axis_options"):
|
||||||
@ -414,6 +422,9 @@ class Script(scripts.Script):
|
|||||||
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
|
||||||
|
with gr.Column():
|
||||||
|
csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
|
||||||
|
|
||||||
|
|
||||||
with gr.Row(variant="compact", elem_id="swap_axes"):
|
with gr.Row(variant="compact", elem_id="swap_axes"):
|
||||||
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
|
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
|
||||||
@ -430,50 +441,71 @@ class Script(scripts.Script):
|
|||||||
xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
|
xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
|
||||||
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
|
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
|
||||||
|
|
||||||
def fill(x_type):
|
def fill(axis_type, csv_mode):
|
||||||
axis = self.current_axis_options[x_type]
|
axis = self.current_axis_options[axis_type]
|
||||||
return axis.choices() if axis.choices else gr.update()
|
if axis.choices:
|
||||||
|
if csv_mode:
|
||||||
|
return list_to_csv_string(axis.choices()), gr.update()
|
||||||
|
else:
|
||||||
|
return gr.update(), axis.choices()
|
||||||
|
else:
|
||||||
|
return gr.update(), gr.update()
|
||||||
|
|
||||||
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
|
fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown])
|
||||||
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
|
fill_y_button.click(fn=fill, inputs=[y_type, csv_mode], outputs=[y_values, y_values_dropdown])
|
||||||
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
|
fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])
|
||||||
|
|
||||||
def select_axis(axis_type,axis_values_dropdown):
|
def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
|
||||||
choices = self.current_axis_options[axis_type].choices
|
choices = self.current_axis_options[axis_type].choices
|
||||||
has_choices = choices is not None
|
has_choices = choices is not None
|
||||||
current_values = axis_values_dropdown
|
|
||||||
|
current_values = axis_values
|
||||||
|
current_dropdown_values = axis_values_dropdown
|
||||||
if has_choices:
|
if has_choices:
|
||||||
choices = choices()
|
choices = choices()
|
||||||
if isinstance(current_values,str):
|
if csv_mode:
|
||||||
current_values = current_values.split(",")
|
current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
|
||||||
current_values = list(filter(lambda x: x in choices, current_values))
|
current_values = list_to_csv_string(current_dropdown_values)
|
||||||
return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)
|
else:
|
||||||
|
current_dropdown_values = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(axis_values)))]
|
||||||
|
current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
|
||||||
|
|
||||||
x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
|
return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=current_values),
|
||||||
y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
|
gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=current_dropdown_values))
|
||||||
z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
|
|
||||||
|
|
||||||
def get_dropdown_update_from_params(axis,params):
|
x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])
|
||||||
|
y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])
|
||||||
|
z_type.change(fn=select_axis, inputs=[z_type, z_values, z_values_dropdown, csv_mode], outputs=[fill_z_button, z_values, z_values_dropdown])
|
||||||
|
|
||||||
|
def change_choice_mode(csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown):
|
||||||
|
_fill_x_button, _x_values, _x_values_dropdown = select_axis(x_type, x_values, x_values_dropdown, csv_mode)
|
||||||
|
_fill_y_button, _y_values, _y_values_dropdown = select_axis(y_type, y_values, y_values_dropdown, csv_mode)
|
||||||
|
_fill_z_button, _z_values, _z_values_dropdown = select_axis(z_type, z_values, z_values_dropdown, csv_mode)
|
||||||
|
return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown
|
||||||
|
|
||||||
|
csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown])
|
||||||
|
|
||||||
|
def get_dropdown_update_from_params(axis, params):
|
||||||
val_key = f"{axis} Values"
|
val_key = f"{axis} Values"
|
||||||
vals = params.get(val_key,"")
|
vals = params.get(val_key, "")
|
||||||
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
|
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
|
||||||
return gr.update(value = valslist)
|
return gr.update(value=valslist)
|
||||||
|
|
||||||
self.infotext_fields = (
|
self.infotext_fields = (
|
||||||
(x_type, "X Type"),
|
(x_type, "X Type"),
|
||||||
(x_values, "X Values"),
|
(x_values, "X Values"),
|
||||||
(x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
|
(x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)),
|
||||||
(y_type, "Y Type"),
|
(y_type, "Y Type"),
|
||||||
(y_values, "Y Values"),
|
(y_values, "Y Values"),
|
||||||
(y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
|
(y_values_dropdown, lambda params: get_dropdown_update_from_params("Y", params)),
|
||||||
(z_type, "Z Type"),
|
(z_type, "Z Type"),
|
||||||
(z_values, "Z Values"),
|
(z_values, "Z Values"),
|
||||||
(z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
|
(z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
|
||||||
)
|
)
|
||||||
|
|
||||||
return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
|
return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode]
|
||||||
|
|
||||||
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
|
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode):
|
||||||
if not no_fixed_seeds:
|
if not no_fixed_seeds:
|
||||||
modules.processing.fix_seed(p)
|
modules.processing.fix_seed(p)
|
||||||
|
|
||||||
@ -484,7 +516,7 @@ class Script(scripts.Script):
|
|||||||
if opt.label == 'Nothing':
|
if opt.label == 'Nothing':
|
||||||
return [0]
|
return [0]
|
||||||
|
|
||||||
if opt.choices is not None:
|
if opt.choices is not None and not csv_mode:
|
||||||
valslist = vals_dropdown
|
valslist = vals_dropdown
|
||||||
else:
|
else:
|
||||||
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
|
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
|
||||||
@ -503,8 +535,8 @@ class Script(scripts.Script):
|
|||||||
valslist_ext += list(range(start, end, step))
|
valslist_ext += list(range(start, end, step))
|
||||||
elif mc is not None:
|
elif mc is not None:
|
||||||
start = int(mc.group(1))
|
start = int(mc.group(1))
|
||||||
end = int(mc.group(2))
|
end = int(mc.group(2))
|
||||||
num = int(mc.group(3)) if mc.group(3) is not None else 1
|
num = int(mc.group(3)) if mc.group(3) is not None else 1
|
||||||
|
|
||||||
valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
|
valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
|
||||||
else:
|
else:
|
||||||
@ -525,8 +557,8 @@ class Script(scripts.Script):
|
|||||||
valslist_ext += np.arange(start, end + step, step).tolist()
|
valslist_ext += np.arange(start, end + step, step).tolist()
|
||||||
elif mc is not None:
|
elif mc is not None:
|
||||||
start = float(mc.group(1))
|
start = float(mc.group(1))
|
||||||
end = float(mc.group(2))
|
end = float(mc.group(2))
|
||||||
num = int(mc.group(3)) if mc.group(3) is not None else 1
|
num = int(mc.group(3)) if mc.group(3) is not None else 1
|
||||||
|
|
||||||
valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
|
valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
|
||||||
else:
|
else:
|
||||||
@ -545,22 +577,22 @@ class Script(scripts.Script):
|
|||||||
return valslist
|
return valslist
|
||||||
|
|
||||||
x_opt = self.current_axis_options[x_type]
|
x_opt = self.current_axis_options[x_type]
|
||||||
if x_opt.choices is not None:
|
if x_opt.choices is not None and not csv_mode:
|
||||||
x_values = ",".join(x_values_dropdown)
|
x_values = list_to_csv_string(x_values_dropdown)
|
||||||
xs = process_axis(x_opt, x_values, x_values_dropdown)
|
xs = process_axis(x_opt, x_values, x_values_dropdown)
|
||||||
|
|
||||||
y_opt = self.current_axis_options[y_type]
|
y_opt = self.current_axis_options[y_type]
|
||||||
if y_opt.choices is not None:
|
if y_opt.choices is not None and not csv_mode:
|
||||||
y_values = ",".join(y_values_dropdown)
|
y_values = list_to_csv_string(y_values_dropdown)
|
||||||
ys = process_axis(y_opt, y_values, y_values_dropdown)
|
ys = process_axis(y_opt, y_values, y_values_dropdown)
|
||||||
|
|
||||||
z_opt = self.current_axis_options[z_type]
|
z_opt = self.current_axis_options[z_type]
|
||||||
if z_opt.choices is not None:
|
if z_opt.choices is not None and not csv_mode:
|
||||||
z_values = ",".join(z_values_dropdown)
|
z_values = list_to_csv_string(z_values_dropdown)
|
||||||
zs = process_axis(z_opt, z_values, z_values_dropdown)
|
zs = process_axis(z_opt, z_values, z_values_dropdown)
|
||||||
|
|
||||||
# this could be moved to common code, but unlikely to be ever triggered anywhere else
|
# this could be moved to common code, but unlikely to be ever triggered anywhere else
|
||||||
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
|
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
|
||||||
grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
|
grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
|
||||||
assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
|
assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
|
||||||
|
|
||||||
@ -720,7 +752,7 @@ class Script(scripts.Script):
|
|||||||
# Auto-save main and sub-grids:
|
# Auto-save main and sub-grids:
|
||||||
grid_count = z_count + 1 if z_count > 1 else 1
|
grid_count = z_count + 1 if z_count > 1 else 1
|
||||||
for g in range(grid_count):
|
for g in range(grid_count):
|
||||||
#TODO: See previous comment about intentional data misalignment.
|
# TODO: See previous comment about intentional data misalignment.
|
||||||
adj_g = g-1 if g > 0 else g
|
adj_g = g-1 if g > 0 else g
|
||||||
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
|
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
|
||||||
|
|
||||||
|
55
style.css
55
style.css
@ -166,16 +166,6 @@ a{
|
|||||||
color: var(--button-secondary-text-color-hover);
|
color: var(--button-secondary-text-color-hover);
|
||||||
}
|
}
|
||||||
|
|
||||||
.checkboxes-row{
|
|
||||||
margin-bottom: 0.5em;
|
|
||||||
margin-left: 0em;
|
|
||||||
}
|
|
||||||
.checkboxes-row > div{
|
|
||||||
flex: 0;
|
|
||||||
white-space: nowrap;
|
|
||||||
min-width: auto !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
button.custom-button{
|
button.custom-button{
|
||||||
border-radius: var(--button-large-radius);
|
border-radius: var(--button-large-radius);
|
||||||
padding: var(--button-large-padding);
|
padding: var(--button-large-padding);
|
||||||
@ -192,7 +182,7 @@ button.custom-button{
|
|||||||
text-align: center;
|
text-align: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
div.gradio-accordion {
|
div.block.gradio-accordion {
|
||||||
border: 1px solid var(--block-border-color) !important;
|
border: 1px solid var(--block-border-color) !important;
|
||||||
border-radius: 8px !important;
|
border-radius: 8px !important;
|
||||||
margin: 2px 0;
|
margin: 2px 0;
|
||||||
@ -239,10 +229,14 @@ div.gradio-accordion {
|
|||||||
}
|
}
|
||||||
|
|
||||||
[id$=_subseed_show] label{
|
[id$=_subseed_show] label{
|
||||||
margin-bottom: 0.5em;
|
margin-bottom: 0.65em;
|
||||||
align-self: end;
|
align-self: end;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[id$=_seed_extras] > div{
|
||||||
|
gap: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
.html-log .comments{
|
.html-log .comments{
|
||||||
padding-top: 0.5em;
|
padding-top: 0.5em;
|
||||||
}
|
}
|
||||||
@ -352,7 +346,7 @@ div.gradio-accordion {
|
|||||||
}
|
}
|
||||||
|
|
||||||
div.dimensions-tools{
|
div.dimensions-tools{
|
||||||
min-width: 0 !important;
|
min-width: 1.6em !important;
|
||||||
max-width: fit-content;
|
max-width: fit-content;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
place-content: center;
|
place-content: center;
|
||||||
@ -369,8 +363,8 @@ div#extras_scale_to_tab div.form{
|
|||||||
z-index: 5;
|
z-index: 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
.image-buttons button{
|
.image-buttons > .form{
|
||||||
min-width: auto;
|
justify-content: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
.infotext {
|
.infotext {
|
||||||
@ -391,19 +385,21 @@ div#extras_scale_to_tab div.form{
|
|||||||
|
|
||||||
/* settings */
|
/* settings */
|
||||||
#quicksettings {
|
#quicksettings {
|
||||||
width: fit-content;
|
|
||||||
align-items: end;
|
align-items: end;
|
||||||
}
|
}
|
||||||
|
|
||||||
#quicksettings > div, #quicksettings > fieldset{
|
#quicksettings > div, #quicksettings > fieldset{
|
||||||
max-width: 24em;
|
max-width: 36em;
|
||||||
min-width: 24em;
|
width: fit-content;
|
||||||
width: 24em;
|
flex: 0 1 fit-content;
|
||||||
padding: 0;
|
padding: 0;
|
||||||
border: none;
|
border: none;
|
||||||
box-shadow: none;
|
box-shadow: none;
|
||||||
background: none;
|
background: none;
|
||||||
}
|
}
|
||||||
|
#quicksettings > div.gradio-dropdown{
|
||||||
|
min-width: 24em !important;
|
||||||
|
}
|
||||||
|
|
||||||
#settings{
|
#settings{
|
||||||
display: block;
|
display: block;
|
||||||
@ -1012,10 +1008,29 @@ div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-chi
|
|||||||
}
|
}
|
||||||
|
|
||||||
div.block.input-accordion{
|
div.block.input-accordion{
|
||||||
margin-bottom: 0.4em;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
.input-accordion-extra{
|
.input-accordion-extra{
|
||||||
flex: 0 0 auto !important;
|
flex: 0 0 auto !important;
|
||||||
margin: 0 0.5em 0 auto;
|
margin: 0 0.5em 0 auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div.accordions > div.input-accordion{
|
||||||
|
min-width: fit-content !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.accordions > div.gradio-accordion .label-wrap span{
|
||||||
|
white-space: nowrap;
|
||||||
|
margin-right: 0.25em;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.accordions{
|
||||||
|
gap: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.accordions > div.input-accordion.input-accordion-open{
|
||||||
|
flex: 1 auto;
|
||||||
|
flex-flow: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -12,8 +12,6 @@ fi
|
|||||||
export install_dir="$HOME"
|
export install_dir="$HOME"
|
||||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
|
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
|
||||||
export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
|
export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
|
||||||
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
|
||||||
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
|
||||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||||
|
|
||||||
####################################################################
|
####################################################################
|
||||||
|
Loading…
Reference in New Issue
Block a user