Merge remote-tracking branch 'auto1111/dev' into shared-hires-prompt-test

This commit is contained in:
Robert Barron 2023-08-14 00:35:17 -07:00
commit d61e31bae6
39 changed files with 1100 additions and 520 deletions

View File

@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self): def __init__(self):
super().__init__('lora') super().__init__('lora')
self.errors = {}
"""mapping of network names to the number of errors the network had during operation"""
def activate(self, p, params_list): def activate(self, p, params_list):
additional = shared.opts.sd_lora additional = shared.opts.sd_lora
self.errors.clear()
if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
def deactivate(self, p): def deactivate(self, p):
pass if self.errors:
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
self.errors.clear()

View File

@ -133,7 +133,7 @@ class NetworkModule:
return 1.0 return 1.0
def finalize_updown(self, updown, orig_weight, output_shape): def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None: if self.bias is not None:
updown = updown.reshape(self.bias.shape) updown = updown.reshape(self.bias.shape)
updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
@ -145,7 +145,10 @@ class NetworkModule:
if orig_weight.size().numel() == updown.size().numel(): if orig_weight.size().numel() == updown.size().numel():
updown = updown.reshape(orig_weight.shape) updown = updown.reshape(orig_weight.shape)
return updown * self.calc_scale() * self.multiplier() if ex_bias is not None:
ex_bias = ex_bias * self.multiplier()
return updown * self.calc_scale() * self.multiplier(), ex_bias
def calc_updown(self, target): def calc_updown(self, target):
raise NotImplementedError() raise NotImplementedError()

View File

@ -0,0 +1,28 @@
import network
class ModuleTypeNorm(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
if all(x in weights.w for x in ["w_norm", "b_norm"]):
return NetworkModuleNorm(net, weights)
return None
class NetworkModuleNorm(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.w_norm = weights.w.get("w_norm")
self.b_norm = weights.w.get("b_norm")
def calc_updown(self, orig_weight):
output_shape = self.w_norm.shape
updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
if self.b_norm is not None:
ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
else:
ex_bias = None
return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)

View File

@ -1,3 +1,4 @@
import logging
import os import os
import re import re
@ -7,6 +8,7 @@ import network_hada
import network_ia3 import network_ia3
import network_lokr import network_lokr
import network_full import network_full
import network_norm
import torch import torch
from typing import Union from typing import Union
@ -19,6 +21,7 @@ module_types = [
network_ia3.ModuleTypeIa3(), network_ia3.ModuleTypeIa3(),
network_lokr.ModuleTypeLokr(), network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(), network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
] ]
@ -31,6 +34,8 @@ suffix_conversion = {
"resnets": { "resnets": {
"conv1": "in_layers_2", "conv1": "in_layers_2",
"conv2": "out_layers_3", "conv2": "out_layers_3",
"norm1": "in_layers_0",
"norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1", "time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection", "conv_shortcut": "skip_connection",
} }
@ -190,7 +195,7 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module net.modules[key] = net_module
if keys_failed_to_match: if keys_failed_to_match:
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
return net return net
@ -203,7 +208,6 @@ def purge_networks_from_memory():
devices.torch_gc() devices.torch_gc()
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {} already_loaded = {}
@ -244,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
if net is None: if net is None:
failed_to_load_networks.append(name) failed_to_load_networks.append(name)
print(f"Couldn't find network with name {name}") logging.info(f"Couldn't find network with name {name}")
continue continue
net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@ -253,25 +257,38 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
loaded_networks.append(net) loaded_networks.append(net)
if failed_to_load_networks: if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory() purge_networks_from_memory()
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None) weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
if weights_backup is None: if weights_backup is None and bias_backup is None:
return return
if isinstance(self, torch.nn.MultiheadAttention): if weights_backup is not None:
self.in_proj_weight.copy_(weights_backup[0]) if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.weight.copy_(weights_backup[1]) self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
if bias_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.bias.copy_(bias_backup)
else:
self.bias.copy_(bias_backup)
else: else:
self.weight.copy_(weights_backup) if isinstance(self, torch.nn.MultiheadAttention):
self.out_proj.bias = None
else:
self.bias = None
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
""" """
Applies the currently selected set of networks to the weights of torch layer self. Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing. If weights already have this particular set of networks applied, does nothing.
@ -294,21 +311,41 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None:
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
elif getattr(self, 'bias', None) is not None:
bias_backup = self.bias.to(devices.cpu, copy=True)
else:
bias_backup = None
self.network_bias_backup = bias_backup
if current_names != wanted_names: if current_names != wanted_names:
network_restore_weights_from_backup(self) network_restore_weights_from_backup(self)
for net in loaded_networks: for net in loaded_networks:
module = net.modules.get(network_layer_name, None) module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'): if module is not None and hasattr(self, 'weight'):
with torch.no_grad(): try:
updown = module.calc_updown(self.weight) with torch.no_grad():
updown, ex_bias = module.calc_updown(self.weight)
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9 # inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
self.weight += updown self.weight += updown
continue if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
self.bias = torch.nn.Parameter(ex_bias)
else:
self.bias += ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
module_q = net.modules.get(network_layer_name + "_q_proj", None) module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None) module_k = net.modules.get(network_layer_name + "_k_proj", None)
@ -316,21 +353,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
module_out = net.modules.get(network_layer_name + "_out_proj", None) module_out = net.modules.get(network_layer_name + "_out_proj", None)
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
with torch.no_grad(): try:
updown_q = module_q.calc_updown(self.in_proj_weight) with torch.no_grad():
updown_k = module_k.calc_updown(self.in_proj_weight) updown_q, _ = module_q.calc_updown(self.in_proj_weight)
updown_v = module_v.calc_updown(self.in_proj_weight) updown_k, _ = module_k.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) updown_v, _ = module_v.calc_updown(self.in_proj_weight)
updown_out = module_out.calc_updown(self.out_proj.weight) updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
self.in_proj_weight += updown_qkv self.in_proj_weight += updown_qkv
self.out_proj.weight += updown_out self.out_proj.weight += updown_out
continue if ex_bias is not None:
if self.out_proj.bias is None:
self.out_proj.bias = torch.nn.Parameter(ex_bias)
else:
self.out_proj.bias += ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
continue
if module is None: if module is None:
continue continue
print(f'failed to calculate network weights for layer {network_layer_name}') logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
self.network_current_names = wanted_names self.network_current_names = wanted_names
@ -357,7 +406,7 @@ def network_forward(module, input, original_forward):
if module is None: if module is None:
continue continue
y = module.forward(y, input) y = module.forward(input, y)
return y return y
@ -397,6 +446,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
network_apply_weights(self)
return torch.nn.GroupNorm_forward_before_network(self, input)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
network_apply_weights(self)
return torch.nn.LayerNorm_forward_before_network(self, input)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs): def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self) network_apply_weights(self)
@ -473,6 +552,7 @@ def infotext_pasted(infotext, params):
if added: if added:
params["Prompt"] += "\n" + "".join(added) params["Prompt"] += "\n" + "".join(added)
extra_network_lora = None
available_networks = {} available_networks = {}
available_network_aliases = {} available_network_aliases = {}

View File

@ -23,9 +23,9 @@ def unload():
def before_ui(): def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
extra_network = extra_networks_lora.ExtraNetworkLora() networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
extra_networks.register_extra_network(extra_network) extra_networks.register_extra_network(networks.extra_network_lora)
extra_networks.register_extra_network_alias(extra_network, "lyco") extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
if not hasattr(torch.nn, 'Linear_forward_before_network'): if not hasattr(torch.nn, 'Linear_forward_before_network'):
@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'): if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'): if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict

View File

