Mirror of https://github.com/openvinotoolkit/stable-diffusion-webui.git (synced 2024-12-14 14:45:06 +03:00)

Commit e7c03ccdce: Merge branch 'dev' into extra-norm-module
@@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')

self.errors = {}
"""mapping of network names to the number of errors the network had during operation"""

def activate(self, p, params_list):
additional = shared.opts.sd_lora

self.errors.clear()

if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

@@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

def deactivate(self, p):
pass
if self.errors:
p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))

self.errors.clear()
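The hunk above introduces a per-network error counter: activate() clears it at the start of a job, failures are tallied while the networks are applied, and deactivate() surfaces the totals as a comment on the processing object. A minimal self-contained sketch of the same pattern; apply_network is a stand-in here, not a webui API:

    def apply_network(name):
        """Stand-in for applying one network's weights; may fail for some networks."""
        if name == "broken":
            raise RuntimeError("shape mismatch")

    class NetworkApplier:
        def __init__(self):
            self.errors = {}  # network name -> number of failures during the current run

        def activate(self, names):
            self.errors.clear()
            for name in names:
                try:
                    apply_network(name)
                except RuntimeError:
                    # count the failure instead of aborting the whole generation
                    self.errors[name] = self.errors.get(name, 0) + 1

        def deactivate(self):
            if self.errors:
                print("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
                self.errors.clear()

    applier = NetworkApplier()
    applier.activate(["ok", "broken", "broken"])
    applier.deactivate()  # prints: Networks with errors: broken (2)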
@@ -1,3 +1,4 @@
import logging
import os
import re

@@ -194,7 +195,7 @@ def load_network(name, network_on_disk):
net.modules[key] = net_module

if keys_failed_to_match:
print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

return net

@@ -207,7 +208,6 @@ def purge_networks_from_memory():
devices.torch_gc()



def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {}

@@ -248,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No

if net is None:
failed_to_load_networks.append(name)
print(f"Couldn't find network with name {name}")
logging.info(f"Couldn't find network with name {name}")
continue

net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0

@@ -257,7 +257,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
loaded_networks.append(net)

if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))

purge_networks_from_memory()

@@ -327,20 +327,25 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
for net in loaded_networks:
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
with torch.no_grad():
updown, ex_bias = module.calc_updown(self.weight)
try:
with torch.no_grad():
updown, ex_bias = module.calc_updown(self.weight)

if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

self.weight += updown
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
self.bias = torch.nn.Parameter(ex_bias)
else:
self.bias += ex_bias
continue
self.weight += updown
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
self.bias = torch.nn.Parameter(ex_bias)
else:
self.bias += ex_bias
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

continue

module_q = net.modules.get(network_layer_name + "_q_proj", None)
module_k = net.modules.get(network_layer_name + "_k_proj", None)

@@ -348,26 +353,33 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
module_out = net.modules.get(network_layer_name + "_out_proj", None)

if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
with torch.no_grad():
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
try:
with torch.no_grad():
updown_q, _ = module_q.calc_updown(self.in_proj_weight)
updown_k, _ = module_k.calc_updown(self.in_proj_weight)
updown_v, _ = module_v.calc_updown(self.in_proj_weight)
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

self.in_proj_weight += updown_qkv
self.out_proj.weight += updown_out
self.in_proj_weight += updown_qkv
self.out_proj.weight += updown_out
if ex_bias is not None:
if self.out_proj.bias is None:
self.out_proj.bias = torch.nn.Parameter(ex_bias)
else:
self.out_proj.bias += ex_bias
continue

except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

continue

if module is None:
continue

print(f'failed to calculate network weights for layer {network_layer_name}')
logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

self.network_current_names = wanted_names

@@ -540,6 +552,7 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)

extra_network_lora = None

available_networks = {}
available_network_aliases = {}
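One detail in the hunk above that is easy to miss: for inpainting checkpoints, the computed conv delta is zero-padded from 4 to 9 input channels before being added to the model weight in place. A standalone illustration of that padding; the shapes below are examples, not taken from any particular model:

    import torch
    import torch.nn.functional as F

    # LoRA delta computed against a standard 4-channel SD input conv
    updown = torch.randn(320, 4, 3, 3)

    # zero-pad the input-channel dimension (dim 1) so the delta matches an
    # inpainting UNet's 9-channel input conv; the pad tuple is applied from the
    # last dimension backwards: (kw left, kw right, kh left, kh right, ch left, ch right)
    padded = F.pad(updown, (0, 0, 0, 0, 0, 5))
    print(padded.shape)  # torch.Size([320, 9, 3, 3])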
@@ -23,9 +23,9 @@ def unload():
def before_ui():
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())

extra_network = extra_networks_lora.ExtraNetworkLora()
extra_networks.register_extra_network(extra_network)
extra_networks.register_extra_network_alias(extra_network, "lyco")
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
extra_networks.register_extra_network(networks.extra_network_lora)
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")


if not hasattr(torch.nn, 'Linear_forward_before_network'):

@@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
item = {
"name": name,
"filename": lora_on_disk.filename,
"shorthash": lora_on_disk.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename),
"search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": lora_on_disk.metadata,
"sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
@@ -173,9 +173,12 @@ def git_clone(url, dir, name, commithash=None):
if current_hash == commithash:
return

run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)

run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)

run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)

return
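The git_clone() change above makes an existing checkout verify, and if needed repoint, its origin remote before fetching a pinned commit, instead of fetching blindly. A rough standalone equivalent using plain subprocess calls; the repository path and URL below are examples, not values from the launcher:

    import subprocess

    def ensure_origin_and_fetch(repo_dir, expected_url):
        # read the currently configured origin URL of the existing clone
        current = subprocess.run(
            ["git", "-C", repo_dir, "config", "--get", "remote.origin.url"],
            capture_output=True, text=True, check=True,
        ).stdout.strip()

        # if the checkout points somewhere else, repoint it before fetching
        if current != expected_url:
            subprocess.run(["git", "-C", repo_dir, "remote", "set-url", "origin", expected_url], check=True)

        subprocess.run(["git", "-C", repo_dir, "fetch"], check=True)

    # example call (hypothetical path/URL):
    # ensure_origin_and_fetch("repositories/k-diffusion", "https://github.com/crowsonkb/k-diffusion.git")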
@@ -319,12 +322,12 @@ def prepare_environment():

stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

try:
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
# the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
os.remove(os.path.join(script_path, "tmp", "restart"))
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
except OSError:

@@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):


if has_mps:
# MPS fix for randn in torchsde
CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')

if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
@@ -1,9 +1,11 @@
from __future__ import annotations
import json
import logging
import math
import os
import sys
import hashlib
from dataclasses import dataclass, field

import torch
import numpy as np

@@ -11,7 +13,7 @@ from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
from typing import Any, Dict, List
from typing import Any

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng

@@ -104,97 +106,160 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)


@dataclass(repr=False)
class StableDiffusionProcessing:
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
sd_model: object = None
outpath_samples: str = None
outpath_grids: str = None
prompt: str = ""
prompt_for_display: str = None
negative_prompt: str = ""
styles: list[str] = field(default_factory=list)
seed: int = -1
subseed: int = -1
subseed_strength: float = 0
seed_resize_from_h: int = -1
seed_resize_from_w: int = -1
seed_enable_extras: bool = True
sampler_name: str = None
batch_size: int = 1
n_iter: int = 1
steps: int = 50
cfg_scale: float = 7.0
width: int = 512
height: int = 512
restore_faces: bool = None
tiling: bool = None
do_not_save_samples: bool = False
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = 0
ddim_discretize: str = None
s_min_uncond: float = None
s_churn: float = None
s_tmax: float = None
s_tmin: float = None
s_noise: float = None
override_settings: dict[str, Any] = None
override_settings_restore_afterwards: bool = True
sampler_index: int = None
refiner_checkpoint: str = None
refiner_switch_at: float = None
token_merging_ratio = 0
token_merging_ratio_hr = 0
disable_extra_networks: bool = False

scripts_value: scripts.ScriptRunner = field(default=None, init=False)
script_args_value: list = field(default=None, init=False)
scripts_setup_complete: bool = field(default=False, init=False)

cached_uc = [None, None]
cached_c = [None, None]

def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
comments: dict = None
sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
is_using_inpainting_conditioning: bool = field(default=False, init=False)
paste_to: tuple | None = field(default=None, init=False)

is_hr_pass: bool = field(default=False, init=False)

c: tuple = field(default=None, init=False)
uc: tuple = field(default=None, init=False)

rng: rng.ImageRNG | None = field(default=None, init=False)
step_multiplier: int = field(default=1, init=False)
color_corrections: list = field(default=None, init=False)

all_prompts: list = field(default=None, init=False)
all_negative_prompts: list = field(default=None, init=False)
all_seeds: list = field(default=None, init=False)
all_subseeds: list = field(default=None, init=False)
iteration: int = field(default=0, init=False)
main_prompt: str = field(default=None, init=False)
main_negative_prompt: str = field(default=None, init=False)

prompts: list = field(default=None, init=False)
negative_prompts: list = field(default=None, init=False)
seeds: list = field(default=None, init=False)
subseeds: list = field(default=None, init=False)
extra_network_data: dict = field(default=None, init=False)

user: str = field(default=None, init=False)

sd_model_name: str = field(default=None, init=False)
sd_model_hash: str = field(default=None, init=False)
sd_vae_name: str = field(default=None, init=False)
sd_vae_hash: str = field(default=None, init=False)

def __post_init__(self):
if self.sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
self.seed_resize_from_h: int = seed_resize_from_h
self.seed_resize_from_w: int = seed_resize_from_w
self.sampler_name: str = sampler_name
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
self.cfg_scale: float = cfg_scale
self.width: int = width
self.height: int = height
self.restore_faces: bool = restore_faces
self.tiling: bool = tiling
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images
self.eta = eta
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = denoising_strength
self.sampler_noise_scheduler_override = None
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
self.s_noise = s_noise if s_noise is not None else opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False
self.token_merging_ratio = 0
self.token_merging_ratio_hr = 0
self.comments = {}

if not seed_enable_extras:
self.sampler_noise_scheduler_override = None
self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise

self.extra_generation_params = self.extra_generation_params or {}
self.override_settings = self.override_settings or {}
self.script_args = self.script_args or {}

self.refiner_checkpoint_info = None

if not self.seed_enable_extras:
self.subseed = -1
self.subseed_strength = 0
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0

self.scripts = None
self.script_args = script_args
self.all_prompts = None
self.all_negative_prompts = None
self.all_seeds = None
self.all_subseeds = None
self.iteration = 0
self.is_hr_pass = False
self.sampler = None
self.main_prompt = None
self.main_negative_prompt = None

self.prompts = None
self.negative_prompts = None
self.extra_network_data = None
self.seeds = None
self.subseeds = None

self.step_multiplier = 1
self.cached_uc = StableDiffusionProcessing.cached_uc
self.cached_c = StableDiffusionProcessing.cached_c
self.uc = None
self.c = None
self.rng: rng.ImageRNG = None

self.user = None

@property
def sd_model(self):
return shared.sd_model

@sd_model.setter
def sd_model(self, value):
pass

@property
def scripts(self):
return self.scripts_value

@scripts.setter
def scripts(self, value):
self.scripts_value = value

if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
self.setup_scripts()

@property
def script_args(self):
return self.script_args_value

@script_args.setter
def script_args(self, value):
self.script_args_value = value

if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
self.setup_scripts()

def setup_scripts(self):
self.scripts_setup_complete = True

self.scripts.setup_scrips(self)

def comment(self, text):
self.comments[text] = 1

def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
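The largest hunk above rewrites StableDiffusionProcessing from a hand-written __init__ into a @dataclass: constructor arguments become declared fields, runtime-only state is declared with field(..., init=False), and the normalization that used to live in __init__ moves into __post_init__. A toy version of that pattern, with illustrative names rather than the webui's:

    from dataclasses import dataclass, field

    @dataclass(repr=False)
    class Job:
        # constructor arguments with defaults (part of the generated __init__)
        prompt: str = ""
        steps: int = 50
        override_settings: dict = None
        # runtime-only state, excluded from the generated __init__
        seeds: list = field(default=None, init=False)
        iteration: int = field(default=0, init=False)

        def __post_init__(self):
            # normalize mutable defaults after the generated __init__ has run
            self.override_settings = self.override_settings or {}

    job = Job(prompt="a cat", steps=20)
    print(job.steps, job.override_settings, job.iteration)  # 20 {} 0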
@@ -398,7 +463,7 @@ class Processed:
self.subseed = subseed
self.subseed_strength = p.subseed_strength
self.info = info
self.comments = comments
self.comments = "".join(f"{comment}\n" for comment in p.comments)
self.width = p.width
self.height = p.height
self.sampler_name = p.sampler_name

@@ -408,7 +473,10 @@ class Processed:
self.batch_size = p.batch_size
self.restore_faces = p.restore_faces
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
self.sd_model_hash = shared.sd_model.sd_model_hash
self.sd_model_name = p.sd_model_name
self.sd_model_hash = p.sd_model_hash
self.sd_vae_name = p.sd_vae_name
self.sd_vae_hash = p.sd_vae_hash
self.seed_resize_from_w = p.seed_resize_from_w
self.seed_resize_from_h = p.seed_resize_from_h
self.denoising_strength = getattr(p, 'denoising_strength', None)

@@ -459,7 +527,10 @@ class Processed:
"batch_size": self.batch_size,
"restore_faces": self.restore_faces,
"face_restoration_model": self.face_restoration_model,
"sd_model_name": self.sd_model_name,
"sd_model_hash": self.sd_model_hash,
"sd_vae_name": self.sd_vae_name,
"sd_vae_hash": self.sd_vae_hash,
"seed_resize_from_w": self.seed_resize_from_w,
"seed_resize_from_h": self.seed_resize_from_h,
"denoising_strength": self.denoising_strength,

@@ -578,10 +649,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
"Face restoration": opts.face_restoration_model if p.restore_faces else None,
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
"VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None,
"VAE": p.loaded_vae_name if opts.add_model_name_to_info else None,
"Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
"Model": p.sd_model_name if opts.add_model_name_to_info else None,
"VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
"VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),

@@ -670,14 +741,19 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.tiling is None:
p.tiling = opts.tiling

p.loaded_vae_name = sd_vae.get_loaded_vae_name()
p.loaded_vae_hash = sd_vae.get_loaded_vae_hash()
if p.refiner_checkpoint not in (None, "", "None"):
p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
if p.refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')

p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
p.sd_model_hash = shared.sd_model.sd_model_hash
p.sd_vae_name = sd_vae.get_loaded_vae_name()
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()

comments = {}

p.setup_prompts()

if type(seed) == list:

@@ -757,7 +833,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds()

for comment in model_hijack.comments:
comments[comment] = 1
p.comment(comment)

p.extra_generation_params.update(model_hijack.extra_generation_params)

@@ -886,7 +962,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images_list=output_images,
seed=p.all_seeds[0],
info=infotexts[0],
comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,

@@ -910,49 +985,51 @@ def old_hires_fix_first_pass_dimensions(width, height):
return width, height


@dataclass(repr=False)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
enable_hr: bool = False
denoising_strength: float = 0.75
firstphase_width: int = 0
firstphase_height: int = 0
hr_scale: float = 2.0
hr_upscaler: str = None
hr_second_pass_steps: int = 0
hr_resize_x: int = 0
hr_resize_y: int = 0
hr_checkpoint_name: str = None
hr_sampler_name: str = None
hr_prompt: str = ''
hr_negative_prompt: str = ''

cached_hr_uc = [None, None]
cached_hr_c = [None, None]

def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
self.hr_scale = hr_scale
self.hr_upscaler = hr_upscaler
self.hr_second_pass_steps = hr_second_pass_steps
self.hr_resize_x = hr_resize_x
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
self.hr_checkpoint_name = hr_checkpoint_name
self.hr_checkpoint_info = None
self.hr_sampler_name = hr_sampler_name
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
self.latent_scale_mode = None
hr_checkpoint_info: dict = field(default=None, init=False)
hr_upscale_to_x: int = field(default=0, init=False)
hr_upscale_to_y: int = field(default=0, init=False)
truncate_x: int = field(default=0, init=False)
truncate_y: int = field(default=0, init=False)
applied_old_hires_behavior_to: tuple = field(default=None, init=False)
latent_scale_mode: dict = field(default=None, init=False)
hr_c: tuple | None = field(default=None, init=False)
hr_uc: tuple | None = field(default=None, init=False)
all_hr_prompts: list = field(default=None, init=False)
all_hr_negative_prompts: list = field(default=None, init=False)
hr_prompts: list = field(default=None, init=False)
hr_negative_prompts: list = field(default=None, init=False)
hr_extra_network_data: list = field(default=None, init=False)

if firstphase_width != 0 or firstphase_height != 0:
def __post_init__(self):
super().__post_init__()

if self.firstphase_width != 0 or self.firstphase_height != 0:
self.hr_upscale_to_x = self.width
self.hr_upscale_to_y = self.height
self.width = firstphase_width
self.height = firstphase_height

self.truncate_x = 0
self.truncate_y = 0
self.applied_old_hires_behavior_to = None

self.hr_prompts = None
self.hr_negative_prompts = None
self.hr_extra_network_data = None
self.width = self.firstphase_width
self.height = self.firstphase_height

self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
self.hr_c = None
self.hr_uc = None

def calculate_target_resolution(self):
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):

@@ -1146,6 +1223,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

self.sampler = None
devices.torch_gc()

decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

self.is_hr_pass = False

@@ -1230,7 +1310,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

return super().get_conds()


def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()

@@ -1243,32 +1322,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return res


@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
init_images: list = None
resize_mode: int = 0
denoising_strength: float = 0.75
image_cfg_scale: float = None
mask: Any = None
mask_blur_x: int = 4
mask_blur_y: int = 4
mask_blur: int = None
inpainting_fill: int = 0
inpaint_full_res: bool = True
inpaint_full_res_padding: int = 0
inpainting_mask_invert: int = 0
initial_noise_multiplier: float = None
latent_mask: Image = None

def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)
image_mask: Any = field(default=None, init=False)

self.init_images = init_images
self.resize_mode: int = resize_mode
self.denoising_strength: float = denoising_strength
self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
self.init_latent = None
self.image_mask = mask
self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur_x = mask_blur_x
self.mask_blur_y = mask_blur_y
if mask_blur is not None:
self.mask_blur = mask_blur
self.inpainting_fill = inpainting_fill
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
self.inpainting_mask_invert = inpainting_mask_invert
self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier
nmask: torch.Tensor = field(default=None, init=False)
image_conditioning: torch.Tensor = field(default=None, init=False)
init_img_hash: str = field(default=None, init=False)
mask_for_overlay: Image = field(default=None, init=False)
init_latent: torch.Tensor = field(default=None, init=False)

def __post_init__(self):
super().__post_init__()

self.image_mask = self.mask
self.mask = None
self.nmask = None
self.image_conditioning = None
self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier

@property
def mask_blur(self):

@@ -1278,15 +1362,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):

@mask_blur.setter
def mask_blur(self, value):
self.mask_blur_x = value
self.mask_blur_y = value

@mask_blur.deleter
def mask_blur(self):
del self.mask_blur_x
del self.mask_blur_y
if isinstance(value, int):
self.mask_blur_x = value
self.mask_blur_y = value

def init(self, all_prompts, all_seeds, all_subseeds):
self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
crop_region = None
@@ -38,18 +38,12 @@ class ScriptRefiner(scripts.Script):

return enable_refiner, refiner_checkpoint, refiner_switch_at

def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
# the actual implementation is in sd_samplers_common.py, apply_refiner

p.refiner_checkpoint_info = None
p.refiner_switch_at = None

if not enable_refiner or refiner_checkpoint in (None, "", "None"):
return

refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
if refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')

p.refiner_checkpoint_info = refiner_checkpoint_info
p.refiner_switch_at = refiner_switch_at
p.refiner_checkpoint_info = None
p.refiner_switch_at = None
else:
p.refiner_checkpoint = refiner_checkpoint
p.refiner_switch_at = refiner_switch_at

@@ -58,7 +58,7 @@ class ScriptSeed(scripts.ScriptBuiltin):

return self.seed, subseed, subseed_strength

def before_process(self, p, seed, subseed, subseed_strength):
def setup(self, p, seed, subseed, subseed_strength):
p.seed = seed

if subseed_strength > 0:

@@ -106,9 +106,16 @@ class Script:

pass

def setup(self, p, *args):
"""For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
args contains all values returned by components from ui().
"""
pass


def before_process(self, p, *args):
"""
This function is called very early before processing begins for AlwaysVisible scripts.
This function is called very early during processing begins for AlwaysVisible scripts.
You can modify the processing object (p) here, inject hooks, etc.
args contains all values returned by components from ui()
"""

@@ -706,6 +713,14 @@ class ScriptRunner:
except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)

def setup_scrips(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.setup(p, *script_args)
except Exception:
errors.report(f"Error running setup: {script.filename}", exc_info=True)


scripts_txt2img: ScriptRunner = None
scripts_img2img: ScriptRunner = None
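The new setup() hook added above runs for AlwaysVisible scripts as soon as the processing object is assembled, before any processing starts, and ScriptRunner.setup_scrips() feeds each script its own slice of p.script_args. A minimal, hypothetical extension script using the hook; the class shape follows the Script API shown in the hunks, but this particular script is only an illustration and is meant to run inside the webui:

    import gradio as gr
    from modules import scripts

    class ExampleSetupScript(scripts.Script):
        def title(self):
            return "Example setup hook"

        def show(self, is_img2img):
            return scripts.AlwaysVisible

        def ui(self, is_img2img):
            enabled = gr.Checkbox(label="Tag generations", value=False)
            return [enabled]

        def setup(self, p, enabled):
            # called when p is set up, before any processing starts
            if enabled:
                p.extra_generation_params["Example tag"] = "on"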
@@ -1,6 +1,7 @@
from __future__ import annotations
import math
import psutil
import platform

import torch
from torch import einsum

@@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
class SdOptimizationSubQuad(SdOptimization):
name = "sub-quadratic"
cmd_opt = "opt_sub_quad_attention"
priority = 10

@property
def priority(self):
return 1000 if shared.device.type == 'mps' else 10

def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward

@@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):

@property
def priority(self):
return 1000 if not torch.cuda.is_available() else 10
return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10

def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI

@@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

if chunk_threshold is None:
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
if q.device.type == 'mps':
chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
else:
chunk_threshold_bytes = int(get_available_vram() * 0.7)
elif chunk_threshold == 0:
chunk_threshold_bytes = None
else:

@@ -92,7 +92,15 @@ def images_tensor_to_samples(image, approximation=None, model=None):
model = shared.sd_model
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
if len(image) > 1:
x_latent = torch.stack([
model.get_first_stage_encoding(
model.encode_first_stage(torch.unsqueeze(img, 0))
)[0]
for img in image
])
else:
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))

return x_latent
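The last hunk above switches batch VAE encoding to a per-image loop whenever more than one image is passed, re-assembling the batch with torch.stack afterwards. The shape bookkeeping is the interesting part; a dummy-encoder illustration of it (the encoder below is a stand-in, not the SD first-stage VAE):

    import torch

    def encode_one(img):  # stand-in for the first-stage VAE encode
        return img.mean(dim=1, keepdim=True)  # shape [1, 1, H, W]

    image = torch.rand(3, 3, 64, 64)  # a batch of 3 RGB images
    if len(image) > 1:
        # encode each image separately and re-assemble the batch
        x_latent = torch.stack([encode_one(torch.unsqueeze(img, 0))[0] for img in image])
    else:
        x_latent = encode_one(image)
    print(x_latent.shape)  # torch.Size([3, 1, 64, 64])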
@@ -145,7 +153,7 @@ def apply_refiner(cfg_denoiser):
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info

if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
return False

if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:

@@ -276,19 +284,19 @@ class Sampler:
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
s_noise = getattr(opts, 's_noise', p.s_noise)

if s_churn != self.s_churn:
if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
extra_params_kwargs['s_churn'] = s_churn
p.s_churn = s_churn
p.extra_generation_params['Sigma churn'] = s_churn
if s_tmin != self.s_tmin:
if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
extra_params_kwargs['s_tmin'] = s_tmin
p.s_tmin = s_tmin
p.extra_generation_params['Sigma tmin'] = s_tmin
if s_tmax != self.s_tmax:
if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
extra_params_kwargs['s_tmax'] = s_tmax
p.s_tmax = s_tmax
p.extra_generation_params['Sigma tmax'] = s_tmax
if s_noise != self.s_noise:
if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
extra_params_kwargs['s_noise'] = s_noise
p.s_noise = s_noise
p.extra_generation_params['Sigma noise'] = s_noise

@@ -305,5 +313,8 @@ class Sampler:
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()


def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
raise NotImplementedError()

@@ -22,6 +22,9 @@ samplers_k_diffusion = [
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),

@@ -42,6 +45,12 @@ sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
'sample_dpm_fast': ['s_noise'],
'sample_dpm_2_ancestral': ['s_noise'],
'sample_dpmpp_2s_ancestral': ['s_noise'],
'sample_dpmpp_sde': ['s_noise'],
'sample_dpmpp_2m_sde': ['s_noise'],
'sample_dpmpp_3m_sde': ['s_noise'],
}

k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}

@@ -67,6 +76,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
def __init__(self, funcname, sd_model, options=None):
super().__init__(funcname)

self.extra_params = sampler_extra_params.get(funcname, [])

self.options = options or {}
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)

@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))

@@ -42,7 +42,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

extra_args = {} if extra_args is None else extra_args

@@ -285,12 +285,12 @@ options_templates.update(options_section(('ui', "Live previews"), {
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
"eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"),
"eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; applies to Euler a and other samplers that have a in them"),
"eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),

@@ -58,7 +58,7 @@ def _summarize_chunk(
scale: float,
) -> AttnChunk:
attn_weights = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,

@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
scale: float,
) -> Tensor:
attn_scores = torch.baddbmm(
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
query,
key.transpose(1,2),
alpha=scale,
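Both attention hunks above swap torch.empty for torch.zeros as the input tensor of torch.baddbmm. Even when such calls pass beta=0 (the beta argument is not visible in this excerpt), uninitialized memory can contain NaN or inf, and 0 * NaN is still NaN, so garbage can leak into the attention scores on some backends; a zero-filled tensor sidesteps that. A small standalone check of the pattern:

    import torch

    q = torch.randn(2, 4, 8)   # [batch*heads, q_tokens, dim]
    k = torch.randn(2, 5, 8)   # [batch*heads, k_tokens, dim]
    scale = 8 ** -0.5

    # zero-filled bias tensor: broadcasting adds nothing to q @ k^T
    bias = torch.zeros(1, 1, 1, dtype=q.dtype)
    attn = torch.baddbmm(bias, q, k.transpose(1, 2), alpha=scale, beta=0)

    # same result as the plain scaled matmul
    assert torch.allclose(attn, torch.bmm(q, k.transpose(1, 2)) * scale)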
@@ -19,6 +19,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
return {
"name": checkpoint.name_for_extra,
"filename": checkpoint.filename,
"shorthash": checkpoint.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),

@@ -2,6 +2,7 @@ import os

from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js
from modules.hashes import sha256_from_cache


class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):

@@ -14,13 +15,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def create_item(self, name, index=None, enable_filter=True):
full_path = shared.hypernetworks[name]
path, ext = os.path.splitext(full_path)
sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
shorthash = sha256[0:10] if sha256 else None

return {
"name": name,
"filename": full_path,
"shorthash": shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(path),
"search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
"prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(path + ext)},

@@ -19,9 +19,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
return {
"name": name,
"filename": embedding.filename,
"shorthash": embedding.shorthash,
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(embedding.filename),
"search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
"prompt": quote_js(embedding.name),
"local_preview": f"{path}.preview.{shared.opts.samples_format}",
"sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},

@@ -93,11 +93,13 @@ class UserMetadataEditor:
item = self.page.items.get(name, {})
try:
filename = item["filename"]
shorthash = item.get("shorthash", None)

stats = os.stat(filename)
params = [
('Filename: ', os.path.basename(filename)),
('File size: ', sysinfo.pretty_bytes(stats.st_size)),
('Hash: ', shorthash),
('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
]

@@ -115,7 +117,7 @@ class UserMetadataEditor:
errors.display(e, f"reading metadata info for {name}")
params = []

table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'

return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
@ -175,14 +175,22 @@ def do_nothing(p, x, xs):
|
||||
def format_nothing(p, opt, x):
|
||||
return ""
|
||||
|
||||
|
||||
def format_remove_path(p, opt, x):
|
||||
return os.path.basename(x)
|
||||
|
||||
|
||||
def str_permutations(x):
|
||||
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
|
||||
return x
|
||||
|
||||
|
||||
def list_to_csv_string(data_list):
|
||||
with StringIO() as o:
|
||||
csv.writer(o).writerow(data_list)
|
||||
return o.getvalue().strip()
|
||||
|
||||
|
||||
class AxisOption:
|
||||
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
|
||||
self.label = label
|
||||
@ -199,6 +207,7 @@ class AxisOptionImg2Img(AxisOption):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.is_img2img = True
|
||||
|
||||
|
||||
class AxisOptionTxt2Img(AxisOption):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -286,11 +295,10 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
||||
cell_size = (processed_result.width, processed_result.height)
|
||||
if processed_result.images[0] is not None:
|
||||
cell_mode = processed_result.images[0].mode
|
||||
#This corrects size in case of batches:
|
||||
# This corrects size in case of batches:
|
||||
cell_size = processed_result.images[0].size
|
||||
processed_result.images[idx] = Image.new(cell_mode, cell_size)
|
||||
|
||||
|
||||
if first_axes_processed == 'x':
|
||||
for ix, x in enumerate(xs):
|
||||
if second_axes_processed == 'y':
|
||||
@ -348,9 +356,9 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
|
||||
if draw_legend:
|
||||
z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
|
||||
processed_result.images.insert(0, z_grid)
|
||||
#TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
|
||||
#processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
|
||||
#processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
|
||||
# TODO: Deeper aspects of the program rely on grid info being misaligned between metadata arrays, which is not ideal.
|
||||
# processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
|
||||
# processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
|
||||
processed_result.infotexts.insert(0, processed_result.infotexts[0])
|
||||
|
||||
return processed_result
|
||||
@ -374,8 +382,8 @@ class SharedSettingsStackHelper(object):
|
||||
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
||||
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
|
||||
|
||||
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
|
||||
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
|
||||
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*])?\s*")
|
||||
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*])?\s*")
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
@ -390,19 +398,19 @@ class Script(scripts.Script):
with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
x_values_dropdown = gr.Dropdown(label="X values", visible=False, multiselect=True, interactive=True)
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)

with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
y_values_dropdown = gr.Dropdown(label="Y values", visible=False, multiselect=True, interactive=True)
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)

with gr.Row():
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True)
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)

with gr.Row(variant="compact", elem_id="axis_options"):
@ -414,6 +422,9 @@ class Script(scripts.Script):
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
with gr.Column():
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
with gr.Column():
csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode"))

with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
@ -430,50 +441,71 @@ class Script(scripts.Script):
xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)

def fill(x_type):
axis = self.current_axis_options[x_type]
return axis.choices() if axis.choices else gr.update()
def fill(axis_type, csv_mode):
axis = self.current_axis_options[axis_type]
if axis.choices:
if csv_mode:
return list_to_csv_string(axis.choices()), gr.update()
else:
return gr.update(), axis.choices()
else:
return gr.update(), gr.update()

fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown])
fill_y_button.click(fn=fill, inputs=[y_type, csv_mode], outputs=[y_values, y_values_dropdown])
fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])

def select_axis(axis_type,axis_values_dropdown):
def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):
choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None

current_values = axis_values_dropdown
current_values = axis_values
current_dropdown_values = axis_values_dropdown
if has_choices:
choices = choices()
if isinstance(current_values,str):
current_values = current_values.split(",")
current_values = list(filter(lambda x: x in choices, current_values))
return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)
if csv_mode:
current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))
current_values = list_to_csv_string(current_dropdown_values)
else:
current_dropdown_values = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(axis_values)))]
current_dropdown_values = list(filter(lambda x: x in choices, current_dropdown_values))

x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=current_values),
gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, value=current_dropdown_values))

def get_dropdown_update_from_params(axis,params):
x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])
y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])
z_type.change(fn=select_axis, inputs=[z_type, z_values, z_values_dropdown, csv_mode], outputs=[fill_z_button, z_values, z_values_dropdown])

def change_choice_mode(csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown):
_fill_x_button, _x_values, _x_values_dropdown = select_axis(x_type, x_values, x_values_dropdown, csv_mode)
_fill_y_button, _y_values, _y_values_dropdown = select_axis(y_type, y_values, y_values_dropdown, csv_mode)
_fill_z_button, _z_values, _z_values_dropdown = select_axis(z_type, z_values, z_values_dropdown, csv_mode)
return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown

csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown])

def get_dropdown_update_from_params(axis, params):
val_key = f"{axis} Values"
vals = params.get(val_key,"")
vals = params.get(val_key, "")
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
return gr.update(value = valslist)
return gr.update(value=valslist)

self.infotext_fields = (
(x_type, "X Type"),
(x_values, "X Values"),
(x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
(x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)),
(y_type, "Y Type"),
(y_values, "Y Values"),
(y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
(y_values_dropdown, lambda params: get_dropdown_update_from_params("Y", params)),
(z_type, "Z Type"),
(z_values, "Z Values"),
(z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
(z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)),
)

return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode]

def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode):
if not no_fixed_seeds:
modules.processing.fix_seed(p)

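Both fill() and select_axis() above call a list_to_csv_string() helper that is not shown in this diff. A minimal sketch consistent with how it is used, assuming it just writes the list as a single CSV row so that values containing commas survive the round trip back through csv.reader:

def list_to_csv_string(data_list):
    # Join values into one CSV row, quoting entries that contain commas.
    with StringIO() as o:
        csv.writer(o).writerow(data_list)
        return o.getvalue().strip()

list_to_csv_string(["a", "b,c"])  # -> 'a,"b,c"'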
@ -484,7 +516,7 @@ class Script(scripts.Script):
if opt.label == 'Nothing':
return [0]

if opt.choices is not None:
if opt.choices is not None and not csv_mode:
valslist = vals_dropdown
else:
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
@ -503,8 +535,8 @@ class Script(scripts.Script):
valslist_ext += list(range(start, end, step))
elif mc is not None:
start = int(mc.group(1))
end = int(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1
end = int(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1

valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
else:
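As a concrete example of the count syntax handled by this branch:

# "1-10 [5]" -> start=1, end=10, num=5
[int(x) for x in np.linspace(start=1, stop=10, num=5).tolist()]  # [1, 3, 5, 7, 10]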
@ -525,8 +557,8 @@ class Script(scripts.Script):
valslist_ext += np.arange(start, end + step, step).tolist()
elif mc is not None:
start = float(mc.group(1))
end = float(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1
end = float(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1

valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
else:
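The float branch behaves the same way: the stepped form at the top of this hunk uses np.arange with end + step so the end value is included, and the counted form uses np.linspace. Assuming start, end and step are parsed from the same range syntax as in the integer branch:

np.arange(0.0, 1.0 + 0.25, 0.25).tolist()         # "0.0-1.0 (+0.25)" -> [0.0, 0.25, 0.5, 0.75, 1.0]
np.linspace(start=0.0, stop=1.0, num=3).tolist()   # "0.0-1.0 [3]"     -> [0.0, 0.5, 1.0]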
@ -545,22 +577,22 @@ class Script(scripts.Script):
return valslist

x_opt = self.current_axis_options[x_type]
if x_opt.choices is not None:
x_values = ",".join(x_values_dropdown)
if x_opt.choices is not None and not csv_mode:
x_values = list_to_csv_string(x_values_dropdown)
xs = process_axis(x_opt, x_values, x_values_dropdown)

y_opt = self.current_axis_options[y_type]
if y_opt.choices is not None:
y_values = ",".join(y_values_dropdown)
if y_opt.choices is not None and not csv_mode:
y_values = list_to_csv_string(y_values_dropdown)
ys = process_axis(y_opt, y_values, y_values_dropdown)

z_opt = self.current_axis_options[z_type]
if z_opt.choices is not None:
z_values = ",".join(z_values_dropdown)
if z_opt.choices is not None and not csv_mode:
z_values = list_to_csv_string(z_values_dropdown)
zs = process_axis(z_opt, z_values, z_values_dropdown)

# this could be moved to common code, but unlikely to be ever triggered anywhere else
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
Image.MAX_IMAGE_PIXELS = None  # disable check in Pillow and rely on check below to allow large custom image sizes
grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)
assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'

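The megapixel guard is simply the cell count times the cell area. For example, a 5x4 grid over a single Z value at 512x768 per image works out to:

round(5 * 4 * 1 * 512 * 768 / 1000000)  # 8 MPixels, compared against opts.img_max_size_mp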
@ -720,7 +752,7 @@ class Script(scripts.Script):
# Auto-save main and sub-grids:
grid_count = z_count + 1 if z_count > 1 else 1
for g in range(grid_count):
#TODO: See previous comment about intentional data misalignment.
# TODO: See previous comment about intentional data misalignment.
adj_g = g-1 if g > 0 else g
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)

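The g-1 offset is the misalignment the earlier TODO refers to: the draw_xyz_grid hunk above inserts the combined grid at index 0 of processed_result.images and infotexts but deliberately leaves all_prompts and all_seeds unshifted, so sub-grid g reads its prompt and seed from index g-1. A small illustration, assuming three Z values (z_count = 3, so grid_count = 4):

[(g, g - 1 if g > 0 else g) for g in range(4)]  # [(0, 0), (1, 0), (2, 1), (3, 2)]
# g = 0 is the combined grid; g = 1..3 are the Z sub-grids mapped to prompts/seeds 0..2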
@ -12,8 +12,6 @@ fi
export install_dir="$HOME"
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
export PYTORCH_ENABLE_MPS_FALLBACK=1

####################################################################