Mirror of https://github.com/openvinotoolkit/stable-diffusion-webui.git, synced 2024-12-14 14:45:06 +03:00

Merge branch 'dev' into extra-networks-always-visible

This commit is contained in: commit ef1698fd6d
.github/workflows/run_tests.yaml (vendored): 1 line added

@@ -41,6 +41,7 @@ jobs:
 --skip-prepare-environment
 --skip-torch-cuda-test
 --test-server
+--do-not-download-clip
 --no-half
 --disable-opt-split-attention
 --use-cpu all
CHANGELOG.md: 90 lines added

@@ -1,3 +1,93 @@
Added at the top of the changelog:

## 1.5.1

### Minor:
 * support parsing text encoder blocks in some new LoRAs
 * delete scale checker script due to user demand

### Extensions and API:
 * add postprocess_batch_list script callback

### Bug Fixes:
 * fix TI training for SD1
 * fix reload altclip model error
 * prepend the pythonpath instead of overriding it
 * fix typo in SD_WEBUI_RESTARTING
 * if txt2img/img2img raises an exception, finally call state.end()
 * fix composable diffusion weight parsing
 * restyle Startup profile for black users
 * fix webui not launching with --nowebui
 * catch exception for non git extensions
 * fix some options missing from /sdapi/v1/options
 * fix for extension update status always saying "unknown"
 * fix display of extra network cards that have `<>` in the name
 * update lora extension to work with python 3.8


## 1.5.0

### Features:
 * SD XL support
 * user metadata system for custom networks
 * extended Lora metadata editor: set activation text, default weight, view tags, training info
 * Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)
 * show github stars for extensions
 * img2img batch mode can read extra stuff from png info
 * img2img batch works with subdirectories
 * hotkeys to move prompt elements: alt+left/right
 * restyle time taken/VRAM display
 * add textual inversion hashes to infotext
 * optimization: cache git extension repo information
 * move generate button next to the generated picture for mobile clients
 * hide cards for networks of incompatible Stable Diffusion version in Lora extra networks interface
 * skip installing packages with pip if they all are already installed - startup speedup of about 2 seconds

### Minor:
 * checkbox to check/uncheck all extensions in the Installed tab
 * add gradio user to infotext and to filename patterns
 * allow gif for extra network previews
 * add options to change colors in grid
 * use natural sort for items in extra networks
 * Mac: use empty_cache() from torch 2 to clear VRAM
 * added automatic support for installing the right libraries for Navi3 (AMD)
 * add option SWIN_torch_compile to accelerate SwinIR upscale
 * suppress printing TI embedding info at start to console by default
 * speedup extra networks listing
 * added `[none]` filename token.
 * removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs)
 * add always_discard_next_to_last_sigma option to XYZ plot
 * automatically switch to 32-bit float VAE if the generated picture has NaNs without the need for `--no-half-vae` commandline flag.

### Extensions and API:
 * api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop
 * allow Script to have custom metaclass
 * add model exists status check /sdapi/v1/options
 * rename --add-stop-route to --api-server-stop
 * add `before_hr` script callback
 * add callback `after_extra_networks_activate`
 * disable rich exception output in console for API by default, use WEBUI_RICH_EXCEPTIONS env var to enable
 * return http 404 when thumb file not found
 * allow replacing extensions index with environment variable

### Bug Fixes:
 * fix for catch errors when retrieving extension index #11290
 * fix very slow loading speed of .safetensors files when reading from network drives
 * API cache cleanup
 * fix UnicodeEncodeError when writing to file CLIP Interrogator batch mode
 * fix warning of 'has_mps' deprecated from PyTorch
 * fix problem with extra network saving images as previews losing generation info
 * fix throwing exception when trying to resize image with I;16 mode
 * fix for #11534: canvas zoom and pan extension hijacking shortcut keys
 * fixed launch script to be runnable from any directory
 * don't add "Seed Resize: -1x-1" to API image metadata
 * correctly remove end parenthesis with ctrl+up/down
 * fixing --subpath on newer gradio version
 * fix: check fill size none zero when resize (fixes #11425)
 * use submit and blur for quick settings textbox
 * save img2img batch with images.save_image()
 * prevent running preload.py for disabled extensions
 * fix: previously, model name was added together with directory name to infotext and to [model_name] filename pattern; directory name is now not included

The existing "## 1.4.1" section and its "### Bug Fixes:" heading follow unchanged.
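The 1.5.0 entry above lists new API endpoints (/sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop). A minimal sketch of calling one of them, assuming the webui was started with --api and --api-server-stop and listens on the default 127.0.0.1:7860; the POST verb and port are assumptions, not stated in the changelog:

import requests

# Assumed base URL; adjust to wherever the webui API is actually listening.
BASE_URL = "http://127.0.0.1:7860"

# Ask the server to shut down gracefully (endpoint path taken from the changelog above).
resp = requests.post(f"{BASE_URL}/sdapi/v1/server-stop")
print(resp.status_code)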
README.md

@@ -88,7 +88,7 @@ A browser interface based on Gradio library for Stable Diffusion.
 - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
 - Now without any bad letters!
 - Load checkpoints in safetensors format
-- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
+- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
 - Now with a license!
 - Reorder elements in the UI from settings screen

@@ -168,5 +168,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - Security advice - RyotaK
 - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
 - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
+- LyCORIS - KohakuBlueleaf
+- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
extensions-builtin/Lora/extra_networks_lora.py

@@ -1,5 +1,5 @@
 from modules import extra_networks, shared
-import lora
+import networks


 class ExtraNetworkLora(extra_networks.ExtraNetwork):

@@ -9,24 +9,38 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora

-        if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
+        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

         names = []
-        multipliers = []
+        te_multipliers = []
+        unet_multipliers = []
+        dyn_dims = []
         for params in params_list:
             assert params.items

-            names.append(params.items[0])
-            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+            names.append(params.positional[0])

-        lora.load_loras(names, multipliers)
+            te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+            te_multiplier = float(params.named.get("te", te_multiplier))
+
+            unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
+            unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+            dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+            dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+            te_multipliers.append(te_multiplier)
+            unet_multipliers.append(unet_multiplier)
+            dyn_dims.append(dyn_dim)
+
+        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)

         if shared.opts.lora_add_hashes_to_infotext:
-            lora_hashes = []
-            for item in lora.loaded_loras:
-                shorthash = item.lora_on_disk.shorthash
+            network_hashes = []
+            for item in networks.loaded_networks:
+                shorthash = item.network_on_disk.shorthash
                 if not shorthash:
                     continue

@@ -36,10 +50,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
                 alias = alias.replace(":", "").replace(",", "")

-                lora_hashes.append(f"{alias}: {shorthash}")
+                network_hashes.append(f"{alias}: {shorthash}")

-            if lora_hashes:
-                p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+            if network_hashes:
+                p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

     def deactivate(self, p):
         pass
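The reworked activate() above now reads up to three values per extra-network tag (text-encoder multiplier, unet multiplier, dyn_dim), either positionally or by name. A minimal illustrative sketch, not part of the commit, of how those values resolve; the `positional`/`named` lists mirror the fields used by activate(), and the example tags are hypothetical:

# For a tag like <lora:myLora:0.6:0.8:64> or <lora:myLora:unet=0.8> (syntax assumed from the parsing above).
def resolve_lora_args(positional, named):
    te = float(positional[1]) if len(positional) > 1 else 1.0
    te = float(named.get("te", te))

    unet = float(positional[2]) if len(positional) > 2 else te  # unet defaults to the te value
    unet = float(named.get("unet", unet))

    dyn = int(positional[3]) if len(positional) > 3 else None
    dyn = int(named["dyn"]) if "dyn" in named else dyn

    return te, unet, dyn


print(resolve_lora_args(["myLora", "0.6", "0.8", "64"], {}))  # (0.6, 0.8, 64)
print(resolve_lora_args(["myLora"], {"unet": "0.8"}))         # (1.0, 0.8, None)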
extensions-builtin/Lora/lora.py

@@ -1,511 +1,9 @@
The old 511-line implementation is replaced with a 9-line compatibility shim that re-exports the new `networks` module:

import networks

list_available_loras = networks.list_available_networks

available_loras = networks.available_networks
available_lora_aliases = networks.available_network_aliases
available_lora_hash_lookup = networks.available_network_hash_lookup
forbidden_lora_aliases = networks.forbidden_network_aliases
loaded_loras = networks.loaded_networks

Removed from the old lora.py, with its logic carried over in renamed form to the new network.py and networks.py files shown below: the module-level re_digits/re_x_proj regexes and suffix_conversion table, convert_diffusers_name_to_compvis(), the LoraOnDisk, LoraModule and LoraUpDownModule classes, assign_lora_names_to_compvis_modules(), load_lora(), load_loras(), lora_calc_updown(), lora_restore_weights_from_backup(), lora_apply_weights(), lora_forward(), lora_reset_cached_weight(), the lora_Linear_*/lora_Conv2d_*/lora_MultiheadAttention_* forward and load_state_dict hooks, list_available_loras(), infotext_pasted(), and the module-level available_loras, available_lora_aliases, available_lora_hash_lookup, forbidden_lora_aliases and loaded_loras containers.
extensions-builtin/Lora/lyco_helpers.py (new file, 21 lines)

import torch


def make_weight_cp(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


def rebuild_conventional(up, down, shape, dyn_dim=None):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(shape)


def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
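For orientation, a small self-contained check (not part of the commit) of what rebuild_conventional computes for typical LoRA factor shapes; the tensor sizes are made up for illustration:

import torch

# Hypothetical rank-8 factors for a 3x3 conv with 64 input and 128 output channels.
up = torch.randn(128, 8, 1, 1)
down = torch.randn(8, 64, 3, 3)

# rebuild_conventional flattens both factors, multiplies them, and reshapes to the full
# weight shape; dyn_dim (when given) truncates the rank before the multiplication.
delta = (up.reshape(up.size(0), -1) @ down.reshape(down.size(0), -1)).reshape(128, 64, 3, 3)
print(delta.shape)  # torch.Size([128, 64, 3, 3])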
extensions-builtin/Lora/network.py (new file, 155 lines)

from __future__ import annotations
import os
from collections import namedtuple
import enum

from modules import sd_models, cache, errors, hashes, shared

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)
            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()

    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0

    def finalize_updown(self, updown, orig_weight, output_shape):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        return updown * self.calc_scale() * self.multiplier()

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        raise NotImplementedError()
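To make the scaling in NetworkModule concrete: the delta returned by a subclass's calc_updown is multiplied by calc_scale() (an explicit scale if stored, otherwise alpha/dim, otherwise 1.0) and by the per-part multiplier (text encoder vs. unet). A toy numeric sketch with made-up values, not taken from the commit:

# Made-up numbers: alpha=4 over dim=8 gives scale 0.5; a text-encoder multiplier of 0.6
# then scales the raw delta to 0.3 of its value.
alpha, dim = 4.0, 8
scale = alpha / dim       # what calc_scale() returns when no explicit "scale" key exists
te_multiplier = 0.6       # what multiplier() returns for text-encoder keys
effective = scale * te_multiplier
print(effective)          # 0.3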
extensions-builtin/Lora/network_full.py (new file, 22 lines)

import network


class ModuleTypeFull(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["diff"]):
            return NetworkModuleFull(net, weights)

        return None


class NetworkModuleFull(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.weight = weights.w.get("diff")

    def calc_updown(self, orig_weight):
        output_shape = self.weight.shape
        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)

        return self.finalize_updown(updown, orig_weight, output_shape)
extensions-builtin/Lora/network_hada.py (new file, 55 lines)

import lyco_helpers
import network


class ModuleTypeHada(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
            return NetworkModuleHada(net, weights)

        return None


class NetworkModuleHada(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.w1a = weights.w["hada_w1_a"]
        self.w1b = weights.w["hada_w1_b"]
        self.dim = self.w1b.shape[0]
        self.w2a = weights.w["hada_w2_a"]
        self.w2b = weights.w["hada_w2_b"]

        self.t1 = weights.w.get("hada_t1")
        self.t2 = weights.w.get("hada_t2")

    def calc_updown(self, orig_weight):
        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

        if self.t2 is not None:
            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

        updown = updown1 * updown2

        return self.finalize_updown(updown, orig_weight, output_shape)
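The Hadamard-product (LoHa) module above reconstructs the weight delta as an elementwise product of two low-rank factorizations. A tiny standalone sketch with made-up shapes, not taken from the commit:

import torch

# Hypothetical rank-4 LoHa factors for a 16x32 linear weight.
w1a, w1b = torch.randn(16, 4), torch.randn(4, 32)
w2a, w2b = torch.randn(16, 4), torch.randn(4, 32)

# Without the optional t1/t2 CP tensors, calc_updown amounts to:
updown = (w1a @ w1b) * (w2a @ w2b)   # elementwise (Hadamard) product of two rank-4 matrices
print(updown.shape)                  # torch.Size([16, 32])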
extensions-builtin/Lora/network_ia3.py (new file, 30 lines)

import network


class ModuleTypeIa3(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["weight"]):
            return NetworkModuleIa3(net, weights)

        return None


class NetworkModuleIa3(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w = weights.w["weight"]
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [w.size(0), orig_weight.size(1)]
        if self.on_input:
            output_shape.reverse()
        else:
            w = w.reshape(-1, 1)

        updown = orig_weight * w

        return self.finalize_updown(updown, orig_weight, output_shape)
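The IA3 module above scales the original weight with a learned vector rather than adding a low-rank product. A short standalone sketch of the two broadcasting cases, with made-up shapes:

import torch

# Made-up IA3 vectors for a 16x32 linear weight (16 output features, 32 input features).
orig_weight = torch.randn(16, 32)
w_in = torch.randn(32)    # on_input=True: one scale per input feature
w_out = torch.randn(16)   # on_input=False: one scale per output feature

updown_in = orig_weight * w_in                   # broadcasts across columns
updown_out = orig_weight * w_out.reshape(-1, 1)  # broadcasts across rows
print(updown_in.shape, updown_out.shape)         # both torch.Size([16, 32])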
extensions-builtin/Lora/network_lokr.py (new file, 64 lines)

import torch

import lyco_helpers
import network


class ModuleTypeLokr(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
        has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
        if has_1 and has_2:
            return NetworkModuleLokr(net, weights)

        return None


def make_kron(orig_shape, w1, w2):
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    return torch.kron(w1, w2).reshape(orig_shape)


class NetworkModuleLokr(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w1 = weights.w.get("lokr_w1")
        self.w1a = weights.w.get("lokr_w1_a")
        self.w1b = weights.w.get("lokr_w1_b")
        self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
        self.w2 = weights.w.get("lokr_w2")
        self.w2a = weights.w.get("lokr_w2_a")
        self.w2b = weights.w.get("lokr_w2_b")
        self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
        self.t2 = weights.w.get("lokr_t2")

    def calc_updown(self, orig_weight):
        if self.w1 is not None:
            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
        else:
            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
            w1 = w1a @ w1b

        if self.w2 is not None:
            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
        elif self.t2 is None:
            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
            w2 = w2a @ w2b
        else:
            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

        output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
        if len(orig_weight.shape) == 4:
            output_shape = orig_weight.shape

        updown = make_kron(output_shape, w1, w2)

        return self.finalize_updown(updown, orig_weight, output_shape)
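The LoKr module above assembles the delta as a Kronecker product of two small blocks, which is why output_shape multiplies the factor dimensions. A minimal standalone sketch with made-up shapes:

import torch

# Made-up LoKr factors: a 4x8 block Kronecker-multiplied with an 8x16 block
# reconstructs a 32x128 delta, matching output_shape = [4*8, 8*16].
w1 = torch.randn(4, 8)
w2 = torch.randn(8, 16)

updown = torch.kron(w1, w2)
print(updown.shape)  # torch.Size([32, 128])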
extensions-builtin/Lora/network_lora.py (new file, 86 lines)

import torch

import lyco_helpers
import network
from modules import devices


class ModuleTypeLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
            return NetworkModuleLora(net, weights)

        return None


class NetworkModuleLora(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.up_model = self.create_module(weights.w, "lora_up.weight")
        self.down_model = self.create_module(weights.w, "lora_down.weight")
        self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)

        self.dim = weights.w["lora_down.weight"].shape[0]

    def create_module(self, weights, key, none_ok=False):
        weight = weights.get(key)

        if weight is None and none_ok:
            return None

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)

            if weight.shape[2] != 1 or weight.shape[3] != 1:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module

    def calc_updown(self, orig_weight):
        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is not None:
            # cp-decomposition
            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)

        return self.finalize_updown(updown, orig_weight, output_shape)

    def forward(self, x, y):
        self.up_model.to(device=devices.device)
        self.down_model.to(device=devices.device)

        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
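For the conventional LoRA case above, the delta that calc_updown builds is simply the product of the up and down factors, later scaled by alpha/dim and the multiplier in finalize_updown. A short standalone sketch with made-up sizes:

import torch

# Made-up rank-8 LoRA for a 320x768 projection.
up = torch.randn(320, 8)
down = torch.randn(8, 768)
alpha, multiplier = 4.0, 0.8

delta = (up @ down) * (alpha / down.shape[0]) * multiplier  # dim is the rank, here 8
print(delta.shape)  # torch.Size([320, 768])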
468
extensions-builtin/Lora/networks.py
Normal file
468
extensions-builtin/Lora/networks.py
Normal file
@ -0,0 +1,468 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import network
|
||||||
|
import network_lora
|
||||||
|
import network_hada
|
||||||
|
import network_ia3
|
||||||
|
import network_lokr
|
||||||
|
import network_full
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
||||||
|
|
||||||
|
module_types = [
|
||||||
|
network_lora.ModuleTypeLora(),
|
||||||
|
network_hada.ModuleTypeHada(),
|
||||||
|
network_ia3.ModuleTypeIa3(),
|
||||||
|
network_lokr.ModuleTypeLokr(),
|
||||||
|
network_full.ModuleTypeFull(),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
re_digits = re.compile(r"\d+")
|
||||||
|
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
|
||||||
|
re_compiled = {}
|
||||||
|
|
||||||
|
suffix_conversion = {
|
||||||
|
"attentions": {},
|
||||||
|
"resnets": {
|
||||||
|
"conv1": "in_layers_2",
|
||||||
|
"conv2": "out_layers_3",
|
||||||
|
"time_emb_proj": "emb_layers_1",
|
||||||
|
"conv_shortcut": "skip_connection",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def convert_diffusers_name_to_compvis(key, is_sd2):
|
||||||
|
def match(match_list, regex_text):
|
||||||
|
regex = re_compiled.get(regex_text)
|
||||||
|
if regex is None:
|
||||||
|
regex = re.compile(regex_text)
|
||||||
|
re_compiled[regex_text] = regex
|
||||||
|
|
||||||
|
r = re.match(regex, key)
|
||||||
|
if not r:
|
||||||
|
return False
|
||||||
|
|
||||||
|
match_list.clear()
|
||||||
|
match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
|
||||||
|
return True
|
||||||
|
|
||||||
|
m = []
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_conv_in(.*)"):
|
||||||
|
return f'diffusion_model_input_blocks_0_0{m[0]}'
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_conv_out(.*)"):
|
||||||
|
return f'diffusion_model_out_2{m[0]}'
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
|
||||||
|
return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||||
|
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||||
|
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
|
||||||
|
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
|
||||||
|
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||||
|
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||||
|
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
|
||||||
|
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
|
||||||
|
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
|
||||||
|
|
||||||
|
if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
|
||||||
|
if is_sd2:
|
||||||
|
if 'mlp_fc1' in m[1]:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||||
|
elif 'mlp_fc2' in m[1]:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||||
|
else:
|
||||||
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||||
|
|
||||||
|
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||||
|
|
||||||
|
if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
|
||||||
|
if 'mlp_fc1' in m[1]:
|
||||||
|
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||||
|
elif 'mlp_fc2' in m[1]:
|
||||||
|
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||||
|
else:
|
||||||
|
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||||
|
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def assign_network_names_to_compvis_modules(sd_model):
    network_layer_mapping = {}

    if shared.sd_model.is_sdxl:
        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                network_name = f'{i}_{name.replace(".", "_")}'
                network_layer_mapping[network_name] = module
                module.network_layer_name = network_name
    else:
        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
            network_name = name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    for name, module in shared.sd_model.model.named_modules():
        network_name = name.replace(".", "_")
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    sd_model.network_layer_mapping = network_layer_mapping


def load_network(name, network_on_disk):
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    matched_networks = {}

    for key_network, weight in sd.items():
        key_network_without_network_parts, network_part = key_network.split(".", 1)

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    if keys_failed_to_match:
        print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")

    return net

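A sketch of how one state-dict entry flows through the matching loop above (hypothetical key; the split on the first "." separates the layer name from the LoRA part):

# "lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight"
#   key_network_without_network_parts = "lora_unet_down_blocks_0_attentions_0_proj_in"
#   network_part                      = "lora_up.weight"
# after key conversion the tensor ends up in matched_networks[key].w["lora_up.weight"]
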
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net

    loaded_networks.clear()

    networks_on_disk = [available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()

        networks_on_disk = [available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, name in enumerate(names):
        net = already_loaded.get(name, None)

        network_on_disk = networks_on_disk[i]

        if network_on_disk is not None:
            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    net = load_network(name, network_on_disk)
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            print(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))

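A minimal usage sketch (hypothetical file name); this is roughly the call that the <lora:...> prompt syntax triggers through the extra-networks machinery:

# "my_style" is a hypothetical file in the Lora directory - comparable to <lora:my_style:0.8> in a prompt
load_networks(["my_style"], te_multipliers=[0.8], unet_multipliers=[0.8])
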
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
    weights_backup = getattr(self, "network_weights_backup", None)

    if weights_backup is None:
        return

    if isinstance(self, torch.nn.MultiheadAttention):
        self.in_proj_weight.copy_(weights_backup[0])
        self.out_proj.weight.copy_(weights_backup[1])
    else:
        self.weight.copy_(weights_backup)


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores orginal weights from backup and alters weights according to networks.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                with torch.no_grad():
                    updown = module.calc_updown(self.weight)

                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                        # inpainting model. zero pad updown to make channel[1] 4 to 9
                        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                    self.weight += updown
                    continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                with torch.no_grad():
                    updown_q = module_q.calc_updown(self.in_proj_weight)
                    updown_k = module_k.calc_updown(self.in_proj_weight)
                    updown_v = module_v.calc_updown(self.in_proj_weight)
                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                    updown_out = module_out.calc_updown(self.out_proj.weight)

                    self.in_proj_weight += updown_qkv
                    self.out_proj.weight += updown_out
                    continue

            if module is None:
                continue

            print(f'failed to calculate network weights for layer {network_layer_name}')

        self.network_current_names = wanted_names

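In effect, the function above maintains this invariant for each patched layer (editor's sketch of the logic, not extra code in the commit):

# weight = original_backup + sum(net.modules[layer_name].calc_updown(original_backup) for net in loaded_networks)
# so switching the active networks only needs a restore from backup plus one pass of additions
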
def network_forward(module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.
    """

    if len(loaded_networks) == 0:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(module)
    network_reset_cached_weight(module)

    y = original_forward(module, input)

    network_layer_name = getattr(module, 'network_layer_name', None)
    for lora in loaded_networks:
        module = lora.modules.get(network_layer_name, None)
        if module is None:
            continue

        y = module.forward(y, input)

    return y


def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    self.network_current_names = ()
    self.network_weights_backup = None


def network_Linear_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, torch.nn.Linear_forward_before_network)

    network_apply_weights(self)

    return torch.nn.Linear_forward_before_network(self, input)


def network_Linear_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)


def network_Conv2d_forward(self, input):
    if shared.opts.lora_functional:
        return network_forward(self, input, torch.nn.Conv2d_forward_before_network)

    network_apply_weights(self)

    return torch.nn.Conv2d_forward_before_network(self, input)


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)


def network_MultiheadAttention_forward(self, *args, **kwargs):
    network_apply_weights(self)

    return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    network_reset_cached_weight(self)

    return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)


def list_available_networks():
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)

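A sketch of what infotext_pasted does with AddNet-style parameters (hypothetical values):

# params = {"AddNet Module 1": "LoRA", "AddNet Model 1": "my_style(ab12cd34)", "AddNet Weight A 1": "0.6", "Prompt": "..."}
# appends "<lora:my_style:0.6>" to params["Prompt"]
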
available_networks = {}
available_network_aliases = {}
loaded_networks = []
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()
@@ -4,3 +4,4 @@ from modules import paths
 def preload(parser):
     parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
+    parser.add_argument("--lyco-dir-backcompat", type=str, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
@@ -4,69 +4,76 @@ import torch
 import gradio as gr
 from fastapi import FastAPI
 
-import lora
+import network
+import networks
+import lora # noqa:F401
 import extra_networks_lora
 import ui_extra_networks_lora
 from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 
 
 def unload():
-    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
-    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
-    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
-    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
-    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
-    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
+    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
+    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
+    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
+    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
 
 
 def before_ui():
     ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
-    extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
+
+    extra_network = extra_networks_lora.ExtraNetworkLora()
+    extra_networks.register_extra_network(extra_network)
+    extra_networks.register_extra_network_alias(extra_network, "lyco")
 
 
-if not hasattr(torch.nn, 'Linear_forward_before_lora'):
-    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
+if not hasattr(torch.nn, 'Linear_forward_before_network'):
+    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
 
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
-    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
+    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
 
-if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
-    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
+    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
 
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
-    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
+    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
 
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
-    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
+    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
 
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
-    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
+    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
 
-torch.nn.Linear.forward = lora.lora_Linear_forward
-torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
-torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
-torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
+torch.nn.Linear.forward = networks.network_Linear_forward
+torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
+torch.nn.Conv2d.forward = networks.network_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
 
-script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
+script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_before_ui(before_ui)
-script_callbacks.on_infotext_pasted(lora.infotext_pasted)
+script_callbacks.on_infotext_pasted(networks.infotext_pasted)
 
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
-    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
+    "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
     "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
     "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
+    "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
+    "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
 }))
 
 
 shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
-    "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
+    "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
 }))
 
 
-def create_lora_json(obj: lora.LoraOnDisk):
+def create_lora_json(obj: network.NetworkOnDisk):
     return {
         "name": obj.name,
         "alias": obj.alias,
@@ -75,17 +82,17 @@ def create_lora_json(obj: lora.LoraOnDisk):
     }
 
 
-def api_loras(_: gr.Blocks, app: FastAPI):
+def api_networks(_: gr.Blocks, app: FastAPI):
     @app.get("/sdapi/v1/loras")
     async def get_loras():
-        return [create_lora_json(obj) for obj in lora.available_loras.values()]
+        return [create_lora_json(obj) for obj in networks.available_networks.values()]
 
     @app.post("/sdapi/v1/refresh-loras")
     async def refresh_loras():
-        return lora.list_available_loras()
+        return networks.list_available_networks()
 
 
-script_callbacks.on_app_started(api_loras)
+script_callbacks.on_app_started(api_networks)
 
 re_lora = re.compile("<lora:([^:]+):")
 
@@ -98,19 +105,19 @@ def infotext_pasted(infotext, d):
     hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
     hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
 
-    def lora_replacement(m):
+    def network_replacement(m):
         alias = m.group(1)
         shorthash = hashes.get(alias)
         if shorthash is None:
            return m.group(0)
 
-        lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
-        if lora_on_disk is None:
+        network_on_disk = networks.available_network_hash_lookup.get(shorthash)
+        if network_on_disk is None:
            return m.group(0)
 
-        return f'<lora:{lora_on_disk.get_alias()}:'
+        return f'<lora:{network_on_disk.get_alias()}:'
 
-    d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
+    d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
 
 
 script_callbacks.on_infotext_pasted(infotext_pasted)
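The hunks above all follow the same save-and-wrap pattern; a minimal self-contained sketch of that pattern (generic names, not the webui code):

import torch

# keep a reference to the original forward exactly once, then delegate to it from the wrapper
if not hasattr(torch.nn, 'Linear_forward_before_network'):
    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward


def patched_linear_forward(self, input):
    # a real implementation would apply or refresh the network deltas here first
    return torch.nn.Linear_forward_before_network(self, input)


torch.nn.Linear.forward = patched_linear_forward
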
@@ -1,3 +1,4 @@
+import datetime
 import html
 import random
 
@@ -46,14 +47,17 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
     def __init__(self, ui, tabname, page):
         super().__init__(ui, tabname, page)
 
+        self.select_sd_version = None
+
         self.taginfo = None
         self.edit_activation_text = None
         self.slider_preferred_weight = None
         self.edit_notes = None
 
-    def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes):
+    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
         user_metadata = self.get_user_metadata(name)
         user_metadata["description"] = desc
+        user_metadata["sd version"] = sd_version
         user_metadata["activation text"] = activation_text
         user_metadata["preferred weight"] = preferred_weight
         user_metadata["notes"] = notes
@@ -68,6 +72,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         keys = {
             'ss_sd_model_name': "Model:",
             'ss_clip_skip': "Clip skip:",
+            'ss_network_module': "Kohya module:",
         }
 
         for key, label in keys.items():
@@ -75,6 +80,10 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             if value is not None and str(value) != "None":
                 table.append((label, html.escape(value)))
 
+        ss_training_started_at = metadata.get('ss_training_started_at')
+        if ss_training_started_at:
+            table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M')))
+
         ss_bucket_info = metadata.get("ss_bucket_info")
         if ss_bucket_info and "buckets" in ss_bucket_info:
             resolutions = {}
@@ -112,11 +121,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
         gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]
 
         return [
-            *values[0:4],
+            *values[0:5],
+            item.get("sd_version", "Unknown"),
             gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
             user_metadata.get('activation text', ''),
             float(user_metadata.get('preferred weight', 0.0)),
-            user_metadata.get('notes', ''),
             gr.update(visible=True if tags else False),
             gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
         ]
@@ -141,10 +150,15 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
 
         return ", ".join(sorted(res))
 
+    def create_extra_default_items_in_left_column(self):
+
+        # this would be a lot better as gr.Radio but I can't make it work
+        self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)
+
     def create_editor(self):
         self.create_default_editor_elems()
 
-        self.taginfo = gr.HighlightedText(label="Tags")
+        self.taginfo = gr.HighlightedText(label="Training dataset tags")
         self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
         self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
 
@@ -153,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
 
             with gr.Column(scale=1, min_width=120):
-                generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+                generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
 
         self.edit_notes = gr.TextArea(label='Notes', lines=4)
 
@@ -178,10 +192,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
             self.edit_description,
             self.html_filedata,
             self.html_preview,
+            self.edit_notes,
+            self.select_sd_version,
             self.taginfo,
             self.edit_activation_text,
             self.slider_preferred_weight,
-            self.edit_notes,
             row_random_prompt,
             random_prompt,
         ]
@@ -192,6 +207,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
 
         edited_components = [
             self.edit_description,
+            self.select_sd_version,
             self.edit_activation_text,
             self.slider_preferred_weight,
             self.edit_notes,
@@ -1,5 +1,7 @@
 import os
-import lora
+
+import network
+import networks
 
 from modules import shared, ui_extra_networks
 from modules.ui_extra_networks import quote_js
@@ -11,16 +13,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         super().__init__('Lora')
 
     def refresh(self):
-        lora.list_available_loras()
+        networks.list_available_networks()
 
-    def create_item(self, name, index=None):
-        lora_on_disk = lora.available_loras.get(name)
+    def create_item(self, name, index=None, enable_filter=True):
+        lora_on_disk = networks.available_networks.get(name)
 
         path, ext = os.path.splitext(lora_on_disk.filename)
 
         alias = lora_on_disk.get_alias()
 
-        # in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string
         item = {
             "name": name,
             "filename": lora_on_disk.filename,
@@ -30,6 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
             "local_preview": f"{path}.{shared.opts.samples_format}",
             "metadata": lora_on_disk.metadata,
             "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
+            "sd_version": lora_on_disk.sd_version.name,
         }
 
         self.read_user_metadata(item)
@@ -40,15 +42,37 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         if activation_text:
             item["prompt"] += " + " + quote_js(" " + activation_text)
 
+        sd_version = item["user_metadata"].get("sd version")
+        if sd_version in network.SdVersion.__members__:
+            item["sd_version"] = sd_version
+            sd_version = network.SdVersion[sd_version]
+        else:
+            sd_version = lora_on_disk.sd_version
+
+        if shared.opts.lora_show_all or not enable_filter:
+            pass
+        elif sd_version == network.SdVersion.Unknown:
+            model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
+            if model_version.name in shared.opts.lora_hide_unknown_for_versions:
+                return None
+        elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
+            return None
+        elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
+            return None
+        elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
+            return None
+
         return item
 
     def list_items(self):
-        for index, name in enumerate(lora.available_loras):
+        for index, name in enumerate(networks.available_networks):
             item = self.create_item(name, index)
-            yield item
+
+            if item is not None:
+                yield item
 
     def allowed_directories_for_previews(self):
-        return [shared.cmd_opts.lora_dir]
+        return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat]
 
     def create_user_metadata_editor(self, ui, tabname):
         return LoraUserMetadataEditor(ui, tabname, self)
26
extensions-builtin/mobile/javascript/mobile.js
Normal file
@@ -0,0 +1,26 @@
var isSetupForMobile = false;

function isMobile() {
    for (var tab of ["txt2img", "img2img"]) {
        var imageTab = gradioApp().getElementById(tab + '_results');
        if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) {
            return true;
        }
    }

    return false;
}

function reportWindowSize() {
    var currentlyMobile = isMobile();
    if (currentlyMobile == isSetupForMobile) return;
    isSetupForMobile = currentlyMobile;

    for (var tab of ["txt2img", "img2img"]) {
        var button = gradioApp().getElementById(tab + '_generate_box');
        var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
        target.insertBefore(button, target.firstElementChild);
    }
}

window.addEventListener("resize", reportWindowSize);
@@ -1,8 +1,8 @@
 <div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
     {background_image}
     <div class="button-row">
-        {edit_button}
         {metadata_button}
+        {edit_button}
     </div>
     <div class='actions'>
         <div class='additional'>
@@ -211,7 +211,7 @@ function popup(contents) {
         globalPopupInner.classList.add('global-popup-inner');
         globalPopup.appendChild(globalPopupInner);
 
-        gradioApp().appendChild(globalPopup);
+        gradioApp().querySelector('.main').appendChild(globalPopup);
     }
 
     globalPopupInner.innerHTML = '';
@@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
         tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
     }
 });
+
+onUiLoaded(function() {
+    for (var comp of window.gradio_config.components) {
+        if (comp.props.webui_tooltip && comp.props.elem_id) {
+            var elem = gradioApp().getElementById(comp.props.elem_id);
+            if (elem) {
+                elem.title = comp.props.webui_tooltip;
+            }
+        }
+    }
+});
@@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
     train_hypernetwork: 'OPTION',
     txt2img_styles: 'OPTION',
     img2img_styles: 'OPTION',
-    setting_random_artist_categories: 'SPAN',
-    setting_face_restoration_model: 'SPAN',
-    setting_realesrgan_enabled_models: 'SPAN',
-    extras_upscaler_1: 'SPAN',
-    extras_upscaler_2: 'SPAN',
+    setting_random_artist_categories: 'OPTION',
+    setting_face_restoration_model: 'OPTION',
+    setting_realesrgan_enabled_models: 'OPTION',
+    extras_upscaler_1: 'OPTION',
+    extras_upscaler_2: 'OPTION',
 };
 
 var re_num = /^[.\d]+$/;
@@ -152,7 +152,11 @@ function submit() {
     showSubmitButtons('txt2img', false);
 
     var id = randomId();
-    localStorage.setItem("txt2img_task_id", id);
+    try {
+        localStorage.setItem("txt2img_task_id", id);
+    } catch (e) {
+        console.warn(`Failed to save txt2img task id to localStorage: ${e}`);
+    }
 
     requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
         showSubmitButtons('txt2img', true);
@@ -171,7 +175,11 @@ function submit_img2img() {
     showSubmitButtons('img2img', false);
 
     var id = randomId();
-    localStorage.setItem("img2img_task_id", id);
+    try {
+        localStorage.setItem("img2img_task_id", id);
+    } catch (e) {
+        console.warn(`Failed to save img2img task id to localStorage: ${e}`);
+    }
 
     requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
         showSubmitButtons('img2img', true);
@@ -191,8 +199,6 @@ function restoreProgressTxt2img() {
     showRestoreProgressButton("txt2img", false);
     var id = localStorage.getItem("txt2img_task_id");
 
-    id = localStorage.getItem("txt2img_task_id");
-
     if (id) {
         requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
             showSubmitButtons('txt2img', true);
@@ -1,6 +1,5 @@
 from modules import launch_utils
 
-
 args = launch_utils.args
 python = launch_utils.python
 git = launch_utils.git
@@ -18,6 +17,7 @@ run_pip = launch_utils.run_pip
 check_run_python = launch_utils.check_run_python
 git_clone = launch_utils.git_clone
 git_pull_recursive = launch_utils.git_pull_recursive
+list_extensions = launch_utils.list_extensions
 run_extension_installer = launch_utils.run_extension_installer
 prepare_environment = launch_utils.prepare_environment
 configure_for_tests = launch_utils.configure_for_tests
@@ -25,8 +25,11 @@ start = launch_utils.start
 
 
 def main():
-    if not args.skip_prepare_environment:
-        prepare_environment()
+    launch_utils.startup_timer.record("initial startup")
+
+    with launch_utils.startup_timer.subcategory("prepare environment"):
+        if not args.skip_prepare_environment:
+            prepare_environment()
 
     if args.test_server:
         configure_for_tests()
@@ -15,7 +15,7 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -197,6 +197,7 @@ class Api:
         self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+        self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
@@ -333,14 +334,17 @@ class Api:
         p.outpath_grids = opts.outdir_txt2img_grids
         p.outpath_samples = opts.outdir_txt2img_samples
 
-        shared.state.begin(job="scripts_txt2img")
-        if selectable_scripts is not None:
-            p.script_args = script_args
-            processed = scripts.scripts_txt2img.run(p, *p.script_args)  # Need to pass args as list here
-        else:
-            p.script_args = tuple(script_args)  # Need to pass args as tuple here
-            processed = process_images(p)
-        shared.state.end()
+        try:
+            shared.state.begin(job="scripts_txt2img")
+            if selectable_scripts is not None:
+                p.script_args = script_args
+                processed = scripts.scripts_txt2img.run(p, *p.script_args)  # Need to pass args as list here
+            else:
+                p.script_args = tuple(script_args)  # Need to pass args as tuple here
+                processed = process_images(p)
+        finally:
+            shared.state.end()
+            shared.total_tqdm.clear()
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
@@ -390,14 +394,17 @@ class Api:
         p.outpath_grids = opts.outdir_img2img_grids
         p.outpath_samples = opts.outdir_img2img_samples
 
-        shared.state.begin(job="scripts_img2img")
-        if selectable_scripts is not None:
-            p.script_args = script_args
-            processed = scripts.scripts_img2img.run(p, *p.script_args)  # Need to pass args as list here
-        else:
-            p.script_args = tuple(script_args)  # Need to pass args as tuple here
-            processed = process_images(p)
-        shared.state.end()
+        try:
+            shared.state.begin(job="scripts_img2img")
+            if selectable_scripts is not None:
+                p.script_args = script_args
+                processed = scripts.scripts_img2img.run(p, *p.script_args)  # Need to pass args as list here
+            else:
+                p.script_args = tuple(script_args)  # Need to pass args as tuple here
+                processed = process_images(p)
+        finally:
+            shared.state.end()
+            shared.total_tqdm.clear()
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
 
@@ -604,6 +611,10 @@ class Api:
         with self.queue_lock:
             shared.refresh_checkpoints()
 
+    def refresh_vae(self):
+        with self.queue_lock:
+            shared_items.refresh_vae_list()
+
     def create_embedding(self, args: dict):
         try:
             shared.state.begin(job="create_embedding")
@@ -720,9 +731,9 @@ class Api:
             cuda = {'error': f'{err}'}
         return models.MemoryResponse(ram=ram, cuda=cuda)
 
-    def launch(self, server_name, port):
+    def launch(self, server_name, port, root_path):
         self.app.include_router(self.router)
-        uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive)
+        uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
 
     def kill_webui(self):
         restart.stop_program()
@@ -1,4 +1,5 @@
 import inspect
+
 from pydantic import BaseModel, Field, create_model
 from typing import Any, Optional
 from typing_extensions import Literal
@@ -207,11 +208,10 @@ class PreprocessResponse(BaseModel):
 fields = {}
 for key, metadata in opts.data_labels.items():
     value = opts.data.get(key)
-    optType = opts.typemap.get(type(metadata.default), type(value))
+    optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any
 
-    if (metadata is not None):
-        fields.update({key: (Optional[optType], Field(
-            default=metadata.default ,description=metadata.label))})
+    if metadata is not None:
+        fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
     else:
         fields.update({key: (Optional[optType], Field())})
 
@@ -1,6 +1,7 @@
 import json
 import os.path
 import threading
+import time
 
 from modules.paths import data_path, script_path
 
@@ -8,15 +9,37 @@ cache_filename = os.path.join(data_path, "cache.json")
 cache_data = None
 cache_lock = threading.Lock()
 
+dump_cache_after = None
+dump_cache_thread = None
+
+
 def dump_cache():
     """
-    Saves all cache data to a file.
+    Marks cache for writing to disk. 5 seconds after no one else flags the cache for writing, it is written.
     """
 
+    global dump_cache_after
+    global dump_cache_thread
+
+    def thread_func():
+        global dump_cache_after
+        global dump_cache_thread
+
+        while dump_cache_after is not None and time.time() < dump_cache_after:
+            time.sleep(1)
+
+        with cache_lock:
+            with open(cache_filename, "w", encoding="utf8") as file:
+                json.dump(cache_data, file, indent=4)
+
+            dump_cache_after = None
+            dump_cache_thread = None
+
     with cache_lock:
-        with open(cache_filename, "w", encoding="utf8") as file:
-            json.dump(cache_data, file, indent=4)
+        dump_cache_after = time.time() + 5
+        if dump_cache_thread is None:
+            dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)
+            dump_cache_thread.start()
 
 
 def cache(subsection):
@@ -84,7 +107,7 @@ def cached_data_for_file(subsection, title, filename, func):
         if ondisk_mtime > cached_mtime:
             entry = None
 
-    if not entry:
+    if not entry or 'value' not in entry:
         value = func()
         if value is None:
             return None
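With the debounce above, callers can flag the cache as dirty as often as they like (sketch):

# repeated calls are cheap: each one only pushes the 5-second deadline back,
# and the single 'cache-writer' thread performs one JSON write once things go quiet
dump_cache()
dump_cache()
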
@@ -3,7 +3,7 @@ import html
 import threading
 import time
 
-from modules import shared, progress, errors
+from modules import shared, progress, errors, devices
 
 queue_lock = threading.Lock()
 
@@ -75,6 +75,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
             error_message = f'{type(e).__name__}: {e}'
             res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
 
+        devices.torch_gc()
+
         shared.state.skipped = False
         shared.state.interrupted = False
         shared.state.job_count = 0
@@ -13,8 +13,10 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py
 parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
 parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
 parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
+parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
 parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
 parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
+parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
 parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
 parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)

@@ -65,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
 parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
+parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)

@@ -109,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
 parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
 parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
+parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
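The new switches are ordinary argparse store_true flags, so other modules can read them off the parsed namespace. A small illustrative check (the flag values passed here are examples):

from modules import cmd_args

args, _ = cmd_args.parser.parse_known_args(["--do-not-download-clip", "--log-startup"])
assert args.do_not_download_clip and args.log_startup
assert not args.disable_model_loading_ram_optimization  # defaults to False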
@@ -3,7 +3,7 @@ import contextlib
 from functools import lru_cache

 import torch
-from modules import errors
+from modules import errors, rng_philox

 if sys.platform == "darwin":
     from modules import mac_specific

@@ -71,14 +71,17 @@ def enable_tf32():
         torch.backends.cudnn.allow_tf32 = True


 errors.run(enable_tf32, "Enabling TF32")

-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
 unet_needs_upcast = False


@@ -90,23 +93,87 @@ def cond_cast_float(input):
     return input.float() if unet_needs_upcast else input


+nv_rng = None
+
+
 def randn(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
     from modules.shared import opts

-    torch.manual_seed(seed)
+    manual_seed(seed)
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
     if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)

     return torch.randn(shape, device=device)


+def randn_local(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        rng = rng_philox.Generator(seed)
+        return torch.asarray(rng.randn(shape), device=device)
+
+    local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+    local_generator = torch.Generator(local_device).manual_seed(int(seed))
+    return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+    if opts.randn_source == "CPU" or x.device.type == 'mps':
+        return torch.randn_like(x, device=cpu).to(x.device)
+
+    return torch.randn_like(x)
+
+
 def randn_without_seed(shape):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
     from modules.shared import opts

+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
     if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)

     return torch.randn(shape, device=device)


+def manual_seed(seed):
+    """Set up a global random number generator using the specified seed."""
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        global nv_rng
+        nv_rng = rng_philox.Generator(seed)
+        return
+
+    torch.manual_seed(seed)
+
+
 def autocast(disable=False):
     from modules import shared

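Taken together, the new helpers give one entry point for reproducible noise regardless of whether the generator is torch's or the NV-compatible Philox one selected by opts.randn_source. An illustrative combination (seed and shape are examples, not values from this commit):

from modules import devices

seed, shape = 1234, (4, 64, 64)

first = devices.randn(seed, shape)            # seeds the global generator and returns the first tensor
second = devices.randn_without_seed(shape)    # continues the same sequence
preview = devices.randn_local(seed, shape)    # same first tensor, without touching global state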
@@ -14,7 +14,8 @@ def record_exception():
     if exception_records and exception_records[-1] == e:
         return

-    exception_records.append((e, tb))
+    from modules import sysinfo
+    exception_records.append(sysinfo.format_exception(e, tb))

     if len(exception_records) > 5:
         exception_records.pop(0)

@@ -83,3 +84,53 @@ def run(code, task):
         code()
     except Exception as e:
         display(task, e)
+
+
+def check_versions():
+    from packaging import version
+    from modules import shared
+
+    import torch
+    import gradio
+
+    expected_torch_version = "2.0.0"
+    expected_xformers_version = "0.0.20"
+    expected_gradio_version = "3.39.0"
+
+    if version.parse(torch.__version__) < version.parse(expected_torch_version):
+        print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, as well as
+there are reports of issues with training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+        """.strip())
+
+    if shared.xformers_available:
+        import xformers
+
+        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+            print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
+            """.strip())
+
+    if gradio.__version__ != expected_gradio_version:
+        print_error_explanation(f"""
+You are running gradio {gradio.__version__}.
+The program is designed to work with gradio {expected_gradio_version}.
+Using a different version of gradio is extremely likely to break the program.
+
+Reasons why you have the mismatched gradio version can be:
+- you use --skip-install flag.
+- you use webui.py to start the program instead of launch.py.
+- an extension installs the incompatible gradio version.
+
+Use --skip-version-check commandline argument to disable this check.
+        """.strip())
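A hypothetical call site for the new check (the commit wires it up elsewhere in startup): the function only prints explanations via print_error_explanation, so it can run once at launch without aborting.

from modules import errors

errors.check_versions()  # warns when torch/xformers/gradio differ from the tested versions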
@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)

 def active():
-    if shared.opts.disable_all_extensions == "all":
+    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
         return []
-    elif shared.opts.disable_all_extensions == "extra":
+    elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
         return [x for x in extensions if x.enabled and x.is_builtin]
     else:
         return [x for x in extensions if x.enabled]

@@ -56,10 +56,12 @@ class Extension:
                 self.do_read_info_from_repo()

                 return self.to_dict()
+        try:
             d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
             self.from_dict(d)
-        self.status = 'unknown'
+        except FileNotFoundError:
+            pass
+        self.status = 'unknown' if self.status == '' else self.status

     def do_read_info_from_repo(self):
         repo = None

@@ -139,8 +141,12 @@ def list_extensions():
     if not os.path.isdir(extensions_dir):
         return

-    if shared.opts.disable_all_extensions == "all":
+    if shared.cmd_opts.disable_all_extensions:
+        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
+    elif shared.opts.disable_all_extensions == "all":
         print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+    elif shared.cmd_opts.disable_extra_extensions:
+        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
     elif shared.opts.disable_all_extensions == "extra":
         print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
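The net effect is that the new command-line switches take precedence over the saved disable_all_extensions option. A condensed sketch of the decision order (the function name is illustrative, not webui API):

def extension_filter_mode(cmd_opts, opts):
    if cmd_opts.disable_all_extensions or opts.disable_all_extensions == "all":
        return "none"          # no extensions run at all
    if cmd_opts.disable_extra_extensions or opts.disable_all_extensions == "extra":
        return "builtin-only"  # only built-in extensions run
    return "all-enabled"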
@@ -1,19 +1,27 @@
+import json
+import os
 import re
 from collections import defaultdict

 from modules import errors

 extra_network_registry = {}
+extra_network_aliases = {}


 def initialize():
     extra_network_registry.clear()
+    extra_network_aliases.clear()


 def register_extra_network(extra_network):
     extra_network_registry[extra_network.name] = extra_network


+def register_extra_network_alias(extra_network, alias):
+    extra_network_aliases[alias] = extra_network
+
+
 def register_default_extra_networks():
     from modules.extra_networks_hypernet import ExtraNetworkHypernet
     register_extra_network(ExtraNetworkHypernet())

@@ -82,20 +90,26 @@ def activate(p, extra_network_data):
     """call activate for extra networks in extra_network_data in specified order, then call
     activate for all remaining registered networks with an empty argument list"""

+    activated = []
+
     for extra_network_name, extra_network_args in extra_network_data.items():
         extra_network = extra_network_registry.get(extra_network_name, None)

+        if extra_network is None:
+            extra_network = extra_network_aliases.get(extra_network_name, None)
+
         if extra_network is None:
             print(f"Skipping unknown extra network: {extra_network_name}")
             continue

         try:
             extra_network.activate(p, extra_network_args)
+            activated.append(extra_network)
         except Exception as e:
             errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")

     for extra_network_name, extra_network in extra_network_registry.items():
-        args = extra_network_data.get(extra_network_name, None)
-        if args is not None:
+        if extra_network in activated:
             continue

         try:

@@ -165,3 +179,20 @@ def parse_prompts(prompts):

     return res, extra_data
+
+
+def get_user_metadata(filename):
+    if filename is None:
+        return {}
+
+    basename, ext = os.path.splitext(filename)
+    metadata_filename = basename + '.json'
+
+    metadata = {}
+    try:
+        if os.path.isfile(metadata_filename):
+            with open(metadata_filename, "r", encoding="utf8") as file:
+                metadata = json.load(file)
+    except Exception as e:
+        errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+    return metadata
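The alias registry lets one handler answer to several prompt keywords, and get_user_metadata reads an optional sidecar .json next to a network file. An illustrative registration (the "lyco" alias is an example of how an extension might use it):

from modules import extra_networks

lora_handler = extra_networks.extra_network_registry.get("lora")
if lora_handler is not None:
    # prompts written as <lyco:name:1.0> would now resolve to the same handler
    extra_networks.register_extra_network_alias(lora_handler, "lyco")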
@@ -7,7 +7,7 @@ import json
 import torch
 import tqdm

-from modules import shared, images, sd_models, sd_vae, sd_models_config
+from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
 from modules.ui_common import plaintext_to_html
 import gradio as gr
 import safetensors.torch

@@ -72,7 +72,20 @@ def to_half(tensor, enable):
     return tensor


-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
+def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
+    metadata = {}
+
+    for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
+        checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
+        if checkpoint_info is None:
+            continue
+
+        metadata.update(checkpoint_info.metadata)
+
+    return json.dumps(metadata, indent=4, ensure_ascii=False)
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
     shared.state.begin(job="model-merge")

     def fail(message):

@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
     shared.state.textinfo = "Saving"
     print(f"Saving to {output_modelname}...")

-    metadata = None
+    metadata = {}
+
+    if save_metadata and copy_metadata_fields:
+        if primary_model_info:
+            metadata.update(primary_model_info.metadata)
+        if secondary_model_info:
+            metadata.update(secondary_model_info.metadata)
+        if tertiary_model_info:
+            metadata.update(tertiary_model_info.metadata)

     if save_metadata:
-        metadata = {"format": "pt"}
+        try:
+            metadata.update(json.loads(metadata_json))
+        except Exception as e:
+            errors.display(e, "reading metadata from json")
+
+        metadata["format"] = "pt"

+    if save_metadata and add_merge_recipe:
         merge_recipe = {
             "type": "webui", # indicate this model was merged with webui's built-in merger
             "primary_model_hash": primary_model_info.sha256,

@@ -261,7 +288,6 @@
             "is_inpainting": result_is_inpainting_model,
             "is_instruct_pix2pix": result_is_instruct_pix2pix_model
         }
-        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

         sd_merge_models = {}

@@ -281,11 +307,12 @@
         if tertiary_model_info:
             add_model_metadata(tertiary_model_info)

+        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
         metadata["sd_merge_models"] = json.dumps(sd_merge_models)

     _, extension = os.path.splitext(output_modelname)
     if extension.lower() == ".safetensors":
-        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
+        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
     else:
         torch.save(theta_0, output_modelname)
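Because the merge recipe and source-model hashes are stored as JSON strings inside the safetensors header, they can be read back later without loading the weights. A hedged sketch (the file name is an example):

import json
from safetensors import safe_open

with safe_open("merged-model.safetensors", framework="pt") as f:
    header = f.metadata() or {}

recipe = json.loads(header.get("sd_merge_recipe", "{}"))
print(recipe.get("type"))  # "webui" for models merged by the built-in merger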
@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Hires sampler" not in res:
         res["Hires sampler"] = "Use same sampler"

+    if "Hires checkpoint" not in res:
+        res["Hires checkpoint"] = "Use same checkpoint"
+
     if "Hires prompt" not in res:
         res["Hires prompt"] = ""
60  modules/gradio_extensons.py  Normal file
@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import scripts
+
+
+def add_classes_to_gradio_component(comp):
+    """
+    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+    """
+
+    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
+
+    if getattr(comp, 'multiselect', False):
+        comp.elem_classes.append('multiselect')
+
+
+def IOComponent_init(self, *args, **kwargs):
+    self.webui_tooltip = kwargs.pop('tooltip', None)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.before_component(self, **kwargs)
+
+    scripts.script_callbacks.before_component_callback(self, **kwargs)
+
+    res = original_IOComponent_init(self, *args, **kwargs)
+
+    add_classes_to_gradio_component(self)
+
+    scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+    if scripts.scripts_current is not None:
+        scripts.scripts_current.after_component(self, **kwargs)
+
+    return res
+
+
+def Block_get_config(self):
+    config = original_Block_get_config(self)
+
+    webui_tooltip = getattr(self, 'webui_tooltip', None)
+    if webui_tooltip:
+        config["webui_tooltip"] = webui_tooltip
+
+    return config
+
+
+def BlockContext_init(self, *args, **kwargs):
+    res = original_BlockContext_init(self, *args, **kwargs)
+
+    add_classes_to_gradio_component(self)
+
+    return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+original_Block_get_config = gr.blocks.Block.get_config
+original_BlockContext_init = gr.blocks.BlockContext.__init__
+
+gr.components.IOComponent.__init__ = IOComponent_init
+gr.blocks.Block.get_config = Block_get_config
+gr.blocks.BlockContext.__init__ = BlockContext_init
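The new module patches gradio at import time, so merely importing it is enough for every component created afterwards to pick up the gradio-<type> CSS class and the optional tooltip keyword. A sketch of the observable effect (assumes gradio 3.x and that the patching import has run):

import gradio as gr
from modules import gradio_extensons  # noqa: F401  - importing applies the patches

with gr.Blocks():
    btn = gr.Button("Generate", tooltip="Runs the current job")  # tooltip is stored as webui_tooltip

print(btn.elem_classes)  # contains "gradio-button"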
@@ -10,7 +10,7 @@ import torch
 import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
 from modules.textual_inversion import textual_inversion, logging
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum

@@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None):
     return context_k, context_v


-def attention_CrossAttention_forward(self, x, context=None, mask=None):
+def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
     h = self.heads

     q = self.to_q(x)

@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,

 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
-    from modules import images
+    from modules import images, processing

     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
     return res


-invalid_filename_chars = '<>:"/\\|?*\n'
+invalid_filename_chars = '<>:"/\\|?*\n\r\t'
 invalid_filename_prefix = ' '
 invalid_filename_postfix = ' .'
 re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')

@@ -363,7 +363,7 @@ class FilenameGenerator:
     'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
     'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
     'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
-    'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
+    'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
     'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
     'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
     'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
@@ -3,14 +3,13 @@ from contextlib import closing
 from pathlib import Path

 import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
 import gradio as gr

 from modules import sd_samplers, images as imgutil
 from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
-from modules.images import save_image
 import modules.shared as shared
 import modules.processing as processing
 from modules.ui import plaintext_to_html

@@ -18,9 +17,10 @@ import modules.scripts

 def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
+    output_dir = output_dir.strip()
     processing.fix_seed(p)

-    images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
+    images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))

     is_inpaint_batch = False
     if inpaint_mask_dir:

@@ -32,11 +32,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
     print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

-    save_normally = output_dir == ''
-
-    p.do_not_save_grid = True
-    p.do_not_save_samples = not save_normally
-
     state.job_count = len(images) * p.n_iter

     # extract "default" params to use in case getting png info fails

@@ -111,21 +106,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
         proc = modules.scripts.scripts_img2img.run(p, *args)
         if proc is None:
-            proc = process_images(p)
-
-        for n, processed_image in enumerate(proc.images):
-            filename = image_path.stem
-            infotext = proc.infotext(p, n)
-            relpath = os.path.dirname(os.path.relpath(image, input_dir))
-
-            if n > 0:
-                filename += f"-{n}"
-
-            if not save_normally:
-                os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
-                if processed_image.mode == 'RGBA':
-                    processed_image = processed_image.convert("RGB")
-                save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
+            if output_dir:
+                p.outpath_samples = output_dir
+                p.override_settings['save_to_dirs'] = False
+                if p.n_iter > 1 or p.batch_size > 1:
+                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
+                else:
+                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+            process_images(p)


 def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):

@@ -141,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         mask = None
     elif mode == 2:  # inpaint
         image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
-        alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
-        mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
-        mask = ImageChops.lighter(alpha_mask, mask).convert('L')
+        mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
         image = image.convert("RGB")
     elif mode == 3:  # inpaint sketch
         image = inpaint_color_sketch
@ -1,4 +1,5 @@
|
|||||||
# this scripts installs necessary requirements and launches main program in webui.py
|
# this scripts installs necessary requirements and launches main program in webui.py
|
||||||
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -9,6 +10,7 @@ from functools import lru_cache
|
|||||||
|
|
||||||
from modules import cmd_args, errors
|
from modules import cmd_args, errors
|
||||||
from modules.paths_internal import script_path, extensions_dir
|
from modules.paths_internal import script_path, extensions_dir
|
||||||
|
from modules.timer import startup_timer
|
||||||
|
|
||||||
args, _ = cmd_args.parser.parse_known_args()
|
args, _ = cmd_args.parser.parse_known_args()
|
||||||
|
|
||||||
@ -192,7 +194,7 @@ def run_extension_installer(extension_dir):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
env['PYTHONPATH'] = os.path.abspath(".")
|
env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"
|
||||||
|
|
||||||
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -222,8 +224,51 @@ def run_extensions_installers(settings_file):
|
|||||||
if not os.path.isdir(extensions_dir):
|
if not os.path.isdir(extensions_dir):
|
||||||
return
|
return
|
||||||
|
|
||||||
for dirname_extension in list_extensions(settings_file):
|
with startup_timer.subcategory("run extensions installers"):
|
||||||
run_extension_installer(os.path.join(extensions_dir, dirname_extension))
|
for dirname_extension in list_extensions(settings_file):
|
||||||
|
path = os.path.join(extensions_dir, dirname_extension)
|
||||||
|
|
||||||
|
if os.path.isdir(path):
|
||||||
|
run_extension_installer(path)
|
||||||
|
startup_timer.record(dirname_extension)
|
||||||
|
|
||||||
|
|
||||||
|
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
|
||||||
|
|
||||||
|
|
||||||
|
def requirements_met(requirements_file):
|
||||||
|
"""
|
||||||
|
Does a simple parse of a requirements.txt file to determine if all rerqirements in it
|
||||||
|
are already installed. Returns True if so, False if not installed or parsing fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.metadata
|
||||||
|
import packaging.version
|
||||||
|
|
||||||
|
with open(requirements_file, "r", encoding="utf8") as file:
|
||||||
|
for line in file:
|
||||||
|
if line.strip() == "":
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(re_requirement, line)
|
||||||
|
if m is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
package = m.group(1).strip()
|
||||||
|
version_required = (m.group(2) or "").strip()
|
||||||
|
|
||||||
|
if version_required == "":
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
version_installed = importlib.metadata.version(package)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def prepare_environment():
|
def prepare_environment():
|
||||||
@ -237,11 +282,13 @@ def prepare_environment():
|
|||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
|
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||||
|
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
@ -249,15 +296,18 @@ def prepare_environment():
|
|||||||
try:
|
try:
|
||||||
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
# the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
|
||||||
os.remove(os.path.join(script_path, "tmp", "restart"))
|
os.remove(os.path.join(script_path, "tmp", "restart"))
|
||||||
os.environ.setdefault('SD_WEBUI_RESTARTING ', '1')
|
os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if not args.skip_python_version_check:
|
if not args.skip_python_version_check:
|
||||||
check_python_version()
|
check_python_version()
|
||||||
|
|
||||||
|
startup_timer.record("checks")
|
||||||
|
|
||||||
commit = commit_hash()
|
commit = commit_hash()
|
||||||
tag = git_tag()
|
tag = git_tag()
|
||||||
|
startup_timer.record("git version info")
|
||||||
|
|
||||||
print(f"Python {sys.version}")
|
print(f"Python {sys.version}")
|
||||||
print(f"Version: {tag}")
|
print(f"Version: {tag}")
|
||||||
@ -265,21 +315,27 @@ def prepare_environment():
|
|||||||
|
|
||||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||||
|
startup_timer.record("install torch")
|
||||||
|
|
||||||
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'Torch is not able to use GPU; '
|
'Torch is not able to use GPU; '
|
||||||
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||||
)
|
)
|
||||||
|
startup_timer.record("torch GPU test")
|
||||||
|
|
||||||
|
|
||||||
if not is_installed("gfpgan"):
|
if not is_installed("gfpgan"):
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
|
startup_timer.record("install gfpgan")
|
||||||
|
|
||||||
if not is_installed("clip"):
|
if not is_installed("clip"):
|
||||||
run_pip(f"install {clip_package}", "clip")
|
run_pip(f"install {clip_package}", "clip")
|
||||||
|
startup_timer.record("install clip")
|
||||||
|
|
||||||
if not is_installed("open_clip"):
|
if not is_installed("open_clip"):
|
||||||
run_pip(f"install {openclip_package}", "open_clip")
|
run_pip(f"install {openclip_package}", "open_clip")
|
||||||
|
startup_timer.record("install open_clip")
|
||||||
|
|
||||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
@ -293,36 +349,49 @@ def prepare_environment():
|
|||||||
elif platform.system() == "Linux":
|
elif platform.system() == "Linux":
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
||||||
|
|
||||||
|
startup_timer.record("install xformers")
|
||||||
|
|
||||||
if not is_installed("ngrok") and args.ngrok:
|
if not is_installed("ngrok") and args.ngrok:
|
||||||
run_pip("install ngrok", "ngrok")
|
run_pip("install ngrok", "ngrok")
|
||||||
|
startup_timer.record("install ngrok")
|
||||||
|
|
||||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||||
|
|
||||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
|
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||||
|
|
||||||
|
startup_timer.record("clone repositores")
|
||||||
|
|
||||||
if not is_installed("lpips"):
|
if not is_installed("lpips"):
|
||||||
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
|
||||||
|
startup_timer.record("install CodeFormer requirements")
|
||||||
|
|
||||||
if not os.path.isfile(requirements_file):
|
if not os.path.isfile(requirements_file):
|
||||||
requirements_file = os.path.join(script_path, requirements_file)
|
requirements_file = os.path.join(script_path, requirements_file)
|
||||||
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
|
||||||
|
if not requirements_met(requirements_file):
|
||||||
|
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||||
|
startup_timer.record("install requirements")
|
||||||
|
|
||||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||||
|
|
||||||
if args.update_check:
|
if args.update_check:
|
||||||
version_check(commit)
|
version_check(commit)
|
||||||
|
startup_timer.record("check version")
|
||||||
|
|
||||||
if args.update_all_extensions:
|
if args.update_all_extensions:
|
||||||
git_pull_recursive(extensions_dir)
|
git_pull_recursive(extensions_dir)
|
||||||
|
startup_timer.record("update extensions")
|
||||||
|
|
||||||
if "--exit" in sys.argv:
|
if "--exit" in sys.argv:
|
||||||
print("Exiting because of --exit argument")
|
print("Exiting because of --exit argument")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def configure_for_tests():
|
def configure_for_tests():
|
||||||
if "--api" not in sys.argv:
|
if "--api" not in sys.argv:
|
||||||
sys.argv.append("--api")
|
sys.argv.append("--api")
|
||||||
|
@ -15,6 +15,9 @@ def send_everything_to_cpu():
|
|||||||
|
|
||||||
|
|
||||||
def setup_for_low_vram(sd_model, use_medvram):
|
def setup_for_low_vram(sd_model, use_medvram):
|
||||||
|
if getattr(sd_model, 'lowvram', False):
|
||||||
|
return
|
||||||
|
|
||||||
sd_model.lowvram = True
|
sd_model.lowvram = True
|
||||||
|
|
||||||
parents = {}
|
parents = {}
|
||||||
@ -53,19 +56,50 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
send_me_to_gpu(first_stage_model, None)
|
send_me_to_gpu(first_stage_model, None)
|
||||||
return first_stage_model_decode(z)
|
return first_stage_model_decode(z)
|
||||||
|
|
||||||
# for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
|
to_remain_in_cpu = [
|
||||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
(sd_model, 'first_stage_model'),
|
||||||
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
(sd_model, 'depth_model'),
|
||||||
|
(sd_model, 'embedder'),
|
||||||
|
(sd_model, 'model'),
|
||||||
|
(sd_model, 'embedder'),
|
||||||
|
]
|
||||||
|
|
||||||
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
|
is_sdxl = hasattr(sd_model, 'conditioner')
|
||||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
||||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
|
|
||||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
|
if is_sdxl:
|
||||||
|
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
||||||
|
elif is_sd2:
|
||||||
|
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
||||||
|
else:
|
||||||
|
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
|
||||||
|
|
||||||
|
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
|
||||||
|
stored = []
|
||||||
|
for obj, field in to_remain_in_cpu:
|
||||||
|
module = getattr(obj, field, None)
|
||||||
|
stored.append(module)
|
||||||
|
setattr(obj, field, None)
|
||||||
|
|
||||||
|
# send the model to GPU.
|
||||||
sd_model.to(devices.device)
|
sd_model.to(devices.device)
|
||||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
|
|
||||||
|
# put modules back. the modules will be in CPU.
|
||||||
|
for (obj, field), module in zip(to_remain_in_cpu, stored):
|
||||||
|
setattr(obj, field, module)
|
||||||
|
|
||||||
# register hooks for those the first three models
|
# register hooks for those the first three models
|
||||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
if is_sdxl:
|
||||||
|
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
elif is_sd2:
|
||||||
|
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
|
||||||
|
parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
|
||||||
|
else:
|
||||||
|
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||||
|
|
||||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||||
@ -73,11 +107,6 @@ def setup_for_low_vram(sd_model, use_medvram):
|
|||||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
if sd_model.embedder:
|
if sd_model.embedder:
|
||||||
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
||||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
|
||||||
|
|
||||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
|
||||||
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
|
|
||||||
del sd_model.cond_stage_model.transformer
|
|
||||||
|
|
||||||
if use_medvram:
|
if use_medvram:
|
||||||
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||||
|
@@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio
 import modules.safe  # noqa: F401
 
 
+def mute_sdxl_imports():
+    """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
+
+    class Dummy:
+        pass
+
+    module = Dummy()
+    module.LPIPS = None
+    sys.modules['taming.modules.losses.lpips'] = module
+
+    module = Dummy()
+    module.StableDataModuleFromConfig = None
+    sys.modules['sgm.data'] = module
+
+
 # data_path = cmd_opts_pre.data
 sys.path.insert(0, script_path)
 
@@ -18,8 +33,11 @@ for possible_sd_path in possible_sd_paths:
 
 assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
 
+mute_sdxl_imports()
+
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
+    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -35,6 +53,13 @@ for d, must_exist, what, options in path_dirs:
         d = os.path.abspath(d)
         if "atstart" in options:
             sys.path.insert(0, d)
+        elif "sgm" in options:
+            # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
+            # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.
+
+            sys.path.insert(0, d)
+            import sgm  # noqa: F401
+            sys.path.pop(0)
         else:
             sys.path.append(d)
         paths[what] = d
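`mute_sdxl_imports()` above works by planting placeholder objects in `sys.modules` before the SDXL repo is imported, so its unused imports resolve without pulling in the real dependencies. A generic sketch of the same trick, with a made-up module name for illustration:

```python
import sys
import types


def stub_module(name: str, **attrs):
    # a later `from <name> import X` will resolve against this fake module
    # (parent packages may still need to be importable for plain `import` statements)
    fake = types.ModuleType(name)
    for key, value in attrs.items():
        setattr(fake, key, value)
    sys.modules[name] = fake


stub_module("heavy.optional.dependency", SomeUnusedClass=None)  # hypothetical name
```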
@@ -14,7 +14,7 @@ from skimage import exposure
 from typing import Any, Dict, List
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
 
+decode_first_stage = sd_samplers_common.decode_first_stage
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
@@ -330,8 +331,21 @@ class StableDiffusionProcessing:
 
         caches is a list with items described above.
         """
+
+        cached_params = (
+            required_prompts,
+            steps,
+            opts.CLIP_stop_at_last_layers,
+            shared.sd_model.sd_checkpoint_info,
+            extra_network_data,
+            opts.sdxl_crop_left,
+            opts.sdxl_crop_top,
+            self.width,
+            self.height,
+        )
+
         for cache in caches:
-            if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+            if cache[0] is not None and cached_params == cache[0]:
                 return cache[1]
 
         cache = caches[0]
@@ -339,14 +353,17 @@ class StableDiffusionProcessing:
         with devices.autocast():
             cache[1] = function(shared.sd_model, required_prompts, steps)
 
-        cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
+        cache[0] = cached_params
         return cache[1]
 
     def setup_conds(self):
+        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
+
         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
         self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
-        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
-        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
 
     def parse_extra_network_prompts(self):
         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
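The `cached_params` tuple above widens the conditioning cache key (it now also covers the SDXL crop options and the output size), so cached conds are thrown away whenever anything that influenced them changes. A stripped-down sketch of that [params, value] cache, with illustrative names:

```python
def cached_call(cache, params, function, *args):
    # cache is a two-element list: [params_of_last_call, result_of_last_call]
    if cache[0] is not None and cache[0] == params:
        return cache[1]

    cache[1] = function(*args)
    cache[0] = params
    return cache[1]


cond_cache = [None, None]
# result = cached_call(cond_cache, (prompt, steps, width, height), compute_conds, prompt, steps)
```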
@@ -476,7 +493,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
         noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
 
         subnoise = None
-        if subseeds is not None:
+        if subseeds is not None and subseed_strength != 0:
             subseed = 0 if i >= len(subseeds) else subseeds[i]
 
             subnoise = devices.randn(subseed, noise_shape)
@@ -508,7 +525,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
             cnt = p.sampler.number_of_needed_noises(p)
 
             if eta_noise_seed_delta > 0:
-                torch.manual_seed(seed + eta_noise_seed_delta)
+                devices.manual_seed(seed + eta_noise_seed_delta)
 
             for j in range(cnt):
                 sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -522,11 +539,42 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
     return x
 
 
-def decode_first_stage(model, x):
-    with devices.autocast(disable=x.dtype == devices.dtype_vae):
-        x = model.decode_first_stage(x)
-
-    return x
+class DecodedSamples(list):
+    already_decoded = True
+
+
+def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+    samples = DecodedSamples()
+
+    for i in range(batch.shape[0]):
+        sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+        if check_for_nans:
+            try:
+                devices.test_for_nans(sample, "vae")
+            except devices.NansException as e:
+                if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+                    raise e
+
+                errors.print_error_explanation(
+                    "A tensor with all NaNs was produced in VAE.\n"
+                    "Web UI will now convert VAE into 32-bit float and retry.\n"
+                    "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
+                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+                )
+
+                devices.dtype_vae = torch.float32
+                model.first_stage_model.to(devices.dtype_vae)
+                batch = batch.to(devices.dtype_vae)
+
+                sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+        if target_device is not None:
+            sample = sample.to(target_device)
+
+        samples.append(sample)
+
+    return samples
 
 
 def get_fixed_seed(seed):
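`decode_latent_batch()` above decodes latents one sample at a time so it can NaN-check each result and, when allowed, retry once with the VAE promoted to float32. A compact sketch of just the retry idea (hypothetical helper, not the webui API):

```python
import torch


def decode_with_fp32_fallback(vae_decode, latent, vae_module):
    sample = vae_decode(latent)
    if torch.isnan(sample).all():
        # half-precision VAEs sometimes overflow; retry the same latent in full precision
        vae_module.to(torch.float32)
        sample = vae_decode(latent.to(torch.float32))
    return sample
```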
@@ -551,8 +599,12 @@ def program_version():
     return res
 
 
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
-    index = position_in_batch + iteration * p.batch_size
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
+    if index is None:
+        index = position_in_batch + iteration * p.batch_size
 
+    if all_negative_prompts is None:
+        all_negative_prompts = p.all_negative_prompts
+
     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
     enable_hr = getattr(p, 'enable_hr', False)
@@ -568,12 +620,12 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Sampler": p.sampler_name,
         "CFG scale": p.cfg_scale,
         "Image CFG scale": getattr(p, 'image_cfg_scale', None),
-        "Seed": all_seeds[index],
+        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
-        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
-        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
+        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
+        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
@@ -583,7 +635,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
         "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
         "Init image hash": getattr(p, 'init_img_hash', None),
-        "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+        "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         **p.extra_generation_params,
         "Version": program_version() if opts.add_version_to_infotext else None,
@@ -593,7 +645,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
     prompt_text = p.prompt if use_main_prompt else all_prompts[index]
-    negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
+    negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
 
     return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
 
@@ -667,9 +719,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         else:
             p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
 
-    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
-        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
-
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
         model_hijack.embedding_db.load_textual_inversion_embeddings()
 
@@ -743,9 +792,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
-            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
-            for x in x_samples_ddim:
-                devices.test_for_nans(x, "vae")
+            if getattr(samples_ddim, 'already_decoded', False):
+                x_samples_ddim = samples_ddim
+            else:
+                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -760,6 +810,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.scripts is not None:
                 p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
 
+                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
+                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
+                x_samples_ddim = batch_params.images
+
+            def infotext(index=0, use_main_prompt=False):
+                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+
             for i, x_sample in enumerate(x_samples_ddim):
                 p.batch_index = i
 
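With the change above, a sampler may hand back a `DecodedSamples` list that is already in image space; `process_images_inner` keys off the `already_decoded` attribute to decide whether the VAE still needs to run. A simplified sketch of that handshake (names match the hunks above, but the function itself is hypothetical):

```python
def to_image_space(sd_model, samples_ddim, decode_latent_batch, cpu_device):
    if getattr(samples_ddim, 'already_decoded', False):
        return samples_ddim  # the hires pass already returned decoded images
    return decode_latent_batch(sd_model, samples_ddim, target_device=cpu_device, check_for_nans=True)
```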
@@ -768,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
                 if p.restore_faces:
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
-                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
 
                     devices.torch_gc()
 
@@ -785,15 +845,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                         image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
-                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                     image = apply_color_correction(p.color_corrections[i], image)
 
                 image = apply_overlay(image, p.paste_to, i, p.overlay_images)
 
                 if opts.samples_save and not p.do_not_save_samples:
-                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
 
-                text = infotext(n, i)
+                text = infotext(i)
                 infotexts.append(text)
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
@@ -804,10 +864,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                     image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
 
                     if opts.save_mask:
-                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
 
                     if opts.save_mask_composite:
-                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
 
                     if opts.return_mask:
                         output_images.append(image_mask)
@@ -848,7 +908,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         p,
         images_list=output_images,
         seed=p.all_seeds[0],
-        info=infotext(),
+        info=infotexts[0],
         comments="".join(f"{comment}\n" for comment in comments),
         subseed=p.all_subseeds[0],
         index_of_first_image=index_of_first_image,
@@ -878,7 +938,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     cached_hr_uc = [None, None]
     cached_hr_c = [None, None]
 
-    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
@@ -889,11 +949,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_resize_y = hr_resize_y
         self.hr_upscale_to_x = hr_resize_x
         self.hr_upscale_to_y = hr_resize_y
+        self.hr_checkpoint_name = hr_checkpoint_name
+        self.hr_checkpoint_info = None
         self.hr_sampler_name = hr_sampler_name
         self.hr_prompt = hr_prompt
         self.hr_negative_prompt = hr_negative_prompt
         self.all_hr_prompts = None
         self.all_hr_negative_prompts = None
+        self.latent_scale_mode = None
 
         if firstphase_width != 0 or firstphase_height != 0:
             self.hr_upscale_to_x = self.width
@@ -916,6 +979,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
+            if self.hr_checkpoint_name:
+                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+                if self.hr_checkpoint_info is None:
+                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
 
@@ -925,6 +996,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                 self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
 
+            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+            if self.enable_hr and self.latent_scale_mode is None:
+                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
             if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                 self.hr_resize_x = self.width
                 self.hr_resize_y = self.height
@@ -963,14 +1039,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
             self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
 
-        # special case: the user has chosen to do nothing
-        if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
-            self.enable_hr = False
-            self.denoising_strength = None
-            self.extra_generation_params.pop("Hires upscale", None)
-            self.extra_generation_params.pop("Hires resize", None)
-            return
-
         if not state.processing_has_refined_job_count:
             if state.job_count == -1:
                 state.job_count = self.n_iter
@@ -988,17 +1056,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
-        latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
-        if self.enable_hr and latent_scale_mode is None:
-            if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
-                raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-
         x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
         samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+        del x
 
         if not self.enable_hr:
             return samples
 
+        if self.latent_scale_mode is None:
+            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+        else:
+            decoded_samples = None
+
+        current = shared.sd_model.sd_checkpoint_info
+        try:
+            if self.hr_checkpoint_info is not None:
+                self.sampler = None
+                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+                devices.torch_gc()
+
+            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+        finally:
+            self.sampler = None
+            sd_models.reload_model_weights(info=current)
+            devices.torch_gc()
+
+    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
         self.is_hr_pass = True
 
         target_width = self.hr_upscale_to_x
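The refactored `sample()` above wraps the optional hires checkpoint swap in try/finally so the original weights are always restored, even if the hires pass throws. The shape of that pattern, simplified (the real code also drops the sampler and runs garbage collection around each reload):

```python
def with_temporary_checkpoint(sd_models, shared, hr_checkpoint_info, run_pass):
    current = shared.sd_model.sd_checkpoint_info
    try:
        if hr_checkpoint_info is not None:
            sd_models.reload_model_weights(info=hr_checkpoint_info)
        return run_pass()
    finally:
        # always switch back to the checkpoint the first pass was sampled with
        sd_models.reload_model_weights(info=current)
```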
@@ -1014,13 +1097,20 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 image = sd_samplers.sample_to_image(image, index, approximation=0)
 
             info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
-            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix")
+            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
 
-        if latent_scale_mode is not None:
+        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+            img2img_sampler_name = 'DDIM'
+
+        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+        if self.latent_scale_mode is not None:
             for i in range(samples.shape[0]):
                 save_intermediate(samples, i)
 
-            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
 
             # Avoid making the inpainting conditioning unless necessary as
             # this does need some extra compute to decode / encode the image again.
@@ -1029,7 +1119,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             else:
                 image_conditioning = self.txt2img_image_conditioning(samples)
         else:
-            decoded_samples = decode_first_stage(self.sd_model, samples)
             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
             batch_images = []
@@ -1048,6 +1137,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             decoded_samples = torch.from_numpy(np.array(batch_images))
             decoded_samples = decoded_samples.to(shared.device)
             decoded_samples = 2. * decoded_samples - 1.
+            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
 
             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
 
@@ -1055,19 +1145,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         shared.state.nextjob()
 
-        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
-        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
-            img2img_sampler_name = 'DDIM'
-
-        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
         samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
 
         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
 
         # GC now before running the next img2img to prevent running out of memory
-        x = None
         devices.torch_gc()
 
         if not self.disable_extra_networks:
@@ -1086,9 +1168,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
 
+        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
         self.is_hr_pass = False
 
-        return samples
+        return decoded_samples
 
     def close(self):
         super().close()
@@ -1127,8 +1211,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if self.hr_c is not None:
             return
 
-        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
-        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
 
     def setup_conds(self):
         super().setup_conds()
@@ -1136,7 +1223,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.hr_uc = None
         self.hr_c = None
 
-        if self.enable_hr:
+        if self.enable_hr and self.hr_checkpoint_info is None:
             if shared.opts.hires_fix_use_firstpass_conds:
                 self.calculate_hr_conds()
 
@@ -1288,9 +1375,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
 
         image = torch.from_numpy(batch_images)
         image = 2. * image - 1.
-        image = image.to(shared.device)
+        image = image.to(shared.device, dtype=devices.dtype_vae)
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+        devices.torch_gc()
 
         if self.resize_mode == 3:
             self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import List
|
from typing import List
|
||||||
@ -17,8 +19,8 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
|
|||||||
!emphasized: "(" prompt ")"
|
!emphasized: "(" prompt ")"
|
||||||
| "(" prompt ":" prompt ")"
|
| "(" prompt ":" prompt ")"
|
||||||
| "[" prompt "]"
|
| "[" prompt "]"
|
||||||
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
|
scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
|
||||||
alternate: "[" prompt ("|" prompt)+ "]"
|
alternate: "[" prompt ("|" [prompt])+ "]"
|
||||||
WHITESPACE: /\s+/
|
WHITESPACE: /\s+/
|
||||||
plain: /([^\\\[\]():|]|\\.)+/
|
plain: /([^\\\[\]():|]|\\.)+/
|
||||||
%import common.SIGNED_NUMBER -> NUMBER
|
%import common.SIGNED_NUMBER -> NUMBER
|
||||||
@ -51,6 +53,10 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
[[3, '((a][:b:c '], [10, '((a][:b:c d']]
|
||||||
>>> g("[a|(b:1.1)]")
|
>>> g("[a|(b:1.1)]")
|
||||||
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
[[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
|
||||||
|
>>> g("[fe|]male")
|
||||||
|
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
|
>>> g("[fe|||]male")
|
||||||
|
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
@ -58,11 +64,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
|
|
||||||
class CollectSteps(lark.Visitor):
|
class CollectSteps(lark.Visitor):
|
||||||
def scheduled(self, tree):
|
def scheduled(self, tree):
|
||||||
tree.children[-1] = float(tree.children[-1])
|
tree.children[-2] = float(tree.children[-2])
|
||||||
if tree.children[-1] < 1:
|
if tree.children[-2] < 1:
|
||||||
tree.children[-1] *= steps
|
tree.children[-2] *= steps
|
||||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
tree.children[-2] = min(steps, int(tree.children[-2]))
|
||||||
res.append(tree.children[-1])
|
res.append(tree.children[-2])
|
||||||
|
|
||||||
def alternate(self, tree):
|
def alternate(self, tree):
|
||||||
res.extend(range(1, steps+1))
|
res.extend(range(1, steps+1))
|
||||||
@ -73,10 +79,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
def at_step(step, tree):
|
def at_step(step, tree):
|
||||||
class AtStep(lark.Transformer):
|
class AtStep(lark.Transformer):
|
||||||
def scheduled(self, args):
|
def scheduled(self, args):
|
||||||
before, after, _, when = args
|
before, after, _, when, _ = args
|
||||||
yield before or () if step <= when else after
|
yield before or () if step <= when else after
|
||||||
def alternate(self, args):
|
def alternate(self, args):
|
||||||
yield next(args[(step - 1)%len(args)])
|
args = ["" if not arg else arg for arg in args]
|
||||||
|
yield args[(step - 1) % len(args)]
|
||||||
def start(self, args):
|
def start(self, args):
|
||||||
def flatten(x):
|
def flatten(x):
|
||||||
if type(x) == str:
|
if type(x) == str:
|
||||||
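The grammar changes above let the scheduled syntax tolerate whitespace before the closing bracket and let the alternation syntax accept empty options, which the two new doctests exercise. A quick check of the new behavior (assumes running from the webui root so `modules` is importable):

```python
from modules import prompt_parser

schedules = prompt_parser.get_learned_conditioning_prompt_schedules(["[fe|]male"], 4)
# per the doctests above this is expected to print:
# [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male']]
print(schedules[0])
```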
@@ -109,7 +116,25 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
 
 
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+    """
+    A list with prompts for stable diffusion's conditioner model.
+    Can also specify width and height of created image - SDXL needs it.
+    """
+    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
+        super().__init__()
+        self.extend(prompts)
+
+        if copy_from is None:
+            copy_from = prompts
+
+        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
+        self.width = width or getattr(copy_from, 'width', None)
+        self.height = height or getattr(copy_from, 'height', None)
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
     """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
     and the sampling step at which this condition is to be replaced by the next one.
 
@@ -139,12 +164,17 @@ def get_learned_conditioning(model, prompts, steps):
             res.append(cached)
             continue
 
-        texts = [x[1] for x in prompt_schedule]
+        texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
        conds = model.get_learned_conditioning(texts)
 
         cond_schedule = []
         for i, (end_at_step, _) in enumerate(prompt_schedule):
-            cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+            if isinstance(conds, dict):
+                cond = {k: v[i] for k, v in conds.items()}
+            else:
+                cond = conds[i]
+
+            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
 
         cache[prompt] = cond_schedule
         res.append(cond_schedule)
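A short sketch of how the new `SdConditioning` wrapper travels with the prompts (names from the hunk above; assumes the webui package layout):

```python
from modules import prompt_parser

prompts = prompt_parser.SdConditioning(["a photo of a cat"], width=1024, height=1024)
negative = prompt_parser.SdConditioning([""], is_negative_prompt=True, copy_from=prompts)

# the size information is inherited via copy_from, so SDXL's conditioner can see it:
# negative.width == 1024, negative.height == 1024, negative.is_negative_prompt is True
```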
@@ -153,13 +183,15 @@ def get_learned_conditioning(model, prompts, steps):
 
 
 re_AND = re.compile(r"\bAND\b")
-re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
+re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
 
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
     res_indexes = []
 
-    prompt_flat_list = []
     prompt_indexes = {}
+    prompt_flat_list = SdConditioning(prompts)
+    prompt_flat_list.clear()
 
     for prompt in prompts:
         subprompts = re_AND.split(prompt)
@@ -196,6 +228,7 @@ class MulticondLearnedConditioning:
         self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
         self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
 
+
 def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
     """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
     For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -214,20 +247,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
 
 
+class DictWithShape(dict):
+    def __init__(self, x, shape):
+        super().__init__()
+        self.update(x)
+
+    @property
+    def shape(self):
+        return self["crossattn"].shape
+
+
 def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
-    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    is_dict = isinstance(param, dict)
+
+    if is_dict:
+        dict_cond = param
+        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+    else:
+        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
     for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, entry in enumerate(cond_schedule):
             if current_step <= entry.end_at_step:
                 target_index = current
                 break
-        res[i] = cond_schedule[target_index].cond
+
+        if is_dict:
+            for k, param in cond_schedule[target_index].cond.items():
+                res[k][i] = param
+        else:
+            res[i] = cond_schedule[target_index].cond
 
     return res
 
 
+def stack_conds(tensors):
+    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+    # and won't be able to torch.stack them. So this fixes that.
+    token_count = max([x.shape[0] for x in tensors])
+    for i in range(len(tensors)):
+        if tensors[i].shape[0] != token_count:
+            last_vector = tensors[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+    return torch.stack(tensors)
+
+
 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
     param = c.batch[0][0].schedules[0].cond
 
@@ -249,16 +319,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
 
         conds_list.append(conds_for_batch)
 
-    # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
-    # and won't be able to torch.stack them. So this fixes that.
-    token_count = max([x.shape[0] for x in tensors])
-    for i in range(len(tensors)):
-        if tensors[i].shape[0] != token_count:
-            last_vector = tensors[i][-1:]
-            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
-            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+    if isinstance(tensors[0], dict):
+        keys = list(tensors[0].keys())
+        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+    else:
+        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
 
-    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+    return conds_list, stacked
 
 
 re_attention = re.compile(r"""
@@ -270,7 +338,7 @@ re_attention = re.compile(r"""
 \\|
 \(|
 \[|
-:([+-]?[.\d]+)\)|
+:\s*([+-]?[.\d]+)\s*\)|
 \)|
 ]|
 [^\\()\[\]:]+|
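The factored-out `stack_conds()` pads shorter conditioning tensors by repeating their last vector so a batch with different token counts can still be stacked, and the dict branch applies the same padding per key for SDXL's dict-shaped conds. A toy illustration of the padding step in plain torch:

```python
import torch


def pad_and_stack(tensors):
    token_count = max(t.shape[0] for t in tensors)
    padded = []
    for t in tensors:
        if t.shape[0] != token_count:
            last = t[-1:]
            t = torch.vstack([t, last.repeat(token_count - t.shape[0], 1)])
        padded.append(t)
    return torch.stack(padded)


a = torch.randn(77, 768)
b = torch.randn(154, 768)
print(pad_and_stack([a, b]).shape)  # torch.Size([2, 154, 768])
```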
102  modules/rng_philox.py  (Normal file)
@@ -0,0 +1,102 @@
+"""RNG imitating torch cuda randn on CPU. You are welcome.
+
+Usage:
+
+```
+g = Generator(seed=0)
+print(g.randn(shape=(3, 4)))
+```
+
+Expected output:
+```
+[[-0.92466259 -0.42534415 -2.6438457   0.14518388]
+ [-0.12086647 -0.57972564 -0.62285122 -0.32838709]
+ [-1.07454231 -0.36314407 -1.67105067  2.26550497]]
+```
+"""
+
+import numpy as np
+
+philox_m = [0xD2511F53, 0xCD9E8D57]
+philox_w = [0x9E3779B9, 0xBB67AE85]
+
+two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
+two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
+
+
+def uint32(x):
+    """Converts (N,) np.uint64 array into (2, N) np.uint32 array."""
+    return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
+
+
+def philox4_round(counter, key):
+    """A single round of the Philox 4x32 random number generator."""
+    v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
+    v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
+
+    counter[0] = v2[1] ^ counter[1] ^ key[0]
+    counter[1] = v2[0]
+    counter[2] = v1[1] ^ counter[3] ^ key[1]
+    counter[3] = v1[0]
+
+
+def philox4_32(counter, key, rounds=10):
+    """Generates 32-bit random numbers using the Philox 4x32 random number generator.
+
+    Parameters:
+        counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
+        key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
+        rounds (int): The number of rounds to perform.
+
+    Returns:
+        numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
+    """
+    for _ in range(rounds - 1):
+        philox4_round(counter, key)
+
+        key[0] = key[0] + philox_w[0]
+        key[1] = key[1] + philox_w[1]
+
+    philox4_round(counter, key)
+    return counter
+
+
+def box_muller(x, y):
+    """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
+    u = x * two_pow32_inv + two_pow32_inv / 2
+    v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
+
+    s = np.sqrt(-2.0 * np.log(u))
+
+    r1 = s * np.sin(v)
+    return r1.astype(np.float32)
+
+
+class Generator:
+    """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
+
+    def __init__(self, seed):
+        self.seed = seed
+        self.offset = 0
+
+    def randn(self, shape):
+        """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
+
+        n = 1
+        for x in shape:
+            n *= x
+
+        counter = np.zeros((4, n), dtype=np.uint32)
+        counter[0] = self.offset
+        counter[2] = np.arange(n, dtype=np.uint32)  # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
+        self.offset += 1
+
+        key = np.empty(n, dtype=np.uint64)
+        key.fill(self.seed)
+        key = uint32(key)
+
+        g = philox4_32(counter, key)
+
+        return box_muller(g[0], g[1]).reshape(shape)  # discard g[2] and g[3]
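The new `Generator` is meant to reproduce `torch.randn(..., device='cuda')` on the CPU, so seeded noise can be generated identically with or without a GPU. A rough usage sketch (assumes a CUDA-enabled torch build for the comparison; exact agreement is the module's stated goal, not something verified here):

```python
import torch
from modules.rng_philox import Generator  # assumes the webui package layout

g = Generator(seed=0)
cpu_noise = torch.from_numpy(g.randn((2, 3)))

torch.manual_seed(0)
cuda_noise = torch.randn((2, 3), device='cuda').cpu()

print(cpu_noise)
print(cuda_noise)  # per the module docstring these should line up; small float drift is possible
```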
@ -12,11 +12,12 @@ def load_module(path):
    return module


def preload_extensions(extensions_dir, parser, extension_list=None):
    if not os.path.isdir(extensions_dir):
        return

    extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)
    for dirname in sorted(extensions):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue
|
||||||
|
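A hedged sketch of the new `extension_list` parameter in use; the `script_loading` module path and the argument values here are assumptions for illustration only.

import argparse
from modules import script_loading  # assumed location of preload_extensions

parser = argparse.ArgumentParser()
# Preload only the listed extension directories instead of everything under "extensions":
script_loading.preload_extensions("extensions", parser, extension_list=["my-extension"])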
@ -16,6 +16,11 @@ class PostprocessImageArgs:
|
|||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
|
|
||||||
|
class PostprocessBatchListArgs:
|
||||||
|
def __init__(self, images):
|
||||||
|
self.images = images
|
||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
name = None
|
name = None
|
||||||
"""script's internal name derived from title"""
|
"""script's internal name derived from title"""
|
||||||
@ -119,7 +124,7 @@ class Script:
|
|||||||
|
|
||||||
def after_extra_networks_activate(self, p, *args, **kwargs):
|
def after_extra_networks_activate(self, p, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Calledafter extra networks activation, before conds calculation
|
Called after extra networks activation, before conds calculation
|
||||||
allow modification of the network after extra networks activation has been applied
won't be called if p.disable_extra_networks
|
||||||
|
|
||||||
@ -156,6 +161,25 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
        """
        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
        This is useful when you want to update the entire batch instead of individual images.

        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
        If the number of images is different from the batch size when returning,
        then the script has the responsibility to also update the following attributes in the processing object (p):
          - p.prompts
          - p.negative_prompts
          - p.seeds
          - p.subseeds

        **kwargs will have same items as process_batch, and also:
          - batch_number - index of current batch, from 0 to number of batches-1
        """

        pass
|
||||||
|
|
||||||
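A minimal sketch of an extension script using the new callback. It duplicates every image in the batch and, as the docstring above requires, keeps the per-image lists on the processing object in sync; the script title and registration details are hypothetical.

from modules import scripts

class DuplicateBatchExample(scripts.Script):
    def title(self):
        return "Duplicate batch (example)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def postprocess_batch_list(self, p, pp, *args, **kwargs):
        # pp.images is a list of 3D tensors; doubling it changes the batch size,
        # so p.prompts / p.negative_prompts / p.seeds / p.subseeds must be doubled too.
        pp.images.extend(list(pp.images))
        p.prompts = p.prompts * 2
        p.negative_prompts = p.negative_prompts * 2
        p.seeds = p.seeds * 2
        p.subseeds = p.subseeds * 2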
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
|
||||||
"""
|
"""
|
||||||
Called for every image after it has been generated.
|
Called for every image after it has been generated.
|
||||||
@ -536,6 +560,14 @@ class ScriptRunner:
|
|||||||
except Exception:
|
except Exception:
|
||||||
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
|
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
|
    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_batch_list(p, pp, *script_args, **kwargs)
            except Exception:
                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
|
||||||
|
|
||||||
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
def postprocess_image(self, p, pp: PostprocessImageArgs):
|
||||||
for script in self.alwayson_scripts:
|
for script in self.alwayson_scripts:
|
||||||
try:
|
try:
|
||||||
@ -599,49 +631,3 @@ def reload_script_body_only():
|
|||||||
|
|
||||||
|
|
||||||
reload_scripts = load_scripts # compatibility alias
|
reload_scripts = load_scripts # compatibility alias
|
||||||
|
|
||||||
|
|
||||||
def add_classes_to_gradio_component(comp):
|
|
||||||
"""
|
|
||||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
|
||||||
"""
|
|
||||||
|
|
||||||
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
|
||||||
|
|
||||||
if getattr(comp, 'multiselect', False):
|
|
||||||
comp.elem_classes.append('multiselect')
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def IOComponent_init(self, *args, **kwargs):
|
|
||||||
if scripts_current is not None:
|
|
||||||
scripts_current.before_component(self, **kwargs)
|
|
||||||
|
|
||||||
script_callbacks.before_component_callback(self, **kwargs)
|
|
||||||
|
|
||||||
res = original_IOComponent_init(self, *args, **kwargs)
|
|
||||||
|
|
||||||
add_classes_to_gradio_component(self)
|
|
||||||
|
|
||||||
script_callbacks.after_component_callback(self, **kwargs)
|
|
||||||
|
|
||||||
if scripts_current is not None:
|
|
||||||
scripts_current.after_component(self, **kwargs)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
original_IOComponent_init = gr.components.IOComponent.__init__
|
|
||||||
gr.components.IOComponent.__init__ = IOComponent_init
|
|
||||||
|
|
||||||
|
|
||||||
def BlockContext_init(self, *args, **kwargs):
|
|
||||||
res = original_BlockContext_init(self, *args, **kwargs)
|
|
||||||
|
|
||||||
add_classes_to_gradio_component(self)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
original_BlockContext_init = gr.blocks.BlockContext.__init__
|
|
||||||
gr.blocks.BlockContext.__init__ = BlockContext_init
|
|
||||||
|
@ -3,8 +3,31 @@ import open_clip
|
|||||||
import torch
|
import torch
|
||||||
import transformers.utils.hub
|
import transformers.utils.hub
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
class DisableInitialization:
|
|
||||||
|
class ReplaceHelper:
|
||||||
|
def __init__(self):
|
||||||
|
self.replaced = []
|
||||||
|
|
||||||
|
def replace(self, obj, field, func):
|
||||||
|
original = getattr(obj, field, None)
|
||||||
|
if original is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self.replaced.append((obj, field, original))
|
||||||
|
setattr(obj, field, func)
|
||||||
|
|
||||||
|
return original
|
||||||
|
|
||||||
|
def restore(self):
|
||||||
|
for obj, field, original in self.replaced:
|
||||||
|
setattr(obj, field, original)
|
||||||
|
|
||||||
|
self.replaced.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class DisableInitialization(ReplaceHelper):
|
||||||
"""
|
"""
|
||||||
When an object of this class enters a `with` block, it starts:
|
When an object of this class enters a `with` block, it starts:
|
||||||
- preventing torch's layer initialization functions from working
|
- preventing torch's layer initialization functions from working
|
||||||
@ -21,7 +44,7 @@ class DisableInitialization:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, disable_clip=True):
|
def __init__(self, disable_clip=True):
|
||||||
self.replaced = []
|
super().__init__()
|
||||||
self.disable_clip = disable_clip
|
self.disable_clip = disable_clip
|
||||||
|
|
||||||
def replace(self, obj, field, func):
|
def replace(self, obj, field, func):
|
||||||
@ -86,8 +109,81 @@ class DisableInitialization:
|
|||||||
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
|
||||||
|
|
||||||
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
|
|
||||||
|
|
||||||
|
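A small illustration of the `ReplaceHelper` pattern introduced above, patching a plain `math` function so it stays independent of the webui classes; the import path is assumed from the module this hunk belongs to.

import math
from modules.sd_disable_initialization import ReplaceHelper  # assumed module path

helper = ReplaceHelper()
helper.replace(math, "sqrt", lambda x: 42)   # patch and remember the original
assert math.sqrt(9) == 42
helper.restore()                             # put math.sqrt back
assert math.sqrt(9) == 3.0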
class InitializeOnMeta(ReplaceHelper):
    """
    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
    which results in those parameters having no values and taking no memory. model.to() will be broken and
    will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.

    Usage:
    ```
    with sd_disable_initialization.InitializeOnMeta():
        sd_model = instantiate_from_config(sd_config.model)
    ```
    """

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        def set_device(x):
            x["device"] = "meta"
            return x

        linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
        conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
        mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()


class LoadStateDictOnMeta(ReplaceHelper):
    """
    Context manager that allows reading parameters from state_dict into a model that has some of its parameters in the meta device.
    As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
    Meant to be used together with InitializeOnMeta above.

    Usage:
    ```
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
        model.load_state_dict(state_dict, strict=False)
    ```
    """

    def __init__(self, state_dict, device):
        super().__init__()
        self.state_dict = state_dict
        self.device = device

    def __enter__(self):
        if shared.cmd_opts.disable_model_loading_ram_optimization:
            return

        sd = self.state_dict
        device = self.device

        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
            params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]

            for name, param in params:
                if param.is_meta:
                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)

            original(self, state_dict, prefix, *args, **kwargs)

            for name, _ in params:
                key = prefix + name
                if key in sd:
                    del sd[key]

        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()
|
||||||
|
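A combined sketch of the two context managers above, following the usage snippets in their docstrings; `sd_config` and `state_dict` stand in for the config and checkpoint loaded elsewhere in sd_models.py, so this is illustrative rather than exact loader code.

from ldm.util import instantiate_from_config
from modules import sd_disable_initialization, devices

with sd_disable_initialization.InitializeOnMeta():
    sd_model = instantiate_from_config(sd_config.model)       # parameters land on the meta device, using no RAM

with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
    sd_model.load_state_dict(state_dict, strict=False)        # meta parameters are materialized and filled from state_dict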
@ -2,11 +2,10 @@ import torch
|
|||||||
from torch.nn.functional import silu
|
from torch.nn.functional import silu
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
@ -15,6 +14,11 @@ import ldm.models.diffusion.ddim
|
|||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
import ldm.modules.encoders.modules
|
import ldm.modules.encoders.modules
|
||||||
|
|
||||||
|
import sgm.modules.attention
|
||||||
|
import sgm.modules.diffusionmodules.model
|
||||||
|
import sgm.modules.diffusionmodules.openaimodel
|
||||||
|
import sgm.modules.encoders.modules
|
||||||
|
|
||||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
@ -25,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
|
|||||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||||
|
|
||||||
# silence new console spam from SD2
|
# silence new console spam from SD2
|
||||||
ldm.modules.attention.print = lambda *args: None
|
ldm.modules.attention.print = shared.ldm_print
|
||||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
ldm.modules.diffusionmodules.model.print = shared.ldm_print
|
||||||
|
ldm.util.print = shared.ldm_print
|
||||||
|
ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||||
|
|
||||||
|
sd_hijack_inpainting.do_inpainting_hijack()
|
||||||
|
|
||||||
optimizers = []
|
optimizers = []
|
||||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||||
@ -56,6 +64,9 @@ def apply_optimizations(option=None):
|
|||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||||
|
|
||||||
|
sgm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
|
sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||||
|
|
||||||
if current_optimizer is not None:
|
if current_optimizer is not None:
|
||||||
current_optimizer.undo()
|
current_optimizer.undo()
|
||||||
current_optimizer = None
|
current_optimizer = None
|
||||||
@ -89,6 +100,10 @@ def undo_optimizations():
|
|||||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
|
sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||||
|
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||||
|
sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
|
|
||||||
def fix_checkpoint():
|
def fix_checkpoint():
|
||||||
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||||
@ -152,12 +167,13 @@ class StableDiffusionModelHijack:
|
|||||||
clip = None
|
clip = None
|
||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
import modules.textual_inversion.textual_inversion
|
||||||
|
|
||||||
self.extra_generation_params = {}
|
self.extra_generation_params = {}
|
||||||
self.comments = []
|
self.comments = []
|
||||||
|
|
||||||
|
self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
def apply_optimizations(self, option=None):
|
def apply_optimizations(self, option=None):
|
||||||
@ -168,6 +184,32 @@ class StableDiffusionModelHijack:
|
|||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
|
conditioner = getattr(m, 'conditioner', None)
|
||||||
|
if conditioner:
|
||||||
|
text_cond_models = []
|
||||||
|
|
||||||
|
for i in range(len(conditioner.embedders)):
|
||||||
|
embedder = conditioner.embedders[i]
|
||||||
|
typename = type(embedder).__name__
|
||||||
|
if typename == 'FrozenOpenCLIPEmbedder':
|
||||||
|
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||||
|
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
|
||||||
|
text_cond_models.append(conditioner.embedders[i])
|
||||||
|
if typename == 'FrozenCLIPEmbedder':
|
||||||
|
model_embeddings = embedder.transformer.text_model.embeddings
|
||||||
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||||
|
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
||||||
|
text_cond_models.append(conditioner.embedders[i])
|
||||||
|
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||||
|
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
|
||||||
|
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
||||||
|
text_cond_models.append(conditioner.embedders[i])
|
||||||
|
|
||||||
|
if len(text_cond_models) == 1:
|
||||||
|
m.cond_stage_model = text_cond_models[0]
|
||||||
|
else:
|
||||||
|
m.cond_stage_model = conditioner
|
||||||
|
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||||
@ -205,7 +247,7 @@ class StableDiffusionModelHijack:
|
|||||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
@ -254,10 +296,11 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingsWithFixes(torch.nn.Module):
|
class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
def __init__(self, wrapped, embeddings):
|
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
self.embeddings = embeddings
|
self.embeddings = embeddings
|
||||||
|
self.textual_inversion_key = textual_inversion_key
|
||||||
|
|
||||||
def forward(self, input_ids):
|
def forward(self, input_ids):
|
||||||
batch_fixes = self.embeddings.fixes
|
batch_fixes = self.embeddings.fixes
|
||||||
@ -271,7 +314,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
|||||||
vecs = []
|
vecs = []
|
||||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||||
for offset, embedding in fixes:
|
for offset, embedding in fixes:
|
||||||
emb = devices.cond_cast_unet(embedding.vec)
|
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||||
|
emb = devices.cond_cast_unet(vec)
|
||||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||||
|
|
||||||
|
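A hypothetical illustration of the `embedding.vec` formats the updated forward() now accepts: a plain tensor for SD1/SD2 embeddings, or a dict with one tensor per text encoder for SDXL, keyed the same way as `textual_inversion_key` above.

import torch

vec_sd1 = torch.zeros(2, 768)                                                # used as-is
vec_sdxl = {"clip_l": torch.zeros(2, 768), "clip_g": torch.zeros(2, 1280)}   # one tensor per encoder

for vec_any, key in ((vec_sd1, "clip_l"), (vec_sdxl, "clip_g")):
    vec = vec_any[key] if isinstance(vec_any, dict) else vec_any
    print(vec.shape)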
@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
||||||
self.chunk_length = 75
|
self.chunk_length = 75
|
||||||
|
|
||||||
|
self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
||||||
|
self.input_key = getattr(wrapped, 'input_key', 'txt')
|
||||||
|
self.legacy_ucg_val = None
|
||||||
|
|
||||||
def empty_chunk(self):
|
def empty_chunk(self):
|
||||||
"""creates an empty PromptChunk and returns it"""
|
"""creates an empty PromptChunk and returns it"""
|
||||||
|
|
||||||
@ -157,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
position += 1
|
position += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
emb_len = int(embedding.vec.shape[0])
|
emb_len = int(embedding.vectors)
|
||||||
if len(chunk.tokens) + emb_len > self.chunk_length:
|
if len(chunk.tokens) + emb_len > self.chunk_length:
|
||||||
next_chunk()
|
next_chunk()
|
||||||
|
|
||||||
@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
|
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
|
||||||
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
|
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
|
||||||
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
|
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
|
||||||
An example shape returned by this function can be: (2, 77, 768).
|
An example shape returned by this function can be: (2, 77, 768).
|
||||||
|
For SDXL, instead of returning one tensor as above, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||||
"""
|
"""
|
||||||
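A short sketch of the shapes described above; `cond_stage_model` is assumed to be the hijacked text encoder of an already loaded SD1 checkpoint (for example `shared.sd_model.cond_stage_model`), so the numbers are illustrative.

texts = ["a picture of a cat", "a picture of a dog"]
z = cond_stage_model(texts)
print(z.shape)   # torch.Size([2, 77, 768]) for SD1; C is 1024 for SD2 and 1280 for SDXL, and T grows in multiples of 77 for long prompts
# For SDXL the call returns a tuple (z, pooled) instead, with pooled shaped (2, 1280).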
@ -240,9 +245,14 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
hashes.append(f"{name}: {shorthash}")
|
hashes.append(f"{name}: {shorthash}")
|
||||||
|
|
||||||
if hashes:
|
if hashes:
|
||||||
|
if self.hijack.extra_generation_params.get("TI hashes"):
|
||||||
|
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
|
||||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||||
|
|
||||||
return torch.hstack(zs)
|
if getattr(self.wrapped, 'return_pooled', False):
|
||||||
|
return torch.hstack(zs), zs[0].pooled
|
||||||
|
else:
|
||||||
|
return torch.hstack(zs)
|
||||||
|
|
||||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||||
"""
|
"""
|
||||||
@ -262,6 +272,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
|
|
||||||
z = self.encode_with_transformers(tokens)
|
z = self.encode_with_transformers(tokens)
|
||||||
|
|
||||||
|
pooled = getattr(z, 'pooled', None)
|
||||||
|
|
||||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||||
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||||
original_mean = z.mean()
|
original_mean = z.mean()
|
||||||
@ -269,6 +281,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
new_mean = z.mean()
|
new_mean = z.mean()
|
||||||
z = z * (original_mean / new_mean)
|
z = z * (original_mean / new_mean)
|
||||||
|
|
||||||
|
if pooled is not None:
|
||||||
|
z.pooled = pooled
|
||||||
|
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
|
||||||
@ -324,3 +339,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
|||||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||||
|
|
||||||
return embedded
|
return embedded
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
|
||||||
|
def __init__(self, wrapped, hijack):
|
||||||
|
super().__init__(wrapped, hijack)
|
||||||
|
|
||||||
|
def encode_with_transformers(self, tokens):
|
||||||
|
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
|
||||||
|
|
||||||
|
if self.wrapped.layer == "last":
|
||||||
|
z = outputs.last_hidden_state
|
||||||
|
else:
|
||||||
|
z = outputs.hidden_states[self.wrapped.layer_idx]
|
||||||
|
|
||||||
|
return z
|
||||||
|
@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
|
|
||||||
|
|
||||||
def do_inpainting_hijack():
|
def do_inpainting_hijack():
|
||||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
|
||||||
|
|
||||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||||
|
@ -35,3 +35,37 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
|
|||||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||||
|
|
||||||
return embedded
|
return embedded
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||||
|
def __init__(self, wrapped, hijack):
|
||||||
|
super().__init__(wrapped, hijack)
|
||||||
|
|
||||||
|
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||||
|
self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||||
|
self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||||
|
self.id_pad = 0
|
||||||
|
|
||||||
|
def tokenize(self, texts):
|
||||||
|
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||||
|
|
||||||
|
tokenized = [tokenizer.encode(text) for text in texts]
|
||||||
|
|
||||||
|
return tokenized
|
||||||
|
|
||||||
|
def encode_with_transformers(self, tokens):
|
||||||
|
d = self.wrapped.encode_with_transformer(tokens)
|
||||||
|
z = d[self.wrapped.layer]
|
||||||
|
|
||||||
|
pooled = d.get("pooled")
|
||||||
|
if pooled is not None:
|
||||||
|
z.pooled = pooled
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
ids = tokenizer.encode(init_text)
|
||||||
|
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||||
|
embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||||
|
|
||||||
|
return embedded
|
||||||
|
@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork
|
|||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
|
|
||||||
|
import sgm.modules.attention
|
||||||
|
import sgm.modules.diffusionmodules.model
|
||||||
|
|
||||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
|
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
|
|
||||||
|
|
||||||
class SdOptimization:
|
class SdOptimization:
|
||||||
@ -39,6 +43,9 @@ class SdOptimization:
|
|||||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
|
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||||
|
sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
|
|
||||||
class SdOptimizationXformers(SdOptimization):
|
class SdOptimizationXformers(SdOptimization):
|
||||||
name = "xformers"
|
name = "xformers"
|
||||||
@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization):
|
|||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||||
|
sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||||
|
sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||||
|
|
||||||
|
|
||||||
class SdOptimizationSdpNoMem(SdOptimization):
|
class SdOptimizationSdpNoMem(SdOptimization):
|
||||||
@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization):
|
|||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||||
|
sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||||
|
sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||||
|
|
||||||
|
|
||||||
class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||||
@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
|||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||||
|
sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||||
|
sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||||
|
|
||||||
|
|
||||||
class SdOptimizationSubQuad(SdOptimization):
|
class SdOptimizationSubQuad(SdOptimization):
|
||||||
@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization):
|
|||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||||
|
sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||||
|
sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||||
|
|
||||||
|
|
||||||
class SdOptimizationV1(SdOptimization):
|
class SdOptimizationV1(SdOptimization):
|
||||||
@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization):
|
|||||||
cmd_opt = "opt_split_attention_v1"
|
cmd_opt = "opt_split_attention_v1"
|
||||||
priority = 10
|
priority = 10
|
||||||
|
|
||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||||
|
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||||
|
|
||||||
|
|
||||||
class SdOptimizationInvokeAI(SdOptimization):
|
class SdOptimizationInvokeAI(SdOptimization):
|
||||||
@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
|||||||
|
|
||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||||
|
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||||
|
|
||||||
|
|
||||||
class SdOptimizationDoggettx(SdOptimization):
|
class SdOptimizationDoggettx(SdOptimization):
|
||||||
@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization):
|
|||||||
def apply(self):
|
def apply(self):
|
||||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||||
|
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||||
|
sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||||
|
|
||||||
|
|
||||||
def list_optimizers(res):
|
def list_optimizers(res):
|
||||||
@ -155,7 +173,7 @@ def get_available_vram():
|
|||||||
|
|
||||||
|
|
||||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
|
|
||||||
# taken from https://github.com/Doggettx/stable-diffusion and modified
|
# taken from https://github.com/Doggettx/stable-diffusion and modified
|
||||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
@ -238,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
slice_size = q.shape[1] // steps
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = min(i + slice_size, q.shape[1])
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||||
@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_compvis(q, k, v):
|
def einsum_op_compvis(q, k, v):
|
||||||
s = einsum('b i d, b j d -> b i j', q, k)
|
s = einsum('b i d, b j d -> b i j', q, k)
|
||||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||||
return einsum('b i j, b j d -> b i d', s, v)
|
return einsum('b i j, b j d -> b i d', s, v)
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_slice_0(q, k, v, slice_size):
|
def einsum_op_slice_0(q, k, v, slice_size):
|
||||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
for i in range(0, q.shape[0], slice_size):
|
for i in range(0, q.shape[0], slice_size):
|
||||||
@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size):
|
|||||||
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_slice_1(q, k, v, slice_size):
|
def einsum_op_slice_1(q, k, v, slice_size):
|
||||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size):
|
|||||||
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
|
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_mps_v1(q, k, v):
|
def einsum_op_mps_v1(q, k, v):
|
||||||
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||||
return einsum_op_compvis(q, k, v)
|
return einsum_op_compvis(q, k, v)
|
||||||
@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v):
|
|||||||
slice_size -= 1
|
slice_size -= 1
|
||||||
return einsum_op_slice_1(q, k, v, slice_size)
|
return einsum_op_slice_1(q, k, v, slice_size)
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_mps_v2(q, k, v):
|
def einsum_op_mps_v2(q, k, v):
|
||||||
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||||
return einsum_op_compvis(q, k, v)
|
return einsum_op_compvis(q, k, v)
|
||||||
else:
|
else:
|
||||||
return einsum_op_slice_0(q, k, v, 1)
|
return einsum_op_slice_0(q, k, v, 1)
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||||
if size_mb <= max_tensor_mb:
|
if size_mb <= max_tensor_mb:
|
||||||
@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
|||||||
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||||
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_cuda(q, k, v):
|
def einsum_op_cuda(q, k, v):
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v):
|
|||||||
# Divide factor of safety as there's copying and fragmentation
|
# Divide factor of safety as there's copying and fragmentation
|
||||||
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||||
|
|
||||||
|
|
||||||
def einsum_op(q, k, v):
|
def einsum_op(q, k, v):
|
||||||
if q.device.type == 'cuda':
|
if q.device.type == 'cuda':
|
||||||
return einsum_op_cuda(q, k, v)
|
return einsum_op_cuda(q, k, v)
|
||||||
@ -328,7 +354,8 @@ def einsum_op(q, k, v):
|
|||||||
# Tested on i7 with 8MB L3 cache.
|
# Tested on i7 with 8MB L3 cache.
|
||||||
return einsum_op_tensor_mem(q, k, v, 32)
|
return einsum_op_tensor_mem(q, k, v, 32)
|
||||||
|
|
||||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|
||||||
|
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
||||||
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
||||||
def sub_quad_attention_forward(self, x, context=None, mask=None):
|
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||||
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
||||||
|
|
||||||
h = self.heads
|
h = self.heads
|
||||||
@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
|
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
|
||||||
bytes_per_token = torch.finfo(q.dtype).bits//8
|
bytes_per_token = torch.finfo(q.dtype).bits//8
|
||||||
batch_x_heads, q_tokens, _ = q.shape
|
batch_x_heads, q_tokens, _ = q.shape
|
||||||
@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
q_in = self.to_q(x)
|
q_in = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
|
||||||
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
||||||
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
||||||
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||||
batch_size, sequence_length, inner_dim = x.shape
|
batch_size, sequence_length, inner_dim = x.shape
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
|||||||
hidden_states = self.to_out[1](hidden_states)
|
hidden_states = self.to_out[1](hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
|
|
||||||
|
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||||
return scaled_dot_product_attention_forward(self, x, context, mask)
|
return scaled_dot_product_attention_forward(self, x, context, mask)
|
||||||
|
|
||||||
|
|
||||||
def cross_attention_attnblock_forward(self, x):
|
def cross_attention_attnblock_forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x):
|
|||||||
|
|
||||||
return h3
|
return h3
|
||||||
|
|
||||||
|
|
||||||
def xformers_attnblock_forward(self, x):
|
def xformers_attnblock_forward(self, x):
|
||||||
try:
|
try:
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x):
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
return cross_attention_attnblock_forward(self, x)
|
return cross_attention_attnblock_forward(self, x)
|
||||||
|
|
||||||
|
|
||||||
def sdp_attnblock_forward(self, x):
|
def sdp_attnblock_forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x):
|
|||||||
out = self.proj_out(out)
|
out = self.proj_out(out)
|
||||||
return x + out
|
return x + out
|
||||||
|
|
||||||
|
|
||||||
def sdp_no_mem_attnblock_forward(self, x):
|
def sdp_no_mem_attnblock_forward(self, x):
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||||
return sdp_attnblock_forward(self, x)
|
return sdp_attnblock_forward(self, x)
|
||||||
|
|
||||||
|
|
||||||
def sub_quad_attnblock_forward(self, x):
|
def sub_quad_attnblock_forward(self, x):
|
||||||
h_ = x
|
h_ = x
|
||||||
h_ = self.norm(h_)
|
h_ = self.norm(h_)
|
||||||
|
@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
|
|||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
for y in cond.keys():
|
for y in cond.keys():
|
||||||
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
if isinstance(cond[y], list):
|
||||||
|
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
|
||||||
|
else:
|
||||||
|
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
|
||||||
@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi
|
|||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
|
||||||
|
|
||||||
|
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
|
||||||
|
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||||
|
@ -14,8 +14,7 @@ import ldm.modules.midas as midas
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
|
|
||||||
@ -33,6 +32,8 @@ class CheckpointInfo:
|
|||||||
self.filename = filename
|
self.filename = filename
|
||||||
abspath = os.path.abspath(filename)
|
abspath = os.path.abspath(filename)
|
||||||
|
|
||||||
|
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||||
|
|
||||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||||
            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
        elif abspath.startswith(model_path):
@@ -43,6 +44,19 @@ class CheckpointInfo:
        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]
 
+        def read_metadata():
+            metadata = read_metadata_from_safetensors(filename)
+            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
+
+            return metadata
+
+        self.metadata = {}
+        if self.is_safetensors:
+            try:
+                self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
+            except Exception as e:
+                errors.display(e, f"reading metadata for {filename}")
+
        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
@@ -52,17 +66,9 @@ class CheckpointInfo:
        self.shorthash = self.sha256[0:10] if self.sha256 else None
 
        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
 
-        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
 
-        self.metadata = {}
-
-        _, ext = os.path.splitext(self.filename)
-        if ext.lower() == ".safetensors":
-            try:
-                self.metadata = read_metadata_from_safetensors(filename)
-            except Exception as e:
-                errors.display(e, f"reading checkpoint metadata: {filename}")
-
    def register(self):
        checkpoints_list[self.title] = self
@@ -79,8 +85,9 @@ class CheckpointInfo:
        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
 
-        checkpoints_list.pop(self.title)
+        checkpoints_list.pop(self.title, None)
        self.title = f'{self.name} [{self.shorthash}]'
+        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
        self.register()
 
        return self.shorthash
@@ -101,14 +108,8 @@ def setup_model():
    enable_midas_autodownload()
 
 
-def checkpoint_tiles():
-    def convert(name):
-        return int(name) if name.isdigit() else name.lower()
-
-    def alphanumeric_key(key):
-        return [convert(c) for c in re.split('([0-9]+)', key)]
-
-    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
 
 
 def list_models():
@@ -131,11 +132,14 @@ def list_models():
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
 
-    for filename in sorted(model_list, key=str.lower):
+    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()
 
 
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
 def get_closet_checkpoint_match(search_string):
    checkpoint_info = checkpoint_aliases.get(search_string, None)
    if checkpoint_info is not None:
@@ -145,6 +149,11 @@ def get_closet_checkpoint_match(search_string):
    if found:
        return found[0]
 
+    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+    if found:
+        return found[0]
+
    return None
 
 
@@ -289,13 +298,21 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
+    model.is_sdxl = hasattr(model, 'conditioner')
+    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
+    model.is_sd1 = not model.is_sdxl and not model.is_sd2
+
+    if model.is_sdxl:
+        sd_models_xl.extend_sdxl(model)
+
    model.load_state_dict(state_dict, strict=False)
-    del state_dict
    timer.record("apply weights to model")
 
    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+        checkpoints_loaded[checkpoint_info] = state_dict
 
+    del state_dict
+
    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
@@ -319,7 +336,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
 
    timer.record("apply half()")
 
-    devices.dtype_unet = model.model.diffusion_model.dtype
+    devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
    model.first_stage_model.to(devices.dtype_vae)
@@ -334,7 +351,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
 
-    model.logvar = model.logvar.to(devices.device)  # fix for training
+    if hasattr(model, 'logvar'):
+        model.logvar = model.logvar.to(devices.device)  # fix for training
 
    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
@@ -391,10 +409,11 @@ def repair_config(sd_config):
    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False
 
-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
+    if hasattr(sd_config.model.params, 'unet_config'):
+        if shared.cmd_opts.no_half:
+            sd_config.model.params.unet_config.params.use_fp16 = False
+        elif shared.cmd_opts.upcast_sampling:
+            sd_config.model.params.unet_config.params.use_fp16 = True
 
    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
@@ -407,11 +426,14 @@ def repair_config(sd_config):
 
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
 
 
 class SdModelData:
    def __init__(self):
        self.sd_model = None
+        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()
 
@@ -426,6 +448,7 @@ class SdModelData:
 
                try:
                    load_model()
+
                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
@@ -437,23 +460,68 @@ class SdModelData:
    def set_sd_model(self, v):
        self.sd_model = v
 
+        try:
+            self.loaded_sd_models.remove(v)
+        except ValueError:
+            pass
+
+        if v is not None:
+            self.loaded_sd_models.insert(0, v)
+
 
 model_data = SdModelData()
 
 
+def get_empty_cond(sd_model):
+    from modules import extra_networks, processing
+
+    p = processing.StableDiffusionProcessingTxt2Img()
+    extra_networks.activate(p, {})
+
+    if hasattr(sd_model, 'conditioner'):
+        d = sd_model.get_learned_conditioning([""])
+        return d['crossattn']
+    else:
+        return sd_model.cond_stage_model([""])
+
+
+def send_model_to_cpu(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        m.to(devices.cpu)
+
+    devices.torch_gc()
+
+
+def send_model_to_device(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+    else:
+        m.to(shared.device)
+
+
+def send_model_to_trash(m):
+    m.to(device="meta")
+    devices.torch_gc()
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
-    from modules import lowvram, sd_hijack
+    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
 
+    timer = Timer()
+
    if model_data.sd_model:
-        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
-        gc.collect()
        devices.torch_gc()
 
-    do_inpainting_hijack()
+    timer.record("unload existing model")
 
-    timer = Timer()
-
    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
@@ -461,7 +529,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
 
    timer.record("find config")
 
@@ -474,26 +542,28 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
    sd_model = None
    try:
-        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
-            sd_model = instantiate_from_config(sd_config.model)
-    except Exception:
-        pass
+        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
+            with sd_disable_initialization.InitializeOnMeta():
+                sd_model = instantiate_from_config(sd_config.model)
+
+    except Exception as e:
+        errors.display(e, "creating model quickly", full_traceback=True)
 
    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
-        sd_model = instantiate_from_config(sd_config.model)
+
+        with sd_disable_initialization.InitializeOnMeta():
+            sd_model = instantiate_from_config(sd_config.model)
 
    sd_model.used_config = checkpoint_config
 
    timer.record("create model")
 
-    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+        timer.record("load weights from state dict")
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
-    else:
-        sd_model.to(shared.device)
-
+    send_model_to_device(sd_model)
    timer.record("move model to device")
 
    sd_hijack.model_hijack.hijack(sd_model)
@@ -501,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    timer.record("hijack")
 
    sd_model.eval()
-    model_data.sd_model = sd_model
+    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True
 
    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
@@ -513,7 +583,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    timer.record("scripts callbacks")
 
    with devices.autocast(), torch.no_grad():
-        sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
 
    timer.record("calculate empty prompt")
 
@@ -522,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    return sd_model
 
 
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+    """
+    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
+    If not, returns the model that can be used to load weights from checkpoint_info's file.
+    If no such model exists, returns None.
+    Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+    """
+
+    already_loaded = None
+    for i in reversed(range(len(model_data.loaded_sd_models))):
+        loaded_model = model_data.loaded_sd_models[i]
+        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+            already_loaded = loaded_model
+            continue
+
+        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+            model_data.loaded_sd_models.pop()
+            send_model_to_trash(loaded_model)
+            timer.record("send model to trash")
+
+    if shared.opts.sd_checkpoints_keep_in_cpu:
+        send_model_to_cpu(sd_model)
+        timer.record("send model to cpu")
+
+    if already_loaded is not None:
+        send_model_to_device(already_loaded)
+        timer.record("send model to device")
+
+        model_data.set_sd_model(already_loaded)
+        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+        return model_data.sd_model
+    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+        model_data.sd_model = None
+        load_model(checkpoint_info)
+        return model_data.sd_model
+    elif len(model_data.loaded_sd_models) > 0:
+        sd_model = model_data.loaded_sd_models.pop()
+        model_data.sd_model = sd_model
+
+        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+        return sd_model
+    else:
+        return None
+
+
 def reload_model_weights(sd_model=None, info=None):
-    from modules import lowvram, devices, sd_hijack
+    from modules import devices, sd_hijack
    checkpoint_info = info or select_checkpoint()
 
+    timer = Timer()
+
    if not sd_model:
        sd_model = model_data.sd_model
 
@@ -534,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None):
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-            return
+            return sd_model
 
+    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+        return sd_model
+
+    if sd_model is not None:
        sd_unet.apply_unet("None")
-
-        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-            lowvram.send_everything_to_cpu()
-        else:
-            sd_model.to(devices.cpu)
-
+        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    timer = Timer()
-
    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -554,7 +673,9 @@ def reload_model_weights(sd_model=None, info=None):
    timer.record("find config")
 
    if sd_model is None or checkpoint_config != sd_model.used_config:
-        del sd_model
+        if sd_model is not None:
+            send_model_to_trash(sd_model)
+
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model
 
@@ -577,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None):
 
    print(f"Weights loaded in {timer.summary()}.")
 
+    model_data.set_sd_model(sd_model)
+
    return sd_model
 
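Editor's note: the fallback added to get_closet_checkpoint_match() above strips a trailing "[hash]" token before retrying the title search. A minimal stand-alone sketch of that behaviour (not part of the commit; the titles and hashes below are made up):

import re

re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")

titles = ["sd_xl_base_1.0.safetensors [31e35c80fc]", "v1-5-pruned-emaonly.safetensors [6ce0161689]"]
search_string = "sd_xl_base_1.0.safetensors [deadbeef00]"   # stale hash, e.g. copied from old infotext

without_checksum = re.sub(re_strip_checksum, '', search_string)
found = sorted([t for t in titles if without_checksum in t], key=len)  # shortest matching title wins, as in the diff
print(found[0] if found else None)   # -> sd_xl_base_1.0.safetensors [31e35c80fc]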
@@ -6,12 +6,15 @@ from modules import shared, paths, sd_disable_initialization
 
 sd_configs_path = shared.sd_configs_path
 sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
 
 
 config_default = shared.sd_default_config
 config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -68,7 +71,11 @@ def guess_model_config_from_state_dict(sd, filename):
    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
    sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
 
-    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
+        return config_sdxl
+    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
+        return config_sdxl_refiner
+    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
        return config_depth_model
    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
        return config_unclip
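Editor's note: a small assumed sketch (not from the commit) of the detection order the hunk above establishes: SDXL base and refiner checkpoints are recognized purely by the presence of their characteristic conditioner CLIP keys in the state dict, before the older depth/unclip checks run.

def classify_checkpoint(sd):
    # sd is a state-dict-like mapping of parameter names to tensors
    if sd.get('conditioner.embedders.1.model.ln_final.weight') is not None:
        return 'sdxl-base'
    if sd.get('conditioner.embedders.0.model.ln_final.weight') is not None:
        return 'sdxl-refiner'
    return 'sd1/sd2/other'

print(classify_checkpoint({'conditioner.embedders.1.model.ln_final.weight': object()}))  # sdxl-base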
108 modules/sd_models_xl.py Normal file
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import torch
+
+import sgm.models.diffusion
+import sgm.modules.diffusionmodules.denoiser_scaling
+import sgm.modules.diffusionmodules.discretizer
+from modules import devices, shared, prompt_parser
+
+
+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
+    for embedder in self.conditioner.embedders:
+        embedder.ucg_rate = 0.0
+
+    width = getattr(batch, 'width', 1024)
+    height = getattr(batch, 'height', 1024)
+    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
+    aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
+
+    devices_args = dict(device=devices.device, dtype=devices.dtype)
+
+    sdxl_conds = {
+        "txt": batch,
+        "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
+        "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+        "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
+    }
+
+    force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
+    c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
+
+    return c
+
+
+def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+    return self.model(x, t, cond)
+
+
+def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
+    return x
+
+
+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
+
+
+def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
+    res = []
+
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
+        encoded = embedder.encode_embedding_init_text(init_text, nvpt)
+        res.append(encoded)
+
+    return torch.cat(res, dim=1)
+
+
+def tokenize(self: sgm.modules.GeneralConditioner, texts):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
+        return embedder.tokenize(texts)
+
+    raise AssertionError('no tokenizer available')
+
+
+def process_texts(self, texts):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
+        return embedder.process_texts(texts)
+
+
+def get_target_prompt_token_count(self, token_count):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
+        return embedder.get_target_prompt_token_count(token_count)
+
+
+# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
+sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.tokenize = tokenize
+sgm.modules.GeneralConditioner.process_texts = process_texts
+sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
+
+
+def extend_sdxl(model):
+    """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
+
+    dtype = next(model.model.diffusion_model.parameters()).dtype
+    model.model.diffusion_model.dtype = dtype
+    model.model.conditioning_key = 'crossattn'
+    model.cond_stage_key = 'txt'
+    # model.cond_stage_model will be set in sd_hijack
+
+    model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
+
+    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+
+    model.conditioner.wrapped = torch.nn.Module()
+
+
+sgm.modules.attention.print = shared.ldm_print
+sgm.modules.diffusionmodules.model.print = shared.ldm_print
+sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
+sgm.modules.encoders.modules.print = shared.ldm_print
+
+# this gets the code to load the vanilla attention that we override
+sgm.modules.attention.SDP_IS_AVAILABLE = True
+sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
@@ -28,6 +28,9 @@ def create_sampler(name, model):
 
    assert config is not None, f'bad sampler name: {name}'
 
+    if model.is_sdxl and config.options.get("no_sdxl", False):
+        raise Exception(f"Sampler {config.name} is not supported for SDXL")
+
    sampler = config.constructor(model)
    sampler.config = config
 
@@ -2,10 +2,8 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
-
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
 from modules.shared import opts, state
-import modules.shared as shared
 
 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
 
@@ -37,7 +35,7 @@ def single_sample_to_image(sample, approximation=None):
        x_sample = sample * 1.5
        x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
    else:
-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+        x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
 
    x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -46,6 +44,12 @@ def single_sample_to_image(sample, approximation=None):
    return Image.fromarray(x_sample)
 
 
+def decode_first_stage(model, x):
+    x = model.decode_first_stage(x.to(devices.dtype_vae))
+
+    return x
+
+
 def sample_to_image(samples, index=0, approximation=None):
    return single_sample_to_image(samples[index], approximation)
 
@@ -85,11 +89,13 @@ class InterruptedException(BaseException):
    pass
 
 
-if opts.randn_source == "CPU":
+def replace_torchsde_browinan():
    import torchsde._brownian.brownian_interval
 
    def torchsde_randn(size, dtype, device, seed):
-        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+        return devices.randn_local(seed, size).to(device=device, dtype=dtype)
 
    torchsde._brownian.brownian_interval._randn = torchsde_randn
+
+
+replace_torchsde_browinan()
@@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc
 
 
 samplers_data_compvis = [
-    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
-    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
-    sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
+    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
+    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
+    sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
 ]
 
74 modules/sd_samplers_extra.py Normal file
@@ -0,0 +1,74 @@
+import torch
+import tqdm
+import k_diffusion.sampling
+
+
+@torch.no_grad()
+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
+    """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
+    Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
+    If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
+    """
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    step_id = 0
+    from k_diffusion.sampling import to_d, get_sigmas_karras
+
+    def heun_step(x, old_sigma, new_sigma, second_order=True):
+        nonlocal step_id
+        denoised = model(x, old_sigma * s_in, **extra_args)
+        d = to_d(x, old_sigma, denoised)
+        if callback is not None:
+            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
+        dt = new_sigma - old_sigma
+        if new_sigma == 0 or not second_order:
+            # Euler method
+            x = x + d * dt
+        else:
+            # Heun's method
+            x_2 = x + d * dt
+            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
+            d_2 = to_d(x_2, new_sigma, denoised_2)
+            d_prime = (d + d_2) / 2
+            x = x + d_prime * dt
+        step_id += 1
+        return x
+
+    steps = sigmas.shape[0] - 1
+    if restart_list is None:
+        if steps >= 20:
+            restart_steps = 9
+            restart_times = 1
+            if steps >= 36:
+                restart_steps = steps // 4
+                restart_times = 2
+            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
+            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
+        else:
+            restart_list = {}
+
+    restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
+
+    step_list = []
+    for i in range(len(sigmas) - 1):
+        step_list.append((sigmas[i], sigmas[i + 1]))
+        if i + 1 in restart_list:
+            restart_steps, restart_times, restart_max = restart_list[i + 1]
+            min_idx = i + 1
+            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
+            if max_idx < min_idx:
+                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
+                while restart_times > 0:
+                    restart_times -= 1
+                    step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
+
+    last_sigma = None
+    for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
+        if last_sigma is None:
+            last_sigma = old_sigma
+        elif last_sigma < old_sigma:
+            x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
+        x = heun_step(x, old_sigma, new_sigma)
+        last_sigma = new_sigma
+
+    return x
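Editor's note: the core "restart" move in the sampler above is re-noising the latent whenever the sigma schedule jumps back up. A tiny stand-alone illustration of that variance top-up (shapes and values are arbitrary, not part of the commit):

import torch

x = torch.zeros(1, 4, 8, 8)
last_sigma, old_sigma, s_noise = 0.1, 2.0, 1.0
# add exactly enough noise to go from noise level last_sigma back up to old_sigma
x = x + torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
print(float(x.std()))  # roughly sqrt(2.0**2 - 0.1**2) ~= 2.0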
@@ -2,7 +2,7 @@ from collections import deque
 import torch
 import inspect
 import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common
+from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
 
 from modules.shared import opts, state
 import modules.shared as shared
@@ -30,12 +30,15 @@ samplers_k_diffusion = [
    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
    ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+    ('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
+    ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
 ]
 
 
 samplers_data_k_diffusion = [
    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
    for label, funcname, aliases, options in samplers_k_diffusion
-    if hasattr(k_diffusion.sampling, funcname)
+    if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
 ]
 
 sampler_extra_params = {
@@ -53,6 +56,28 @@ k_diffusion_scheduler = {
 }
 
 
+def catenate_conds(conds):
+    if not isinstance(conds[0], dict):
+        return torch.cat(conds)
+
+    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+    if not isinstance(cond, dict):
+        return cond[a:b]
+
+    return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+    if not isinstance(tensor, dict):
+        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+    return tensor
+
+
 class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
@@ -105,10 +130,13 @@ class CFGDenoiser(torch.nn.Module):
 
        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
-            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
        else:
            image_uncond = image_cond
-            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+            if isinstance(uncond, dict):
+                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+            else:
+                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
 
        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -140,28 +168,28 @@ class CFGDenoiser(torch.nn.Module):
            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
 
            if num_repeats < 0:
-                tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+                tensor = pad_cond(tensor, -num_repeats, empty)
                self.padded_cond_uncond = True
            elif num_repeats > 0:
-                uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+                uncond = pad_cond(uncond, num_repeats, empty)
                self.padded_cond_uncond = True
 
        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            if is_edit_model:
-                cond_in = torch.cat([tensor, uncond, uncond])
+                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
-                cond_in = torch.cat([tensor, uncond])
+                cond_in = catenate_conds([tensor, uncond])
 
            if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
+                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@@ -170,14 +198,14 @@ class CFGDenoiser(torch.nn.Module):
                b = min(a + batch_size, tensor.shape[0])
 
                if not is_edit_model:
-                    c_crossattn = [tensor[a:b]]
+                    c_crossattn = subscript_cond(tensor, a, b)
                else:
                    c_crossattn = torch.cat([tensor[a:b]], uncond)
 
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
 
        if not skip_uncond:
-            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
 
        denoised_image_indexes = [x[0][0] for x in conds_list]
        if skip_uncond:
@@ -233,10 +261,7 @@ class TorchHijack:
        if noise.shape == x.shape:
            return noise
 
-        if opts.randn_source == "CPU" or x.device.type == 'mps':
-            return torch.randn_like(x, device=devices.cpu).to(x.device)
-        else:
-            return torch.randn_like(x)
+        return devices.randn_like(x)
 
 
 class KDiffusionSampler:
@@ -245,7 +270,7 @@ class KDiffusionSampler:
 
        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
        self.funcname = funcname
-        self.func = getattr(k_diffusion.sampling, self.funcname)
+        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
        self.extra_params = sampler_extra_params.get(funcname, [])
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.sampler_noises = None
@@ -351,6 +376,9 @@ class KDiffusionSampler:
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
 
            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
+        elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
+            m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+            sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
        else:
            sigmas = self.model_wrap.get_sigmas(steps)
 
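Editor's note: the catenate_conds/subscript_cond/pad_cond helpers introduced above exist so CFGDenoiser can treat SD1/SD2 tensor conditionings and SDXL dict conditionings uniformly. A minimal check of that idea (the helper is copied from the diff; the shapes are arbitrary assumptions):

import torch

def catenate_conds(conds):
    # tensor conds concatenate directly; dict conds concatenate per key
    if not isinstance(conds[0], dict):
        return torch.cat(conds)
    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}

tensor_cond = torch.ones(2, 77, 768)                                                 # SD1-style cond
dict_cond = {"crossattn": torch.ones(2, 77, 2048), "vector": torch.ones(2, 2816)}    # SDXL-style cond

print(catenate_conds([tensor_cond, tensor_cond]).shape)             # torch.Size([4, 77, 768])
print(catenate_conds([dict_cond, dict_cond])["crossattn"].shape)    # torch.Size([4, 77, 2048])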
@@ -1,6 +1,6 @@
 import os
 import collections
-from modules import paths, shared, devices, script_callbacks, sd_models
+from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
 import glob
 from copy import deepcopy
 
@@ -16,6 +16,7 @@ checkpoint_info = None
 
 checkpoints_loaded = collections.OrderedDict()
 
+
 def get_base_vae(model):
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
        return base_vae
@@ -50,6 +51,7 @@ def get_filename(filepath):
 
 
 def refresh_vae_list():
+    global vae_dict
    vae_dict.clear()
 
    paths = [
@@ -83,6 +85,8 @@ def refresh_vae_list():
        name = get_filename(filepath)
        vae_dict[name] = filepath
 
+    vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
+
 
 def find_vae_near_checkpoint(checkpoint_file):
    checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
@@ -97,6 +101,16 @@ def resolve_vae(checkpoint_file):
    if shared.cmd_opts.vae_path is not None:
        return shared.cmd_opts.vae_path, 'from commandline argument'
 
+    metadata = extra_networks.get_user_metadata(checkpoint_file)
+    vae_metadata = metadata.get("vae", None)
+    if vae_metadata is not None and vae_metadata != "Automatic":
+        if vae_metadata == "None":
+            return None, None
+
+        vae_from_metadata = vae_dict.get(vae_metadata, None)
+        if vae_from_metadata is not None:
+            return vae_from_metadata, "from user metadata"
+
    is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
 
    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
@@ -2,9 +2,9 @@ import os
 
 import torch
 from torch import nn
-from modules import devices, paths
+from modules import devices, paths, shared
 
-sd_vae_approx_model = None
+sd_vae_approx_models = {}
 
 
 class VAEApprox(nn.Module):
@@ -31,30 +31,55 @@ class VAEApprox(nn.Module):
        return x
 
 
+def download_model(model_path, model_url):
+    if not os.path.exists(model_path):
+        os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+        print(f'Downloading VAEApprox model to: {model_path}')
+        torch.hub.download_url_to_file(model_url, model_path)
+
+
 def model():
-    global sd_vae_approx_model
+    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+    loaded_model = sd_vae_approx_models.get(model_name)
 
-    if sd_vae_approx_model is None:
-        model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
-        sd_vae_approx_model = VAEApprox()
+    if loaded_model is None:
+        model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
        if not os.path.exists(model_path):
-            model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
-        sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
-        sd_vae_approx_model.eval()
-        sd_vae_approx_model.to(devices.device, devices.dtype)
+            model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
 
-    return sd_vae_approx_model
+        if not os.path.exists(model_path):
+            model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+            download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
+
+        loaded_model = VAEApprox()
+        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+        loaded_model.eval()
+        loaded_model.to(devices.device, devices.dtype)
+        sd_vae_approx_models[model_name] = loaded_model
+
+    return loaded_model
 
 
 def cheap_approximation(sample):
    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
 
-    coefs = torch.tensor([
-        [0.298, 0.207, 0.208],
-        [0.187, 0.286, 0.173],
-        [-0.158, 0.189, 0.264],
-        [-0.184, -0.271, -0.473],
-    ]).to(sample.device)
+    if shared.sd_model.is_sdxl:
+        coeffs = [
+            [ 0.3448,  0.4168,  0.4395],
+            [-0.1953, -0.0290,  0.0250],
+            [ 0.1074,  0.0886, -0.0163],
+            [-0.3730, -0.2499, -0.2088],
+        ]
+    else:
+        coeffs = [
+            [ 0.298,  0.207,  0.208],
+            [ 0.187,  0.286,  0.173],
+            [-0.158,  0.189,  0.264],
+            [-0.184, -0.271, -0.473],
+        ]
+
+    coefs = torch.tensor(coeffs).to(sample.device)
 
    x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
 
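Editor's note: for reference, a stand-alone sketch of what cheap_approximation() does per branch: project the 4-channel latent to RGB with a fixed coefficient matrix (SD1 coefficients from the diff shown; the SDXL branch only swaps the matrix). The shapes are assumed, not part of the commit:

import torch

coefs = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])

latent = torch.randn(4, 64, 64)                     # (channels, height, width)
rgb = torch.einsum("lxy,lr -> rxy", latent, coefs)  # (3, height, width) rough RGB preview
print(rgb.shape)  # torch.Size([3, 64, 64])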
@@ -8,9 +8,9 @@ import os
 import torch
 import torch.nn as nn
 
-from modules import devices, paths_internal
+from modules import devices, paths_internal, shared
 
-sd_vae_taesd = None
+sd_vae_taesd_models = {}
 
 
 def conv(n_in, n_out, **kwargs):
@@ -61,9 +61,7 @@ class TAESD(nn.Module):
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
 
 
-def download_model(model_path):
-    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
-
+def download_model(model_path, model_url):
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
 
@@ -72,17 +70,19 @@ def download_model(model_path):
 
 
 def model():
-    global sd_vae_taesd
+    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    loaded_model = sd_vae_taesd_models.get(model_name)
 
-    if sd_vae_taesd is None:
-        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
-        download_model(model_path)
+    if loaded_model is None:
+        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
 
        if os.path.exists(model_path):
-            sd_vae_taesd = TAESD(model_path)
-            sd_vae_taesd.eval()
-            sd_vae_taesd.to(devices.device, devices.dtype)
+            loaded_model = TAESD(model_path)
+            loaded_model.eval()
+            loaded_model.to(devices.device, devices.dtype)
+            sd_vae_taesd_models[model_name] = loaded_model
        else:
            raise FileNotFoundError('TAESD model not found')
 
-    return sd_vae_taesd.decoder
+    return loaded_model.decoder
@@ -11,6 +11,7 @@ import gradio as gr
 import torch
 import tqdm

+import launch
 import modules.interrogate
 import modules.memmon
 import modules.styles
@@ -26,7 +27,7 @@ demo = None

 parser = cmd_args.parser

-script_loading.preload_extensions(extensions_dir, parser)
+script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
 script_loading.preload_extensions(extensions_builtin_dir, parser)

 if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
@@ -219,12 +220,19 @@ class State:
             return

         import modules.sd_samplers
-        if opts.show_progress_grid:
-            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
-        else:
-            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
-
-        self.current_image_sampling_step = self.sampling_step
+
+        try:
+            if opts.show_progress_grid:
+                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+            else:
+                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+            self.current_image_sampling_step = self.sampling_step
+
+        except Exception:
+            # when switching models during generation, VAE would be on CPU, so creating an image will fail.
+            # we silently ignore this error
+            errors.record_exception()

     def assign_current_image(self, image):
         self.current_image = image
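The point of the change: a live-preview decode that fails mid model switch (VAE temporarily on CPU) no longer aborts generation; the error is only recorded. A minimal, generic sketch of that pattern (helper names are placeholders, not webui API):

def safe_preview(decode_latent, latent, record_exception):
    # Build a preview image if possible; on failure (e.g. the VAE was moved to CPU
    # while models are being swapped), record the error and keep generating.
    try:
        return decode_latent(latent)
    except Exception:
        record_exception()
        return None   # caller simply skips this preview frame
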
@@ -384,13 +392,15 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
 }))

 options_templates.update(options_section(('system', "System"), {
-    "show_warnings": OptionInfo(False, "Show warnings in console."),
+    "show_warnings": OptionInfo(False, "Show warnings in console.").needs_restart(),
+    "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_restart(),
     "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
     "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+    "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
 }))

 options_templates.update(options_section(('training', "Training"), {
@@ -410,28 +420,49 @@ options_templates.update(options_section(('training', "Training"), {

 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
-    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+    "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
+    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
-    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
-    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
-    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
-    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
-    "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}),
-    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
+    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_restart(),
     "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
-    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
+    "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
+    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
 }))

+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+    "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+    "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
+    "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
+    "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
+}))
+
+
+options_templates.update(options_section(('img2img', "img2img"), {
+    "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
+    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
+    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"),
+    "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}),
+    "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
+    "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_restart(),
+    "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_restart(),
+    "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_restart(),
+    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
+    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
+}))
+
+
 options_templates.update(options_section(('optimizations', "Optimizations"), {
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
-    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
     "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
     "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
@@ -448,7 +479,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
 }))

-options_templates.update(options_section(('interrogate', "Interrogate Options"), {
+options_templates.update(options_section(('interrogate', "Interrogate"), {
     "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
     "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
     "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
@@ -481,10 +512,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
 options_templates.update(options_section(('ui', "User interface"), {
     "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
     "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
-    "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
-    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
-    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -503,11 +531,12 @@ options_templates.update(options_section(('ui', "User interface"), {
     "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
-    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
+    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(),
     "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
     "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
 }))


 options_templates.update(options_section(('infotext', "Infotext"), {
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
     "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
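Any key registered with OptionInfo above becomes an attribute on shared.opts at runtime; that is how the rest of the codebase consumes these settings. A small illustrative sketch (the option names come from the diff; the surrounding function is hypothetical):

from modules import shared

def pick_checkpoints_to_keep(loaded_checkpoints):
    # honour the new multi-checkpoint settings introduced above
    limit = shared.opts.sd_checkpoints_limit               # max models loaded at the same time
    keep_in_cpu = shared.opts.sd_checkpoints_keep_in_cpu   # park extra models in RAM instead of VRAM

    kept = loaded_checkpoints[:limit]
    placement = "cpu" if keep_in_cpu else "cuda"
    return kept, placement
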
@@ -880,3 +909,10 @@ def walk_files(path, allowed_extensions=None):
             continue

         yield os.path.join(root, filename)
+
+
+def ldm_print(*args, **kwargs):
+    if opts.hide_ldm_prints:
+        return
+
+    print(*args, **kwargs)
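ldm_print is a drop-in replacement for print that the new hide_ldm_prints option can silence. A hedged sketch of how the noisy third-party prints can be routed through it (the exact modules patched by the webui may differ):

from modules import shared

import ldm.util  # one of the Stability-AI modules whose progress prints we want to control

# route the library's print through the option-controlled wrapper;
# with hide_ldm_prints enabled, this output is dropped
ldm.util.print = shared.ldm_print
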
@@ -106,10 +106,7 @@ class StyleDatabase:
         if os.path.exists(path):
             shutil.copy(path, f"{path}.bak")

-        fd = os.open(path, os.O_RDWR | os.O_CREAT)
-        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
-            # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
-            # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+        with open(path, "w", encoding="utf-8-sig", newline='') as file:
             writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
             writer.writeheader()
             writer.writerows(style._asdict() for k, style in self.styles.items())
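Background on this styles.py change: os.open(path, os.O_RDWR | os.O_CREAT) opens an existing file without truncating it, so writing a shorter CSV leaves stale bytes from the old contents at the end of the file, while a plain open(path, "w") truncates first. A standalone demo of the difference (not webui code):

import os

path = "demo.csv"

with open(path, "w") as f:                      # start with a longer file
    f.write("name,prompt\nold-style,a very long prompt line\n")

fd = os.open(path, os.O_RDWR | os.O_CREAT)      # no O_TRUNC: previous bytes are kept
with os.fdopen(fd, "w") as f:
    f.write("name,prompt\nnew,short\n")

print(repr(open(path).read()))                  # tail of the old, longer file is still there

with open(path, "w") as f:                      # "w" truncates, so only the new rows remain
    f.write("name,prompt\nnew,short\n")

print(repr(open(path).read()))
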
@@ -109,11 +109,15 @@ def format_traceback(tb):
     return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]


+def format_exception(e, tb):
+    return {"exception": str(e), "traceback": format_traceback(tb)}
+
+
 def get_exceptions():
     try:
         from modules import errors

-        return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)]
+        return list(reversed(errors.exception_records))
     except Exception as e:
         return str(e)

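With this change the stored records are already JSON-friendly dicts, so the sysinfo side just reverses the list. A self-contained sketch of the intended flow (mirroring, not reproducing, modules/errors.py):

import sys
import traceback

exception_records = []   # newest last, like modules/errors.exception_records

def format_exception(e, tb):
    # same shape as the helper added in the hunk above
    return {"exception": str(e), "traceback": traceback.format_tb(tb)}

def record_exception():
    # call from an except block: store a summary of the active exception
    _, e, tb = sys.exc_info()
    if e is not None:
        exception_records.append(format_exception(e, tb))

def get_exceptions():
    # newest first, already formatted, so an endpoint can return it as-is
    return list(reversed(exception_records))
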
@@ -13,7 +13,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin
 from torch.utils.tensorboard import SummaryWriter

-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
+from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler

@@ -181,29 +181,38 @@ class EmbeddingDatabase:
         else:
             return


        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
-        # diffuser concepts
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+            vectors = data['clip_g'].shape[0]
+        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
        else:
            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

-        vec = emb.detach().to(devices.device, dtype=torch.float32)
        embedding = Embedding(vec, name)
        embedding.step = data.get('step', None)
        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vec.shape[0]
-        embedding.shape = vec.shape[-1]
+        embedding.vectors = vectors
+        embedding.shape = shape
        embedding.filename = path
        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')

@@ -378,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat


 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    from modules import processing
+
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
     template_file = textual_inversion_templates.get(template_filename, None)
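The new branch recognises SDXL textual-inversion files, which store two tensors ('clip_g' for the OpenCLIP-bigG text encoder and 'clip_l' for the CLIP-L one) instead of a single tensor. A standalone sketch of how shape and vector count come out for such a file (the 1280/768 widths are the usual SDXL encoder sizes, assumed here for illustration):

import torch

data = {                              # fake SDXL embedding with 4 vectors per token
    "clip_g": torch.randn(4, 1280),   # OpenCLIP ViT-bigG hidden size (assumed)
    "clip_l": torch.randn(4, 768),    # CLIP ViT-L hidden size (assumed)
}

vec = {k: v.detach().to(torch.float32) for k, v in data.items()}
shape = data["clip_g"].shape[-1] + data["clip_l"].shape[-1]   # 2048: the widths are summed
vectors = data["clip_g"].shape[0]                             # 4

print(shape, vectors)   # 2048 4
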
@@ -1,4 +1,5 @@
 import time
+import argparse


 class TimerSubcategory:
@@ -11,20 +12,27 @@ class TimerSubcategory:
     def __enter__(self):
         self.start = time.time()
         self.timer.base_category = self.original_base_category + self.category + "/"
+        self.timer.subcategory_level += 1
+
+        if self.timer.print_log:
+            print(f"{' ' * self.timer.subcategory_level}{self.category}:")

     def __exit__(self, exc_type, exc_val, exc_tb):
         elapsed_for_subcategroy = time.time() - self.start
         self.timer.base_category = self.original_base_category
         self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
-        self.timer.record(self.category)
+        self.timer.subcategory_level -= 1
+        self.timer.record(self.category, disable_log=True)


 class Timer:
-    def __init__(self):
+    def __init__(self, print_log=False):
         self.start = time.time()
         self.records = {}
         self.total = 0
         self.base_category = ''
+        self.print_log = print_log
+        self.subcategory_level = 0

     def elapsed(self):
         end = time.time()
@@ -38,13 +46,16 @@ class Timer:

         self.records[category] += amount

-    def record(self, category, extra_time=0):
+    def record(self, category, extra_time=0, disable_log=False):
         e = self.elapsed()

         self.add_time_to_record(self.base_category + category, e + extra_time)

         self.total += e + extra_time

+        if self.print_log and not disable_log:
+            print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
+
     def subcategory(self, name):
         self.elapsed()

@@ -71,6 +82,10 @@ class Timer:
         self.__init__()


-startup_timer = Timer()
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
+args = parser.parse_known_args()[0]
+
+startup_timer = Timer(print_log=args.log_startup)

 startup_record = None
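Usage sketch for the updated Timer: with print_log enabled (what the new --log-startup flag turns on), each record and subcategory is printed as it completes, indented by nesting level. Illustrative only; timings are whatever your machine produces:

import time
from modules.timer import Timer   # as modified in this diff

timer = Timer(print_log=True)

time.sleep(0.1)
timer.record("load config")          # prints: load config: done in 0.1xxs

with timer.subcategory("load model"):
    time.sleep(0.2)
    timer.record("read weights")     # printed with one extra level of indentation

print(timer.summary())               # aggregate view of all recorded categories
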
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
 import gradio as gr


-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)

     p = processing.StableDiffusionProcessingTxt2Img(
@@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         hr_second_pass_steps=hr_second_pass_steps,
         hr_resize_x=hr_resize_x,
         hr_resize_y=hr_resize_y,
+        hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
         hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
         hr_prompt=hr_prompt,
         hr_negative_prompt=hr_negative_prompt,
427
modules/ui.py
@@ -12,34 +12,30 @@ import numpy as np
 from PIL import Image, PngImagePlugin # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call

-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo
+from modules import gradio_extensons # noqa: F401
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 from modules.ui_common import create_refresh_button
 from modules.ui_gradio_extensions import reload_javascript


 from modules.shared import opts, cmd_opts

-import modules.codeformer_model
 import modules.generation_parameters_copypaste as parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.scripts
+import modules.hypernetworks.ui as hypernetworks_ui
+import modules.textual_inversion.ui as textual_inversion_ui
+import modules.textual_inversion.textual_inversion as textual_inversion
 import modules.shared as shared
-import modules.styles
-import modules.textual_inversion.ui
+import modules.images
 from modules import prompt_parser
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.textual_inversion import textual_inversion
-import modules.hypernetworks.ui
 from modules.generation_parameters_copypaste import image_from_url_text
-import modules.extras

 create_setting_component = ui_settings.create_setting_component

 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)

 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 mimetypes.init()
@@ -92,19 +88,6 @@ def send_gradio_gallery_to_image(x):
     return image_from_url_text(x[0])


-def add_style(name: str, prompt: str, negative_prompt: str):
-    if name is None:
-        return [gr_show() for x in range(4)]
-
-    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
-    shared.prompt_styles.styles[style.name] = style
-    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
-    # reserialize all styles every time we save them
-    shared.prompt_styles.save_styles(shared.styles_filename)
-
-    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
-
-
 def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
     from modules import processing, devices

@@ -129,13 +112,6 @@ def resize_from_to_html(width, height, scale_by):
     return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"


-def apply_styles(prompt, prompt_neg, styles):
-    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
-    prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
-    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
-
-
 def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
     if mode in {0, 1, 3, 4}:
         return [interrogation_function(ii_singles[mode]), None]
@@ -172,7 +148,6 @@ def interrogate_deepbooru(image):
 def create_seed_inputs(target_interface):
     with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
         seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
-        seed.style(container=False)
         random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
         reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')

@@ -184,7 +159,6 @@ def create_seed_inputs(target_interface):
     with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
         seed_extras.append(seed_extra_row_1)
         subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
-        subseed.style(container=False)
         random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
         reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
         subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
@@ -267,70 +241,77 @@ def update_token_counter(text, steps):
     return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"


-def create_toprow(is_img2img):
-    id_part = "img2img" if is_img2img else "txt2img"
-
-    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
-        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
-            with gr.Row():
-                with gr.Column(scale=80):
-                    with gr.Row():
-                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
-            with gr.Row():
-                with gr.Column(scale=80):
-                    with gr.Row():
-                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
-        button_interrogate = None
-        button_deepbooru = None
-        if is_img2img:
-            with gr.Column(scale=1, elem_classes="interrogate-col"):
-                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
-                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
-        with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
-            with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
-                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
-                skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
-                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
-                skip.click(
-                    fn=lambda: shared.state.skip(),
-                    inputs=[],
-                    outputs=[],
-                )
-
-                interrupt.click(
-                    fn=lambda: shared.state.interrupt(),
-                    inputs=[],
-                    outputs=[],
-                )
-
-            with gr.Row(elem_id=f"{id_part}_tools"):
-                paste = ToolButton(value=paste_symbol, elem_id="paste")
-                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
-                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
-                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
-                restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
-
-            token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
-            token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-            negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
-            negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
-            clear_prompt_button.click(
-                fn=lambda *x: x,
-                _js="confirm_clear_prompt",
-                inputs=[prompt, negative_prompt],
-                outputs=[prompt, negative_prompt],
-            )
-
-        with gr.Row(elem_id=f"{id_part}_styles_row"):
-            prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
-            create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
-
-    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, None, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
+class Toprow:
+    """Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
+
+    def __init__(self, is_img2img):
+        id_part = "img2img" if is_img2img else "txt2img"
+        self.id_part = id_part
+
+        with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+            with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
+                with gr.Row():
+                    with gr.Column(scale=80):
+                        with gr.Row():
+                            self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+                            self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
+
+                with gr.Row():
+                    with gr.Column(scale=80):
+                        with gr.Row():
+                            self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+
+            self.button_interrogate = None
+            self.button_deepbooru = None
+            if is_img2img:
+                with gr.Column(scale=1, elem_classes="interrogate-col"):
+                    self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+                    self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+            with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
+                with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+                    self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+                    self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
+                    self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+                    self.skip.click(
+                        fn=lambda: shared.state.skip(),
+                        inputs=[],
+                        outputs=[],
+                    )
+
+                    self.interrupt.click(
+                        fn=lambda: shared.state.interrupt(),
+                        inputs=[],
+                        outputs=[],
+                    )
+
+                with gr.Row(elem_id=f"{id_part}_tools"):
+                    self.paste = ToolButton(value=paste_symbol, elem_id="paste")
+                    self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+                    self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+                    self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+
+                self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
+                self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+                self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
+                self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+                self.clear_prompt_button.click(
+                    fn=lambda *x: x,
+                    _js="confirm_clear_prompt",
+                    inputs=[self.prompt, self.negative_prompt],
+                    outputs=[self.prompt, self.negative_prompt],
+                )
+
+            self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
+
+        self.prompt_img.change(
+            fn=modules.images.image_data,
+            inputs=[self.prompt_img],
+            outputs=[self.prompt, self.prompt_img],
+            show_progress=False,
+        )


 def setup_progressbar(*args, **kwargs):
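The refactor replaces create_toprow's 15-element return tuple with a Toprow object whose widgets are attributes. A sketch of how calling code changes (a fragment for illustration; txt2img_args is the event dictionary built elsewhere in create_ui):

# before: every widget had to be unpacked positionally
# prompt, prompt_styles, negative_prompt, submit, *rest = create_toprow(is_img2img=False)

# after: build once, reference by attribute
toprow = Toprow(is_img2img=False)

toprow.prompt.submit(**txt2img_args)          # pressing Enter in the prompt box generates
toprow.submit.click(**txt2img_args)           # so does the Generate button
styles_dropdown = toprow.ui_styles.dropdown   # the styles UI now lives in ui_prompt_styles
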
@ -414,21 +395,20 @@ def create_ui():
|
|||||||
|
|
||||||
parameters_copypaste.reset()
|
parameters_copypaste.reset()
|
||||||
|
|
||||||
modules.scripts.scripts_current = modules.scripts.scripts_txt2img
|
scripts.scripts_current = scripts.scripts_txt2img
|
||||||
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, _, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
|
toprow = Toprow(is_img2img=False)
|
||||||
|
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
|
|
||||||
|
|
||||||
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
|
||||||
extra_tabs.__enter__()
|
extra_tabs.__enter__()
|
||||||
|
|
||||||
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
|
with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
|
||||||
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
with gr.Column(variant='compact', elem_id="txt2img_settings"):
|
||||||
modules.scripts.scripts_txt2img.prepare_ui()
|
scripts.scripts_txt2img.prepare_ui()
|
||||||
|
|
||||||
for category in ordered_ui_categories():
|
for category in ordered_ui_categories():
|
||||||
if category == "sampler":
|
if category == "sampler":
|
||||||
@ -474,6 +454,10 @@ def create_ui():
|
|||||||
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
|
||||||
|
|
||||||
|
hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
|
||||||
|
create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
|
||||||
|
|
||||||
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
|
hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")
|
||||||
|
|
||||||
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
|
||||||
@@ -496,10 +480,10 @@ def create_ui():

 elif category == "scripts":
 with FormGroup(elem_id="txt2img_script_container"):
-custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+custom_inputs = scripts.scripts_txt2img.setup_ui()

 else:
-modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+scripts.scripts_txt2img.setup_ui_for_section(category)

 hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]

@@ -530,9 +514,9 @@ def create_ui():
 _js="submit",
 inputs=[
 dummy_component,
-txt2img_prompt,
+toprow.prompt,
-txt2img_negative_prompt,
+toprow.negative_prompt,
-txt2img_prompt_styles,
+toprow.ui_styles.dropdown,
 steps,
 sampler_index,
 restore_faces,
@@ -551,6 +535,7 @@ def create_ui():
 hr_second_pass_steps,
 hr_resize_x,
 hr_resize_y,
+hr_checkpoint_name,
 hr_sampler_index,
 hr_prompt,
 hr_negative_prompt,
@@ -567,12 +552,12 @@ def create_ui():
 show_progress=False,
 )

-txt2img_prompt.submit(**txt2img_args)
+toprow.prompt.submit(**txt2img_args)
-submit.click(**txt2img_args)
+toprow.submit.click(**txt2img_args)

 res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)

-restore_progress_button.click(
+toprow.restore_progress_button.click(
 fn=progress.restore_progress,
 _js="restoreProgressTxt2img",
 inputs=[dummy_component],
@@ -585,18 +570,6 @@ def create_ui():
 show_progress=False,
 )

-txt_prompt_img.change(
-fn=modules.images.image_data,
-inputs=[
-txt_prompt_img
-],
-outputs=[
-txt2img_prompt,
-txt_prompt_img
-],
-show_progress=False,
-)
-
 enable_hr.change(
 fn=lambda x: gr_show(x),
 inputs=[enable_hr],
@@ -605,8 +578,8 @@ def create_ui():
 )

 txt2img_paste_fields = [
-(txt2img_prompt, "Prompt"),
+(toprow.prompt, "Prompt"),
-(txt2img_negative_prompt, "Negative prompt"),
+(toprow.negative_prompt, "Negative prompt"),
 (steps, "Steps"),
 (sampler_index, "Sampler"),
 (restore_faces, "Face restoration"),
@@ -615,34 +588,36 @@ def create_ui():
 (width, "Size-1"),
 (height, "Size-2"),
 (batch_size, "Batch size"),
+(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
 (subseed, "Variation seed"),
 (subseed_strength, "Variation seed strength"),
 (seed_resize_from_w, "Seed resize from-1"),
 (seed_resize_from_h, "Seed resize from-2"),
-(txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
 (denoising_strength, "Denoising strength"),
-(enable_hr, lambda d: "Denoising strength" in d),
+(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
-(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
+(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
 (hr_scale, "Hires upscale"),
 (hr_upscaler, "Hires upscaler"),
 (hr_second_pass_steps, "Hires steps"),
 (hr_resize_x, "Hires resize-1"),
 (hr_resize_y, "Hires resize-2"),
+(hr_checkpoint_name, "Hires checkpoint"),
 (hr_sampler_index, "Hires sampler"),
-(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()),
+(hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()),
 (hr_prompt, "Hires prompt"),
 (hr_negative_prompt, "Hires negative prompt"),
 (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
-*modules.scripts.scripts_txt2img.infotext_fields
+*scripts.scripts_txt2img.infotext_fields
 ]
 parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
 parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
 ))

 txt2img_preview_params = [
-txt2img_prompt,
+toprow.prompt,
-txt2img_negative_prompt,
+toprow.negative_prompt,
 steps,
 sampler_index,
 cfg_scale,
@@ -651,8 +626,8 @@ def create_ui():
 height,
 ]

-token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
+toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
-negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])

 from modules import ui_extra_networks
 extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
@@ -660,13 +635,11 @@ def create_ui():

 extra_tabs.__exit__()

-modules.scripts.scripts_current = modules.scripts.scripts_img2img
+scripts.scripts_current = scripts.scripts_img2img
-modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
+scripts.scripts_img2img.initialize_scripts(is_img2img=True)

 with gr.Blocks(analytics_enabled=False) as img2img_interface:
-img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, _, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
+toprow = Toprow(is_img2img=True)

-img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
-
 extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
 extra_tabs.__enter__()
@@ -693,19 +666,19 @@ def create_ui():
 img2img_selected_tab = gr.State(0)

 with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
-init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=opts.img2img_editor_height)
+init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
 add_copy_image_controls('img2img', init_img)

 with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
-sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
 add_copy_image_controls('sketch', sketch)

 with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
-init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
 add_copy_image_controls('inpaint', init_img_with_mask)

 with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
-inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
 inpaint_color_sketch_orig = gr.State(None)
 add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)

@@ -765,7 +738,7 @@ def create_ui():
 with FormRow():
 resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")

-modules.scripts.scripts_img2img.prepare_ui()
+scripts.scripts_img2img.prepare_ui()

 for category in ordered_ui_categories():
 if category == "sampler":
@@ -846,7 +819,7 @@ def create_ui():

 elif category == "scripts":
 with FormGroup(elem_id="img2img_script_container"):
-custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+custom_inputs = scripts.scripts_img2img.setup_ui()

 elif category == "inpaint":
 with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
@@ -877,34 +850,22 @@ def create_ui():
 outputs=[inpaint_controls, mask_alpha],
 )
 else:
-modules.scripts.scripts_img2img.setup_ui_for_section(category)
+scripts.scripts_img2img.setup_ui_for_section(category)

 img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)

 connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
 connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)

-img2img_prompt_img.change(
-fn=modules.images.image_data,
-inputs=[
-img2img_prompt_img
-],
-outputs=[
-img2img_prompt,
-img2img_prompt_img
-],
-show_progress=False,
-)
-
 img2img_args = dict(
 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
 _js="submit_img2img",
 inputs=[
 dummy_component,
 dummy_component,
-img2img_prompt,
+toprow.prompt,
-img2img_negative_prompt,
+toprow.negative_prompt,
-img2img_prompt_styles,
+toprow.ui_styles.dropdown,
 init_img,
 sketch,
 init_img_with_mask,
@@ -963,11 +924,11 @@ def create_ui():
 inpaint_color_sketch,
 init_img_inpaint,
 ],
-outputs=[img2img_prompt, dummy_component],
+outputs=[toprow.prompt, dummy_component],
 )

-img2img_prompt.submit(**img2img_args)
+toprow.prompt.submit(**img2img_args)
-submit.click(**img2img_args)
+toprow.submit.click(**img2img_args)

 res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)

@@ -979,7 +940,7 @@ def create_ui():
 show_progress=False,
 )

-restore_progress_button.click(
+toprow.restore_progress_button.click(
 fn=progress.restore_progress,
 _js="restoreProgressImg2img",
 inputs=[dummy_component],
@@ -992,44 +953,22 @@ def create_ui():
 show_progress=False,
 )

-img2img_interrogate.click(
+toprow.button_interrogate.click(
 fn=lambda *args: process_interrogate(interrogate, *args),
 **interrogate_args,
 )

-img2img_deepbooru.click(
+toprow.button_deepbooru.click(
 fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
 **interrogate_args,
 )

-prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
-style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
-style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
-
-for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
-button.click(
-fn=add_style,
-_js="ask_for_style_name",
-# Have to pass empty dummy component here, because the JavaScript and Python function have to accept
-# the same number of parameters, but we only know the style-name after the JavaScript prompt
-inputs=[dummy_component, prompt, negative_prompt],
-outputs=[txt2img_prompt_styles, img2img_prompt_styles],
-)
-
-for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
-button.click(
-fn=apply_styles,
-_js=js_func,
-inputs=[prompt, negative_prompt, styles],
-outputs=[prompt, negative_prompt, styles],
-)
-
-token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
-negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
+toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])

 img2img_paste_fields = [
-(img2img_prompt, "Prompt"),
+(toprow.prompt, "Prompt"),
-(img2img_negative_prompt, "Negative prompt"),
+(toprow.negative_prompt, "Negative prompt"),
 (steps, "Steps"),
 (sampler_index, "Sampler"),
 (restore_faces, "Face restoration"),
@@ -1039,19 +978,20 @@ def create_ui():
 (width, "Size-1"),
 (height, "Size-2"),
 (batch_size, "Batch size"),
+(seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
 (subseed, "Variation seed"),
 (subseed_strength, "Variation seed strength"),
 (seed_resize_from_w, "Seed resize from-1"),
 (seed_resize_from_h, "Seed resize from-2"),
-(img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
 (denoising_strength, "Denoising strength"),
 (mask_blur, "Mask blur"),
-*modules.scripts.scripts_img2img.infotext_fields
+*scripts.scripts_img2img.infotext_fields
 ]
 parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
 parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
 parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
 ))

 from modules import ui_extra_networks
@@ -1060,13 +1000,13 @@ def create_ui():

 extra_tabs.__exit__()

-modules.scripts.scripts_current = None
+scripts.scripts_current = None

 with gr.Blocks(analytics_enabled=False) as extras_interface:
 ui_postprocessing.create_ui()

 with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
-with gr.Row().style(equal_height=False):
+with gr.Row(equal_height=False):
 with gr.Column(variant='panel'):
 image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")

@@ -1088,64 +1028,13 @@ def create_ui():
 outputs=[html, generation_info, html2],
 )

-def update_interp_description(value):
-interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
-interp_descriptions = {
-"No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
-"Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
-"Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
-}
-return interp_descriptions[value]
-
-with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
-with gr.Row().style(equal_height=False):
-with gr.Column(variant='compact'):
-interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
-
-with FormRow(elem_id="modelmerger_models"):
-primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
-create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
-
-secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
-create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
-
-tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
-create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
-
-custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
-interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
-interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
-interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
-
-with FormRow():
-checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
-save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
-save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
-
-with FormRow():
-with gr.Column():
-config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
-
-with gr.Column():
-with FormRow():
-bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
-create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
-
-with FormRow():
-discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
-
-with gr.Row():
-modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
-
-with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
-with gr.Group(elem_id="modelmerger_results_panel"):
-modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
-
+modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()

 with gr.Blocks(analytics_enabled=False) as train_interface:
-with gr.Row().style(equal_height=False):
+with gr.Row(equal_height=False):
 gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")

-with gr.Row(variant="compact").style(equal_height=False):
+with gr.Row(variant="compact", equal_height=False):
 with gr.Tabs(elem_id="train_tabs"):

 with gr.Tab(label="Create embedding", id="create_embedding"):
@@ -1165,7 +1054,7 @@ def create_ui():
 new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
 new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
 new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
-new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
 new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
 new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
 new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
@@ -1305,12 +1194,12 @@ def create_ui():

 with gr.Column(elem_id='ti_gallery_container'):
 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
+gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)
 gr.HTML(elem_id="ti_progress", value="")
 ti_outcome = gr.HTML(elem_id="ti_error", value="")

 create_embedding.click(
-fn=modules.textual_inversion.ui.create_embedding,
+fn=textual_inversion_ui.create_embedding,
 inputs=[
 new_embedding_name,
 initialization_text,
@@ -1325,7 +1214,7 @@ def create_ui():
 )

 create_hypernetwork.click(
-fn=modules.hypernetworks.ui.create_hypernetwork,
+fn=hypernetworks_ui.create_hypernetwork,
 inputs=[
 new_hypernetwork_name,
 new_hypernetwork_sizes,
@@ -1345,7 +1234,7 @@ def create_ui():
 )

 run_preprocess.click(
-fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
 _js="start_training_textual_inversion",
 inputs=[
 dummy_component,
@@ -1381,7 +1270,7 @@ def create_ui():
 )

 train_embedding.click(
-fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
 _js="start_training_textual_inversion",
 inputs=[
 dummy_component,
@@ -1415,7 +1304,7 @@ def create_ui():
 )

 train_hypernetwork.click(
-fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
 _js="start_training_textual_inversion",
 inputs=[
 dummy_component,
@@ -1469,7 +1358,7 @@ def create_ui():
 (img2img_interface, "img2img", "img2img"),
 (extras_interface, "Extras", "extras"),
 (pnginfo_interface, "PNG Info", "pnginfo"),
-(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+(modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
 (train_interface, "Train", "train"),
 ]

@@ -1521,49 +1410,11 @@ def create_ui():
 settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
 demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])

-def modelmerger(*args):
-try:
-results = modules.extras.run_modelmerger(*args)
-except Exception as e:
-errors.report("Error loading/saving model file", exc_info=True)
-modules.sd_models.list_models() # to remove the potentially missing models from the list
-return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
-return results
-
-modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result])
-modelmerger_merge.click(
-fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
-_js='modelmerger',
-inputs=[
-dummy_component,
-primary_model_name,
-secondary_model_name,
-tertiary_model_name,
-interp_method,
-interp_amount,
-save_as_half,
-custom_name,
-checkpoint_format,
-config_source,
-bake_in_vae,
-discard_weights,
-save_metadata,
-],
-outputs=[
-primary_model_name,
-secondary_model_name,
-tertiary_model_name,
-settings.component_dict['sd_model_checkpoint'],
-modelmerger_result,
-]
-)
+modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])

 loadsave.dump_defaults()
 demo.ui_loadsave = loadsave

-# Required as a workaround for change() event not triggering when loading values from ui-config.json
-interp_description.value = update_interp_description(interp_method.value)
-
 return demo

modules/ui_checkpoint_merger.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+import gradio as gr
+
+from modules import sd_models, sd_vae, errors, extras, call_queue
+from modules.ui_components import FormRow
+from modules.ui_common import create_refresh_button
+
+
+def update_interp_description(value):
+    interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
+    interp_descriptions = {
+        "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+        "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+        "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+    }
+    return interp_descriptions[value]
+
+
+def modelmerger(*args):
+    try:
+        results = extras.run_modelmerger(*args)
+    except Exception as e:
+        errors.report("Error loading/saving model file", exc_info=True)
+        sd_models.list_models() # to remove the potentially missing models from the list
+        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
+    return results
+
+
+class UiCheckpointMerger:
+    def __init__(self):
+        with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
+            with gr.Row(equal_height=False):
+                with gr.Column(variant='compact'):
+                    self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
+
+                    with FormRow(elem_id="modelmerger_models"):
+                        self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+                        create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
+                        self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+                        create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
+                        self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
+                        create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
+                    self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+                    self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+                    self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+                    self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
+
+                    with FormRow():
+                        self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+                        self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+
+                    with FormRow():
+                        with gr.Column():
+                            self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+
+                        with gr.Column():
+                            with FormRow():
+                                self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+                                create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
+
+                    with FormRow():
+                        self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
+
+                    with gr.Accordion("Metadata", open=False) as metadata_editor:
+                        with FormRow():
+                            self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
+                            self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
+                            self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
+
+                        self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
+                        self.read_metadata = gr.Button("Read metadata from selected checkpoints")
+
+                    with FormRow():
+                        self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+
+                with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+                    with gr.Group(elem_id="modelmerger_results_panel"):
+                        self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
+
+        self.metadata_editor = metadata_editor
+        self.blocks = modelmerger_interface
+
+    def setup_ui(self, dummy_component, sd_model_checkpoint_component):
+        self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
+
+        self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
+
+        self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
+        self.modelmerger_merge.click(
+            fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+            _js='modelmerger',
+            inputs=[
+                dummy_component,
+                self.primary_model_name,
+                self.secondary_model_name,
+                self.tertiary_model_name,
+                self.interp_method,
+                self.interp_amount,
+                self.save_as_half,
+                self.custom_name,
+                self.checkpoint_format,
+                self.config_source,
+                self.bake_in_vae,
+                self.discard_weights,
+                self.save_metadata,
+                self.add_merge_recipe,
+                self.copy_metadata_fields,
+                self.metadata_json,
+            ],
+            outputs=[
+                self.primary_model_name,
+                self.secondary_model_name,
+                self.tertiary_model_name,
+                sd_model_checkpoint_component,
+                self.modelmerger_result,
+            ]
+        )
+
+        # Required as a workaround for change() event not triggering when loading values from ui-config.json
+        self.interp_description.value = update_interp_description(self.interp_method.value)
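For orientation, the ui.py hunks above show how this new class is wired into create_ui(); the sketch below simply condenses that wiring (dummy_component and settings.component_dict come from create_ui()'s own scope, so this is a fragment rather than a standalone script):

    from modules import ui_checkpoint_merger

    # inside create_ui(), replacing the old inline merger block:
    modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()

    # the tab list now references the wrapped Blocks object
    interfaces = [
        (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
        # ... other tabs ...
    ]

    # once the settings UI exists, wire up the merge button, its inputs and outputs
    modelmerger_ui.setup_ui(
        dummy_component=dummy_component,
        sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'],
    )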
@@ -134,7 +134,7 @@ Requested path was: {f}

 with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
 with gr.Group(elem_id=f"{tabname}_gallery_container"):
-result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
+result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)

 generation_info = None
 with gr.Column():
@@ -223,20 +223,44 @@ Requested path was: {f}


 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
+
+label = None
+for comp in refresh_components:
+label = getattr(comp, 'label', None)
+if label is not None:
+break
+
 def refresh():
 refresh_method()
 args = refreshed_args() if callable(refreshed_args) else refreshed_args

 for k, v in args.items():
-setattr(refresh_component, k, v)
+for comp in refresh_components:
+setattr(comp, k, v)

-return gr.update(**(args or {}))
+return (gr.update(**(args or {})) for _ in refresh_components) if len(refresh_components) > 1 else gr.update(**(args or {}))

-refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
 refresh_button.click(
 fn=refresh,
 inputs=[],
-outputs=[refresh_component]
+outputs=refresh_components
 )
 return refresh_button

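With this change create_refresh_button accepts either a single component or a list of components that share one refresh method, and the button's tooltip is derived from the first labelled component. A minimal usage sketch, assuming two VAE dropdowns; the dropdowns and the elem_id are illustrative and not taken from the diff:

    import gradio as gr

    from modules import sd_vae
    from modules.ui_common import create_refresh_button

    with gr.Blocks() as demo:
        with gr.Row():
            vae_a = gr.Dropdown(choices=list(sd_vae.vae_dict), label="VAE A")
            vae_b = gr.Dropdown(choices=list(sd_vae.vae_dict), label="VAE B")

            # one ToolButton refreshes both dropdowns; with a list argument,
            # refresh() returns one gr.update() per component
            create_refresh_button(
                [vae_a, vae_b],
                sd_vae.refresh_vae_list,
                lambda: {"choices": list(sd_vae.vae_dict)},
                "example_refresh_vae_dropdowns",  # illustrative elem_id
            )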
+
+def setup_dialog(button_show, dialog, *, button_close=None):
+"""Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
+
+dialog.visible = False
+
+button_show.click(
+fn=lambda: gr.update(visible=True),
+inputs=[],
+outputs=[dialog],
+).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
+
+if button_close:
+button_close.click(fn=None, _js="closePopup")
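A minimal sketch of how the new setup_dialog helper is meant to be used; the component names and elem_id here are illustrative. The helper hides the box and shows it as a fullscreen modal popup when the show button is clicked:

    import gradio as gr

    from modules.ui_common import setup_dialog

    with gr.Blocks() as demo:
        show_button = gr.Button("Open dialog")

        with gr.Box(elem_id="example_dialog") as dialog:  # elem_id is needed by the popup() JS call
            gr.HTML("Dialog contents go here")
            close_button = gr.Button("Close")

        # hides the box and re-shows it as a modal when show_button is clicked
        setup_dialog(show_button, dialog, button_close=close_button)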
@@ -35,7 +35,7 @@ class FormColumn(FormComponent, gr.Column):


 class FormGroup(FormComponent, gr.Group):
-"""Same as gr.Row but fits inside gradio forms"""
+"""Same as gr.Group but fits inside gradio forms"""

 def get_block_name(self):
 return "group"
@@ -164,7 +164,7 @@ def extension_table():
 ext_status = ext.status

 style = ""
-if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
+if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
 style = STYLE_PRIMARY

 version_link = ext.version
@@ -533,16 +533,20 @@ def create_ui():
 apply = gr.Button(value=apply_label, variant="primary")
 check = gr.Button(value="Check for updates")
 extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
-extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
+extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
-extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
+extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False)

 html = ""
-if shared.opts.disable_all_extensions != "none":
-html = """
-<span style="color: var(--primary-400);">
-"Disable all extensions" was set, change it to "none" to load all extensions again
-</span>
-"""
+if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none":
+if shared.cmd_opts.disable_all_extensions:
+msg = '"--disable-all-extensions" was used, remove it to load all extensions again'
+elif shared.opts.disable_all_extensions != "none":
+msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'
+elif shared.cmd_opts.disable_extra_extensions:
+msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
+html = f'<span style="color: var(--primary-400);">{msg}</span>'

 info = gr.HTML(html)
 extensions_table = gr.HTML('Loading...')
 ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
@@ -565,7 +569,7 @@ def create_ui():
 with gr.Row():
 refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
 extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
-available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
+available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)
 extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
 install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)

@@ -574,7 +578,7 @@ def create_ui():
 sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")

 with gr.Row():
-search_extensions_text = gr.Text(label="Search").style(container=False)
+search_extensions_text = gr.Text(label="Search", container=False)

 install_result = gr.HTML()
 available_extensions_table = gr.HTML()
@@ -2,7 +2,7 @@ import os.path
 import urllib.parse
 from pathlib import Path

-from modules import shared, ui_extra_networks_user_metadata, errors
+from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
 from modules.images import read_info_from_image, save_image_with_geninfo
 from modules.ui import up_down_symbol
 import gradio as gr
@@ -62,7 +62,8 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
 page = next(iter([x for x in extra_pages if x.name == page]), None)

 try:
-item = page.create_item(name)
+item = page.create_item(name, enable_filter=False)
+page.items[name] = item
 except Exception as e:
 errors.display(e, "creating item for extra network")
 item = page.items.get(name)
@@ -100,16 +101,7 @@ class ExtraNetworksPage:

 def read_user_metadata(self, item):
 filename = item.get("filename", None)
-basename, ext = os.path.splitext(filename)
-metadata_filename = basename + '.json'
-
-metadata = {}
-try:
-if os.path.isfile(metadata_filename):
-with open(metadata_filename, "r", encoding="utf8") as file:
-metadata = json.load(file)
-except Exception as e:
-errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+metadata = extra_networks.get_user_metadata(filename)

 desc = metadata.get("description", None)
 if desc is not None:
@@ -252,7 +244,7 @@ class ExtraNetworksPage:
 "prompt": item.get("prompt", None),
 "tabname": quote_js(tabname),
 "local_preview": quote_js(item["local_preview"]),
-"name": item["name"],
+"name": html.escape(item["name"]),
 "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),
 "card_clicked": onclick,
 "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',
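The inline JSON-reading logic removed above now lives behind extra_networks.get_user_metadata. The rough sketch below reconstructs the equivalent behaviour from the removed lines; the actual helper in modules/extra_networks.py may differ in details:

    import json
    import os

    from modules import errors


    def get_user_metadata(filename):
        # look for "<model filename without extension>.json" next to the model file
        metadata_filename = os.path.splitext(filename)[0] + '.json'

        metadata = {}
        try:
            if os.path.isfile(metadata_filename):
                with open(metadata_filename, "r", encoding="utf8") as file:
                    metadata = json.load(file)
        except Exception as e:
            errors.display(e, f"reading extra network user metadata from {metadata_filename}")

        return metadata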
@@ -3,6 +3,7 @@ import os

 from modules import shared, ui_extra_networks, sd_models
 from modules.ui_extra_networks import quote_js
+from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor


 class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -12,7 +13,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
 def refresh(self):
 shared.refresh_checkpoints()

-def create_item(self, name, index=None):
+def create_item(self, name, index=None, enable_filter=True):
 checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
 path, ext = os.path.splitext(checkpoint.filename)
 return {
@@ -23,6 +24,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
 "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
 "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',
 "local_preview": f"{path}.{shared.opts.samples_format}",
+"metadata": checkpoint.metadata,
 "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},
 }

@@ -33,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
 def allowed_directories_for_previews(self):
 return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]
+
+def create_user_metadata_editor(self, ui, tabname):
+return CheckpointUserMetadataEditor(ui, tabname, self)
modules/ui_extra_networks_checkpoints_user_metadata.py (new file, 60 lines)
@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import ui_extra_networks_user_metadata, sd_vae
+from modules.ui_common import create_refresh_button
+
+
+class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+    def __init__(self, ui, tabname, page):
+        super().__init__(ui, tabname, page)
+
+        self.select_vae = None
+
+    def save_user_metadata(self, name, desc, notes, vae):
+        user_metadata = self.get_user_metadata(name)
+        user_metadata["description"] = desc
+        user_metadata["notes"] = notes
+        user_metadata["vae"] = vae
+
+        self.write_user_metadata(name, user_metadata)
+
+    def put_values_into_components(self, name):
+        user_metadata = self.get_user_metadata(name)
+        values = super().put_values_into_components(name)
+
+        return [
+            *values[0:5],
+            user_metadata.get('vae', ''),
+        ]
+
+    def create_editor(self):
+        self.create_default_editor_elems()
+
+        with gr.Row():
+            self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
+            create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
+
+        self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+        self.create_default_buttons()
+
+        viewed_components = [
+            self.edit_name,
+            self.edit_description,
+            self.html_filedata,
+            self.html_preview,
+            self.edit_notes,
+            self.select_vae,
+        ]
+
+        self.button_edit\
+            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+        edited_components = [
+            self.edit_description,
+            self.edit_notes,
+            self.select_vae,
+        ]
+
+        self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         shared.reload_hypernetworks()
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         full_path = shared.hypernetworks[name]
         path, ext = os.path.splitext(full_path)
 
@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
 
         path, ext = os.path.splitext(embedding.filename)
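Note: create_item() on the checkpoints, hypernetworks and textual inversion pages now accepts an enable_filter keyword. The assumption (its call sites are not in this excerpt) is that callers rebuilding a single card, for example after its user metadata is edited, pass enable_filter=False so the item is returned even if the page would normally hide it:

    # hypothetical call site; page and name are placeholders
    item = page.create_item(name, enable_filter=False)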
@@ -42,6 +42,9 @@ class UserMetadataEditor:
 
         return user_metadata
 
+    def create_extra_default_items_in_left_column(self):
+        pass
+
     def create_default_editor_elems(self):
         with gr.Row():
             with gr.Column(scale=2):
@@ -49,6 +52,8 @@ class UserMetadataEditor:
                 self.edit_description = gr.Textbox(label="Description", lines=4)
                 self.html_filedata = gr.HTML()
 
+                self.create_extra_default_items_in_left_column()
+
             with gr.Column(scale=1, min_width=0):
                 self.html_preview = gr.HTML()
 
@@ -91,6 +96,7 @@ class UserMetadataEditor:
 
         stats = os.stat(filename)
         params = [
+            ('Filename: ', os.path.basename(filename)),
             ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
             ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
         ]
@@ -111,7 +117,7 @@ class UserMetadataEditor:
 
         table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
 
-        return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''),
+        return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
 
     def write_user_metadata(self, name, metadata):
         item = self.page.items.get(name, {})
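Note: create_extra_default_items_in_left_column() is a new extension point in the default editor layout. A sketch of a hypothetical subclass using it (the widget name is illustrative):

    import gradio as gr
    from modules import ui_extra_networks_user_metadata

    class ExampleEditor(ui_extra_networks_user_metadata.UserMetadataEditor):  # illustrative only
        def create_extra_default_items_in_left_column(self):
            # extra widgets land between the file-data HTML and the preview column
            self.extra_field = gr.Textbox(label="Extra field")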
@@ -6,7 +6,7 @@ import modules.generation_parameters_copypaste as parameters_copypaste
 def create_ui():
     tab_index = gr.State(value=0)
 
-    with gr.Row().style(equal_height=False, variant='compact'):
+    with gr.Row(equal_height=False, variant='compact'):
         with gr.Column(variant='compact'):
             with gr.Tabs(elem_id="mode_extras"):
                 with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
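Note: dropping .style() follows the gradio 3.39 API, where layout keywords move into the component constructor. The same pattern in isolation (values illustrative):

    import gradio as gr

    # gradio 3.32: row = gr.Row().style(equal_height=False, variant='compact')
    row = gr.Row(equal_height=False, variant='compact')  # gradio 3.39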
modules/ui_prompt_styles.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
+styles_materialize_symbol = '\U0001f4cb' # 📋
+
+
+def select_style(name):
+    style = shared.prompt_styles.styles.get(name)
+    existing = style is not None
+    empty = not name
+
+    prompt = style.prompt if style else gr.update()
+    negative_prompt = style.negative_prompt if style else gr.update()
+
+    return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
+    if not name:
+        return gr.update(visible=False)
+
+    style = styles.PromptStyle(name, prompt, negative_prompt)
+    shared.prompt_styles.styles[style.name] = style
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return gr.update(visible=True)
+
+
+def delete_style(name):
+    if name == "":
+        return
+
+    shared.prompt_styles.styles.pop(name, None)
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+    return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+        self.tabname = tabname
+
+        with gr.Row(elem_id=f"{tabname}_styles_row"):
+            self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+            edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+        with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+            with gr.Row():
+                self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+                ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+                self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+
+            with gr.Row():
+                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+            with gr.Row():
+                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+            with gr.Row():
+                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+                self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+                self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+        self.selection.change(
+            fn=select_style,
+            inputs=[self.selection],
+            outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+            show_progress=False,
+        )
+
+        self.save.click(
+            fn=save_style,
+            inputs=[self.selection, self.prompt, self.neg_prompt],
+            outputs=[self.delete],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.delete.click(
+            fn=delete_style,
+            _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+            inputs=[self.selection],
+            outputs=[self.selection, self.prompt, self.neg_prompt],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.materialize.click(
+            fn=materialize_styles,
+            inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            show_progress=False,
+        ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+        ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
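Note: a sketch of how the new UiPromptStyles component is expected to be mounted next to a tab's prompt boxes; the Blocks layout below is illustrative, and the real txt2img/img2img wiring lives elsewhere in the UI code:

    import gradio as gr
    from modules import ui_prompt_styles

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative prompt")
        # builds the styles dropdown, the edit ToolButton and the popup editor dialog
        ui_prompt_styles.UiPromptStyles("txt2img", prompt, negative_prompt)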
@@ -158,7 +158,7 @@ class UiSettings:
                 loadsave.create_ui()
 
             with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):
-                gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo">(or open as text in a new page)</a>', elem_id="sysinfo_download")
+                gr.HTML('<a href="./internal/sysinfo-download" class="sysinfo_big_link" download>Download system info</a><br /><a href="./internal/sysinfo" target="_blank">(or open as text in a new page)</a>', elem_id="sysinfo_download")
 
                 with gr.Row():
                     with gr.Column(scale=1):
@@ -7,13 +7,14 @@ blendmodes
 clean-fid
 einops
 gfpgan
-gradio==3.32.0
+gradio==3.39.0
 inflection
 jsonmerge
 kornia
 lark
 numpy
 omegaconf
+open-clip-torch
 
 piexif
 psutil
@@ -29,4 +30,4 @@ tomesd
 torch
 torchdiffeq
 torchsde
-transformers==4.25.1
+transformers==4.30.2
@@ -1,30 +1,31 @@
-GitPython==3.1.30
+GitPython==3.1.32
 Pillow==9.5.0
-accelerate==0.18.0
+accelerate==0.21.0
 basicsr==1.4.2
 blendmodes==2022
 clean-fid==0.1.35
 einops==0.4.1
 fastapi==0.94.0
 gfpgan==1.3.8
-gradio==3.32.0
+gradio==3.39.0
-httpcore<=0.15
+httpcore==0.15
 inflection==0.5.1
 jsonmerge==1.8.0
 kornia==0.6.7
 lark==1.1.2
 numpy==1.23.5
 omegaconf==2.2.3
+open-clip-torch==2.20.0
 piexif==1.1.3
-psutil~=5.9.5
+psutil==5.9.5
 pytorch_lightning==1.9.4
 realesrgan==0.3.0
 resize-right==0.0.2
 safetensors==0.3.1
-scikit-image==0.20.0
+scikit-image==0.21.0
-timm==0.6.7
+timm==0.9.2
-tomesd==0.1.2
+tomesd==0.1.3
 torch
 torchdiffeq==0.2.3
 torchsde==0.2.5
-transformers==4.25.1
+transformers==4.30.2
@@ -3,6 +3,7 @@ from copy import copy
 from itertools import permutations, chain
 import random
 import csv
+import os.path
 from io import StringIO
 from PIL import Image
 import numpy as np
@@ -10,7 +11,7 @@ import numpy as np
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
 from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
 from modules.shared import opts, state
 import modules.shared as shared
@@ -66,14 +67,6 @@ def apply_order(p, x, xs):
     p.prompt = prompt_tmp + p.prompt
 
 
-def apply_sampler(p, x, xs):
-    sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
-    if sampler_name is None:
-        raise RuntimeError(f"Unknown sampler: {x}")
-
-    p.sampler_name = sampler_name
-
-
 def confirm_samplers(p, xs):
     for x in xs:
         if x.lower() not in sd_samplers.samplers_map:
@@ -144,11 +137,20 @@ def apply_face_restore(p, opt, x):
     p.restore_faces = is_active
 
 
-def apply_override(field):
+def apply_override(field, boolean: bool = False):
     def fun(p, x, xs):
+        if boolean:
+            x = True if x.lower() == "true" else False
         p.override_settings[field] = x
     return fun
 
 
+def boolean_choice(reverse: bool = False):
+    def choice():
+        return ["False", "True"] if reverse else ["True", "False"]
+    return choice
+
+
 def format_value_add_label(p, opt, x):
     if type(x) == float:
         x = round(x, 8)
@@ -173,6 +175,8 @@ def do_nothing(p, x, xs):
 def format_nothing(p, opt, x):
     return ""
 
+def format_remove_path(p, opt, x):
+    return os.path.basename(x)
 
 def str_permutations(x):
     """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
@@ -212,9 +216,10 @@ axis_options = [
     AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
     AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
-    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
-    AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
-    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
+    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+    AxisOptionTxt2Img("Hires sampler", str, apply_field("hr_sampler_name"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+    AxisOptionImg2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
+    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
     AxisOption("Sigma Churn", float, apply_field("s_churn")),
     AxisOption("Sigma min", float, apply_field("s_tmin")),
@@ -235,6 +240,7 @@ axis_options = [
     AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
     AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
     AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
+    AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
 ]
 
 
@@ -638,7 +644,12 @@ class Script(scripts.Script):
             y_opt.apply(pc, y, ys)
             z_opt.apply(pc, z, zs)
 
-            res = process_images(pc)
+            try:
+                res = process_images(pc)
+            except Exception as e:
+                errors.display(e, "generating image for xyz plot")
+
+                res = Processed(p, [], p.seed, "")
 
             # Sets subgrid infotexts
             subgrid_index = 1 + iz
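Note: the new boolean plumbing converts the axis dropdown's "True"/"False" strings into a real boolean before writing the override. A small illustration using the functions added above (the processing object is a stand-in, and importing from the scripts directory is assumed to work in this context):

    from scripts.xyz_grid import apply_override, boolean_choice  # assumption: the script module is importable

    class FakeProcessing:  # stand-in for StableDiffusionProcessing
        override_settings = {}

    fn = apply_override('always_discard_next_to_last_sigma', boolean=True)
    fn(FakeProcessing, "True", ["True", "False"])
    assert FakeProcessing.override_settings['always_discard_next_to_last_sigma'] is True

    print(boolean_choice(reverse=True)())  # ['False', 'True'] is offered in that order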
style.css (42 changed lines)
@@ -8,6 +8,7 @@
     --checkbox-label-gap: 0.25em 0.1em;
     --section-header-text-size: 12pt;
     --block-background-fill: transparent;
+
 }
 
 .block.padded:not(.gradio-accordion) {
@@ -42,7 +43,8 @@ div.form{
 .block.gradio-radio,
 .block.gradio-checkboxgroup,
 .block.gradio-number,
-.block.gradio-colorpicker
+.block.gradio-colorpicker,
+div.gradio-group
 {
     border-width: 0 !important;
     box-shadow: none !important;
@@ -133,6 +135,15 @@ a{
     cursor: pointer;
 }
 
+div.styler{
+    border: none;
+    background: var(--background-fill-primary);
+}
+
+.block.gradio-textbox{
+    overflow: visible !important;
+}
+
 
 /* general styled components */
 
@@ -164,7 +175,7 @@ a{
 .checkboxes-row > div{
     flex: 0;
     white-space: nowrap;
-    min-width: auto;
+    min-width: auto !important;
 }
 
 button.custom-button{
@@ -388,6 +399,7 @@ div#extras_scale_to_tab div.form{
 #quicksettings > div, #quicksettings > fieldset{
     max-width: 24em;
     min-width: 24em;
+    width: 24em;
     padding: 0;
     border: none;
     box-shadow: none;
@@ -423,15 +435,16 @@ div#extras_scale_to_tab div.form{
 }
 
 table.popup-table{
-    background: white;
+    background: var(--body-background-fill);
+    color: var(--body-text-color);
     border-collapse: collapse;
     margin: 1em;
-    border: 4px solid white;
+    border: 4px solid var(--body-background-fill);
 }
 
 table.popup-table td{
     padding: 0.4em;
-    border: 1px solid #ccc;
+    border: 1px solid rgba(128, 128, 128, 0.5);
     max-width: 36em;
 }
 
@@ -845,7 +858,7 @@ footer {
 
 .extra-network-cards .card .card-button {
     text-shadow: 2px 2px 3px black;
-    padding: 0.25em;
+    padding: 0.25em 0.1em;
     font-size: 200%;
     width: 1.5em;
 }
@@ -961,6 +974,10 @@ div.block.gradio-box.edit-user-metadata {
     text-align: left;
 }
 
+.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{
+    padding: 0.3em 1em;
+}
+
 .edit-user-metadata .wrap.translucent{
     background: var(--body-background-fill);
 }
@@ -971,3 +988,16 @@ div.block.gradio-box.edit-user-metadata {
 .edit-user-metadata-buttons{
     margin-top: 1.5em;
 }
+
+
+
+
+div.block.gradio-box.popup-dialog, .popup-dialog {
+    width: 56em;
+    background: var(--body-background-fill);
+    padding: 2em !important;
+}
+
+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
+    margin-top: 1em;
+}
webui.py (60 changed lines)
@@ -14,7 +14,6 @@ from typing import Iterable
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
-from packaging import version
 
 import logging
 
@@ -31,24 +30,26 @@ if log_level:
     logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
-from modules import paths, timer, import_hook, errors, devices  # noqa: F401
+from modules import timer
 
 startup_timer = timer.startup_timer
+startup_timer.record("launcher")
 
 import torch
 import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
 warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
 warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
 
 
 startup_timer.record("import torch")
 
 import gradio # noqa: F401
 startup_timer.record("import gradio")
 
+from modules import paths, timer, import_hook, errors, devices  # noqa: F401
+startup_timer.record("setup paths")
+
 import ldm.modules.encoders.modules # noqa: F401
 startup_timer.record("import ldm")
 
 
 from modules import extra_networks
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
 
@@ -57,10 +58,15 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
     torch.__long_version__ = torch.__version__
     torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
 
-from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+from modules import shared
+
+if not shared.cmd_opts.skip_version_check:
+    errors.check_versions()
 
 import modules.codeformer_model as codeformer
-import modules.face_restoration
 import modules.gfpgan_model as gfpgan
+from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+import modules.face_restoration
 import modules.img2img
 
 import modules.lowvram
@@ -129,37 +135,6 @@ def fix_asyncio_event_loop_policy():
     asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
 
 
-def check_versions():
-    if shared.cmd_opts.skip_version_check:
-        return
-
-    expected_torch_version = "2.0.0"
-
-    if version.parse(torch.__version__) < version.parse(expected_torch_version):
-        errors.print_error_explanation(f"""
-You are running torch {torch.__version__}.
-The program is tested to work with torch {expected_torch_version}.
-To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded, as well as
-there are reports of issues with training tab on the latest version.
-
-Use --skip-version-check commandline argument to disable this check.
-""".strip())
-
-    expected_xformers_version = "0.0.20"
-    if shared.xformers_available:
-        import xformers
-
-        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
-            errors.print_error_explanation(f"""
-You are running xformers {xformers.__version__}.
-The program is tested to work with xformers {expected_xformers_version}.
-To reinstall the desired version, run with commandline flag --reinstall-xformers.
-
-Use --skip-version-check commandline argument to disable this check.
-""".strip())
-
-
 def restore_config_state_file():
     config_state_file = shared.opts.restore_config_state_file
     if config_state_file == "":
@@ -247,7 +222,6 @@ def initialize():
     fix_asyncio_event_loop_policy()
     validate_tls_options()
     configure_sigint_handler()
-    check_versions()
     modelloader.cleanup_models()
     configure_opts_onchange()
 
@@ -319,9 +293,9 @@ def initialize_rest(*, reload_script_modules=False):
         if modules.sd_hijack.current_optimizer is None:
             modules.sd_hijack.apply_optimizations()
 
-    Thread(target=load_model).start()
+    devices.first_time_calculation()
 
-    Thread(target=devices.first_time_calculation).start()
+    Thread(target=load_model).start()
 
     shared.reload_hypernetworks()
     startup_timer.record("reload hypernetworks")
@@ -373,7 +347,7 @@ def api_only():
     api.launch(
         server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
        port=cmd_opts.port if cmd_opts.port else 7861,
-        root_path = f"/{cmd_opts.subpath}"
+        root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
    )
 
 
@@ -406,7 +380,7 @@ def webui():
         ssl_verify=cmd_opts.disable_tls_verify,
         debug=cmd_opts.gradio_debug,
         auth=gradio_auth_creds,
-        inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING ') != '1',
+        inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1',
         prevent_thread_lock=True,
         allowed_paths=cmd_opts.gradio_allowed_path,
         app_kwargs={
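Note: the api_only() change avoids handing FastAPI a bogus root path when --subpath is not given. In isolation (the subpath value is illustrative):

    subpath = None  # cmd_opts.subpath when --subpath is not set (assumption)
    old_root_path = f"/{subpath}"                     # "/None"
    new_root_path = f"/{subpath}" if subpath else ""  # "" (the API mounts at the server root)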
webui.sh (15 changed lines)
@@ -4,8 +4,15 @@
 # change the variables in webui-user.sh instead #
 #################################################
 
+
+use_venv=1
+if [[ $venv_dir == "-" ]]; then
+    use_venv=0
+fi
+
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 
+
 # If run from macOS, load defaults from webui-macos-env.sh
 if [[ "$OSTYPE" == "darwin"* ]]; then
     if [[ -f "$SCRIPT_DIR"/webui-macos-env.sh ]]
@@ -47,7 +54,7 @@ then
 fi
 
 # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
-if [[ -z "${venv_dir}" ]]
+if [[ -z "${venv_dir}" ]] && [[ $use_venv -eq 1 ]]
 then
     venv_dir="venv"
 fi
@@ -164,7 +171,7 @@ do
     fi
 done
 
-if ! "${python_cmd}" -c "import venv" &>/dev/null
+if [[ $use_venv -eq 1 ]] && ! "${python_cmd}" -c "import venv" &>/dev/null
 then
     printf "\n%s\n" "${delimiter}"
     printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m"
@@ -184,7 +191,7 @@ else
     cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
 fi
 
-if [[ -z "${VIRTUAL_ENV}" ]];
+if [[ $use_venv -eq 1 ]] && [[ -z "${VIRTUAL_ENV}" ]];
 then
     printf "\n%s\n" "${delimiter}"
     printf "Create and activate python venv"
@@ -207,7 +214,7 @@ then
     fi
 else
     printf "\n%s\n" "${delimiter}"
-    printf "python venv already activate: ${VIRTUAL_ENV}"
+    printf "python venv already activate or run without venv: ${VIRTUAL_ENV}"
     printf "\n%s\n" "${delimiter}"
 fi
 