mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 22:53:25 +03:00
custom unet support
This commit is contained in:
parent
a6e653be26
commit
339b531570
@ -13,7 +13,7 @@ from skimage import exposure
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -674,6 +674,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
|
||||||
sd_vae_approx.model()
|
sd_vae_approx.model()
|
||||||
|
|
||||||
|
sd_unet.apply_unet()
|
||||||
|
|
||||||
if state.job_count == -1:
|
if state.job_count == -1:
|
||||||
state.job_count = p.n_iter
|
state.job_count = p.n_iter
|
||||||
|
|
||||||
|
@ -111,6 +111,7 @@ callback_map = dict(
|
|||||||
callbacks_before_ui=[],
|
callbacks_before_ui=[],
|
||||||
callbacks_on_reload=[],
|
callbacks_on_reload=[],
|
||||||
callbacks_list_optimizers=[],
|
callbacks_list_optimizers=[],
|
||||||
|
callbacks_list_unets=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -271,6 +272,18 @@ def list_optimizers_callback():
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def list_unets_callback():
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for c in callback_map['callbacks_list_unets']:
|
||||||
|
try:
|
||||||
|
c.callback(res)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'list_unets')
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def add_callback(callbacks, fun):
|
def add_callback(callbacks, fun):
|
||||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||||
@ -430,3 +443,10 @@ def on_list_optimizers(callback):
|
|||||||
to it."""
|
to it."""
|
||||||
|
|
||||||
add_callback(callback_map['callbacks_list_optimizers'], callback)
|
add_callback(callback_map['callbacks_list_optimizers'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_list_unets(callback):
|
||||||
|
"""register a function to be called when UI is making a list of alternative options for unet.
|
||||||
|
The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
|
||||||
|
|
||||||
|
add_callback(callback_map['callbacks_list_unets'], callback)
|
||||||
|
@ -3,7 +3,7 @@ from torch.nn.functional import silu
|
|||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
|
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||||
@ -43,7 +43,7 @@ def list_optimizers():
|
|||||||
optimizers.extend(new_optimizers)
|
optimizers.extend(new_optimizers)
|
||||||
|
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations(option=None):
|
||||||
global current_optimizer
|
global current_optimizer
|
||||||
|
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
@ -60,7 +60,7 @@ def apply_optimizations():
|
|||||||
current_optimizer.undo()
|
current_optimizer.undo()
|
||||||
current_optimizer = None
|
current_optimizer = None
|
||||||
|
|
||||||
selection = shared.opts.cross_attention_optimization
|
selection = option or shared.opts.cross_attention_optimization
|
||||||
if selection == "Automatic" and len(optimizers) > 0:
|
if selection == "Automatic" and len(optimizers) > 0:
|
||||||
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
|
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
|
||||||
else:
|
else:
|
||||||
@ -72,12 +72,13 @@ def apply_optimizations():
|
|||||||
matching_optimizer = optimizers[0]
|
matching_optimizer = optimizers[0]
|
||||||
|
|
||||||
if matching_optimizer is not None:
|
if matching_optimizer is not None:
|
||||||
print(f"Applying optimization: {matching_optimizer.name}... ", end='')
|
print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
|
||||||
matching_optimizer.apply()
|
matching_optimizer.apply()
|
||||||
print("done.")
|
print("done.")
|
||||||
current_optimizer = matching_optimizer
|
current_optimizer = matching_optimizer
|
||||||
return current_optimizer.name
|
return current_optimizer.name
|
||||||
else:
|
else:
|
||||||
|
print("Disabling attention optimization")
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
@ -155,9 +156,9 @@ class StableDiffusionModelHijack:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
def apply_optimizations(self):
|
def apply_optimizations(self, option=None):
|
||||||
try:
|
try:
|
||||||
self.optimization_method = apply_optimizations()
|
self.optimization_method = apply_optimizations(option)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.display(e, "applying cross attention optimization")
|
errors.display(e, "applying cross attention optimization")
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
@ -194,6 +195,11 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
self.layers = flatten(m)
|
self.layers = flatten(m)
|
||||||
|
|
||||||
|
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
|
||||||
|
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
@ -215,6 +221,8 @@ class StableDiffusionModelHijack:
|
|||||||
self.layers = None
|
self.layers = None
|
||||||
self.clip = None
|
self.clip = None
|
||||||
|
|
||||||
|
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
|
||||||
|
|
||||||
def apply_circular(self, enable):
|
def apply_circular(self, enable):
|
||||||
if self.circular_enabled == enable:
|
if self.circular_enabled == enable:
|
||||||
return
|
return
|
||||||
|
@ -14,7 +14,7 @@ import ldm.modules.midas as midas
|
|||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
import tomesd
|
import tomesd
|
||||||
@ -532,6 +532,8 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
sd_unet.apply_unet("None")
|
||||||
|
|
||||||
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
||||||
lowvram.send_everything_to_cpu()
|
lowvram.send_everything_to_cpu()
|
||||||
else:
|
else:
|
||||||
|
92
modules/sd_unet.py
Normal file
92
modules/sd_unet.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import torch.nn
|
||||||
|
import ldm.modules.diffusionmodules.openaimodel
|
||||||
|
|
||||||
|
from modules import script_callbacks, shared, devices
|
||||||
|
|
||||||
|
unet_options = []
|
||||||
|
current_unet_option = None
|
||||||
|
current_unet = None
|
||||||
|
|
||||||
|
|
||||||
|
def list_unets():
|
||||||
|
new_unets = script_callbacks.list_unets_callback()
|
||||||
|
|
||||||
|
unet_options.clear()
|
||||||
|
unet_options.extend(new_unets)
|
||||||
|
|
||||||
|
|
||||||
|
def get_unet_option(option=None):
|
||||||
|
option = option or shared.opts.sd_unet
|
||||||
|
|
||||||
|
if option == "None":
|
||||||
|
return None
|
||||||
|
|
||||||
|
if option == "Automatic":
|
||||||
|
name = shared.sd_model.sd_checkpoint_info.model_name
|
||||||
|
|
||||||
|
options = [x for x in unet_options if x.model_name == name]
|
||||||
|
|
||||||
|
option = options[0].label if options else "None"
|
||||||
|
|
||||||
|
return next(iter([x for x in unet_options if x.label == option]), None)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_unet(option=None):
|
||||||
|
global current_unet_option
|
||||||
|
global current_unet
|
||||||
|
|
||||||
|
new_option = get_unet_option(option)
|
||||||
|
if new_option == current_unet_option:
|
||||||
|
return
|
||||||
|
|
||||||
|
if current_unet is not None:
|
||||||
|
print(f"Dectivating unet: {current_unet.option.label}")
|
||||||
|
current_unet.deactivate()
|
||||||
|
|
||||||
|
current_unet_option = new_option
|
||||||
|
if current_unet_option is None:
|
||||||
|
current_unet = None
|
||||||
|
|
||||||
|
if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
|
||||||
|
shared.sd_model.model.diffusion_model.to(devices.device)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
shared.sd_model.model.diffusion_model.to(devices.cpu)
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
|
current_unet = current_unet_option.create_unet()
|
||||||
|
current_unet.option = current_unet_option
|
||||||
|
print(f"Activating unet: {current_unet.option.label}")
|
||||||
|
current_unet.activate()
|
||||||
|
|
||||||
|
|
||||||
|
class SdUnetOption:
|
||||||
|
model_name = None
|
||||||
|
"""name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this"""
|
||||||
|
|
||||||
|
label = None
|
||||||
|
"""name of the unet in UI"""
|
||||||
|
|
||||||
|
def create_unet(self):
|
||||||
|
"""returns SdUnet object to be used as a Unet instead of built-in unet when making pictures"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class SdUnet(torch.nn.Module):
|
||||||
|
def forward(self, x, timesteps, context, *args, **kwargs):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def activate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def deactivate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
||||||
|
if current_unet is not None:
|
||||||
|
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||||
|
|
||||||
|
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
|
||||||
|
|
@ -403,6 +403,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
|
||||||
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
|
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
|
||||||
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
|
||||||
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
|
||||||
|
@ -29,3 +29,14 @@ def cross_attention_optimizations():
|
|||||||
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]
|
||||||
|
|
||||||
|
|
||||||
|
def sd_unet_items():
|
||||||
|
import modules.sd_unet
|
||||||
|
|
||||||
|
return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_unet_list():
|
||||||
|
import modules.sd_unet
|
||||||
|
|
||||||
|
modules.sd_unet.list_unets()
|
||||||
|
|
||||||
|
4
webui.py
4
webui.py
@ -58,6 +58,7 @@ import modules.sd_hijack
|
|||||||
import modules.sd_hijack_optimizations
|
import modules.sd_hijack_optimizations
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.sd_vae
|
import modules.sd_vae
|
||||||
|
import modules.sd_unet
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
import modules.script_callbacks
|
import modules.script_callbacks
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
@ -291,6 +292,9 @@ def initialize_rest(*, reload_script_modules=False):
|
|||||||
modules.sd_hijack.list_optimizers()
|
modules.sd_hijack.list_optimizers()
|
||||||
startup_timer.record("scripts list_optimizers")
|
startup_timer.record("scripts list_optimizers")
|
||||||
|
|
||||||
|
modules.sd_unet.list_unets()
|
||||||
|
startup_timer.record("scripts list_unets")
|
||||||
|
|
||||||
def load_model():
|
def load_model():
|
||||||
"""
|
"""
|
||||||
Accesses shared.sd_model property to load model.
|
Accesses shared.sd_model property to load model.
|
||||||
|
Loading…
Reference in New Issue
Block a user