Initial ControlNet optimization

ynimmaga 2023-08-13 22:40:48 -07:00
parent 11779119d7
commit 3691141e61


@@ -25,16 +25,21 @@ from modules.shared import opts, state
from PIL import Image, ImageOps
from pathlib import Path
from types import MappingProxyType
from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
from openvino.frontend.pytorch.torchdynamo import backend, compile # noqa: F401
from openvino.frontend.pytorch.torchdynamo.execute import execute, partitioned_modules, compiled_cache # noqa: F401
from openvino.frontend.pytorch.torchdynamo.execute import execute, compiled_cache # noqa: F401
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
from openvino.runtime import Core, Type, PartialShape
from openvino.runtime import Core, Type, PartialShape, serialize
from torch._dynamo.backends.common import fake_tensor_unsupported
from torch._dynamo.backends.registry import register_backend
from torch._inductor.compile_fx import compile_fx
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten
from hashlib import sha256
@@ -42,6 +47,8 @@ from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionControlNetPipeline,
ControlNetModel,
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
@@ -69,19 +76,30 @@ class ModelState:
self.mode = 0
self.partition_id = 0
self.model_hash = ""
self.cn_model = ""
model_state = ModelState()
DEFAULT_OPENVINO_PYTHON_CONFIG = MappingProxyType(
{
"use_python_fusion_cache": True,
"allow_single_op_fusion": True,
},
)
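# Module-level partitioning state: partitioned_modules caches one partitioned
# GraphModule per input signature (see openvino_execute_partitioned below), so
# repeated calls with the same shapes skip re-partitioning.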
max_openvino_partitions = 0
partitioned_modules = {}
@register_backend
@fake_tensor_unsupported
def openvino_fx(subgraph, example_inputs):
try:
executor_parameters = None
core = Core()
model_hash_str_file = ""
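# When OPENVINO_TORCH_MODEL_CACHING is set, hash the subgraph's generated code
# to key the on-disk model cache; the partition id is appended so each
# partition maps to its own cache file.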
if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
model_hash_str_file = model_hash_str + str(model_state.partition_id)
model_state.partition_id = model_state.partition_id + 1
executor_parameters = {"model_hash_str": model_hash_str}
example_inputs.reverse()
@@ -97,7 +115,9 @@ def openvino_fx(subgraph, example_inputs):
file_name = get_cached_file_name(*example_inputs, model_hash_str=model_hash_str_file, device=device, cache_root=cache_root)
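# Reuse the serialized model from disk only when no ControlNet model is
# selected; ControlNet runs fall through to fresh partitioning below.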
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
if (file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin")
and (model_state.cn_model == "" or model_state.cn_model == "None")):
model_state.partition_id = model_state.partition_id + 1
om = core.read_model(file_name + ".xml")
dtype_mapping = {
@@ -129,6 +149,9 @@ def openvino_fx(subgraph, example_inputs):
return _call
else:
example_inputs.reverse()
if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
executor_parameters = {"model_hash_str": model_hash_str}
model = make_fx(subgraph)(*example_inputs)
with torch.no_grad():
model.eval()
@@ -136,13 +159,160 @@ def openvino_fx(subgraph, example_inputs):
compiled_model = partitioner.make_partitions(model)
def _call(*args):
res = execute(compiled_model, *args, executor="openvino",
executor_parameters=executor_parameters)
res = openvino_execute_partitioned(compiled_model, *args,
executor_parameters=executor_parameters, file_name=file_name)
return res
return _call
except Exception:
return compile_fx(subgraph, example_inputs)
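# Wraps a single fused partition: executes it through OpenVINO and falls back
# permanently to the original PyTorch submodule if OpenVINO execution fails.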
class OpenVINOGraphModule(torch.nn.Module):
def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name=""):
super().__init__()
self.gm = gm
self.partition_id = partition_id
self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache,
"model_hash_str": model_hash_str}
self.file_name = file_name
self.perm_fallback = False
def __call__(self, *args):
if self.perm_fallback:
return self.gm(*args)
try:
result = openvino_execute(self.gm, *args, executor_parameters=self.executor_parameters, partition_id=self.partition_id, file_name=self.file_name)
except Exception:
self.perm_fallback = True
return self.gm(*args)
return result
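# Swap every fused_* submodule produced by the partitioner for an
# OpenVINOGraphModule, handing each one a unique partition id.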
def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name=""):
global max_openvino_partitions
for node in gm.graph.nodes:
if node.op == "call_module" and "fused_" in node.name:
openvino_submodule = getattr(gm, node.name)
gm.delete_submodule(node.target)
gm.add_submodule(
node.target,
OpenVINOGraphModule(openvino_submodule, model_state.partition_id, use_python_fusion_cache,
model_hash_str = model_hash_str, file_name=file_name),
)
model_state.partition_id = model_state.partition_id + 1
return gm
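# Compile one partition (or fetch it from the in-memory compiled_cache) and
# run it on the flattened inputs as detached CPU numpy arrays.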
def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition_id, file_name=""):
executor_parameters = executor_parameters or DEFAULT_OPENVINO_PYTHON_CONFIG
use_cache = executor_parameters.get(
"use_python_fusion_cache",
DEFAULT_OPENVINO_PYTHON_CONFIG["use_python_fusion_cache"],
)
global compiled_cache
model_hash_str = executor_parameters.get("model_hash_str", None)
if model_hash_str is not None:
model_hash_str = model_hash_str + str(partition_id)
if use_cache and (partition_id in compiled_cache):
compiled = compiled_cache[partition_id]
else:
compiled = openvino_compile(gm, *args, model_hash_str=model_hash_str, file_name=file_name)
compiled_cache[partition_id] = compiled
flat_args, _ = tree_flatten(args)
ov_inputs = [a.detach().cpu().numpy() for a in flat_args]
res = compiled(ov_inputs)
results1 = [torch.from_numpy(res[out]) for out in compiled.outputs]
if len(results1) == 1:
return results1[0]
return results1
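# The signature combines the GraphModule id with each tensor input's dtype and
# shape (roughly "_0:FloatTensor:[2,4,64,64]" per tensor), so a shape or dtype
# change triggers re-partitioning instead of reusing a stale module.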
def openvino_execute_partitioned(gm: GraphModule, *args, executor_parameters=None, file_name=""):
executor_parameters = executor_parameters or DEFAULT_OPENVINO_PYTHON_CONFIG
global partitioned_modules
use_python_fusion_cache = executor_parameters.get(
"use_python_fusion_cache",
DEFAULT_OPENVINO_PYTHON_CONFIG["use_python_fusion_cache"],
)
model_hash_str = executor_parameters.get("model_hash_str", None)
signature = str(id(gm))
for idx, input_data in enumerate(args):
if isinstance(input_data, torch.Tensor):
signature = signature + "_" + str(idx) + ":" + str(input_data.type())[6:] + ":" + str(input_data.size())[11:-1].replace(" ", "")
else:
signature = signature + "_" + str(idx) + ":" + type(input_data).__name__ + ":val(" + str(input_data) + ")"
if signature not in partitioned_modules:
partitioned_modules[signature] = partition_graph(gm, use_python_fusion_cache=use_python_fusion_cache,
model_hash_str=model_hash_str, file_name=file_name)
return partitioned_modules[signature](*args)
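# Convert the FX graph to an OpenVINO model through the PyTorch frontend (or
# reload a previously serialized .xml/.bin pair), pin the input dtypes and
# shapes, then compile for the selected device.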
def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_name=""):
core = Core()
device = "CPU"
if os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None:
device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"
cache_root = "./cache/"
if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
om = core.read_model(file_name + ".xml")
else:
fe_manager = FrontEndManager()
fe = fe_manager.load_by_framework("pytorch")
input_shapes = []
input_types = []
for idx, input_data in enumerate(args):
input_types.append(input_data.type())
input_shapes.append(input_data.size())
decoder = TorchFXPythonDecoder(gm, gm, input_shapes=input_shapes, input_types=input_types)
im = fe.load(decoder)
om = fe.convert(im)
if file_name is not None:
serialize(om, file_name + ".xml", file_name + ".bin")
dtype_mapping = {
torch.float32: Type.f32,
torch.float64: Type.f64,
torch.float16: Type.f16,
torch.int64: Type.i64,
torch.int32: Type.i32,
torch.uint8: Type.u8,
torch.int8: Type.i8,
torch.bool: Type.boolean
}
for idx, input_data in enumerate(args):
om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
om.validate_nodes_and_infer_types()
if model_hash_str is not None:
core.set_property({'CACHE_DIR': cache_root + '/blob'})
compiled = core.compile_model(om, device)
return compiled
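# Build a deterministic cache file name from the model hash and device plus a
# sha256 of the concatenated input dtype/shape string, which keeps file names
# short even for graphs with many inputs.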
def get_cached_file_name(*args, model_hash_str, device, cache_root):
file_name = None
if model_hash_str is not None:
@@ -150,15 +320,18 @@ def get_cached_file_name(*args, model_hash_str, device, cache_root):
try:
os.makedirs(model_cache_dir, exist_ok=True)
file_name = model_cache_dir + model_hash_str + "_" + device
type_shape_string = ""
for input_data in args:
if file_name is not None:
file_name += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
type_shape_string += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
file_name += sha256(type_shape_string.encode('utf-8')).hexdigest()
except OSError as error:
print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
file_name = None
model_hash_str = None
return file_name
def from_single_file(self, pretrained_model_link_or_path, **kwargs):
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -312,6 +485,7 @@ def set_scheduler(sd_model, sampler_name):
def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode):
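# On recompile, reset the partition counter and clear the dynamo and OpenVINO
# caches so stale compiled partitions are never reused.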
if (model_state.recompile == 1):
model_state.partition_id = 0
torch._dynamo.reset()
openvino_clear_caches()
curr_dir_path = os.getcwd()
@@ -326,6 +500,10 @@ def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_cach
sd_model = StableDiffusionImg2ImgPipeline(**sd_model.components)
elif (mode == 2):
sd_model = StableDiffusionInpaintPipeline(**sd_model.components)
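# mode 3: attach a ControlNet to the base pipeline and route the ControlNet
# itself through the OpenVINO torch.compile backend as well.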
elif (mode == 3):
controlnet = ControlNetModel.from_pretrained(model_state.cn_model)
sd_model = StableDiffusionControlNetPipeline(**sd_model.components, controlnet=controlnet)
controlnet = torch.compile(controlnet, backend="openvino")
checkpoint_info = CheckpointInfo(checkpoint_path)
sd_model.sd_checkpoint_info = checkpoint_info
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
@@ -471,8 +649,26 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_co
if p.scripts is not None:
p.scripts.process(p)
if 'ControlNet' in p.extra_generation_params:
return process_images(p)
control_images = []
cn_model=""
if ('ControlNet' in p.extra_generation_params):
cn_params = p.extra_generation_params['ControlNet']
cn_param_elements = [part.strip() for part in cn_params.split(', ')]
for element in cn_param_elements:
if (element.split(':')[0] == "model"):
cn_model = (element.split(':')[1]).split(' ')[1]
if (cn_model != "None"):
control_res = Processed(
p,
images_list=control_images,
)
p.scripts.postprocess(p, control_res)
control_image = control_images[0]
cn_model = "lllyasviel/" + cn_model
mode = 3
infotexts = []
output_images = []
@@ -506,12 +702,13 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_co
break
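# Recompile when anything that shaped the compiled graphs changes: resolution,
# batch size, mode, checkpoint hash, and now the selected ControlNet model.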
if (model_state.height != p.height or model_state.width != p.width or model_state.batch_size != p.batch_size
or model_state.mode != mode or model_state.model_hash != shared.sd_model.sd_model_hash):
or model_state.mode != mode or model_state.model_hash != shared.sd_model.sd_model_hash or model_state.cn_model != cn_model):
model_state.recompile = 1
model_state.height = p.height
model_state.width = p.width
model_state.batch_size = p.batch_size
model_state.mode = mode
model_state.cn_model = cn_model
model_state.model_hash = shared.sd_model.sd_model_hash
shared.sd_diffusers_model = get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode)
@@ -564,12 +761,19 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_co
'image': p.init_images,
'strength':p.denoising_strength,
})
else:
elif (mode == 2):
custom_inputs.update({
'image': p.init_images,
'strength':p.denoising_strength,
'mask_image': p.mask,
})
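# mode 3 (ControlNet): pass the preprocessed control image and explicit
# width/height to the StableDiffusionControlNetPipeline call.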
elif (mode == 3):
custom_inputs.update({
'image': control_image,
'width': p.width,
'height': p.height,
})
output = shared.sd_diffusers_model(
prompt=p.prompts,
negative_prompt=p.negative_prompts,
@@ -626,6 +830,8 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_co
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
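# When a ControlNet model was used, append the preprocessed control image to
# the output gallery alongside the generated images.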
if ('ControlNet' in p.extra_generation_params and cn_model != "None"):
output_images.append(control_image)
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
image_mask = p.mask_for_overlay.convert('RGB')
@@ -723,11 +929,11 @@ class Script(scripts.Script):
"""
###
### Note:
- First inference involves compilation of the model for best performance.
Since compilation happens only on the first run, the first inference (or warm-up inference) will be slower than subsequent inferences.
- For accurate performance measurements, it is recommended to exclude this slower first inference, as it doesn't reflect normal running time.
- The model is recompiled when the resolution, batch size, device, or sampler (e.g. DPM++ or Karras) is changed.
After recompiling, later inferences reuse the newly compiled model and achieve faster running times.
So it's normal for the first inference after a settings change to be slower, while subsequent inferences use the optimized compiled model and run faster.
""")