mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-15 07:03:06 +03:00
Added initial lora support and removed redundant functions
This commit is contained in:
parent
eb7b1a9af4
commit
c337bed5ff
@ -16,7 +16,7 @@ import modules.scripts as scripts
|
||||
import modules.shared as shared
|
||||
|
||||
from modules import images, devices, extra_networks, generation_parameters_copypaste, masking, sd_samplers, sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||
from modules.processing import StableDiffusionProcessing, Processed, apply_overlay, process_images, get_fixed_seed, program_version, StableDiffusionProcessingImg2Img, create_random_tensors
|
||||
from modules.processing import StableDiffusionProcessing, Processed, apply_overlay, process_images, get_fixed_seed, program_version, StableDiffusionProcessingImg2Img, create_random_tensors, create_infotext
|
||||
from modules.sd_models import list_models, CheckpointInfo
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
||||
from modules.shared import Shared, cmd_opts, opts, state
|
||||
@ -47,7 +47,6 @@ from diffusers import (
|
||||
first_inference_global = 1
|
||||
sampler_name_global = "Euler a"
|
||||
openvino_device_global = "CPU"
|
||||
warmed_up_global = 0
|
||||
|
||||
def sd_diffusers_model(self):
|
||||
import modules.sd_models
|
||||
@ -95,6 +94,8 @@ def set_scheduler(sd_model, sampler_name):
|
||||
def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device):
|
||||
global first_inference_global, sampler_name_global
|
||||
if (first_inference_global == 1):
|
||||
torch._dynamo.reset()
|
||||
torch._dynamo.config.verbose=True
|
||||
curr_dir_path = os.getcwd()
|
||||
model_path = "/models/Stable-diffusion/"
|
||||
checkpoint_name = shared.opts.sd_model_checkpoint.split(" ")[0]
|
||||
@ -103,8 +104,9 @@ def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device):
|
||||
checkpoint_info = CheckpointInfo(checkpoint_path)
|
||||
sd_model.sd_checkpoint_info = checkpoint_info
|
||||
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||
|
||||
sd_model.safety_checker = None
|
||||
sd_model.cond_stage_key = functools.partial(cond_stage_key, shared.sd_model)
|
||||
|
||||
sd_model.scheduler = set_scheduler(sd_model, sampler_name)
|
||||
sd_model.unet = torch.compile(sd_model.unet, backend="openvino")
|
||||
sd_model.vae.decode = torch.compile(sd_model.vae.decode, backend="openvino")
|
||||
@ -115,57 +117,13 @@ def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device):
|
||||
if enable_caching:
|
||||
os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"
|
||||
image = sd_model(warmup_prompt, num_inference_steps=1).images[0]
|
||||
print("warm up run complete")
|
||||
|
||||
first_inference_global = 0
|
||||
shared.sd_diffusers_model = sd_model
|
||||
del sd_model
|
||||
shared.sd_diffusers_model.cond_stage_key = functools.partial(cond_stage_key, shared.sd_diffusers_model)
|
||||
|
||||
return shared.sd_diffusers_model
|
||||
|
||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||
index = position_in_batch + iteration * p.batch_size
|
||||
|
||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||
enable_hr = getattr(p, 'enable_hr', False)
|
||||
token_merging_ratio = p.get_token_merging_ratio()
|
||||
token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
|
||||
|
||||
uses_ensd = opts.eta_noise_seed_delta != 0
|
||||
if uses_ensd:
|
||||
uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
|
||||
|
||||
generation_params = {
|
||||
"Steps": p.steps,
|
||||
"Sampler": p.sampler_name,
|
||||
"CFG scale": p.cfg_scale,
|
||||
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
|
||||
"Seed": all_seeds[index],
|
||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||
"Size": f"{p.width}x{p.height}",
|
||||
"Model hash": getattr(p, 'sd_model_hash', None),
|
||||
"Model": None,
|
||||
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
|
||||
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
|
||||
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
|
||||
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||
**p.extra_generation_params,
|
||||
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||
}
|
||||
|
||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||
|
||||
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||
|
||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||
|
||||
|
||||
def init_new(self, all_prompts, all_seeds, all_subseeds):
|
||||
crop_region = None
|
||||
@ -290,7 +248,7 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
|
||||
self.init_latent = self.init_latent * self.mask
|
||||
|
||||
|
||||
def process_images_openvino(p: StableDiffusionProcessing) -> Processed:
|
||||
def process_images_openvino(p: StableDiffusionProcessing, sampler_name, enable_caching, openvino_device) -> Processed:
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
if type(p.prompt) == list:
|
||||
@ -354,19 +312,25 @@ def process_images_openvino(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
if len(p.prompts) == 0:
|
||||
break
|
||||
|
||||
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device)
|
||||
|
||||
extra_network_data = p.parse_extra_network_prompts()
|
||||
print("Extra network data: ", extra_network_data)
|
||||
|
||||
if not p.disable_extra_networks:
|
||||
with devices.autocast():
|
||||
extra_networks.activate(p, p.extra_network_data)
|
||||
|
||||
# TODO: support multiplier
|
||||
if ('lora' in modules.extra_networks.extra_network_registry):
|
||||
import lora
|
||||
for lora_model in lora.loaded_loras:
|
||||
shared.sd_diffusers_model.load_lora_weights(os.getcwd() + "/models/Lora/", weight_name=lora_model.name + ".safetensors")
|
||||
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
|
||||
|
||||
print("After process batch")
|
||||
|
||||
# params.txt should be saved after scripts.process_batch, since the
|
||||
# infotext could be modified by that callback
|
||||
# Example: a wildcard processed by process_batch sets an extra model
|
||||
@ -375,14 +339,12 @@ def process_images_openvino(p: StableDiffusionProcessing) -> Processed:
|
||||
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
|
||||
processed = Processed(p, [], p.seed, "")
|
||||
file.write(create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments=[], position_in_batch=0 % p.batch_size, iteration=0 // p.batch_size))
|
||||
print("After processed")
|
||||
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
print("After shared state job")
|
||||
|
||||
generator = [torch.Generator(device="cpu").manual_seed(s) for s in p.seeds]
|
||||
print("prompts: ", p.prompts)
|
||||
|
||||
output = shared.sd_diffusers_model(
|
||||
prompt=p.prompts,
|
||||
negative_prompt=p.negative_prompts,
|
||||
@ -494,7 +456,6 @@ def process_images_openvino(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
return res
|
||||
|
||||
warm_up_triggered_global = 0
|
||||
class Script(scripts.Script):
|
||||
def title(self):
|
||||
return "Accelerate with OpenVINO"
|
||||
@ -504,37 +465,27 @@ class Script(scripts.Script):
|
||||
|
||||
def ui(self, is_img2img):
|
||||
core = Core()
|
||||
global first_inference_global, warmed_up_global, warm_up_triggered_global
|
||||
global first_inference_global
|
||||
openvino_device = gr.Dropdown(label="Select a device", choices=[device for device in core.available_devices], value="CPU")
|
||||
override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True)
|
||||
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
|
||||
enable_caching = gr.Checkbox(label="Cache the compiled models for faster model load in subsequent launches (Recommended)", value=True, elem_id=self.elem_id("enable_caching"))
|
||||
run_warmup = gr.Button("Run a warmup iteration (recommended)")
|
||||
warmup_status = gr.Textbox(label="Status of the warm up iteration", interactive=False)
|
||||
warmup_status = gr.Textbox(label="Device", interactive=False, visible=False)
|
||||
|
||||
def device_change(choice):
|
||||
global first_inference_global, warmed_up_global, openvino_device_global
|
||||
global first_inference_global, openvino_device_global
|
||||
if (openvino_device_global == choice):
|
||||
return gr.update(value="Device selected is " + choice)
|
||||
return gr.update(value="Device selected is " + choice, visible=True)
|
||||
else:
|
||||
warmed_up_global = 0
|
||||
first_inference_global = 1
|
||||
return gr.update(value="Device changed to " + choice + ". Press the button to run a new warmup iteration")
|
||||
return gr.update(value="Device changed to " + choice + ". Model will be re-compiled", visible=True)
|
||||
openvino_device.change(device_change, openvino_device, warmup_status)
|
||||
|
||||
def warmup(run_warmup, openvino_device, enable_caching, sampler_name):
|
||||
global first_inference_global, warmed_up_global
|
||||
first_inference_global = 1
|
||||
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device)
|
||||
warmed_up_global = 1
|
||||
return gr.update(value="Warm up run complete")
|
||||
run_warmup.click(warmup, [run_warmup, openvino_device, enable_caching, sampler_name], warmup_status)
|
||||
|
||||
return [openvino_device, override_sampler, sampler_name, warmup_status, enable_caching]
|
||||
|
||||
|
||||
def run(self, p, openvino_device, override_sampler, sampler_name, warmup_status, enable_caching):
|
||||
global first_inference_global, warmed_up_global, sampler_name_global, openvino_device_global
|
||||
global first_inference_global, sampler_name_global, openvino_device_global
|
||||
os.environ["OPENVINO_DEVICE"] = str(openvino_device)
|
||||
if enable_caching:
|
||||
os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"
|
||||
@ -543,9 +494,6 @@ class Script(scripts.Script):
|
||||
first_inference_global = 1
|
||||
openvino_device_global = openvino_device
|
||||
|
||||
if (warmed_up_global == 0):
|
||||
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device)
|
||||
|
||||
if override_sampler:
|
||||
p.sampler_name = sampler_name
|
||||
else:
|
||||
@ -558,9 +506,9 @@ class Script(scripts.Script):
|
||||
sampler_name_global = p.sampler_name
|
||||
|
||||
if self.is_txt2img:
|
||||
processed = process_images_openvino(p)
|
||||
processed = process_images_openvino(p, p.sampler_name, enable_caching, openvino_device)
|
||||
else:
|
||||
p.init = functools.partial(init_new, p)
|
||||
processed = process_images_openvino(p)
|
||||
processed = process_images_openvino(p, p.sampler_name, enable_caching, openvino_device)
|
||||
return processed
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user