Mirror of https://github.com/openvinotoolkit/stable-diffusion-webui.git
SD VAE rework 2
- the setting for preferring opts.sd_vae has been inverted and reworded
- resolve_vae function made easier to read and now returns an object rather than a tuple
- if the checkbox for overriding per-model preferences is checked, opts.sd_vae overrides checkpoint user metadata
- changing VAE in user metadata for currently loaded model immediately applies the selection
This commit is contained in:
parent 5a38a9c0ee
commit c96e4750d8
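The resolution order described in the message can be summarized with a simplified, self-contained sketch. `Opts`, its field names, and the example values below are illustrative stand-ins rather than the real `shared.opts`/`shared.cmd_opts` objects, and the real code additionally resolves names through `vae_dict`:

from dataclasses import dataclass


@dataclass
class VaeResolution:
    vae: str = None
    source: str = None
    resolved: bool = True

    def tuple(self):
        return self.vae, self.source


@dataclass
class Opts:
    vae_path: str = None                                   # --vae-path commandline argument
    sd_vae: str = "Automatic"                              # the "SD VAE" setting
    sd_vae_overrides_per_model_preferences: bool = False   # the new checkbox
    metadata_vae: str = None                               # VAE chosen in checkpoint user metadata
    vae_near_checkpoint: str = None                        # .vae file found next to the checkpoint


def is_automatic(opts):
    return opts.sd_vae in {"Automatic", "auto"}


def from_setting(opts):
    if opts.sd_vae == "None":
        return VaeResolution()                             # explicitly no external VAE
    if not is_automatic(opts):
        return VaeResolution(opts.sd_vae, "specified in settings")
    return VaeResolution(resolved=False)


def resolve(opts):
    if opts.vae_path is not None:                          # 1. commandline always wins
        return VaeResolution(opts.vae_path, "from commandline argument")
    if opts.sd_vae_overrides_per_model_preferences and not is_automatic(opts):
        return from_setting(opts)                          # 2. checkbox: setting beats per-model preferences
    if opts.metadata_vae is not None and opts.metadata_vae != "Automatic":
        if opts.metadata_vae == "None":
            return VaeResolution()                         # 3. per-model "None" = use VAE baked into the checkpoint
        return VaeResolution(opts.metadata_vae, "from user metadata")
    if opts.vae_near_checkpoint is not None:               # 4. file next to the checkpoint
        return VaeResolution(opts.vae_near_checkpoint, "found near the checkpoint")
    return from_setting(opts)                              # 5. fall back to the setting


# with the override checkbox off, per-model metadata wins over the global setting
print(resolve(Opts(sd_vae="mse.vae.pt", metadata_vae="anime.vae.pt")).tuple())
# -> ('anime.vae.pt', 'from user metadata')

# with the override checkbox on, the global setting wins
print(resolve(Opts(sd_vae="mse.vae.pt", metadata_vae="anime.vae.pt",
                   sd_vae_overrides_per_model_preferences=True)).tuple())
# -> ('mse.vae.pt', 'specified in settings')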
@@ -356,7 +356,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
 
     sd_vae.delete_base_vae()
     sd_vae.clear_loaded_vae()
-    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
     sd_vae.load_vae(model, vae_file, vae_source)
     timer.record("load VAE")
 
@@ -1,5 +1,7 @@
 import os
 import collections
+from dataclasses import dataclass
+
 from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
 import glob
 from copy import deepcopy
@@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file):
     return None
 
 
-def resolve_vae(checkpoint_file):
-    if shared.cmd_opts.vae_path is not None:
-        return shared.cmd_opts.vae_path, 'from commandline argument'
+@dataclass
+class VaeResolution:
+    vae: str = None
+    source: str = None
+    resolved: bool = True
+
+    def tuple(self):
+        return self.vae, self.source
+
 
+def is_automatic():
+    return shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
+
+
+def resolve_vae_from_setting() -> VaeResolution:
+    if shared.opts.sd_vae == "None":
+        return VaeResolution()
+
+    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+    if vae_from_options is not None:
+        return VaeResolution(vae_from_options, 'specified in settings')
+
+    if not is_automatic():
+        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+
+    return VaeResolution(resolved=False)
+
+
+def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
     metadata = extra_networks.get_user_metadata(checkpoint_file)
     vae_metadata = metadata.get("vae", None)
     if vae_metadata is not None and vae_metadata != "Automatic":
         if vae_metadata == "None":
-            return None, None
+            return VaeResolution()
 
         vae_from_metadata = vae_dict.get(vae_metadata, None)
         if vae_from_metadata is not None:
-            return vae_from_metadata, "from user metadata"
+            return VaeResolution(vae_from_metadata, "from user metadata")
 
-    is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
+    return VaeResolution(resolved=False)
+
 
+def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
     vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
     if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
-        return vae_near_checkpoint, 'found near the checkpoint'
+        return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
 
-    if shared.opts.sd_vae == "None":
-        return None, None
+    return VaeResolution(resolved=False)
 
-    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
-    if vae_from_options is not None:
-        return vae_from_options, 'specified in settings'
 
-    if not is_automatic:
-        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+def resolve_vae(checkpoint_file) -> VaeResolution:
+    if shared.cmd_opts.vae_path is not None:
+        return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
+
+    if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
+        return resolve_vae_from_setting()
+
+    res = resolve_vae_from_user_metadata(checkpoint_file)
+    if res.resolved:
+        return res
+
+    res = resolve_vae_near_checkpoint(checkpoint_file)
+    if res.resolved:
+        return res
+
+    res = resolve_vae_from_setting()
 
-    return None, None
+    return res
 
 
 def load_vae_dict(filename, map_location):
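Downstream callers can either inspect the structured result or keep the old two-value unpacking through `tuple()`. A minimal illustration using a stripped-down copy of the dataclass and made-up values:

from dataclasses import dataclass


@dataclass
class VaeResolution:
    vae: str = None
    source: str = None
    resolved: bool = True

    def tuple(self):
        return self.vae, self.source


res = VaeResolution("example.vae.safetensors", "found near the checkpoint")

# new style: the fall-through in resolve_vae() checks the resolved flag
if res.resolved:
    print(f"loading VAE {res.vae} ({res.source})")

# old style: existing call sites keep two-value unpacking through .tuple()
vae_file, vae_source = res.tuple()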
@@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     checkpoint_file = checkpoint_info.filename
 
     if vae_file == unspecified:
-        vae_file, vae_source = resolve_vae(checkpoint_file)
+        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
     else:
         vae_source = "from function argument"
 
@@ -479,7 +479,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
 """),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
-    "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+    "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
     "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
     "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
     "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
@@ -733,6 +733,10 @@ class Options:
         with open(filename, "r", encoding="utf8") as file:
             self.data = json.load(file)
 
+        # 1.6.0 VAE defaults
+        if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
+            self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
+
         # 1.1.1 quicksettings list migration
         if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
             self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
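A small sketch of what this migration does to an old config (the JSON fragment is a hypothetical example):

import json

# hypothetical old-style config fragment; only the keys shown here matter
data = json.loads('{"sd_vae_as_default": true}')

# the 1.6.0 migration above stores the inverse of the old flag
if data.get('sd_vae_as_default') is not None and data.get('sd_vae_overrides_per_model_preferences') is None:
    data['sd_vae_overrides_per_model_preferences'] = not data.get('sd_vae_as_default')

print(data['sd_vae_overrides_per_model_preferences'])  # False: the old "ignore selected VAE" behaviour is preserved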
@@ -1,6 +1,6 @@
 import gradio as gr
 
-from modules import ui_extra_networks_user_metadata, sd_vae
+from modules import ui_extra_networks_user_metadata, sd_vae, shared
 from modules.ui_common import create_refresh_button
 
 
@@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
 
         self.write_user_metadata(name, user_metadata)
 
+    def update_vae(self, name):
+        if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
+            sd_vae.reload_vae_weights()
+
     def put_values_into_components(self, name):
         user_metadata = self.get_user_metadata(name)
         values = super().put_values_into_components(name)
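The new `update_vae` hook reloads the VAE only when the checkpoint being edited is the one currently loaded. A toy stand-in of that guard, with placeholder names instead of `shared.sd_model` and `sd_vae.reload_vae_weights`:

loaded_checkpoint = "sd_xl_base_1.0"  # stand-in for shared.sd_model.sd_checkpoint_info.name_for_extra


def reload_vae_weights():
    print("reloading VAE for the currently loaded model")


def update_vae(name):
    # only the active checkpoint triggers an immediate VAE reload
    if name == loaded_checkpoint:
        reload_vae_weights()


update_vae("sd_xl_base_1.0")    # reloads
update_vae("some_other_model")  # no-op; the change applies next time that model is loaded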
@@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
         ]
 
         self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
+        self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
+
webui.py
@@ -211,7 +211,7 @@ def configure_sigint_handler():
 def configure_opts_onchange():
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)