mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 22:53:25 +03:00
Added caching optimizations and local config selection
This commit is contained in:
commit
64ee9ff6bb
@ -26,9 +26,20 @@ from modules.shared import opts, state
|
||||
from PIL import Image, ImageOps
|
||||
from pathlib import Path
|
||||
|
||||
import openvino.frontend.pytorch.torchdynamo.backend # noqa: F401
|
||||
from openvino.frontend.pytorch.torchdynamo.execute import partitioned_modules, compiled_cache # noqa: F401
|
||||
from openvino.runtime import Core
|
||||
#from openvino.frontend import FrontEndManager
|
||||
from openvino.frontend.pytorch.torchdynamo import backend, compile # noqa: F401
|
||||
from openvino.frontend.pytorch.torchdynamo.execute import execute, partitioned_modules, compiled_cache # noqa: F401
|
||||
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
|
||||
from openvino.runtime import Core, Type, PartialShape #, serialize
|
||||
|
||||
from torch._dynamo.backends.common import fake_tensor_unsupported
|
||||
from torch._dynamo.backends.registry import register_backend
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
#from torch.fx import GraphModule
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
#from typing import Callable, Optional
|
||||
from hashlib import sha256
|
||||
|
||||
from diffusers import (
|
||||
StableDiffusionPipeline,
|
||||
@ -59,9 +70,98 @@ class ModelState:
|
||||
self.width = 512
|
||||
self.batch_size = 1
|
||||
self.mode = 0
|
||||
self.partition_id = 0
|
||||
self.model_hash = ""
|
||||
|
||||
model_state = ModelState()
|
||||
|
||||
@register_backend
|
||||
@fake_tensor_unsupported
|
||||
def openvino_fx(subgraph, example_inputs):
|
||||
try:
|
||||
executor_parameters = None
|
||||
core = Core()
|
||||
if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
|
||||
model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
|
||||
model_hash_str_file = model_hash_str + str(model_state.partition_id)
|
||||
model_state.partition_id = model_state.partition_id + 1
|
||||
executor_parameters = {"model_hash_str": model_hash_str}
|
||||
|
||||
example_inputs.reverse()
|
||||
cache_root = "./cache/"
|
||||
if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
|
||||
cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
|
||||
|
||||
device = "CPU"
|
||||
|
||||
if os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None:
|
||||
device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
|
||||
assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"
|
||||
|
||||
file_name = get_cached_file_name(*example_inputs, model_hash_str=model_hash_str_file, device=device, cache_root=cache_root)
|
||||
|
||||
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
|
||||
om = core.read_model(file_name + ".xml")
|
||||
|
||||
dtype_mapping = {
|
||||
torch.float32: Type.f32,
|
||||
torch.float64: Type.f64,
|
||||
torch.float16: Type.f16,
|
||||
torch.int64: Type.i64,
|
||||
torch.int32: Type.i32,
|
||||
torch.uint8: Type.u8,
|
||||
torch.int8: Type.i8,
|
||||
torch.bool: Type.boolean
|
||||
}
|
||||
|
||||
for idx, input_data in enumerate(example_inputs):
|
||||
om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
|
||||
om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
|
||||
om.validate_nodes_and_infer_types()
|
||||
|
||||
if model_hash_str is not None:
|
||||
core.set_property({'CACHE_DIR': cache_root + '/blob'})
|
||||
|
||||
compiled_model = core.compile_model(om, device)
|
||||
def _call(*args):
|
||||
ov_inputs = [a.detach().cpu().numpy() for a in args]
|
||||
ov_inputs.reverse()
|
||||
res = compiled_model(ov_inputs)
|
||||
result = [torch.from_numpy(res[out]) for out in compiled_model.outputs]
|
||||
return result
|
||||
return _call
|
||||
else:
|
||||
example_inputs.reverse()
|
||||
model = make_fx(subgraph)(*example_inputs)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
partitioner = Partitioner()
|
||||
compiled_model = partitioner.make_partitions(model)
|
||||
|
||||
def _call(*args):
|
||||
res = execute(compiled_model, *args, executor="openvino",
|
||||
executor_parameters=executor_parameters)
|
||||
return res
|
||||
return _call
|
||||
except Exception:
|
||||
return compile_fx(subgraph, example_inputs)
|
||||
|
||||
def get_cached_file_name(*args, model_hash_str, device, cache_root):
|
||||
file_name = None
|
||||
if model_hash_str is not None:
|
||||
model_cache_dir = cache_root + "/model/"
|
||||
try:
|
||||
os.makedirs(model_cache_dir, exist_ok=True)
|
||||
file_name = model_cache_dir + model_hash_str + "_" + device
|
||||
for input_data in args:
|
||||
if file_name is not None:
|
||||
file_name += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
|
||||
except OSError as error:
|
||||
print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
|
||||
file_name = None
|
||||
model_hash_str = None
|
||||
return file_name
|
||||
|
||||
def from_single_file(self, pretrained_model_link_or_path, **kwargs):
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
@ -213,16 +313,18 @@ def set_scheduler(sd_model, sampler_name):
|
||||
|
||||
return sd_model.scheduler
|
||||
|
||||
def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device, mode):
|
||||
def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode):
|
||||
if (model_state.recompile == 1):
|
||||
torch._dynamo.reset()
|
||||
openvino_clear_caches()
|
||||
curr_dir_path = os.getcwd()
|
||||
checkpoint_name = shared.opts.sd_model_checkpoint.split(" ")[0]
|
||||
checkpoint_path = os.path.join(curr_dir_path, 'models', 'Stable-diffusion', checkpoint_name)
|
||||
config_name = checkpoint_name.split(".")[0] + ".yaml"
|
||||
local_config_file = os.path.join(curr_dir_path, 'configs',config_name)
|
||||
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, local_config_file=local_config_file, load_safety_checker=False)
|
||||
if local_config:
|
||||
local_config_file = os.path.join(curr_dir_path, 'configs', model_config)
|
||||
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, local_config_file=local_config_file, load_safety_checker=False)
|
||||
else:
|
||||
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, load_safety_checker=False, torch_dtype=torch.float32)
|
||||
if (mode == 1):
|
||||
sd_model = StableDiffusionImg2ImgPipeline(**sd_model.components)
|
||||
elif (mode == 2):
|
||||
@ -233,8 +335,8 @@ def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device, mode):
|
||||
sd_model.safety_checker = None
|
||||
sd_model.cond_stage_key = functools.partial(cond_stage_key, shared.sd_model)
|
||||
sd_model.scheduler = set_scheduler(sd_model, sampler_name)
|
||||
sd_model.unet = torch.compile(sd_model.unet, backend="openvino")
|
||||
sd_model.vae.decode = torch.compile(sd_model.vae.decode, backend="openvino")
|
||||
sd_model.unet = torch.compile(sd_model.unet, backend="openvino_fx")
|
||||
sd_model.vae.decode = torch.compile(sd_model.vae.decode, backend="openvino_fx")
|
||||
shared.sd_diffusers_model = sd_model
|
||||
del sd_model
|
||||
return shared.sd_diffusers_model
|
||||
@ -335,7 +437,7 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
|
||||
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
||||
|
||||
|
||||
def process_images_openvino(p: StableDiffusionProcessing, sampler_name, enable_caching, openvino_device, mode) -> Processed:
|
||||
def process_images_openvino(p: StableDiffusionProcessing, local_config, model_config, sampler_name, enable_caching, openvino_device, mode) -> Processed:
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
|
||||
if (mode == 0 and p.enable_hr):
|
||||
@ -406,14 +508,16 @@ def process_images_openvino(p: StableDiffusionProcessing, sampler_name, enable_c
|
||||
if len(p.prompts) == 0:
|
||||
break
|
||||
|
||||
if (model_state.height != p.height or model_state.width != p.width or model_state.batch_size != p.batch_size or model_state.mode != mode):
|
||||
if (model_state.height != p.height or model_state.width != p.width or model_state.batch_size != p.batch_size
|
||||
or model_state.mode != mode or model_state.model_hash != shared.sd_model.sd_model_hash):
|
||||
model_state.recompile = 1
|
||||
model_state.height = p.height
|
||||
model_state.width = p.width
|
||||
model_state.batch_size = p.batch_size
|
||||
model_state.mode = mode
|
||||
model_state.model_hash = shared.sd_model.sd_model_hash
|
||||
|
||||
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device, mode)
|
||||
shared.sd_diffusers_model = get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode)
|
||||
shared.sd_diffusers_model.scheduler = set_scheduler(shared.sd_diffusers_model, sampler_name)
|
||||
|
||||
extra_network_data = p.parse_extra_network_prompts()
|
||||
@ -604,6 +708,15 @@ class Script(scripts.Script):
|
||||
|
||||
def ui(self, is_img2img):
|
||||
core = Core()
|
||||
config_dir_list = os.listdir(os.path.join(os.getcwd(), 'configs'))
|
||||
|
||||
config_list = []
|
||||
for file in config_dir_list:
|
||||
if file.endswith('.yaml'):
|
||||
config_list.append(file)
|
||||
|
||||
local_config = gr.Checkbox(label="Use a local inference config file", value=False)
|
||||
model_config = gr.Dropdown(label="Select a config for the model (Below config files are listed from the configs directory of the WebUI root)", choices=config_list, value="v1-inference.yaml", visible=False)
|
||||
openvino_device = gr.Dropdown(label="Select a device", choices=list(core.available_devices), value=model_state.device)
|
||||
override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True)
|
||||
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
|
||||
@ -620,6 +733,13 @@ class Script(scripts.Script):
|
||||
iterations use the cached compiled model for faster inference.
|
||||
""")
|
||||
|
||||
def local_config_change(choice):
|
||||
if choice:
|
||||
return gr.update(visible=True)
|
||||
else:
|
||||
return gr.update(visible=False)
|
||||
local_config.change(local_config_change, local_config, model_config)
|
||||
|
||||
def device_change(choice):
|
||||
if (model_state.device == choice):
|
||||
return gr.update(value="Device selected is " + choice, visible=True)
|
||||
@ -629,9 +749,10 @@ class Script(scripts.Script):
|
||||
return gr.update(value="Device changed to " + choice + ". Model will be re-compiled", visible=True)
|
||||
openvino_device.change(device_change, openvino_device, warmup_status)
|
||||
|
||||
return [openvino_device, override_sampler, sampler_name, enable_caching]
|
||||
return [local_config, model_config, openvino_device, override_sampler, sampler_name, enable_caching]
|
||||
|
||||
def run(self, p, openvino_device, override_sampler, sampler_name, enable_caching):
|
||||
def run(self, p, local_config, model_config, openvino_device, override_sampler, sampler_name, enable_caching):
|
||||
model_state.partition_id = 0
|
||||
os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = str(openvino_device)
|
||||
|
||||
if enable_caching:
|
||||
@ -648,14 +769,14 @@ class Script(scripts.Script):
|
||||
mode = 0
|
||||
if self.is_txt2img:
|
||||
mode = 0
|
||||
processed = process_images_openvino(p, p.sampler_name, enable_caching, openvino_device, mode)
|
||||
processed = process_images_openvino(p, local_config, model_config, p.sampler_name, enable_caching, openvino_device, mode)
|
||||
else:
|
||||
if p.image_mask is None:
|
||||
mode = 1
|
||||
else:
|
||||
mode = 2
|
||||
p.init = functools.partial(init_new, p)
|
||||
processed = process_images_openvino(p, p.sampler_name, enable_caching, openvino_device, mode)
|
||||
processed = process_images_openvino(p, local_config, model_config, p.sampler_name, enable_caching, openvino_device, mode)
|
||||
return processed
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user