Added torch.compile support for controlnet

This commit is contained in:
ynimmaga 2023-08-15 21:30:35 -07:00
parent 3691141e61
commit 7eb98ef5be

View File

@ -30,7 +30,7 @@ from types import MappingProxyType
from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
from openvino.frontend.pytorch.torchdynamo import backend, compile # noqa: F401
from openvino.frontend.pytorch.torchdynamo.execute import execute, compiled_cache # noqa: F401
from openvino.frontend.pytorch.torchdynamo.execute import execute # noqa: F401
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
from openvino.runtime import Core, Type, PartialShape, serialize
@ -76,7 +76,7 @@ class ModelState:
self.mode = 0
self.partition_id = 0
self.model_hash = ""
self.cn_model = ""
self.cn_model = "None"
model_state = ModelState()
@ -87,6 +87,7 @@ DEFAULT_OPENVINO_PYTHON_CONFIG = MappingProxyType(
},
)
compiled_cache = {}
max_openvino_partitions = 0
partitioned_modules = {}
@ -100,6 +101,9 @@ def openvino_fx(subgraph, example_inputs):
if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
model_hash_str_file = model_hash_str + str(model_state.partition_id)
if (model_state.cn_model != "None" and model_state.partition_id == 0):
model_hash_str_file = model_hash_str_file + model_state.cn_model
executor_parameters = {"model_hash_str": model_hash_str}
example_inputs.reverse()
@ -116,7 +120,7 @@ def openvino_fx(subgraph, example_inputs):
file_name = get_cached_file_name(*example_inputs, model_hash_str=model_hash_str_file, device=device, cache_root=cache_root)
if (file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin")
and (model_state.cn_model == "" or model_state.cn_model == "None")):
and model_state.cn_model == "None"):
model_state.partition_id = model_state.partition_id + 1
om = core.read_model(file_name + ".xml")
@ -270,6 +274,10 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na
if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
type_shape_string = ""
for input_data in args:
type_shape_string += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
om = core.read_model(file_name + ".xml")
else:
@ -278,7 +286,7 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na
input_shapes = []
input_types = []
for idx, input_data in enumerate(args):
for input_data in args:
input_types.append(input_data.type())
input_shapes.append(input_data.size())
@ -501,9 +509,9 @@ def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_cach
elif (mode == 2):
sd_model = StableDiffusionInpaintPipeline(**sd_model.components)
elif (mode == 3):
controlnet = ControlNetModel.from_pretrained(model_state.cn_model)
controlnet = ControlNetModel.from_pretrained("lllyasviel/" + model_state.cn_model)
sd_model = StableDiffusionControlNetPipeline(**sd_model.components, controlnet=controlnet)
controlnet = torch.compile(controlnet, backend="openvino")
sd_model.controlnet = torch.compile(sd_model.controlnet, backend="openvino_fx")
checkpoint_info = CheckpointInfo(checkpoint_path)
sd_model.sd_checkpoint_info = checkpoint_info
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
@ -651,7 +659,7 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_co
control_images = []
cn_model=""
cn_model="None"
if ('ControlNet' in p.extra_generation_params):
cn_params = p.extra_generation_params['ControlNet']
@ -667,7 +675,6 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_co
)
p.scripts.postprocess(p, control_res)
control_image = control_images[0]
cn_model = "lllyasviel/" + cn_model
mode = 3
infotexts = []