Added a fix for sampler overriding

This commit is contained in:
ynimmaga 2023-07-21 21:08:01 -07:00
parent e8d8156f54
commit eb7b1a9af4

View File

@ -65,8 +65,8 @@ def set_scheduler(sd_model, sampler_name):
sd_model.scheduler = EulerDiscreteScheduler.from_config(sd_model.scheduler.config)
elif (sampler_name == "LMS"):
sd_model.scheduler = LMSDiscreteScheduler.from_config(sd_model.scheduler.config)
elif (sampler_name == "Huen"):
sd_model.scheduler = HuenDiscreteScheduler.from_config(sd_model.scheduler.config)
elif (sampler_name == "Heun"):
sd_model.scheduler = HeunDiscreteScheduler.from_config(sd_model.scheduler.config)
#elif (sampler_name == "DPM2"):
# sd_model.scheduler = KDPM2DiscreteScheduler.from_config(sd_model.scheduler.config)
#elif (sampler_name == "DPM2 a"):
@ -507,25 +507,28 @@ class Script(scripts.Script):
global first_inference_global, warmed_up_global, warm_up_triggered_global
openvino_device = gr.Dropdown(label="Select a device", choices=[device for device in core.available_devices], value="CPU")
override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True)
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Huen", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
enable_caching = gr.Checkbox(label="Cache the compiled models for faster model load in subsequent launches (Recommended)", value=True, elem_id=self.elem_id("enable_caching"))
run_warmup = gr.Button("Run a warmup iteration (recommended)")
warmup_status = gr.Textbox(label="Status of the warm up iteration", interactive=False)
def device_change(choice):
global first_inference_global, warmed_up_global
warmed_up_global = 0
first_inference_global = 1
return gr.update(value="Device changed to " + choice + ". Press the button to run a new warmup iteration")
global first_inference_global, warmed_up_global, openvino_device_global
if (openvino_device_global == choice):
return gr.update(value="Device selected is " + choice)
else:
warmed_up_global = 0
first_inference_global = 1
return gr.update(value="Device changed to " + choice + ". Press the button to run a new warmup iteration")
openvino_device.change(device_change, openvino_device, warmup_status)
def warmup(run_warmup, openvino_device, enable_caching):
def warmup(run_warmup, openvino_device, enable_caching, sampler_name):
global first_inference_global, warmed_up_global
first_inference_global = 1
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device)
warmed_up_global = 1
return gr.update(value="Warm up run complete")
run_warmup.click(warmup, [run_warmup, openvino_device, enable_caching], warmup_status)
run_warmup.click(warmup, [run_warmup, openvino_device, enable_caching, sampler_name], warmup_status)
return [openvino_device, override_sampler, sampler_name, warmup_status, enable_caching]
@ -535,15 +538,24 @@ class Script(scripts.Script):
os.environ["OPENVINO_DEVICE"] = str(openvino_device)
if enable_caching:
os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"
if (openvino_device_global != openvino_device):
first_inference_global = 1
openvino_device_global = openvino_device
if (warmed_up_global == 0):
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device)
if (sampler_name_global != sampler_name):
print("Sampler name: ", sampler_name)
shared.sd_diffusers_model.scheduler = set_scheduler(shared.sd_diffusers_model, sampler_name)
sampler_name_global = sampler_name
if override_sampler:
p.sampler_name = sampler_name
else:
supported_samplers = ["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"]
if (p.sampler_name not in supported_samplers):
p.sampler_name = "Euler a"
if (sampler_name_global != p.sampler_name):
shared.sd_diffusers_model.scheduler = set_scheduler(shared.sd_diffusers_model, p.sampler_name)
sampler_name_global = p.sampler_name
if self.is_txt2img:
processed = process_images_openvino(p)