mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 22:53:25 +03:00
Added caching optimizations
This commit is contained in:
parent
ac9c9e19ae
commit
b154c7e32b
@ -28,6 +28,7 @@ from PIL import Image, ImageOps
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import openvino.frontend.pytorch.torchdynamo.backend
|
import openvino.frontend.pytorch.torchdynamo.backend
|
||||||
|
import openvino.frontend.pytorch.torchdynamo.compile
|
||||||
from openvino.frontend.pytorch.torchdynamo.execute import partitioned_modules, compiled_cache
|
from openvino.frontend.pytorch.torchdynamo.execute import partitioned_modules, compiled_cache
|
||||||
from openvino.runtime import Core
|
from openvino.runtime import Core
|
||||||
|
|
||||||
@ -63,6 +64,118 @@ class ModelState:
|
|||||||
|
|
||||||
model_state = ModelState()
|
model_state = ModelState()
|
||||||
|
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
|
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
|
||||||
|
from openvino.runtime import Core, Type, PartialShape, serialize
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
from hashlib import sha256
|
||||||
|
|
||||||
|
from torch._dynamo.backends.common import fake_tensor_unsupported
|
||||||
|
from torch._dynamo.backends.registry import register_backend
|
||||||
|
from torch._inductor.compile_fx import compile_fx
|
||||||
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
|
|
||||||
|
from openvino.frontend import FrontEndManager
|
||||||
|
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
|
||||||
|
from openvino.frontend.pytorch.torchdynamo.execute import execute
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
partition_id = 0
|
||||||
|
|
||||||
|
@register_backend
|
||||||
|
@fake_tensor_unsupported
|
||||||
|
def openvino_fx(subgraph, example_inputs):
|
||||||
|
global partition_id
|
||||||
|
try:
|
||||||
|
executor_parameters = None
|
||||||
|
core = Core()
|
||||||
|
if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
|
||||||
|
model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
|
||||||
|
model_hash_str_file = model_hash_str + str(partition_id)
|
||||||
|
partition_id = partition_id + 1
|
||||||
|
executor_parameters = {"model_hash_str": model_hash_str}
|
||||||
|
|
||||||
|
example_inputs.reverse()
|
||||||
|
cache_root = "./cache/"
|
||||||
|
if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
|
||||||
|
cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
|
||||||
|
|
||||||
|
device = "CPU"
|
||||||
|
|
||||||
|
if os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None:
|
||||||
|
device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
|
||||||
|
assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"
|
||||||
|
|
||||||
|
file_name = get_cached_file_name(*example_inputs, model_hash_str=model_hash_str_file, device=device, cache_root=cache_root)
|
||||||
|
|
||||||
|
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
|
||||||
|
start_time = time.time()
|
||||||
|
om = core.read_model(file_name + ".xml")
|
||||||
|
|
||||||
|
dtype_mapping = {
|
||||||
|
torch.float32: Type.f32,
|
||||||
|
torch.float64: Type.f64,
|
||||||
|
torch.float16: Type.f16,
|
||||||
|
torch.int64: Type.i64,
|
||||||
|
torch.int32: Type.i32,
|
||||||
|
torch.uint8: Type.u8,
|
||||||
|
torch.int8: Type.i8,
|
||||||
|
torch.bool: Type.boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, input_data in enumerate(example_inputs):
|
||||||
|
om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
|
||||||
|
om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
|
||||||
|
om.validate_nodes_and_infer_types()
|
||||||
|
|
||||||
|
if model_hash_str is not None:
|
||||||
|
core.set_property({'CACHE_DIR': cache_root + '/blob'})
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
compiled_model = core.compile_model(om, device)
|
||||||
|
def _call(*args):
|
||||||
|
ov_inputs = [a.detach().cpu().numpy() for a in args]
|
||||||
|
ov_inputs.reverse()
|
||||||
|
res = compiled_model(ov_inputs)
|
||||||
|
result = [torch.from_numpy(res[out]) for out in compiled_model.outputs]
|
||||||
|
return result
|
||||||
|
return _call
|
||||||
|
else:
|
||||||
|
example_inputs.reverse()
|
||||||
|
model = make_fx(subgraph)(*example_inputs)
|
||||||
|
with torch.no_grad():
|
||||||
|
model.eval()
|
||||||
|
partitioner = Partitioner()
|
||||||
|
compiled_model = partitioner.make_partitions(model)
|
||||||
|
|
||||||
|
def _call(*args):
|
||||||
|
res = execute(compiled_model, *args, executor="openvino",
|
||||||
|
executor_parameters=executor_parameters)
|
||||||
|
return res
|
||||||
|
return _call
|
||||||
|
except Exception as e:
|
||||||
|
log.debug(f"Failed in OpenVINO execution: {e}")
|
||||||
|
return compile_fx(subgraph, example_inputs)
|
||||||
|
|
||||||
|
def get_cached_file_name(*args, model_hash_str, device, cache_root):
|
||||||
|
file_name = None
|
||||||
|
if model_hash_str is not None:
|
||||||
|
model_cache_dir = cache_root + "/model/"
|
||||||
|
try:
|
||||||
|
os.makedirs(model_cache_dir, exist_ok=True)
|
||||||
|
file_name = model_cache_dir + model_hash_str + "_" + device
|
||||||
|
for idx, input_data in enumerate(args):
|
||||||
|
if file_name is not None:
|
||||||
|
file_name += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
|
||||||
|
except OSError as error:
|
||||||
|
print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
|
||||||
|
file_name = None
|
||||||
|
model_hash_str = None
|
||||||
|
return file_name
|
||||||
|
|
||||||
def from_single_file(self, pretrained_model_link_or_path, **kwargs):
|
def from_single_file(self, pretrained_model_link_or_path, **kwargs):
|
||||||
|
|
||||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||||
@ -233,8 +346,8 @@ def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device, mode):
|
|||||||
sd_model.safety_checker = None
|
sd_model.safety_checker = None
|
||||||
sd_model.cond_stage_key = functools.partial(cond_stage_key, shared.sd_model)
|
sd_model.cond_stage_key = functools.partial(cond_stage_key, shared.sd_model)
|
||||||
sd_model.scheduler = set_scheduler(sd_model, sampler_name)
|
sd_model.scheduler = set_scheduler(sd_model, sampler_name)
|
||||||
sd_model.unet = torch.compile(sd_model.unet, backend="openvino")
|
sd_model.unet = torch.compile(sd_model.unet, backend="openvino_fx")
|
||||||
sd_model.vae.decode = torch.compile(sd_model.vae.decode, backend="openvino")
|
sd_model.vae.decode = torch.compile(sd_model.vae.decode, backend="openvino_fx")
|
||||||
shared.sd_diffusers_model = sd_model
|
shared.sd_diffusers_model = sd_model
|
||||||
del sd_model
|
del sd_model
|
||||||
return shared.sd_diffusers_model
|
return shared.sd_diffusers_model
|
||||||
@ -644,6 +757,8 @@ class Script(scripts.Script):
|
|||||||
return [openvino_device, override_sampler, sampler_name, enable_caching]
|
return [openvino_device, override_sampler, sampler_name, enable_caching]
|
||||||
|
|
||||||
def run(self, p, openvino_device, override_sampler, sampler_name, enable_caching):
|
def run(self, p, openvino_device, override_sampler, sampler_name, enable_caching):
|
||||||
|
global partition_id
|
||||||
|
partition_id = 0
|
||||||
os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = str(openvino_device)
|
os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = str(openvino_device)
|
||||||
|
|
||||||
if enable_caching:
|
if enable_caching:
|
||||||
|
Loading…
Reference in New Issue
Block a user