From 6c932e61002014b893215fb68a08cdd0fbd8f3f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=AE=E5=BD=B1?= <15190450708@163.com> Date: Thu, 4 Jan 2024 23:13:18 +0800 Subject: [PATCH] add DPM++ 2M SDE Karras sampler see: https://github.com/openvinotoolkit/stable-diffusion-webui/issues/91 --- scripts/openvino_accelerate.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/openvino_accelerate.py b/scripts/openvino_accelerate.py index 2bcd5002..d54d40d6 100644 --- a/scripts/openvino_accelerate.py +++ b/scripts/openvino_accelerate.py @@ -521,6 +521,8 @@ def set_scheduler(sd_model, sampler_name): sd_model.scheduler = LMSDiscreteScheduler.from_config(sd_model.scheduler.config, use_karras_sigmas=True) elif (sampler_name == "DPM++ 2M Karras"): sd_model.scheduler = DPMSolverMultistepScheduler.from_config(sd_model.scheduler.config, algorithm_type="dpmsolver++", use_karras_sigmas=True) + elif (sampler_name == "DPM++ 2M SDE Karras"): + sd_model.scheduler = DPMSolverMultistepScheduler.from_config(sd_model.scheduler.config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True) elif (sampler_name == "DDIM"): sd_model.scheduler = DDIMScheduler.from_config(sd_model.scheduler.config) elif (sampler_name == "PLMS"): @@ -1168,7 +1170,7 @@ class Script(scripts.Script): refiner_frac = gr.Slider(minimum=0, maximum=1, step=0.1, label='Refiner Denosing Fraction:', value=0.8) override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True) - sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a") + sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras","DPM++ 2M SDE Karras","DDIM", "PLMS"], value="Euler a") enable_caching = gr.Checkbox(label="Cache the compiled models on disk for faster model load in subsequent launches (Recommended)", value=True, elem_id=self.elem_id("enable_caching")) warmup_status = gr.Textbox(label="Device", interactive=False, visible=False) vae_status = gr.Textbox(label="VAE", interactive=False, visible=False) @@ -1217,7 +1219,7 @@ class Script(scripts.Script): if override_sampler: p.sampler_name = sampler_name else: - supported_samplers = ["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"] + supported_samplers = ["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras","DPM++ 2M SDE Karras","DDIM", "PLMS"] if (p.sampler_name not in supported_samplers): p.sampler_name = "Euler a"