Merge pull request #97 from openvinotoolkit/dev

HiRes Latent upscaler support
This commit is contained in:
Anna Likholat 2024-05-06 18:45:12 +02:00 committed by GitHub
commit e5a634da06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -53,6 +53,7 @@ from diffusers import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
ControlNetModel,
StableDiffusionLatentUpscalePipeline,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
@ -535,6 +536,23 @@ class NoWatermark:
def apply_watermark(self, img):
return img
def get_diffusers_upscaler(upscaler: str):
    """Load the SD x2 latent upscaler pipeline and compile it for OpenVINO.

    Resets torch dynamo state and the OpenVINO caches, downloads the
    ``stabilityai/sd-x2-latent-upscaler`` pipeline, disables its safety
    checker, compiles the UNet forward pass and VAE decode with the
    OpenVINO torch.compile backend, caches the pipeline on
    ``shared.sd_diffusers_model`` and returns that shared instance.

    NOTE(review): *upscaler* is currently unused — presumably only the
    "Latent" upscaler is supported; confirm against the UI dropdown.
    """
    torch._dynamo.reset()
    openvino_clear_caches()
    model_name = "stabilityai/sd-x2-latent-upscaler"
    print("OpenVINO Script: loading upscaling model: " + model_name)
    pipeline = StableDiffusionLatentUpscalePipeline.from_pretrained(
        model_name, torch_dtype=torch.float32
    )
    pipeline.safety_checker = None
    # Bind cond_stage_key against the currently loaded base model.
    pipeline.cond_stage_key = functools.partial(cond_stage_key, shared.sd_model)
    # Compile the two hot paths (denoising UNet and VAE decoder) for OpenVINO.
    pipeline.unet = torch.compile(pipeline.unet, backend="openvino")
    pipeline.vae.decode = torch.compile(pipeline.vae.decode, backend="openvino")
    shared.sd_diffusers_model = pipeline
    del pipeline
    return shared.sd_diffusers_model
def get_diffusers_sd_model(model_config, vae_ckpt, sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac):
if (model_state.recompile == 1):
model_state.partition_id = 0
@ -770,7 +788,8 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
else:
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
def process_images_openvino(p: StableDiffusionProcessing, model_config, vae_ckpt, sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac) -> Processed:
def process_images_openvino(p: StableDiffusionProcessing, model_config, vae_ckpt, sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if (mode == 0 and p.enable_hr):
@ -1092,6 +1111,23 @@ def process_images_openvino(p: StableDiffusionProcessing, model_config, vae_ckpt
devices.torch_gc()
# High resolution mode
if override_hires:
if upscaler == "Latent":
model_state.mode = -1
shared.sd_diffusers_model = get_diffusers_upscaler(upscaler)
img_idx = slice(len(output_images)) if p.batch_size == 1 else slice(1, len(output_images))
output_images[img_idx] = shared.sd_diffusers_model(
image=output_images[img_idx],
prompt=p.prompts,
negative_prompt=p.negative_prompts,
num_inference_steps=hires_steps,
guidance_scale=p.cfg_scale,
generator=generator,
callback = callback,
callback_steps = 1,
).images
res = Processed(
p,
images_list=output_images,
@ -1102,7 +1138,8 @@ def process_images_openvino(p: StableDiffusionProcessing, model_config, vae_ckpt
index_of_first_image=index_of_first_image,
infotexts=infotexts,
)
if override_hires:
res.info = res.info + f", Hires upscaler: {upscaler}, Denoising strength: {d_strength}"
res.info = res.info + ", Warm up time: " + str(round(warmup_duration, 2)) + " secs "
if (generation_rate >= 1.0):
@ -1116,6 +1153,9 @@ def process_images_openvino(p: StableDiffusionProcessing, model_config, vae_ckpt
return res
def on_change(mode):
    """Gradio change-callback: set a component's visibility to *mode*."""
    visibility_update = gr.update(visible=mode)
    return visibility_update
class Script(scripts.Script):
def title(self):
return "Accelerate with OpenVINO"
@ -1170,6 +1210,12 @@ class Script(scripts.Script):
override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True)
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
enable_caching = gr.Checkbox(label="Cache the compiled models on disk for faster model load in subsequent launches (Recommended)", value=True, elem_id=self.elem_id("enable_caching"))
override_hires = gr.Checkbox(label="Override the Hires.fix selection from the main UI (Recommended as only below upscalers have been validated for OpenVINO)", value=False, visible=self.is_txt2img)
with gr.Group(visible=False) as hires:
with gr.Row():
upscaler = gr.Dropdown(label="Upscaler", choices=["Latent"], value="Latent")
hires_steps = gr.Slider(1, 150, value=10, step=1, label="Steps")
d_strength = gr.Slider(0, 1, value=0.5, step=0.01, label="Strength")
warmup_status = gr.Textbox(label="Device", interactive=False, visible=False)
vae_status = gr.Textbox(label="VAE", interactive=False, visible=False)
gr.Markdown(
@ -1184,6 +1230,8 @@ class Script(scripts.Script):
So it's normal for the first inference after a settings change to be slower, while subsequent inferences use the optimized compiled model and run faster.
""")
override_hires.change(on_change, override_hires, hires)
def device_change(choice):
if (model_state.device == choice):
return gr.update(value="Device selected is " + choice, visible=True)
@ -1206,9 +1254,9 @@ class Script(scripts.Script):
else:
model_state.refiner_ckpt = choice
refiner_ckpt.change(refiner_ckpt_change, refiner_ckpt)
return [model_config, vae_ckpt, openvino_device, override_sampler, sampler_name, enable_caching, is_xl_ckpt, refiner_ckpt, refiner_frac]
return [model_config, vae_ckpt, openvino_device, override_sampler, sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, is_xl_ckpt, refiner_ckpt, refiner_frac]
def run(self, p, model_config, vae_ckpt, openvino_device, override_sampler, sampler_name, enable_caching, is_xl_ckpt, refiner_ckpt, refiner_frac):
def run(self, p, model_config, vae_ckpt, openvino_device, override_sampler, sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, is_xl_ckpt, refiner_ckpt, refiner_frac):
os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = str(openvino_device)
if enable_caching:
@ -1225,14 +1273,12 @@ class Script(scripts.Script):
mode = 0
if self.is_txt2img:
mode = 0
processed = process_images_openvino(p, model_config, vae_ckpt, p.sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
processed = process_images_openvino(p, model_config, vae_ckpt, p.sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
else:
if p.image_mask is None:
mode = 1
else:
mode = 2
p.init = functools.partial(init_new, p)
processed = process_images_openvino(p, model_config, vae_ckpt, p.sampler_name, enable_caching, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
processed = process_images_openvino(p, model_config, vae_ckpt, p.sampler_name, enable_caching, override_hires, upscaler, hires_steps, d_strength, openvino_device, mode, is_xl_ckpt, refiner_ckpt, refiner_frac)
return processed