diff --git a/scripts/openvino_accelerate.py b/scripts/openvino_accelerate.py
index 443b2389..7785b14e 100644
--- a/scripts/openvino_accelerate.py
+++ b/scripts/openvino_accelerate.py
@@ -28,6 +28,7 @@ from PIL import Image, ImageOps
 from pathlib import Path
 
 import openvino.frontend.pytorch.torchdynamo.backend
+import openvino.frontend.pytorch.torchdynamo.compile
 from openvino.frontend.pytorch.torchdynamo.execute import partitioned_modules, compiled_cache
 from openvino.runtime import Core
 
@@ -63,6 +64,118 @@ class ModelState:
 
 model_state = ModelState()
 
+from torch.fx import GraphModule
+
+from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
+from openvino.runtime import Core, Type, PartialShape, serialize
+
+from typing import Callable, Optional
+from hashlib import sha256
+
+from torch._dynamo.backends.common import fake_tensor_unsupported
+from torch._dynamo.backends.registry import register_backend
+from torch._inductor.compile_fx import compile_fx
+from torch.fx.experimental.proxy_tensor import make_fx
+
+from openvino.frontend import FrontEndManager
+from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
+from openvino.frontend.pytorch.torchdynamo.execute import execute
+
+# Monotonically increasing id so every compiled partition gets its own cache entry
+partition_id = 0
+
+@register_backend
+@fake_tensor_unsupported
+def openvino_fx(subgraph, example_inputs):
+    global partition_id
+    try:
+        executor_parameters = None
+        core = Core()
+        # Default to "caching disabled" so the uncached path below cannot hit a NameError
+        model_hash_str = None
+        model_hash_str_file = None
+        if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
+            model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
+            model_hash_str_file = model_hash_str + str(partition_id)
+            partition_id = partition_id + 1
+            executor_parameters = {"model_hash_str": model_hash_str}
+
+        # Reverse the inputs for the cache lookup; the uncached branch restores the order below
+        example_inputs.reverse()
+        cache_root = "./cache/"
+        if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
+            cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
+
+        device = "CPU"
+        if os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None:
+            device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
+            assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"
+
+        file_name = get_cached_file_name(*example_inputs, model_hash_str=model_hash_str_file, device=device, cache_root=cache_root)
+
+        if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
+            # Cache hit: load the serialized OpenVINO model instead of retracing the subgraph
+            om = core.read_model(file_name + ".xml")
+
+            dtype_mapping = {
+                torch.float32: Type.f32,
+                torch.float64: Type.f64,
+                torch.float16: Type.f16,
+                torch.int64: Type.i64,
+                torch.int32: Type.i32,
+                torch.uint8: Type.u8,
+                torch.int8: Type.i8,
+                torch.bool: Type.boolean
+            }
+
+            for idx, input_data in enumerate(example_inputs):
+                om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
+                om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
+            om.validate_nodes_and_infer_types()
+
+            if model_hash_str is not None:
+                core.set_property({'CACHE_DIR': cache_root + '/blob'})
+
+            compiled_model = core.compile_model(om, device)
+
+            def _call(*args):
+                ov_inputs = [a.detach().cpu().numpy() for a in args]
+                ov_inputs.reverse()
+                res = compiled_model(ov_inputs)
+                result = [torch.from_numpy(res[out]) for out in compiled_model.outputs]
+                return result
+            return _call
+        else:
+            # Cache miss: restore the input order, trace the subgraph and partition it
+            example_inputs.reverse()
+            model = make_fx(subgraph)(*example_inputs)
+            with torch.no_grad():
+                model.eval()
+            partitioner = Partitioner()
+            compiled_model = partitioner.make_partitions(model)
+
+            def _call(*args):
+                res = execute(compiled_model, *args, executor="openvino",
+                              executor_parameters=executor_parameters)
+                return res
+            return _call
+    except Exception as e:
+        log.debug(f"Failed in OpenVINO execution: {e}")
+        # Anything that goes wrong above falls back to the default Inductor path
+        return compile_fx(subgraph, example_inputs)
+
+def get_cached_file_name(*args, model_hash_str, device, cache_root):
+    file_name = None
+    if model_hash_str is not None:
+        model_cache_dir = cache_root + "/model/"
+        try:
+            os.makedirs(model_cache_dir, exist_ok=True)
+            file_name = model_cache_dir + model_hash_str + "_" + device
+            # Append each input's dtype and shape so the cache key is shape-specific
+            for input_data in args:
+                file_name += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
+        except OSError as error:
+            print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
+            file_name = None
+            model_hash_str = None
+    return file_name
+
 def from_single_file(self, pretrained_model_link_or_path, **kwargs):
     cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
 
@@ -233,8 +346,8 @@ def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device, mode):
         sd_model.safety_checker = None
         sd_model.cond_stage_key = functools.partial(cond_stage_key, shared.sd_model)
         sd_model.scheduler = set_scheduler(sd_model, sampler_name)
-        sd_model.unet = torch.compile(sd_model.unet, backend="openvino")
-        sd_model.vae.decode = torch.compile(sd_model.vae.decode, backend="openvino")
+        sd_model.unet = torch.compile(sd_model.unet, backend="openvino_fx")
+        sd_model.vae.decode = torch.compile(sd_model.vae.decode, backend="openvino_fx")
         shared.sd_diffusers_model = sd_model
         del sd_model
         return shared.sd_diffusers_model
@@ -644,6 +757,8 @@ class Script(scripts.Script):
         return [openvino_device, override_sampler, sampler_name, enable_caching]
 
     def run(self, p, openvino_device, override_sampler, sampler_name, enable_caching):
+        global partition_id
+        partition_id = 0
         os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = str(openvino_device)
 
         if enable_caching:
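For reviewers: below is a minimal sketch of how the `openvino_fx` backend registered by this patch can be exercised on its own, assuming `scripts/openvino_accelerate.py` has already been imported so that `@register_backend` has run. The toy `Sequential` model, the input shape, and the environment values are placeholders for illustration, not part of the patch.

```python
import os
import torch

# Placeholder settings; these are the variables the patch reads.
os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"     # turn on the sha256-keyed model cache
os.environ["OPENVINO_TORCH_CACHE_DIR"] = "./cache/"  # same default the patch uses
os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = "CPU"  # checked against core.available_devices

# Toy stand-in for the UNet/VAE modules the patch compiles.
net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU()).eval()
compiled = torch.compile(net, backend="openvino_fx")

with torch.no_grad():
    out = compiled(torch.randn(1, 3, 64, 64))  # first call triggers openvino_fx and caching
```

On a second run with the same shapes, `get_cached_file_name` resolves to an existing `.xml`/`.bin` pair and the backend skips retracing entirely, which is where the speedup from `OPENVINO_TORCH_MODEL_CACHING` comes from.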