add DPM++ 2M SDE Karras sampler

see: https://github.com/openvinotoolkit/stable-diffusion-webui/issues/91
This commit is contained in:
微影 2024-01-04 23:13:18 +08:00 committed by GitHub
parent 44006297e0
commit 6c932e6100
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -521,6 +521,8 @@ def set_scheduler(sd_model, sampler_name):
sd_model.scheduler = LMSDiscreteScheduler.from_config(sd_model.scheduler.config, use_karras_sigmas=True)
elif (sampler_name == "DPM++ 2M Karras"):
sd_model.scheduler = DPMSolverMultistepScheduler.from_config(sd_model.scheduler.config, algorithm_type="dpmsolver++", use_karras_sigmas=True)
elif (sampler_name == "DPM++ 2M SDE Karras"):
sd_model.scheduler = DPMSolverMultistepScheduler.from_config(sd_model.scheduler.config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
elif (sampler_name == "DDIM"):
sd_model.scheduler = DDIMScheduler.from_config(sd_model.scheduler.config)
elif (sampler_name == "PLMS"):
@ -1168,7 +1170,7 @@ class Script(scripts.Script):
refiner_frac = gr.Slider(minimum=0, maximum=1, step=0.1, label='Refiner Denosing Fraction:', value=0.8)
override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True)
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras","DPM++ 2M SDE Karras","DDIM", "PLMS"], value="Euler a")
enable_caching = gr.Checkbox(label="Cache the compiled models on disk for faster model load in subsequent launches (Recommended)", value=True, elem_id=self.elem_id("enable_caching"))
warmup_status = gr.Textbox(label="Device", interactive=False, visible=False)
vae_status = gr.Textbox(label="VAE", interactive=False, visible=False)
@ -1217,7 +1219,7 @@ class Script(scripts.Script):
if override_sampler:
p.sampler_name = sampler_name
else:
supported_samplers = ["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"]
supported_samplers = ["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras","DPM++ 2M SDE Karras","DDIM", "PLMS"]
if (p.sampler_name not in supported_samplers):
p.sampler_name = "Euler a"