@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
item = { item = {
"name": name, "name": name,
"filename": lora_on_disk.filename, "filename": lora_on_disk.filename,
"shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path), "preview": self.find_preview(path),
"description": self.find_description(path), "description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename), "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
"local_preview": f"{path}.{shared.opts.samples_format}", "local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata, "metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},

View File

@ -12,6 +12,7 @@ onUiLoaded(async() => {
"Sketch": elementIDs.sketch "Sketch": elementIDs.sketch
}; };
// Helper functions // Helper functions
// Get active tab // Get active tab
function getActiveTab(elements, all = false) { function getActiveTab(elements, all = false) {
@ -377,6 +378,11 @@ onUiLoaded(async() => {
toggleOverlap("off"); toggleOverlap("off");
fullScreenMode = false; fullScreenMode = false;
const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']");
if (closeBtn) {
closeBtn.addEventListener("click", resetZoom);
}
if ( if (
canvas && canvas &&
parseFloat(canvas.style.width) > 865 && parseFloat(canvas.style.width) > 865 &&
@ -657,17 +663,20 @@ onUiLoaded(async() => {
// Simulation of the function to put a long image into the screen. // Simulation of the function to put a long image into the screen.
// We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element. // We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.
// We hide the image and show it to the user when it is ready. // We hide the image and show it to the user when it is ready.
function autoExpand(e) {
targetElement.isExpanded = false;
function autoExpand() {
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`); const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
const isMainTab = activeElement === elementIDs.inpaint || activeElement === elementIDs.inpaintSketch || activeElement === elementIDs.sketch; const isMainTab = activeElement === elementIDs.inpaint || activeElement === elementIDs.inpaintSketch || activeElement === elementIDs.sketch;
if (canvas && isMainTab) { if (canvas && isMainTab) {
if (hasHorizontalScrollbar(targetElement)) { if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {
targetElement.style.visibility = "hidden"; targetElement.style.visibility = "hidden";
setTimeout(() => { setTimeout(() => {
fitToScreen(); fitToScreen();
resetZoom(); resetZoom();
targetElement.style.visibility = "visible"; targetElement.style.visibility = "visible";
targetElement.isExpanded = true;
}, 10); }, 10);
} }
} }
@ -675,9 +684,24 @@ onUiLoaded(async() => {
targetElement.addEventListener("mousemove", getMousePosition); targetElement.addEventListener("mousemove", getMousePosition);
//observers
// Creating an observer with a callback function to handle DOM changes
const observer = new MutationObserver((mutationsList, observer) => {
for (let mutation of mutationsList) {
// If the style attribute of the canvas has changed, by observation it happens only when the picture changes
if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&
mutation.target.tagName.toLowerCase() === 'canvas') {
targetElement.isExpanded = false;
setTimeout(resetZoom, 10);
}
}
});
// Apply auto expand if enabled // Apply auto expand if enabled
if (hotkeysConfig.canvas_auto_expand) { if (hotkeysConfig.canvas_auto_expand) {
targetElement.addEventListener("mousemove", autoExpand); targetElement.addEventListener("mousemove", autoExpand);
// Set up an observer to track attribute changes
observer.observe(targetElement, {attributes: true, childList: true, subtree: true});
} }
// Handle events only inside the targetElement // Handle events only inside the targetElement

View File

@ -50,10 +50,12 @@ class PydanticModelGenerator:
additional_fields = None, additional_fields = None,
): ):
def field_type_generator(k, v): def field_type_generator(k, v):
# field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default)
field_type = v.annotation field_type = v.annotation
if field_type == 'Image':
# images are sent as base64 strings via API
field_type = 'str'
return Optional[field_type] return Optional[field_type]
def merge_class_params(class_): def merge_class_params(class_):
@ -63,7 +65,6 @@ class PydanticModelGenerator:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters} parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters return parameters
self._model_name = model_name self._model_name = model_name
self._class_data = merge_class_params(class_instance) self._class_data = merge_class_params(class_instance)
@ -72,7 +73,7 @@ class PydanticModelGenerator:
field=underscore(k), field=underscore(k),
field_alias=k, field_alias=k,
field_type=field_type_generator(k, v), field_type=field_type_generator(k, v),
field_value=v.default field_value=None if isinstance(v.default, property) else v.default
) )
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
] ]

View File

@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
process_images(p) process_images(p)
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5 is_batch = mode == 5
@ -166,12 +166,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
styles=prompt_styles, styles=prompt_styles,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_name=sampler_name, sampler_name=sampler_name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,

View File

@ -173,9 +173,12 @@ def git_clone(url, dir, name, commithash=None):
if current_hash == commithash: if current_hash == commithash:
return return
run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False) if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
return return
@ -243,7 +246,7 @@ def list_extensions(settings_file):
disabled_extensions = set(settings.get('disabled_extensions', [])) disabled_extensions = set(settings.get('disabled_extensions', []))
disable_all_extensions = settings.get('disable_all_extensions', 'none') disable_all_extensions = settings.get('disable_all_extensions', 'none')
if disable_all_extensions != 'none': if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions:
return [] return []
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions] return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]
@ -319,12 +322,12 @@ def prepare_environment():
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
try: try:
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
os.remove(os.path.join(script_path, "tmp", "restart")) os.remove(os.path.join(script_path, "tmp", "restart"))
os.environ.setdefault('SD_WEBUI_RESTARTING', '1') os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
except OSError: except OSError:

View File

@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
if has_mps: if has_mps:
# MPS fix for randn in torchsde
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
if platform.mac_ver()[0].startswith("13.2."): if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760) CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)

View File

@ -11,37 +11,32 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
shared.state.begin(job="extras") shared.state.begin(job="extras")
image_data = []
image_names = []
outputs = [] outputs = []
if extras_mode == 1: def get_images(extras_mode, image, image_folder, input_dir):
for img in image_folder: if extras_mode == 1:
if isinstance(img, Image.Image): for img in image_folder:
image = img if isinstance(img, Image.Image):
fn = '' image = img
else: fn = ''
image = Image.open(os.path.abspath(img.name)) else:
fn = os.path.splitext(img.orig_name)[0] image = Image.open(os.path.abspath(img.name))
image_data.append(image) fn = os.path.splitext(img.orig_name)[0]
image_names.append(fn) yield image, fn
elif extras_mode == 2: elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected' assert input_dir, 'input directory not selected'
image_list = shared.listfiles(input_dir) image_list = shared.listfiles(input_dir)
for filename in image_list: for filename in image_list:
try: try:
image = Image.open(filename) image = Image.open(filename)
except Exception: except Exception:
continue continue
image_data.append(image) yield image, filename
image_names.append(filename) else:
else: assert image, 'image not selected'
assert image, 'image not selected' yield image, None
image_data.append(image)
image_names.append(None)
if extras_mode == 2 and output_dir != '': if extras_mode == 2 and output_dir != '':
outpath = output_dir outpath = output_dir
@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
infotext = '' infotext = ''
for image, name in zip(image_data, image_names): for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
image_data: Image.Image
shared.state.textinfo = name shared.state.textinfo = name
parameters, existing_pnginfo = images.read_info_from_image(image) parameters, existing_pnginfo = images.read_info_from_image(image_data)
if parameters: if parameters:
existing_pnginfo["parameters"] = parameters existing_pnginfo["parameters"] = parameters
pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB")) pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
scripts.scripts_postproc.run(pp, args) scripts.scripts_postproc.run(pp, args)
@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode != 2 or show_extras_results: if extras_mode != 2 or show_extras_results:
outputs.append(pp.image) outputs.append(pp.image)
image_data.close()
devices.torch_gc() devices.torch_gc()
return outputs, ui_common.plaintext_to_html(infotext), '' return outputs, ui_common.plaintext_to_html(infotext), ''

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import json import json
import logging import logging
import math import math
import os import os
import sys import sys
import hashlib import hashlib
from dataclasses import dataclass, field
import torch import torch
import numpy as np import numpy as np
@ -11,7 +13,7 @@ from PIL import Image, ImageOps
import random import random
import cv2 import cv2
from skimage import exposure from skimage import exposure
from typing import Any, Dict, List from typing import Any
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
@ -57,7 +59,7 @@ def apply_color_correction(correction, original_image):
image = blendLayers(image, original_image, BlendType.LUMINOSITY) image = blendLayers(image, original_image, BlendType.LUMINOSITY)
return image return image.convert('RGB')
def apply_overlay(image, paste_loc, index, overlays): def apply_overlay(image, paste_loc, index, overlays):
@ -104,97 +106,163 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
@dataclass(repr=False)
class StableDiffusionProcessing: class StableDiffusionProcessing:
""" sd_model: object = None
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing outpath_samples: str = None
""" outpath_grids: str = None
prompt: str = ""
prompt_for_display: str = None
negative_prompt: str = ""
styles: list[str] = None
seed: int = -1
subseed: int = -1
subseed_strength: float = 0
seed_resize_from_h: int = -1
seed_resize_from_w: int = -1
seed_enable_extras: bool = True
sampler_name: str = None
batch_size: int = 1
n_iter: int = 1
steps: int = 50
cfg_scale: float = 7.0
width: int = 512
height: int = 512
restore_faces: bool = None
tiling: bool = None
do_not_save_samples: bool = False
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = 0
ddim_discretize: str = None
s_min_uncond: float = None
s_churn: float = None
s_tmax: float = None
s_tmin: float = None
s_noise: float = None
override_settings: dict[str, Any] = None
override_settings_restore_afterwards: bool = True
sampler_index: int = None
refiner_checkpoint: str = None
refiner_switch_at: float = None
token_merging_ratio = 0
token_merging_ratio_hr = 0
disable_extra_networks: bool = False
scripts_value: scripts.ScriptRunner = field(default=None, init=False)
script_args_value: list = field(default=None, init=False)
scripts_setup_complete: bool = field(default=False, init=False)
cached_uc = [None, None] cached_uc = [None, None]
cached_c = [None, None] cached_c = [None, None]
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): comments: dict = None
if sampler_index is not None: sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
is_using_inpainting_conditioning: bool = field(default=False, init=False)
paste_to: tuple | None = field(default=None, init=False)
is_hr_pass: bool = field(default=False, init=False)
c: tuple = field(default=None, init=False)
uc: tuple = field(default=None, init=False)
rng: rng.ImageRNG | None = field(default=None, init=False)
step_multiplier: int = field(default=1, init=False)
color_corrections: list = field(default=None, init=False)
all_prompts: list = field(default=None, init=False)
all_negative_prompts: list = field(default=None, init=False)
all_seeds: list = field(default=None, init=False)
all_subseeds: list = field(default=None, init=False)
iteration: int = field(default=0, init=False)
main_prompt: str = field(default=None, init=False)
main_negative_prompt: str = field(default=None, init=False)
prompts: list = field(default=None, init=False)
negative_prompts: list = field(default=None, init=False)
seeds: list = field(default=None, init=False)
subseeds: list = field(default=None, init=False)
extra_network_data: dict = field(default=None, init=False)
user: str = field(default=None, init=False)
sd_model_name: str = field(default=None, init=False)
sd_model_hash: str = field(default=None, init=False)
sd_vae_name: str = field(default=None, init=False)
sd_vae_hash: str = field(default=None, init=False)
def __post_init__(self):
if self.sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
self.outpath_samples: str = outpath_samples self.comments = {}
self.outpath_grids: str = outpath_grids
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w
self.sampler_name: str = sampler_name
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
self.cfg_scale: float = cfg_scale
self.width: int = width
self.height: int = height
self.restore_faces: bool = restore_faces
self.tiling: bool = tiling
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images
self.eta = eta
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
self.s_noise = s_noise if s_noise is not None else opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False
self.token_merging_ratio = 0
self.token_merging_ratio_hr = 0
if not seed_enable_extras: if self.styles is None:
self.styles = []
self.sampler_noise_scheduler_override = None
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise
self.extra_generation_params = self.extra_generation_params or {}
self.override_settings = self.override_settings or {}
self.script_args = self.script_args or {}
self.refiner_checkpoint_info = None
if not self.seed_enable_extras:
self.subseed = -1 self.subseed = -1
self.subseed_strength = 0 self.subseed_strength = 0
self.seed_resize_from_h = 0 self.seed_resize_from_h = 0
self.seed_resize_from_w = 0 self.seed_resize_from_w = 0
self.scripts = None
self.script_args = script_args
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
self.iteration = 0
self.is_hr_pass = False
self.sampler = None
self.main_prompt = None
self.main_negative_prompt = None
self.prompts = None
self.negative_prompts = None
self.extra_network_data = None
self.seeds = None
self.subseeds = None
self.step_multiplier = 1
self.cached_uc = StableDiffusionProcessing.cached_uc self.cached_uc = StableDiffusionProcessing.cached_uc
self.cached_c = StableDiffusionProcessing.cached_c self.cached_c = StableDiffusionProcessing.cached_c
self.uc = None
self.c = None
self.rng: rng.ImageRNG = None
self.user = None
@property @property
def sd_model(self): def sd_model(self):
return shared.sd_model return shared.sd_model
@sd_model.setter
def sd_model(self, value):
pass
@property
def scripts(self):
return self.scripts_value
@scripts.setter
def scripts(self, value):
self.scripts_value = value
if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
self.setup_scripts()
@property
def script_args(self):
return self.script_args_value
@script_args.setter
def script_args(self, value):
self.script_args_value = value
if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
self.setup_scripts()
def setup_scripts(self):
self.scripts_setup_complete = True
self.scripts.setup_scrips(self)
def comment(self, text):
self.comments[text] = 1
def txt2img_image_conditioning(self, x, width=None, height=None): def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
@ -343,7 +411,7 @@ class StableDiffusionProcessing:
self.height, self.height,
) )
def get_conds_with_caching(self, function, required_prompts, steps, hires_steps, caches, extra_network_data): def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
""" """
Returns the result of calling function(shared.sd_model, required_prompts, steps) Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before. using a cache to store the result if the same arguments have been used before.
@ -375,10 +443,12 @@ class StableDiffusionProcessing:
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
sampler_config = sd_samplers.find_sampler_config(self.sampler_name) sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
self.firstpass_steps = self.steps * self.step_multiplier self.step_multiplier = total_steps // self.steps
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.firstpass_steps, None, [self.cached_uc], self.extra_network_data) self.firstpass_steps = total_steps
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.firstpass_steps, None, [self.cached_c], self.extra_network_data)
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data, None)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data, None )
def get_conds(self): def get_conds(self):
return self.c, self.uc return self.c, self.uc
@ -400,7 +470,7 @@ class Processed:
self.subseed = subseed self.subseed = subseed
self.subseed_strength = p.subseed_strength self.subseed_strength = p.subseed_strength
self.info = info self.info = info
self.comments = comments self.comments = "".join(f"{comment}\n" for comment in p.comments)
self.width = p.width self.width = p.width
self.height = p.height self.height = p.height
self.sampler_name = p.sampler_name self.sampler_name = p.sampler_name
@ -410,7 +480,10 @@ class Processed:
self.batch_size = p.batch_size self.batch_size = p.batch_size
self.restore_faces = p.restore_faces self.restore_faces = p.restore_faces
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
self.sd_model_hash = shared.sd_model.sd_model_hash self.sd_model_name = p.sd_model_name
self.sd_model_hash = p.sd_model_hash
self.sd_vae_name = p.sd_vae_name
self.sd_vae_hash = p.sd_vae_hash
self.seed_resize_from_w = p.seed_resize_from_w self.seed_resize_from_w = p.seed_resize_from_w
self.seed_resize_from_h = p.seed_resize_from_h self.seed_resize_from_h = p.seed_resize_from_h
self.denoising_strength = getattr(p, 'denoising_strength', None) self.denoising_strength = getattr(p, 'denoising_strength', None)
@ -461,7 +534,10 @@ class Processed:
"batch_size": self.batch_size, "batch_size": self.batch_size,
"restore_faces": self.restore_faces, "restore_faces": self.restore_faces,
"face_restoration_model": self.face_restoration_model, "face_restoration_model": self.face_restoration_model,
"sd_model_name": self.sd_model_name,
"sd_model_hash": self.sd_model_hash, "sd_model_hash": self.sd_model_hash,
"sd_vae_name": self.sd_vae_name,
"sd_vae_hash": self.sd_vae_hash,
"seed_resize_from_w": self.seed_resize_from_w, "seed_resize_from_w": self.seed_resize_from_w,
"seed_resize_from_h": self.seed_resize_from_h, "seed_resize_from_h": self.seed_resize_from_h,
"denoising_strength": self.denoising_strength, "denoising_strength": self.denoising_strength,
@ -580,10 +656,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index], "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
"Face restoration": opts.face_restoration_model if p.restore_faces else None, "Face restoration": opts.face_restoration_model if p.restore_faces else None,
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra), "Model": p.sd_model_name if opts.add_model_name_to_info else None,
"VAE hash": sd_vae.get_loaded_vae_hash() if opts.add_model_hash_to_info else None, "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
"VAE": sd_vae.get_loaded_vae_name() if opts.add_model_name_to_info else None, "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@ -672,11 +748,19 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.tiling is None: if p.tiling is None:
p.tiling = opts.tiling p.tiling = opts.tiling
if p.refiner_checkpoint not in (None, "", "None"):
p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
if p.refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')
p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
p.sd_model_hash = shared.sd_model.sd_model_hash
p.sd_vae_name = sd_vae.get_loaded_vae_name()
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments() modules.sd_hijack.model_hijack.clear_comments()
comments = {}
p.setup_prompts() p.setup_prompts()
if type(seed) == list: if type(seed) == list:
@ -756,7 +840,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds() p.setup_conds()
for comment in model_hijack.comments: for comment in model_hijack.comments:
comments[comment] = 1 p.comment(comment)
p.extra_generation_params.update(model_hijack.extra_generation_params) p.extra_generation_params.update(model_hijack.extra_generation_params)
@ -885,7 +969,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images_list=output_images, images_list=output_images,
seed=p.all_seeds[0], seed=p.all_seeds[0],
info=infotexts[0], info=infotexts[0],
comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0], subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image, index_of_first_image=index_of_first_image,
infotexts=infotexts, infotexts=infotexts,
@ -909,49 +992,51 @@ def old_hires_fix_first_pass_dimensions(width, height):
return width, height return width, height
@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None enable_hr: bool = False
denoising_strength: float = 0.75
firstphase_width: int = 0
firstphase_height: int = 0
hr_scale: float = 2.0
hr_upscaler: str = None
hr_second_pass_steps: int = 0
hr_resize_x: int = 0
hr_resize_y: int = 0
hr_checkpoint_name: str = None
hr_sampler_name: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''
cached_hr_uc = [None, None] cached_hr_uc = [None, None]
cached_hr_c = [None, None] cached_hr_c = [None, None]
def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): hr_checkpoint_info: dict = field(default=None, init=False)
super().__init__(**kwargs) hr_upscale_to_x: int = field(default=0, init=False)
self.enable_hr = enable_hr hr_upscale_to_y: int = field(default=0, init=False)
self.denoising_strength = denoising_strength truncate_x: int = field(default=0, init=False)
self.hr_scale = hr_scale truncate_y: int = field(default=0, init=False)
self.hr_upscaler = hr_upscaler applied_old_hires_behavior_to: tuple = field(default=None, init=False)
self.hr_second_pass_steps = hr_second_pass_steps latent_scale_mode: dict = field(default=None, init=False)
self.hr_resize_x = hr_resize_x hr_c: tuple | None = field(default=None, init=False)
self.hr_resize_y = hr_resize_y hr_uc: tuple | None = field(default=None, init=False)
self.hr_upscale_to_x = hr_resize_x all_hr_prompts: list = field(default=None, init=False)
self.hr_upscale_to_y = hr_resize_y all_hr_negative_prompts: list = field(default=None, init=False)
self.hr_checkpoint_name = hr_checkpoint_name hr_prompts: list = field(default=None, init=False)
self.hr_checkpoint_info = None hr_negative_prompts: list = field(default=None, init=False)
self.hr_sampler_name = hr_sampler_name hr_extra_network_data: list = field(default=None, init=False)
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
self.latent_scale_mode = None
if firstphase_width != 0 or firstphase_height != 0: def __post_init__(self):
super().__post_init__()
if self.firstphase_width != 0 or self.firstphase_height != 0:
self.hr_upscale_to_x = self.width self.hr_upscale_to_x = self.width
self.hr_upscale_to_y = self.height self.hr_upscale_to_y = self.height
self.width = firstphase_width self.width = self.firstphase_width
self.height = firstphase_height self.height = self.firstphase_height
self.truncate_x = 0
self.truncate_y = 0
self.applied_old_hires_behavior_to = None
self.hr_prompts = None
self.hr_negative_prompts = None
self.hr_extra_network_data = None
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
self.hr_c = None
self.hr_uc = None
def calculate_target_resolution(self): def calculate_target_resolution(self):
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
@ -1145,6 +1230,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
self.sampler = None
devices.torch_gc()
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False self.is_hr_pass = False
@ -1191,11 +1279,20 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y) hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True) hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
hires_steps = (self.hr_second_pass_steps or self.steps) * self.step_multiplier sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) steps = self.hr_second_pass_steps or self.steps
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) total_steps = sampler_config.total_steps(steps) if sampler_config else steps
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
def setup_conds(self): def setup_conds(self):
if self.is_hr_pass:
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
self.hr_c = None
self.calculate_hr_conds()
return
super().setup_conds() super().setup_conds()
self.hr_uc = None self.hr_uc = None
@ -1220,7 +1317,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return super().get_conds() return super().get_conds()
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts() res = super().parse_extra_network_prompts()
@ -1233,35 +1329,53 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return res return res
@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None init_images: list = None
resize_mode: int = 0
denoising_strength: float = 0.75
image_cfg_scale: float = None
mask: Any = None
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
inpainting_mask_invert: int = 0
initial_noise_multiplier: float = None
latent_mask: Image = None
def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): image_mask: Any = field(default=None, init=False)
super().__init__(**kwargs)
self.init_images = init_images nmask: torch.Tensor = field(default=None, init=False)
self.resize_mode: int = resize_mode image_conditioning: torch.Tensor = field(default=None, init=False)
self.denoising_strength: float = denoising_strength init_img_hash: str = field(default=None, init=False)
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None mask_for_overlay: Image = field(default=None, init=False)
self.init_latent = None init_latent: torch.Tensor = field(default=None, init=False)
self.image_mask = mask
self.latent_mask = None def __post_init__(self):
self.mask_for_overlay = None super().__post_init__()
if mask_blur is not None:
mask_blur_x = mask_blur self.image_mask = self.mask
mask_blur_y = mask_blur
self.mask_blur_x = mask_blur_x
self.mask_blur_y = mask_blur_y
self.inpainting_fill = inpainting_fill
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
self.inpainting_mask_invert = inpainting_mask_invert
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
self.mask = None self.mask = None
self.nmask = None self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier
self.image_conditioning = None
@property
def mask_blur(self):
if self.mask_blur_x == self.mask_blur_y:
return self.mask_blur_x
return None
@mask_blur.setter
def mask_blur(self, value):
if isinstance(value, int):
self.mask_blur_x = value
self.mask_blur_y = value
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None crop_region = None
@ -1275,13 +1389,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask_blur_x > 0: if self.mask_blur_x > 0:
np_mask = np.array(image_mask) np_mask = np.array(image_mask)
kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1 kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
image_mask = Image.fromarray(np_mask) image_mask = Image.fromarray(np_mask)
if self.mask_blur_y > 0: if self.mask_blur_y > 0:
np_mask = np.array(image_mask) np_mask = np.array(image_mask)
kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1 kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
image_mask = Image.fromarray(np_mask) image_mask = Image.fromarray(np_mask)

View File

@ -0,0 +1,49 @@
import gradio as gr
from modules import scripts, sd_models
from modules.ui_common import create_refresh_button
from modules.ui_components import InputAccordion
class ScriptRefiner(scripts.Script):
section = "accordions"
create_group = False
def __init__(self):
pass
def title(self):
return "Refiner"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
with gr.Row():
refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
def lookup_checkpoint(title):
info = sd_models.get_closet_checkpoint_match(title)
return None if info is None else info.title
self.infotext_fields = [
(enable_refiner, lambda d: 'Refiner' in d),
(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))),
(refiner_switch_at, 'Refiner switch at'),
]
return enable_refiner, refiner_checkpoint, refiner_switch_at
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
# the actual implementation is in sd_samplers_common.py, apply_refiner
if not enable_refiner or refiner_checkpoint in (None, "", "None"):
p.refiner_checkpoint_info = None
p.refiner_switch_at = None
else:
p.refiner_checkpoint = refiner_checkpoint
p.refiner_switch_at = refiner_switch_at

View File

@ -0,0 +1,111 @@
import json
import gradio as gr
from modules import scripts, ui, errors
from modules.shared import cmd_opts
from modules.ui_components import ToolButton
class ScriptSeed(scripts.ScriptBuiltin):
section = "seed"
create_group = False
def __init__(self):
self.seed = None
self.reuse_seed = None
self.reuse_subseed = None
def title(self):
return "Seed"
def show(self, is_img2img):
return scripts.AlwaysVisible
def ui(self, is_img2img):
with gr.Row(elem_id=self.elem_id("seed_row")):
if cmd_opts.use_textbox_seed:
self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100)
else:
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras:
with gr.Row(elem_id=self.elem_id("subseed_row")):
subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0)
random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed"))
reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed"))
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength"))
with gr.Row(elem_id=self.elem_id("seed_resize_from_row")):
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w"))
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h"))
random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[])
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[])
seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])
self.infotext_fields = [
(self.seed, "Seed"),
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
]
self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')
self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')
return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h
def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):
p.seed = seed
if seed_checkbox and subseed_strength > 0:
p.subseed = subseed
p.subseed_strength = subseed_strength
if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:
p.seed_resize_from_w = seed_resize_from_w
p.seed_resize_from_h = seed_resize_from_h
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
def copy_seed(gen_info_string: str, index):
res = -1
try:
gen_info = json.loads(gen_info_string)
index -= gen_info.get('index_of_first_image', 0)
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
all_subseeds = gen_info.get('all_subseeds', [-1])
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
else:
all_seeds = gen_info.get('all_seeds', [-1])
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
except json.decoder.JSONDecodeError:
if gen_info_string:
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr.update()]
reuse_seed.click(
fn=copy_seed,
_js="(x, y) => [x, selected_gallery_index()]",
show_progress=False,
inputs=[generation_info, seed],
outputs=[seed, seed]
)

View File

@ -3,6 +3,7 @@ import re
import sys import sys
import inspect import inspect
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass
import gradio as gr import gradio as gr
@ -21,6 +22,11 @@ class PostprocessBatchListArgs:
self.images = images self.images = images
@dataclass
class OnComponent:
component: gr.blocks.Block
class Script: class Script:
name = None name = None
"""script's internal name derived from title""" """script's internal name derived from title"""
@ -35,9 +41,13 @@ class Script:
is_txt2img = False is_txt2img = False
is_img2img = False is_img2img = False
tabname = None
group = None group = None
"""A gr.Group component that has all script's UI inside it""" """A gr.Group component that has all script's UI inside it."""
create_group = True
"""If False, for alwayson scripts, a group component will not be created."""
infotext_fields = None infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
@ -52,6 +62,12 @@ class Script:
api_info = None api_info = None
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API""" """Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
on_before_component_elem_id = None
"""list of callbacks to be called before a component with an elem_id is created"""
on_after_component_elem_id = None
"""list of callbacks to be called after a component with an elem_id is created"""
def title(self): def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu.""" """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@ -90,9 +106,16 @@ class Script:
pass pass
def setup(self, p, *args):
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
args contains all values returned by components from ui().
"""
pass
def before_process(self, p, *args): def before_process(self, p, *args):
""" """
This function is called very early before processing begins for AlwaysVisible scripts. This function is called very early during processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc. You can modify the processing object (p) here, inject hooks, etc.
args contains all values returned by components from ui() args contains all values returned by components from ui()
""" """
@ -212,6 +235,30 @@ class Script:
pass pass
def on_before_component(self, callback, *, elem_id):
"""
Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
May be called in show() or ui() - but it may be too late in latter as some components may already be created.
This function is an alternative to before_component in that it also cllows to run before a component is created, but
it doesn't require to be called for every created component - just for the one you need.
"""
if self.on_before_component_elem_id is None:
self.on_before_component_elem_id = []
self.on_before_component_elem_id.append((elem_id, callback))
def on_after_component(self, callback, *, elem_id):
"""
Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
"""
if self.on_after_component_elem_id is None:
self.on_after_component_elem_id = []
self.on_after_component_elem_id.append((elem_id, callback))
def describe(self): def describe(self):
"""unused""" """unused"""
return "" return ""
@ -232,6 +279,18 @@ class Script:
""" """
pass pass
class ScriptBuiltin(Script):
def elem_id(self, item_id):
"""helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
return f'{tabname}{item_id}'
current_basedir = paths.script_path current_basedir = paths.script_path
@ -250,7 +309,7 @@ postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def list_scripts(scriptdirname, extension): def list_scripts(scriptdirname, extension, *, include_extensions=True):
scripts_list = [] scripts_list = []
basedir = os.path.join(paths.script_path, scriptdirname) basedir = os.path.join(paths.script_path, scriptdirname)
@ -258,8 +317,9 @@ def list_scripts(scriptdirname, extension):
for filename in sorted(os.listdir(basedir)): for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
for ext in extensions.active(): if include_extensions:
scripts_list += ext.list_files(scriptdirname, extension) for ext in extensions.active():
scripts_list += ext.list_files(scriptdirname, extension)
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@ -288,7 +348,7 @@ def load_scripts():
postprocessing_scripts_data.clear() postprocessing_scripts_data.clear()
script_callbacks.clear_callbacks() script_callbacks.clear_callbacks()
scripts_list = list_scripts("scripts", ".py") scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False)
syspath = sys.path syspath = sys.path
@ -349,10 +409,17 @@ class ScriptRunner:
self.selectable_scripts = [] self.selectable_scripts = []
self.alwayson_scripts = [] self.alwayson_scripts = []
self.titles = [] self.titles = []
self.title_map = {}
self.infotext_fields = [] self.infotext_fields = []
self.paste_field_names = [] self.paste_field_names = []
self.inputs = [None] self.inputs = [None]
self.on_before_component_elem_id = {}
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
self.on_after_component_elem_id = {}
"""dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
def initialize_scripts(self, is_img2img): def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing from modules import scripts_auto_postprocessing
@ -367,6 +434,7 @@ class ScriptRunner:
script.filename = script_data.path script.filename = script_data.path
script.is_txt2img = not is_img2img script.is_txt2img = not is_img2img
script.is_img2img = is_img2img script.is_img2img = is_img2img
script.tabname = "img2img" if is_img2img else "txt2img"
visibility = script.show(script.is_img2img) visibility = script.show(script.is_img2img)
@ -379,6 +447,28 @@ class ScriptRunner:
self.scripts.append(script) self.scripts.append(script)
self.selectable_scripts.append(script) self.selectable_scripts.append(script)
self.apply_on_before_component_callbacks()
def apply_on_before_component_callbacks(self):
for script in self.scripts:
on_before = script.on_before_component_elem_id or []
on_after = script.on_after_component_elem_id or []
for elem_id, callback in on_before:
if elem_id not in self.on_before_component_elem_id:
self.on_before_component_elem_id[elem_id] = []
self.on_before_component_elem_id[elem_id].append((callback, script))
for elem_id, callback in on_after:
if elem_id not in self.on_after_component_elem_id:
self.on_after_component_elem_id[elem_id] = []
self.on_after_component_elem_id[elem_id].append((callback, script))
on_before.clear()
on_after.clear()
def create_script_ui(self, script): def create_script_ui(self, script):
import modules.api.models as api_models import modules.api.models as api_models
@ -429,15 +519,20 @@ class ScriptRunner:
if script.alwayson and script.section != section: if script.alwayson and script.section != section:
continue continue
with gr.Group(visible=script.alwayson) as group: if script.create_group:
self.create_script_ui(script) with gr.Group(visible=script.alwayson) as group:
self.create_script_ui(script)
script.group = group script.group = group
else:
self.create_script_ui(script)
def prepare_ui(self): def prepare_ui(self):
self.inputs = [None] self.inputs = [None]
def setup_ui(self): def setup_ui(self):
all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
self.setup_ui_for_section(None) self.setup_ui_for_section(None)
@ -484,6 +579,8 @@ class ScriptRunner:
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None')))) self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts]) self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
self.apply_on_before_component_callbacks()
return self.inputs return self.inputs
def run(self, p, *args): def run(self, p, *args):
@ -577,6 +674,12 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs): def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try:
callback(OnComponent(component=component))
except Exception:
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
for script in self.scripts: for script in self.scripts:
try: try:
script.before_component(component, **kwargs) script.before_component(component, **kwargs)
@ -584,12 +687,21 @@ class ScriptRunner:
errors.report(f"Error running before_component: {script.filename}", exc_info=True) errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs): def after_component(self, component, **kwargs):
for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):
try:
callback(OnComponent(component=component))
except Exception:
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
for script in self.scripts: for script in self.scripts:
try: try:
script.after_component(component, **kwargs) script.after_component(component, **kwargs)
except Exception: except Exception:
errors.report(f"Error running after_component: {script.filename}", exc_info=True) errors.report(f"Error running after_component: {script.filename}", exc_info=True)
def script(self, title):
return self.title_map.get(title.lower())
def reload_sources(self, cache): def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)): for si, script in list(enumerate(self.scripts)):
args_from = script.args_from args_from = script.args_from
@ -608,7 +720,6 @@ class ScriptRunner:
self.scripts[si].args_from = args_from self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to self.scripts[si].args_to = args_to
def before_hr(self, p): def before_hr(self, p):
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
try: try:
@ -617,6 +728,14 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True) errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
def setup_scrips(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.setup(p, *script_args)
except Exception:
errors.report(f"Error running setup: {script.filename}", exc_info=True)
scripts_txt2img: ScriptRunner = None scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None scripts_img2img: ScriptRunner = None

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import math import math
import psutil import psutil
import platform
import torch import torch
from torch import einsum from torch import einsum
@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
class SdOptimizationSubQuad(SdOptimization): class SdOptimizationSubQuad(SdOptimization):
name = "sub-quadratic" name = "sub-quadratic"
cmd_opt = "opt_sub_quad_attention" cmd_opt = "opt_sub_quad_attention"
priority = 10
@property
def priority(self):
return 1000 if shared.device.type == 'mps' else 10
def apply(self): def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
@property @property
def priority(self): def priority(self):
return 1000 if not torch.cuda.is_available() else 10 return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
def apply(self): def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
if chunk_threshold is None: if chunk_threshold is None:
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) if q.device.type == 'mps':
chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
else:
chunk_threshold_bytes = int(get_available_vram() * 0.7)
elif chunk_threshold == 0: elif chunk_threshold == 0:
chunk_threshold_bytes = None chunk_threshold_bytes = None
else: else:

View File

@ -147,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
def get_closet_checkpoint_match(search_string): def get_closet_checkpoint_match(search_string):
if not search_string:
return None
checkpoint_info = checkpoint_aliases.get(search_string, None) checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None: if checkpoint_info is not None:
return checkpoint_info return checkpoint_info

View File

@ -45,18 +45,23 @@ class CFGDenoiser(torch.nn.Module):
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None
self.steps = None self.steps = None
"""number of steps as specified by user in UI"""
self.total_steps = None
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
self.step = 0 self.step = 0
self.image_cfg_scale = None self.image_cfg_scale = None
self.padded_cond_uncond = False self.padded_cond_uncond = False
self.sampler = sampler self.sampler = sampler
self.model_wrap = None self.model_wrap = None
self.p = None self.p = None
self.mask_before_denoising = False
@property @property
def inner_model(self): def inner_model(self):
raise NotImplementedError() raise NotImplementedError()
def combine_denoised(self, x_out, conds_list, uncond, cond_scale): def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:] denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond) denoised = torch.clone(denoised_uncond)
@ -100,7 +105,7 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
if self.mask is not None: if self.mask_before_denoising and self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x x = self.init_latent * self.mask + self.nmask * x
batch_size = len(conds_list) batch_size = len(conds_list)
@ -202,6 +207,9 @@ class CFGDenoiser(torch.nn.Module):
else: else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if not self.mask_before_denoising and self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
if opts.live_preview_content == "Prompt": if opts.live_preview_content == "Prompt":

View File

@ -7,7 +7,16 @@ from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, s
from modules.shared import opts, state from modules.shared import opts, state
import k_diffusion.sampling import k_diffusion.sampling
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
class SamplerData(SamplerDataTuple):
def total_steps(self, steps):
if self.options.get("second_order", False):
steps = steps * 2
return steps
def setup_img2img_steps(p, steps=None): def setup_img2img_steps(p, steps=None):
@ -83,7 +92,15 @@ def images_tensor_to_samples(image, approximation=None, model=None):
model = shared.sd_model model = shared.sd_model
image = image.to(shared.device, dtype=devices.dtype_vae) image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1 image = image * 2 - 1
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image)) if len(image) > 1:
x_latent = torch.stack([
model.get_first_stage_encoding(
model.encode_first_stage(torch.unsqueeze(img, 0))
)[0]
for img in image
])
else:
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
return x_latent return x_latent
@ -131,31 +148,29 @@ def replace_torchsde_browinan():
replace_torchsde_browinan() replace_torchsde_browinan()
def apply_refiner(sampler): def apply_refiner(cfg_denoiser):
completed_ratio = sampler.step / sampler.steps completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
if completed_ratio <= shared.opts.sd_refiner_switch_at: if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False return False
if shared.opts.sd_refiner_checkpoint == "None": if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
return False return False
if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint: if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass:
return False return False
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
if refiner_checkpoint_info is None: cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at
with sd_models.SkipWritingToConfig(): with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info) sd_models.reload_model_weights(info=refiner_checkpoint_info)
devices.torch_gc() devices.torch_gc()
sampler.p.setup_conds() cfg_denoiser.p.setup_conds()
sampler.update_inner_model() cfg_denoiser.update_inner_model()
return True return True
@ -192,7 +207,7 @@ class Sampler:
self.sampler_noises = None self.sampler_noises = None
self.stop_at = None self.stop_at = None
self.eta = None self.eta = None
self.config = None # set by the function calling the constructor self.config: SamplerData = None # set by the function calling the constructor
self.last_latent = None self.last_latent = None
self.s_min_uncond = None self.s_min_uncond = None
self.s_churn = 0.0 self.s_churn = 0.0
@ -208,6 +223,7 @@ class Sampler:
self.p = None self.p = None
self.model_wrap_cfg = None self.model_wrap_cfg = None
self.sampler_extra_args = None self.sampler_extra_args = None
self.options = {}
def callback_state(self, d): def callback_state(self, d):
step = d['i'] step = d['i']
@ -220,6 +236,7 @@ class Sampler:
def launch_sampling(self, steps, func): def launch_sampling(self, steps, func):
self.model_wrap_cfg.steps = steps self.model_wrap_cfg.steps = steps
self.model_wrap_cfg.total_steps = self.config.total_steps(steps)
state.sampling_steps = steps state.sampling_steps = steps
state.sampling_step = 0 state.sampling_step = 0
@ -267,19 +284,19 @@ class Sampler:
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
s_noise = getattr(opts, 's_noise', p.s_noise) s_noise = getattr(opts, 's_noise', p.s_noise)
if s_churn != self.s_churn: if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
extra_params_kwargs['s_churn'] = s_churn extra_params_kwargs['s_churn'] = s_churn
p.s_churn = s_churn p.s_churn = s_churn
p.extra_generation_params['Sigma churn'] = s_churn p.extra_generation_params['Sigma churn'] = s_churn
if s_tmin != self.s_tmin: if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
extra_params_kwargs['s_tmin'] = s_tmin extra_params_kwargs['s_tmin'] = s_tmin
p.s_tmin = s_tmin p.s_tmin = s_tmin
p.extra_generation_params['Sigma tmin'] = s_tmin p.extra_generation_params['Sigma tmin'] = s_tmin
if s_tmax != self.s_tmax: if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
extra_params_kwargs['s_tmax'] = s_tmax extra_params_kwargs['s_tmax'] = s_tmax
p.s_tmax = s_tmax p.s_tmax = s_tmax
p.extra_generation_params['Sigma tmax'] = s_tmax p.extra_generation_params['Sigma tmax'] = s_tmax
if s_noise != self.s_noise: if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
extra_params_kwargs['s_noise'] = s_noise extra_params_kwargs['s_noise'] = s_noise
p.s_noise = s_noise p.s_noise = s_noise
p.extra_generation_params['Sigma noise'] = s_noise p.extra_generation_params['Sigma noise'] = s_noise
@ -296,5 +313,8 @@ class Sampler:
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()

View File

@ -22,6 +22,12 @@ samplers_k_diffusion = [
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
@ -42,6 +48,12 @@ sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_fast': ['s_noise'],
'sample_dpm_2_ancestral': ['s_noise'],
'sample_dpmpp_2s_ancestral': ['s_noise'],
'sample_dpmpp_sde': ['s_noise'],
'sample_dpmpp_2m_sde': ['s_noise'],
'sample_dpmpp_3m_sde': ['s_noise'],
} }
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
@ -64,9 +76,12 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
class KDiffusionSampler(sd_samplers_common.Sampler): class KDiffusionSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model): def __init__(self, funcname, sd_model, options=None):
super().__init__(funcname) super().__init__(funcname)
self.extra_params = sampler_extra_params.get(funcname, [])
self.options = options or {}
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiserKDiffusion(self) self.model_wrap_cfg = CFGDenoiserKDiffusion(self)
@ -149,6 +164,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
noise_sampler = self.create_noise_sampler(x, sigmas, p) noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler extra_params_kwargs['noise_sampler'] = noise_sampler
if self.config.options.get('solver_type', None) == 'heun':
extra_params_kwargs['solver_type'] = 'heun'
self.model_wrap_cfg.init_latent = x self.model_wrap_cfg.init_latent = x
self.last_latent = x self.last_latent = x
self.sampler_extra_args = { self.sampler_extra_args = {
@ -190,6 +208,9 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
noise_sampler = self.create_noise_sampler(x, sigmas, p) noise_sampler = self.create_noise_sampler(x, sigmas, p)
extra_params_kwargs['noise_sampler'] = noise_sampler extra_params_kwargs['noise_sampler'] = noise_sampler
if self.config.options.get('solver_type', None) == 'heun':
extra_params_kwargs['solver_type'] = 'heun'
self.last_latent = x self.last_latent = x
self.sampler_extra_args = { self.sampler_extra_args = {
'cond': conditioning, 'cond': conditioning,
@ -198,6 +219,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
'cond_scale': p.cfg_scale, 'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
} }
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond: if self.model_wrap_cfg.padded_cond_uncond:

View File

@ -49,12 +49,12 @@ class CFGDenoiserTimesteps(CFGDenoiser):
super().__init__(sampler) super().__init__(sampler)
self.alphas = shared.sd_model.alphas_cumprod self.alphas = shared.sd_model.alphas_cumprod
self.mask_before_denoising = True
def get_pred_x0(self, x_in, x_out, sigma): def get_pred_x0(self, x_in, x_out, sigma):
ts = int(sigma.item()) ts = sigma.to(dtype=int)
s_in = x_in.new_ones([x_in.shape[0]]) a_t = self.alphas[ts][:, None, None, None]
a_t = self.alphas[ts].item() * s_in
sqrt_one_minus_at = (1 - a_t).sqrt() sqrt_one_minus_at = (1 - a_t).sqrt()
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt() pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()

View File

@ -11,21 +11,22 @@ from modules.models.diffusion.uni_pc import uni_pc
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps] alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64) alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones((x.shape[0]))
s_x = x.new_ones((x.shape[0], 1, 1, 1))
for i in tqdm.trange(len(timesteps) - 1, disable=disable): for i in tqdm.trange(len(timesteps) - 1, disable=disable):
index = len(timesteps) - 1 - i index = len(timesteps) - 1 - i
e_t = model(x, timesteps[index].item() * s_in, **extra_args) e_t = model(x, timesteps[index].item() * s_in, **extra_args)
a_t = alphas[index].item() * s_in a_t = alphas[index].item() * s_x
a_prev = alphas_prev[index].item() * s_in a_prev = alphas_prev[index].item() * s_x
sigma_t = sigmas[index].item() * s_in sigma_t = sigmas[index].item() * s_x
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
@ -42,18 +43,19 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps] alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64) alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas) sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
s_x = x.new_ones((x.shape[0], 1, 1, 1))
old_eps = [] old_eps = []
def get_x_prev_and_pred_x0(e_t, index): def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = alphas[index].item() * s_in a_t = alphas[index].item() * s_x
a_prev = alphas_prev[index].item() * s_in a_prev = alphas_prev[index].item() * s_x
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
# current prediction for x_0 # current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

View File

@ -31,7 +31,9 @@ def get_loaded_vae_hash():
if loaded_vae_file is None: if loaded_vae_file is None:
return None return None
return hashes.sha256(loaded_vae_file, 'vae')[0:10] sha256 = hashes.sha256(loaded_vae_file, 'vae')
return sha256[0:10] if sha256 else None
def get_base_vae(model): def get_base_vae(model):

View File

@ -69,10 +69,11 @@ def reload_hypernetworks():
ui_reorder_categories_builtin_items = [ ui_reorder_categories_builtin_items = [
"inpaint", "inpaint",
"sampler", "sampler",
"accordions",
"checkboxes", "checkboxes",
"hires_fix",
"dimensions", "dimensions",
"cfg", "cfg",
"denoising",
"seed", "seed",
"batch", "batch",
"override_settings", "override_settings",
@ -86,7 +87,7 @@ def ui_reorder_categories():
sections = {} sections = {}
for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts: for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
if isinstance(script.section, str): if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:
sections[script.section] = 1 sections[script.section] = 1
yield from sections yield from sections

View File

@ -140,8 +140,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
"tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"), "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"),
"sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"),
"sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"),
})) }))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
@ -288,12 +286,12 @@ options_templates.update(options_section(('ui', "Live previews"), {
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(), "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"), "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; applies to Euler a and other samplers that have a in them"), "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'), 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"), 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),

View File

@ -58,7 +58,7 @@ def _summarize_chunk(
scale: float, scale: float,
) -> AttnChunk: ) -> AttnChunk:
attn_weights = torch.baddbmm( attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query, query,
key.transpose(1,2), key.transpose(1,2),
alpha=scale, alpha=scale,
@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
scale: float, scale: float,
) -> Tensor: ) -> Tensor:
attn_scores = torch.baddbmm( attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query, query,
key.transpose(1,2), key.transpose(1,2),
alpha=scale, alpha=scale,

View File

@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import gradio as gr import gradio as gr
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
@ -19,12 +19,6 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
prompt=prompt, prompt=prompt,
styles=prompt_styles, styles=prompt_styles,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
seed=seed,
subseed=subseed,
subseed_strength=subseed_strength,
seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras,
sampler_name=sampler_name, sampler_name=sampler_name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,

View File

@ -1,5 +1,4 @@
import datetime import datetime
import json
import mimetypes import mimetypes
import os import os
import sys import sys
@ -13,7 +12,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401 from modules import gradio_extensons # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion
from modules.paths import script_path from modules.paths import script_path
from modules.ui_common import create_refresh_button from modules.ui_common import create_refresh_button
@ -142,45 +141,6 @@ def interrogate_deepbooru(image):
return gr.update() if prompt is None else prompt return gr.update() if prompt is None else prompt
def create_seed_inputs(target_interface):
with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
if cmd_opts.use_textbox_seed:
seed = gr.Textbox(label='Seed', value="", elem_id=f"{target_interface}_seed")
else:
seed = gr.Number(label='Seed', value=-1, elem_id=f"{target_interface}_seed", precision=0)
random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
seed_checkbox = gr.Checkbox(label='Extra', elem_id=f"{target_interface}_subseed_show", value=False)
# Components to show/hide based on the 'Extra' checkbox
seed_extras = []
with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
seed_extras.append(seed_extra_row_1)
subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed", precision=0)
random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
with FormRow(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=f"{target_interface}_seed_resize_from_w")
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h")
random_seed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_seed')}", show_progress=False, inputs=[], outputs=[])
random_subseed.click(fn=None, _js="function(){setRandomSeed('" + target_interface + "_subseed')}", show_progress=False, inputs=[], outputs=[])
def change_visibility(show):
return {comp: gr_show(show) for comp in seed_extras}
seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
def connect_clear_prompt(button): def connect_clear_prompt(button):
"""Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
button.click( button.click(
@ -191,39 +151,6 @@ def connect_clear_prompt(button):
) )
def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
""" Connects a 'reuse (sub)seed' button's click event so that it copies last used
(sub)seed value from generation info the to the seed field. If copying subseed and subseed strength
was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
def copy_seed(gen_info_string: str, index):
res = -1
try:
gen_info = json.loads(gen_info_string)
index -= gen_info.get('index_of_first_image', 0)
if is_subseed and gen_info.get('subseed_strength', 0) > 0:
all_subseeds = gen_info.get('all_subseeds', [-1])
res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
else:
all_seeds = gen_info.get('all_seeds', [-1])
res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
except json.decoder.JSONDecodeError:
if gen_info_string:
errors.report(f"Error parsing JSON generation info: {gen_info_string}")
return [res, gr_show(False)]
reuse_seed.click(
fn=copy_seed,
_js="(x, y) => [x, selected_gallery_index()]",
show_progress=False,
inputs=[generation_info, dummy_component],
outputs=[seed, dummy_component]
)
def update_token_counter(text, steps): def update_token_counter(text, steps):
try: try:
text, _ = extra_networks.parse_prompt(text) text, _ = extra_networks.parse_prompt(text)
@ -429,44 +356,45 @@ def create_ui():
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size")
elif category == "cfg": elif category == "cfg":
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") with gr.Row():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale")
elif category == "seed":
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img')
elif category == "checkboxes": elif category == "checkboxes":
with FormRow(elem_classes="checkboxes-row", variant="compact"): with FormRow(elem_classes="checkboxes-row", variant="compact"):
pass pass
elif category == "hires_fix": elif category == "accordions":
with InputAccordion(False, label="Hires. fix") as enable_hr: with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"):
with enable_hr.extra(): with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr:
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0) with enable_hr.extra():
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength")
with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"):
hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale")
hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler")
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
with gr.Column(scale=80): with gr.Column(scale=80):
with gr.Row(): with gr.Row():
hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"])
with gr.Column(scale=80): with gr.Column(scale=80):
with gr.Row(): with gr.Row():
hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"])
scripts.scripts_txt2img.setup_ui_for_section(category)
elif category == "batch": elif category == "batch":
if not opts.dimensions_and_batch_together: if not opts.dimensions_and_batch_together:
@ -482,7 +410,7 @@ def create_ui():
with FormGroup(elem_id="txt2img_script_container"): with FormGroup(elem_id="txt2img_script_container"):
custom_inputs = scripts.scripts_txt2img.setup_ui() custom_inputs = scripts.scripts_txt2img.setup_ui()
else: if category not in {"accordions"}:
scripts.scripts_txt2img.setup_ui_for_section(category) scripts.scripts_txt2img.setup_ui_for_section(category)
hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
@ -506,9 +434,6 @@ def create_ui():
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict( txt2img_args = dict(
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
_js="submit", _js="submit",
@ -522,8 +447,6 @@ def create_ui():
batch_count, batch_count,
batch_size, batch_size,
cfg_scale, cfg_scale,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
height, height,
width, width,
enable_hr, enable_hr,
@ -574,15 +497,9 @@ def create_ui():
(steps, "Steps"), (steps, "Steps"),
(sampler_name, "Sampler"), (sampler_name, "Sampler"),
(cfg_scale, "CFG scale"), (cfg_scale, "CFG scale"),
(seed, "Seed"),
(width, "Size-1"), (width, "Size-1"),
(height, "Size-2"), (height, "Size-2"),
(batch_size, "Batch size"), (batch_size, "Batch size"),
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
@ -610,7 +527,7 @@ def create_ui():
steps, steps,
sampler_name, sampler_name,
cfg_scale, cfg_scale,
seed, scripts.scripts_txt2img.script('Seed').seed,
width, width,
height, height,
] ]
@ -780,20 +697,22 @@ def create_ui():
batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count")
batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size")
elif category == "cfg": elif category == "denoising":
with FormGroup(): denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
with FormRow():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
elif category == "seed": elif category == "cfg":
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') with gr.Row():
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
elif category == "checkboxes": elif category == "checkboxes":
with FormRow(elem_classes="checkboxes-row", variant="compact"): with FormRow(elem_classes="checkboxes-row", variant="compact"):
pass pass
elif category == "accordions":
with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"):
scripts.scripts_img2img.setup_ui_for_section(category)
elif category == "batch": elif category == "batch":
if not opts.dimensions_and_batch_together: if not opts.dimensions_and_batch_together:
with FormRow(elem_id="img2img_column_batch"): with FormRow(elem_id="img2img_column_batch"):
@ -836,14 +755,12 @@ def create_ui():
inputs=[], inputs=[],
outputs=[inpaint_controls, mask_alpha], outputs=[inpaint_controls, mask_alpha],
) )
else:
if category not in {"accordions"}:
scripts.scripts_img2img.setup_ui_for_section(category) scripts.scripts_img2img.setup_ui_for_section(category)
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
img2img_args = dict( img2img_args = dict(
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
_js="submit_img2img", _js="submit_img2img",
@ -870,8 +787,6 @@ def create_ui():
cfg_scale, cfg_scale,
image_cfg_scale, image_cfg_scale,
denoising_strength, denoising_strength,
seed,
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
selected_scale_tab, selected_scale_tab,
height, height,
width, width,
@ -958,15 +873,9 @@ def create_ui():
(sampler_name, "Sampler"), (sampler_name, "Sampler"),
(cfg_scale, "CFG scale"), (cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"), (image_cfg_scale, "Image CFG scale"),
(seed, "Seed"),
(width, "Size-1"), (width, "Size-1"),
(height, "Size-2"), (height, "Size-2"),
(batch_size, "Batch size"), (batch_size, "Batch size"),
(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
(subseed, "Variation seed"),
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"), (mask_blur, "Mask blur"),

View File

@ -137,13 +137,17 @@ Requested path was: {f}
generation_info = None generation_info = None
with gr.Column(): with gr.Column():
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
open_folder_button = gr.Button(folder_symbol, visible=not shared.cmd_opts.hide_ui_dir_config) open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
if tabname != "extras": if tabname != "extras":
save = gr.Button('Save', elem_id=f'save_{tabname}') save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f"Save the image to a dedicated directory ({shared.opts.outdir_save}).")
save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') save_zip = ToolButton('🗃️', elem_id=f'save_zip_{tabname}', tooltip=f"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})")
buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) buttons = {
'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."),
'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."),
'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
}
open_folder_button.click( open_folder_button.click(
fn=lambda: open_folder(shared.opts.outdir_samples or outdir), fn=lambda: open_folder(shared.opts.outdir_samples or outdir),

View File

@ -87,13 +87,23 @@ class InputAccordion(gr.Checkbox):
self.accordion_id = f"input-accordion-{InputAccordion.global_index}" self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
InputAccordion.global_index += 1 InputAccordion.global_index += 1
kwargs['elem_id'] = self.accordion_id + "-checkbox" kwargs_checkbox = {
kwargs['visible'] = False **kwargs,
super().__init__(value, **kwargs) "elem_id": f"{self.accordion_id}-checkbox",
"visible": False,
}
super().__init__(value, **kwargs_checkbox)
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self]) self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
self.accordion = gr.Accordion(kwargs.get('label', 'Accordion'), open=value, elem_id=self.accordion_id, elem_classes=['input-accordion']) kwargs_accordion = {
**kwargs,
"elem_id": self.accordion_id,
"label": kwargs.get('label', 'Accordion'),
"elem_classes": ['input-accordion'],
"open": value,
}
self.accordion = gr.Accordion(**kwargs_accordion)
def extra(self): def extra(self):
"""Allows you to put something into the label of the accordion. """Allows you to put something into the label of the accordion.

View File

@ -19,6 +19,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
return { return {
"name": checkpoint.name_for_extra, "name": checkpoint.name_for_extra,
"filename": checkpoint.filename, "filename": checkpoint.filename,
"shorthash": checkpoint.shorthash,
"preview": self.find_preview(path), "preview": self.find_preview(path),
"description": self.find_description(path), "description": self.find_description(path),
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),

View File

@ -2,6 +2,7 @@ import os
from modules import shared, ui_extra_networks from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js from modules.ui_extra_networks import quote_js
from modules.hashes import sha256_from_cache
class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
@ -14,13 +15,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True): def create_item(self, name, index=None, enable_filter=True):
full_path = shared.hypernetworks[name] full_path = shared.hypernetworks[name]
path, ext = os.path.splitext(full_path) path, ext = os.path.splitext(full_path)
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
shorthash = sha256[0:10] if sha256 else None
return { return {
"name": name, "name": name,
"filename": full_path, "filename": full_path,
"shorthash": shorthash,
"preview": self.find_preview(path), "preview": self.find_preview(path),
"description": self.find_description(path), "description": self.find_description(path),
"search_term": self.search_terms_from_path(path), "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"), "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
"local_preview": f"{path}.preview.{shared.opts.samples_format}", "local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},

View File

@ -19,9 +19,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
return { return {
"name": name, "name": name,
"filename": embedding.filename, "filename": embedding.filename,
"shorthash": embedding.shorthash,
"preview": self.find_preview(path), "preview": self.find_preview(path),
"description": self.find_description(path), "description": self.find_description(path),
"search_term": self.search_terms_from_path(embedding.filename), "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
"prompt": quote_js(embedding.name), "prompt": quote_js(embedding.name),
"local_preview": f"{path}.preview.{shared.opts.samples_format}", "local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},

View File

@ -93,11 +93,13 @@ class UserMetadataEditor:
item = self.page.items.get(name, {}) item = self.page.items.get(name, {})
try: try:
filename = item["filename"] filename = item["filename"]
shorthash = item.get("shorthash", None)
stats = os.stat(filename) stats = os.stat(filename)
params = [ params = [
('Filename: ', os.path.basename(filename)), ('Filename: ', os.path.basename(filename)),
('File size: ', sysinfo.pretty_bytes(stats.st_size)), ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
('Hash: ', shorthash),
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
] ]
@ -115,7 +117,7 @@ class UserMetadataEditor:
errors.display(e, f"reading metadata info for {name}") errors.display(e, f"reading metadata info for {name}")
params = [] params = []
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>' table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'
return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '') return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')

View File

@ -48,13 +48,13 @@ class UiLoadsave:
elif condition and not condition(saved_value): elif condition and not condition(saved_value):
pass pass
else: else:
if isinstance(x, gr.Textbox) and field == 'value': # due to an undersirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
saved_value = str(saved_value) saved_value = str(saved_value)
elif isinstance(x, gr.Number) and field == 'value': elif isinstance(x, gr.Number) and field == 'value':
try: try:
saved_value = float(saved_value) saved_value = float(saved_value)
except ValueError: except ValueError:
saved_value = -1 return
setattr(obj, field, saved_value) setattr(obj, field, saved_value)
if init_field is not None: if init_field is not None:

View File

@ -175,14 +175,22 @@ def do_nothing(p, x, xs):
def format_nothing(p, opt, x): def format_nothing(p, opt, x):
return "" return ""
def format_remove_path(p, opt, x): def format_remove_path(p, opt, x):
return os.path.basename(x) return os.path.basename(x)
def str_permutations(x): def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x return x
def list_to_csv_string(data_list):
with StringIO() as o:
csv.writer(o).writerow(data_list)
return o.getvalue().strip()
class AxisOption: class AxisOption:
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
self.label = label self.label = label
@ -199,6 +207,7 @@ class AxisOptionImg2Img(AxisOption):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.is_img2img = True self.is_img2img = True
class AxisOptionTxt2Img(AxisOption): class AxisOptionTxt2Img(AxisOption):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -286,11 +295,10 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
cell_size = (processed_result.width, processed_result.height) cell_size = (processed_result.width, processed_result.height)
if processed_result.images[0] is not None: if processed_result.images[0] is not None:
cell_mode = processed_result.images[0].mode cell_mode = processed_result.images[0].mode
#This corrects size in case of batches: # This corrects size in case of batches:
cell_size = processed_result.images[0].size cell_size = processed_result.images[0].size
processed_result.images[idx] = Image.new(cell_mode, cell_size) processed_result.images[idx] = Image.new(cell_mode, cell_size)
if first_axes_processed == 'x': if first_axes_processed == 'x':
for ix, x in enumerate(xs): for ix, x in enumerate(xs):
if second_axes_processed == 'y': if second_axes_processed == 'y':
@ -348,9 +356,9 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
if draw_legend: if draw_legend:
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
processed_result.images.insert(0, z_grid) processed_result.images.insert(0, z_grid)
#TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal. # TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
#processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) # processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
#processed_result.all_seeds.insert(0, processed_result.all_seeds[0]) # processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
processed_result.infotexts.insert(0, processed_result.infotexts[0]) processed_result.infotexts.insert(0, processed_result.infotexts[0])
return processed_result return processed_result
@ -374,8 +382,8 @@ class SharedSettingsStackHelper(object):
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*])?\s*")
class Script(scripts.Script): class Script(scripts.Script):
@ -390,19 +398,19 @@ class Script(scripts.Script):
with gr.Row(): with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True) x_values_dropdown = gr.Dropdown(label="X values", visible=False, multiselect=True, interactive=True)
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
with gr.Row(): with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True) y_values_dropdown = gr.Dropdown(label="Y values", visible=False, multiselect=True, interactive=True)
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
with gr.Row(): with gr.Row():
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True) z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True)
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
with gr.Row(variant="compact", elem_id="axis_options"): with gr.Row(variant="compact", elem_id="axis_options"):
@ -414,6 +422,9 @@ class Script(scripts.Script):
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
with gr.Column(): with gr.Column():
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
with gr.Column():
csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))
with gr.Row(variant="compact", elem_id="swap_axes"): with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
@ -430,50 +441,71 @@ class Script(scripts.Script):
xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown] xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
def fill(x_type): def fill(axis_type, csv_mode):
axis = self.current_axis_options[x_type] axis = self.current_axis_options[axis_type]
return axis.choices() if axis.choices else gr.update() if axis.choices:
if csv_mode:
return list_to_csv_string(axis.choices()), gr.update()
else:
return gr.update(), axis.choices()
else:
return gr.update(), gr.update()
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown]) fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown]) fill_y_button.click(fn=fill, inputs=[y_type, csv_mode], outputs=[y_values, y_values_dropdown])
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown]) fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])
def select_axis(axis_type,axis_values_dropdown): def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
choices = self.current_axis_options[axis_type].choices choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None has_choices = choices is not None
current_values = axis_values_dropdown
current_values = axis_values
current_dropdown_values = axis_values_dropdown
if has_choices: if has_choices:
choices = choices() choices = choices()
if isinstance(current_values,str): if csv_mode:
current_values = current_values.split(",") current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
current_values = list(filter(lambda x: x in choices, current_values)) current_values = list_to_csv_string(current_dropdown_values)
return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values) else:
current_dropdown_values = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(axis_values)))]
current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown]) return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=current_values),
y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown]) gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=current_dropdown_values))
z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
def get_dropdown_update_from_params(axis,params): x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])
y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])
z_type.change(fn=select_axis, inputs=[z_type, z_values, z_values_dropdown, csv_mode], outputs=[fill_z_button, z_values, z_values_dropdown])
def change_choice_mode(csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown):
_fill_x_button, _x_values, _x_values_dropdown = select_axis(x_type, x_values, x_values_dropdown, csv_mode)
_fill_y_button, _y_values, _y_values_dropdown = select_axis(y_type, y_values, y_values_dropdown, csv_mode)
_fill_z_button, _z_values, _z_values_dropdown = select_axis(z_type, z_values, z_values_dropdown, csv_mode)
return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown
csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown])
def get_dropdown_update_from_params(axis, params):
val_key = f"{axis} Values" val_key = f"{axis} Values"
vals = params.get(val_key,"") vals = params.get(val_key, "")
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
return gr.update(value = valslist) return gr.update(value=valslist)
self.infotext_fields = ( self.infotext_fields = (
(x_type, "X Type"), (x_type, "X Type"),
(x_values, "X Values"), (x_values, "X Values"),
(x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)), (x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)),
(y_type, "Y Type"), (y_type, "Y Type"),
(y_values, "Y Values"), (y_values, "Y Values"),
(y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)), (y_values_dropdown, lambda params: get_dropdown_update_from_params("Y", params)),
(z_type, "Z Type"), (z_type, "Z Type"),
(z_values, "Z Values"), (z_values, "Z Values"),
(z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)), (z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
) )
return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode]
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode):
if not no_fixed_seeds: if not no_fixed_seeds:
modules.processing.fix_seed(p) modules.processing.fix_seed(p)
@ -484,7 +516,7 @@ class Script(scripts.Script):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
if opt.choices is not None: if opt.choices is not None and not csv_mode:
valslist = vals_dropdown valslist = vals_dropdown
else: else:
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
@ -503,8 +535,8 @@ class Script(scripts.Script):
valslist_ext += list(range(start, end, step)) valslist_ext += list(range(start, end, step))
elif mc is not None: elif mc is not None:
start = int(mc.group(1)) start = int(mc.group(1))
end = int(mc.group(2)) end = int(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1 num = int(mc.group(3)) if mc.group(3) is not None else 1
valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
else: else:
@ -525,8 +557,8 @@ class Script(scripts.Script):
valslist_ext += np.arange(start, end + step, step).tolist() valslist_ext += np.arange(start, end + step, step).tolist()
elif mc is not None: elif mc is not None:
start = float(mc.group(1)) start = float(mc.group(1))
end = float(mc.group(2)) end = float(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1 num = int(mc.group(3)) if mc.group(3) is not None else 1
valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
else: else:
@ -545,22 +577,22 @@ class Script(scripts.Script):
return valslist return valslist
x_opt = self.current_axis_options[x_type] x_opt = self.current_axis_options[x_type]
if x_opt.choices is not None: if x_opt.choices is not None and not csv_mode:
x_values = ",".join(x_values_dropdown) x_values = list_to_csv_string(x_values_dropdown)
xs = process_axis(x_opt, x_values, x_values_dropdown) xs = process_axis(x_opt, x_values, x_values_dropdown)
y_opt = self.current_axis_options[y_type] y_opt = self.current_axis_options[y_type]
if y_opt.choices is not None: if y_opt.choices is not None and not csv_mode:
y_values = ",".join(y_values_dropdown) y_values = list_to_csv_string(y_values_dropdown)
ys = process_axis(y_opt, y_values, y_values_dropdown) ys = process_axis(y_opt, y_values, y_values_dropdown)
z_opt = self.current_axis_options[z_type] z_opt = self.current_axis_options[z_type]
if z_opt.choices is not None: if z_opt.choices is not None and not csv_mode:
z_values = ",".join(z_values_dropdown) z_values = list_to_csv_string(z_values_dropdown)
zs = process_axis(z_opt, z_values, z_values_dropdown) zs = process_axis(z_opt, z_values, z_values_dropdown)
# this could be moved to common code, but unlikely to be ever triggered anywhere else # this could be moved to common code, but unlikely to be ever triggered anywhere else
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)' assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'
@ -720,7 +752,7 @@ class Script(scripts.Script):
# Auto-save main and sub-grids: # Auto-save main and sub-grids:
grid_count = z_count + 1 if z_count > 1 else 1 grid_count = z_count + 1 if z_count > 1 else 1
for g in range(grid_count): for g in range(grid_count):
#TODO: See previous comment about intentional data misalignment. # TODO: See previous comment about intentional data misalignment.
adj_g = g-1 if g > 0 else g adj_g = g-1 if g > 0 else g
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed) images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)

View File

@ -166,16 +166,6 @@ a{
color: var(--button-secondary-text-color-hover); color: var(--button-secondary-text-color-hover);
} }
.checkboxes-row{
margin-bottom: 0.5em;
margin-left: 0em;
}
.checkboxes-row > div{
flex: 0;
white-space: nowrap;
min-width: auto !important;
}
button.custom-button{ button.custom-button{
border-radius: var(--button-large-radius); border-radius: var(--button-large-radius);
padding: var(--button-large-padding); padding: var(--button-large-padding);
@ -192,7 +182,7 @@ button.custom-button{
text-align: center; text-align: center;
} }
div.gradio-accordion { div.block.gradio-accordion {
border: 1px solid var(--block-border-color) !important; border: 1px solid var(--block-border-color) !important;
border-radius: 8px !important; border-radius: 8px !important;
margin: 2px 0; margin: 2px 0;
@ -239,10 +229,14 @@ div.gradio-accordion {
} }
[id$=_subseed_show] label{ [id$=_subseed_show] label{
margin-bottom: 0.5em; margin-bottom: 0.65em;
align-self: end; align-self: end;
} }
[id$=_seed_extras] > div{
gap: 0.5em;
}
.html-log .comments{ .html-log .comments{
padding-top: 0.5em; padding-top: 0.5em;
} }
@ -352,7 +346,7 @@ div.gradio-accordion {
} }
div.dimensions-tools{ div.dimensions-tools{
min-width: 0 !important; min-width: 1.6em !important;
max-width: fit-content; max-width: fit-content;
flex-direction: column; flex-direction: column;
place-content: center; place-content: center;
@ -369,8 +363,8 @@ div#extras_scale_to_tab div.form{
z-index: 5; z-index: 5;
} }
.image-buttons button{ .image-buttons > .form{
min-width: auto; justify-content: center;
} }
.infotext { .infotext {
@ -391,19 +385,21 @@ div#extras_scale_to_tab div.form{
/* settings */ /* settings */
#quicksettings { #quicksettings {
width: fit-content;
align-items: end; align-items: end;
} }
#quicksettings > div, #quicksettings > fieldset{ #quicksettings > div, #quicksettings > fieldset{
max-width: 24em; max-width: 36em;
min-width: 24em; width: fit-content;
width: 24em; flex: 0 1 fit-content;
padding: 0; padding: 0;
border: none; border: none;
box-shadow: none; box-shadow: none;
background: none; background: none;
} }
#quicksettings > div.gradio-dropdown{
min-width: 24em !important;
}
#settings{ #settings{
display: block; display: block;
@ -1012,10 +1008,29 @@ div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-chi
} }
div.block.input-accordion{ div.block.input-accordion{
margin-bottom: 0.4em;
} }
.input-accordion-extra{ .input-accordion-extra{
flex: 0 0 auto !important; flex: 0 0 auto !important;
margin: 0 0.5em 0 auto; margin: 0 0.5em 0 auto;
} }
div.accordions > div.input-accordion{
min-width: fit-content !important;
}
div.accordions > div.gradio-accordion .label-wrap span{
white-space: nowrap;
margin-right: 0.25em;
}
div.accordions{
gap: 0.5em;
}
div.accordions > div.input-accordion.input-accordion-open{
flex: 1 auto;
flex-flow: column;
}

View File

@ -12,8 +12,6 @@ fi
export install_dir="$HOME" export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1 export PYTORCH_ENABLE_MPS_FALLBACK=1
#################################################################### ####################################################################