mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-15 07:03:06 +03:00
Added a fix for sampler overriding
This commit is contained in:
parent
e8d8156f54
commit
eb7b1a9af4
@ -65,8 +65,8 @@ def set_scheduler(sd_model, sampler_name):
|
||||
sd_model.scheduler = EulerDiscreteScheduler.from_config(sd_model.scheduler.config)
|
||||
elif (sampler_name == "LMS"):
|
||||
sd_model.scheduler = LMSDiscreteScheduler.from_config(sd_model.scheduler.config)
|
||||
elif (sampler_name == "Huen"):
|
||||
sd_model.scheduler = HuenDiscreteScheduler.from_config(sd_model.scheduler.config)
|
||||
elif (sampler_name == "Heun"):
|
||||
sd_model.scheduler = HeunDiscreteScheduler.from_config(sd_model.scheduler.config)
|
||||
#elif (sampler_name == "DPM2"):
|
||||
# sd_model.scheduler = KDPM2DiscreteScheduler.from_config(sd_model.scheduler.config)
|
||||
#elif (sampler_name == "DPM2 a"):
|
||||
@ -507,25 +507,28 @@ class Script(scripts.Script):
|
||||
global first_inference_global, warmed_up_global, warm_up_triggered_global
|
||||
openvino_device = gr.Dropdown(label="Select a device", choices=[device for device in core.available_devices], value="CPU")
|
||||
override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True)
|
||||
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Huen", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
|
||||
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
|
||||
enable_caching = gr.Checkbox(label="Cache the compiled models for faster model load in subsequent launches (Recommended)", value=True, elem_id=self.elem_id("enable_caching"))
|
||||
run_warmup = gr.Button("Run a warmup iteration (recommended)")
|
||||
warmup_status = gr.Textbox(label="Status of the warm up iteration", interactive=False)
|
||||
|
||||
def device_change(choice):
|
||||
global first_inference_global, warmed_up_global
|
||||
warmed_up_global = 0
|
||||
first_inference_global = 1
|
||||
return gr.update(value="Device changed to " + choice + ". Press the button to run a new warmup iteration")
|
||||
global first_inference_global, warmed_up_global, openvino_device_global
|
||||
if (openvino_device_global == choice):
|
||||
return gr.update(value="Device selected is " + choice)
|
||||
else:
|
||||
warmed_up_global = 0
|
||||
first_inference_global = 1
|
||||
return gr.update(value="Device changed to " + choice + ". Press the button to run a new warmup iteration")
|
||||
openvino_device.change(device_change, openvino_device, warmup_status)
|
||||
|
||||
def warmup(run_warmup, openvino_device, enable_caching):
|
||||
def warmup(run_warmup, openvino_device, enable_caching, sampler_name):
|
||||
global first_inference_global, warmed_up_global
|
||||
first_inference_global = 1
|
||||
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device)
|
||||
warmed_up_global = 1
|
||||
return gr.update(value="Warm up run complete")
|
||||
run_warmup.click(warmup, [run_warmup, openvino_device, enable_caching], warmup_status)
|
||||
run_warmup.click(warmup, [run_warmup, openvino_device, enable_caching, sampler_name], warmup_status)
|
||||
|
||||
return [openvino_device, override_sampler, sampler_name, warmup_status, enable_caching]
|
||||
|
||||
@ -535,15 +538,24 @@ class Script(scripts.Script):
|
||||
os.environ["OPENVINO_DEVICE"] = str(openvino_device)
|
||||
if enable_caching:
|
||||
os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"
|
||||
|
||||
if (openvino_device_global != openvino_device):
|
||||
first_inference_global = 1
|
||||
openvino_device_global = openvino_device
|
||||
|
||||
if (warmed_up_global == 0):
|
||||
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device)
|
||||
if (sampler_name_global != sampler_name):
|
||||
print("Sampler name: ", sampler_name)
|
||||
shared.sd_diffusers_model.scheduler = set_scheduler(shared.sd_diffusers_model, sampler_name)
|
||||
sampler_name_global = sampler_name
|
||||
|
||||
if override_sampler:
|
||||
p.sampler_name = sampler_name
|
||||
else:
|
||||
supported_samplers = ["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"]
|
||||
if (p.sampler_name not in supported_samplers):
|
||||
p.sampler_name = "Euler a"
|
||||
|
||||
if (sampler_name_global != p.sampler_name):
|
||||
shared.sd_diffusers_model.scheduler = set_scheduler(shared.sd_diffusers_model, p.sampler_name)
|
||||
sampler_name_global = p.sampler_name
|
||||
|
||||
if self.is_txt2img:
|
||||
processed = process_images_openvino(p)
|
||||
|
Loading…
Reference in New Issue
Block a user