mirror of https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 14:45:06 +03:00

Initial control net optimization

This commit is contained in:
parent 11779119d7
commit 3691141e61
@@ -25,16 +25,21 @@ from modules.shared import opts, state
 from PIL import Image, ImageOps
 from pathlib import Path
 from types import MappingProxyType

 from openvino.frontend import FrontEndManager
 from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
 from openvino.frontend.pytorch.torchdynamo import backend, compile # noqa: F401
-from openvino.frontend.pytorch.torchdynamo.execute import execute, partitioned_modules, compiled_cache # noqa: F401
+from openvino.frontend.pytorch.torchdynamo.execute import execute, compiled_cache # noqa: F401
 from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
-from openvino.runtime import Core, Type, PartialShape
+from openvino.runtime import Core, Type, PartialShape, serialize

 from torch._dynamo.backends.common import fake_tensor_unsupported
 from torch._dynamo.backends.registry import register_backend
 from torch._inductor.compile_fx import compile_fx
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx import GraphModule
+from torch.utils._pytree import tree_flatten
+
+from hashlib import sha256

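The import changes sketch the whole optimization: `partitioned_modules` is no longer borrowed from the upstream torchdynamo `execute` module but maintained locally, `serialize` is pulled in to persist converted models as OpenVINO IR (.xml topology plus .bin weights), and `sha256` fingerprints the FX subgraph source so cache keys survive restarts. A minimal sketch of the key derivation, with a hypothetical subgraph body standing in for `subgraph.code`:

    from hashlib import sha256

    # Stand-in for the generated source of a TorchDynamo FX subgraph;
    # openvino_fx below hashes `subgraph.code` the same way.
    subgraph_code = "def forward(self, x):\n    return torch.relu(x)\n"

    model_hash_str = sha256(subgraph_code.encode('utf-8')).hexdigest()
    print(model_hash_str)  # stable 64-hex-char key for this exact graph source
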
@@ -42,6 +47,8 @@ from diffusers import (
     StableDiffusionPipeline,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipeline,
+    StableDiffusionControlNetPipeline,
+    ControlNetModel,
     DDIMScheduler,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -69,19 +76,30 @@ class ModelState:
         self.mode = 0
+        self.partition_id = 0
         self.model_hash = ""
+        self.cn_model = ""

 model_state = ModelState()

 DEFAULT_OPENVINO_PYTHON_CONFIG = MappingProxyType(
     {
         "use_python_fusion_cache": True,
         "allow_single_op_fusion": True,
     },
 )

+max_openvino_partitions = 0
+partitioned_modules = {}
+
 @register_backend
 @fake_tensor_unsupported
 def openvino_fx(subgraph, example_inputs):
     try:
         executor_parameters = None
         core = Core()
+        model_hash_str_file = ""
         if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
             model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
+            model_hash_str_file = model_hash_str + str(model_state.partition_id)
+            model_state.partition_id = model_state.partition_id + 1
             executor_parameters = {"model_hash_str": model_hash_str}

         example_inputs.reverse()
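Two small details here are worth spelling out. `MappingProxyType` gives the default executor config a read-only view, so callers can fall back to it without mutating shared state, and the growing `model_state.partition_id` suffix is what keeps one cache entry per compiled partition. A standalone illustration of both (the digest value is hypothetical):

    from types import MappingProxyType

    DEFAULTS = MappingProxyType({"use_python_fusion_cache": True})

    try:
        DEFAULTS["use_python_fusion_cache"] = False  # read-only view
    except TypeError as err:
        print("defaults are immutable:", err)

    # One hash per subgraph, one numeric suffix per partition:
    model_hash_str = "ab12cd"  # hypothetical digest
    cache_keys = [model_hash_str + str(partition_id) for partition_id in range(3)]
    print(cache_keys)  # ['ab12cd0', 'ab12cd1', 'ab12cd2']
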
@@ -97,7 +115,9 @@ def openvino_fx(subgraph, example_inputs):

         file_name = get_cached_file_name(*example_inputs, model_hash_str=model_hash_str_file, device=device, cache_root=cache_root)

-        if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
+        if (file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin")
+                and (model_state.cn_model == "" or model_state.cn_model == "None")):
+            model_state.partition_id = model_state.partition_id + 1
             om = core.read_model(file_name + ".xml")

             dtype_mapping = {
@@ -129,6 +149,9 @@ def openvino_fx(subgraph, example_inputs):
             return _call
         else:
             example_inputs.reverse()
+            if os.getenv("OPENVINO_TORCH_MODEL_CACHING") is not None:
+                model_hash_str = sha256(subgraph.code.encode('utf-8')).hexdigest()
+                executor_parameters = {"model_hash_str": model_hash_str}
             model = make_fx(subgraph)(*example_inputs)
             with torch.no_grad():
                 model.eval()
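Both caching branches key off the same environment variable, so opting into the disk cache is purely environmental. All three variables used by this file appear below; a sketch of setting them before the first torch.compile call:

    import os

    # Enable model caching and pick a device before the first compile.
    os.environ["OPENVINO_TORCH_MODEL_CACHING"] = "1"
    os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = "GPU"  # must be in Core().available_devices
    os.environ["OPENVINO_TORCH_CACHE_DIR"] = "./cache/"  # default used by openvino_compile below
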
@@ -136,13 +159,160 @@ def openvino_fx(subgraph, example_inputs):
             compiled_model = partitioner.make_partitions(model)

             def _call(*args):
-                res = execute(compiled_model, *args, executor="openvino",
-                              executor_parameters=executor_parameters)
+                res = openvino_execute_partitioned(compiled_model, *args,
+                                                   executor_parameters=executor_parameters, file_name=file_name)
                 return res
             return _call
     except Exception:
         return compile_fx(subgraph, example_inputs)

+
+class OpenVINOGraphModule(torch.nn.Module):
+    def __init__(self, gm, partition_id, use_python_fusion_cache, model_hash_str: str = None, file_name=""):
+        super().__init__()
+        self.gm = gm
+        self.partition_id = partition_id
+        self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache,
+                                    "model_hash_str": model_hash_str}
+        self.file_name = file_name
+        self.perm_fallback = False
+
+    def __call__(self, *args):
+        if self.perm_fallback:
+            return self.gm(*args)
+
+        try:
+            result = openvino_execute(self.gm, *args, executor_parameters=self.executor_parameters, partition_id=self.partition_id, file_name=self.file_name)
+        except Exception:
+            self.perm_fallback = True
+            return self.gm(*args)
+
+        return result
+
+
+def partition_graph(gm: GraphModule, use_python_fusion_cache: bool, model_hash_str: str = None, file_name=""):
+    global max_openvino_partitions
+    for node in gm.graph.nodes:
+        if node.op == "call_module" and "fused_" in node.name:
+            openvino_submodule = getattr(gm, node.name)
+            gm.delete_submodule(node.target)
+            gm.add_submodule(
+                node.target,
+                OpenVINOGraphModule(openvino_submodule, model_state.partition_id, use_python_fusion_cache,
+                                    model_hash_str=model_hash_str, file_name=file_name),
+            )
+            model_state.partition_id = model_state.partition_id + 1
+
+    return gm
+
+
+def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition_id, file_name=""):
+    executor_parameters = executor_parameters or DEFAULT_OPENVINO_PYTHON_CONFIG
+
+    use_cache = executor_parameters.get(
+        "use_python_fusion_cache",
+        DEFAULT_OPENVINO_PYTHON_CONFIG["use_python_fusion_cache"],
+    )
+    global compiled_cache
+
+    model_hash_str = executor_parameters.get("model_hash_str", None)
+    if model_hash_str is not None:
+        model_hash_str = model_hash_str + str(partition_id)
+
+    if use_cache and (partition_id in compiled_cache):
+        compiled = compiled_cache[partition_id]
+    else:
+        compiled = openvino_compile(gm, *args, model_hash_str=model_hash_str, file_name=file_name)
+        compiled_cache[partition_id] = compiled
+
+    flat_args, _ = tree_flatten(args)
+    ov_inputs = [a.detach().cpu().numpy() for a in flat_args]
+
+    res = compiled(ov_inputs)
+
+    results1 = [torch.from_numpy(res[out]) for out in compiled.outputs]
+    if len(results1) == 1:
+        return results1[0]
+    return results1
+
+
+def openvino_execute_partitioned(gm: GraphModule, *args, executor_parameters=None, file_name=""):
+    executor_parameters = executor_parameters or DEFAULT_OPENVINO_PYTHON_CONFIG
+
+    global partitioned_modules
+
+    use_python_fusion_cache = executor_parameters.get(
+        "use_python_fusion_cache",
+        DEFAULT_OPENVINO_PYTHON_CONFIG["use_python_fusion_cache"],
+    )
+    model_hash_str = executor_parameters.get("model_hash_str", None)
+
+    signature = str(id(gm))
+    for idx, input_data in enumerate(args):
+        if isinstance(input_data, torch.Tensor):
+            signature = signature + "_" + str(idx) + ":" + str(input_data.type())[6:] + ":" + str(input_data.size())[11:-1].replace(" ", "")
+        else:
+            signature = signature + "_" + str(idx) + ":" + type(input_data).__name__ + ":val(" + str(input_data) + ")"

+    if signature not in partitioned_modules:
+        partitioned_modules[signature] = partition_graph(gm, use_python_fusion_cache=use_python_fusion_cache,
+                                                         model_hash_str=model_hash_str, file_name=file_name)
+    return partitioned_modules[signature](*args)
+
+
+def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_name=""):
+    core = Core()
+
+    device = "CPU"
+
+    if os.getenv("OPENVINO_TORCH_BACKEND_DEVICE") is not None:
+        device = os.getenv("OPENVINO_TORCH_BACKEND_DEVICE")
+        assert device in core.available_devices, "Specified device " + device + " is not in the list of OpenVINO Available Devices"
+
+    cache_root = "./cache/"
+    if os.getenv("OPENVINO_TORCH_CACHE_DIR") is not None:
+        cache_root = os.getenv("OPENVINO_TORCH_CACHE_DIR")
+
+    if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
+        om = core.read_model(file_name + ".xml")
+    else:
+        fe_manager = FrontEndManager()
+        fe = fe_manager.load_by_framework("pytorch")
+
+        input_shapes = []
+        input_types = []
+        for idx, input_data in enumerate(args):
+            input_types.append(input_data.type())
+            input_shapes.append(input_data.size())
+
+        decoder = TorchFXPythonDecoder(gm, gm, input_shapes=input_shapes, input_types=input_types)
+
+        im = fe.load(decoder)
+
+        om = fe.convert(im)
+
+        if file_name is not None:
+            serialize(om, file_name + ".xml", file_name + ".bin")
+
+    dtype_mapping = {
+        torch.float32: Type.f32,
+        torch.float64: Type.f64,
+        torch.float16: Type.f16,
+        torch.int64: Type.i64,
+        torch.int32: Type.i32,
+        torch.uint8: Type.u8,
+        torch.int8: Type.i8,
+        torch.bool: Type.boolean
+    }
+
+    for idx, input_data in enumerate(args):
+        om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
+        om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
+    om.validate_nodes_and_infer_types()
+
+    if model_hash_str is not None:
+        core.set_property({'CACHE_DIR': cache_root + '/blob'})
+
+    compiled = core.compile_model(om, device)
+    return compiled
+
+
+def get_cached_file_name(*args, model_hash_str, device, cache_root):
+    file_name = None
+    if model_hash_str is not None:
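The memoization key in openvino_execute_partitioned deserves a comment: it combines the `id()` of the graph module with a compact dtype/shape encoding of every tensor argument, so a new input shape transparently triggers a fresh partition-and-compile instead of reusing a stale graph. A standalone sketch of the same encoding (the helper name `input_signature` is mine, not the repo's):

    import torch

    def input_signature(gm_id: int, *args) -> str:
        # Mirrors the scheme above: graph-module identity plus a compact
        # dtype/shape string per tensor argument.
        signature = str(gm_id)
        for idx, input_data in enumerate(args):
            if isinstance(input_data, torch.Tensor):
                dtype = str(input_data.type())[6:]  # strips 'torch.' -> 'FloatTensor'
                shape = str(input_data.size())[11:-1].replace(" ", "")  # '[1,4,8,8]'
                signature += "_" + str(idx) + ":" + dtype + ":" + shape
            else:
                signature += "_" + str(idx) + ":" + type(input_data).__name__ + ":val(" + str(input_data) + ")"
        return signature

    print(input_signature(140_000_000, torch.zeros(1, 4, 8, 8), 0.5))
    # 140000000_0:FloatTensor:[1,4,8,8]_1:float:val(0.5)
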
@@ -150,15 +320,18 @@ def get_cached_file_name(*args, model_hash_str, device, cache_root):
         try:
             os.makedirs(model_cache_dir, exist_ok=True)
             file_name = model_cache_dir + model_hash_str + "_" + device
+            type_shape_string = ""
             for input_data in args:
                 if file_name is not None:
-                    file_name += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
+                    type_shape_string += "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "")
+            file_name += sha256(type_shape_string.encode('utf-8')).hexdigest()
         except OSError as error:
             print("Cache directory ", cache_root, " cannot be created. Model caching is disabled. Error: ", error)
             file_name = None
             model_hash_str = None
     return file_name


 def from_single_file(self, pretrained_model_link_or_path, **kwargs):

     cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
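The `type_shape_string` change fixes a practical problem: previously every input's dtype and shape was appended verbatim to the cache filename, which for UNet-sized graphs with many inputs can run past OS path-length limits. Now the concatenation is hashed and only a fixed 64-character digest lands in the filename. A quick before/after comparison, assuming a hypothetical subgraph with 40 identical tensor inputs:

    from hashlib import sha256

    # Hypothetical per-input dtype/shape fragments for a large subgraph.
    fragments = ["_torch.FloatTensor[2,4,64,64]" for _ in range(40)]

    old_suffix = "".join(fragments)  # grows with input count
    new_suffix = sha256("".join(fragments).encode('utf-8')).hexdigest()  # always 64 chars

    print(len(old_suffix), len(new_suffix))  # 1160 64
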
@@ -312,6 +485,7 @@ def set_scheduler(sd_model, sampler_name):

 def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode):
     if (model_state.recompile == 1):
+        model_state.partition_id = 0
         torch._dynamo.reset()
         openvino_clear_caches()
         curr_dir_path = os.getcwd()
@@ -326,6 +500,10 @@ def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode):
         sd_model = StableDiffusionImg2ImgPipeline(**sd_model.components)
     elif (mode == 2):
         sd_model = StableDiffusionInpaintPipeline(**sd_model.components)
+    elif (mode == 3):
+        controlnet = ControlNetModel.from_pretrained(model_state.cn_model)
+        sd_model = StableDiffusionControlNetPipeline(**sd_model.components, controlnet=controlnet)
+        controlnet = torch.compile(controlnet, backend="openvino")
     checkpoint_info = CheckpointInfo(checkpoint_path)
     sd_model.sd_checkpoint_info = checkpoint_info
     sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
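Mode 3 builds the ControlNet pipeline by reusing the already-loaded base components and attaching a separately downloaded ControlNetModel, then routes the ControlNet through the same registered "openvino" dynamo backend. A hedged standalone sketch of that assembly with diffusers (the model ids are illustrative, not taken from this commit):

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionPipeline

    base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")

    # Reuse the base pipeline's components instead of re-loading them.
    pipe = StableDiffusionControlNetPipeline(**base.components, controlnet=controlnet)

    # Compile the ControlNet with the "openvino" backend, as the diff does;
    # like the diff, this rebinds only the local name.
    controlnet = torch.compile(controlnet, backend="openvino")
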
@@ -471,8 +649,26 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_config,
     if p.scripts is not None:
         p.scripts.process(p)

-    if 'ControlNet' in p.extra_generation_params:
-        return process_images(p)
+    control_images = []
+
+    cn_model = ""
+    if ('ControlNet' in p.extra_generation_params):
+        cn_params = p.extra_generation_params['ControlNet']
+
+        cn_param_elements = [part.strip() for part in cn_params.split(', ')]
+        for element in cn_param_elements:
+            if (element.split(':')[0] == "model"):
+                cn_model = (element.split(':')[1]).split(' ')[1]
+
+        if (cn_model != "None"):
+            control_res = Processed(
+                p,
+                images_list=control_images,
+            )
+            p.scripts.postprocess(p, control_res)
+            control_image = control_images[0]
+            cn_model = "lllyasviel/" + cn_model
+            mode = 3

     infotexts = []
     output_images = []
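Instead of bailing out to the stock code path, ControlNet runs are now handled natively: the model name is recovered from the ControlNet infotext string, which looks roughly like "model: <name> [<hash>], weight: ...", and splitting on ':' and then on the space drops both the key and the trailing hash. A standalone sketch with a made-up params string:

    # Hypothetical infotext produced by the ControlNet extension.
    cn_params = "model: control_v11p_sd15_canny [abcdef12], weight: 1.0"

    cn_model = ""
    for element in [part.strip() for part in cn_params.split(', ')]:
        if element.split(':')[0] == "model":
            # " control_v11p_sd15_canny [abcdef12]" -> keep the bare model name
            cn_model = element.split(':')[1].split(' ')[1]

    print(cn_model)                   # control_v11p_sd15_canny
    print("lllyasviel/" + cn_model)   # hub id passed to ControlNetModel.from_pretrained
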
@@ -506,12 +702,13 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_config,
             break

         if (model_state.height != p.height or model_state.width != p.width or model_state.batch_size != p.batch_size
-            or model_state.mode != mode or model_state.model_hash != shared.sd_model.sd_model_hash):
+            or model_state.mode != mode or model_state.model_hash != shared.sd_model.sd_model_hash or model_state.cn_model != cn_model):
             model_state.recompile = 1
             model_state.height = p.height
             model_state.width = p.width
             model_state.batch_size = p.batch_size
             model_state.mode = mode
+            model_state.cn_model = cn_model
             model_state.model_hash = shared.sd_model.sd_model_hash

         shared.sd_diffusers_model = get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode)
@@ -564,12 +761,19 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_config,
                 'image': p.init_images,
                 'strength': p.denoising_strength,
             })
-        else:
+        elif (mode == 2):
             custom_inputs.update({
                 'image': p.init_images,
                 'strength': p.denoising_strength,
+                'mask_image': p.mask,
             })
+        elif (mode == 3):
+            custom_inputs.update({
+                'image': control_image,
+                'width': p.width,
+                'height': p.height,
+            })

         output = shared.sd_diffusers_model(
             prompt=p.prompts,
             negative_prompt=p.negative_prompts,
@@ -626,6 +830,8 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_config,
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
                 output_images.append(image)
+                if ('ControlNet' in p.extra_generation_params and cn_model != "None"):
+                    output_images.append(control_image)

         if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
             image_mask = p.mask_for_overlay.convert('RGB')
@@ -723,11 +929,11 @@ class Script(scripts.Script):
         """
-        ###
-        - First inference involves compilation of the model for best performance.
+        ### Note:
+        - First inference involves compilation of the model for best performance.
         Since compilation happens only on the first run, the first inference (or warm up inference) will be slower than subsequent inferences.
         - For accurate performance measurements, it is recommended to exclude this slower first inference, as it doesn't reflect normal running time.
-        - Model is recompiled when resolution, batchsize, device, or samplers like DPM++ or Karras are changed.
-        After recompiling, later inferences will reuse the newly compiled model and achieve faster running times.
+        - Model is recompiled when resolution, batchsize, device, or samplers like DPM++ or Karras are changed.
+        After recompiling, later inferences will reuse the newly compiled model and achieve faster running times.
+        So it's normal for the first inference after a settings change to be slower, while subsequent inferences use the optimized compiled model and run faster.
         """)
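This note in the UI text corresponds directly to the `model_state` comparison added above: any mismatch between the cached generation settings and the incoming ones flips `recompile`. A condensed illustration (the field names match the diff; the dataclass wrapper and helper are mine):

    from dataclasses import dataclass

    @dataclass
    class ModelState:
        height: int = 512
        width: int = 512
        batch_size: int = 1
        mode: int = 0
        model_hash: str = ""
        cn_model: str = ""

    def needs_recompile(state, height, width, batch_size, mode, model_hash, cn_model):
        # Mirrors the condition in process_images_openvino: any drift in
        # resolution, batch size, pipeline mode, checkpoint, or ControlNet
        # model invalidates the compiled graphs.
        return (state.height != height or state.width != width
                or state.batch_size != batch_size or state.mode != mode
                or state.model_hash != model_hash or state.cn_model != cn_model)

    state = ModelState(model_hash="abc123")
    print(needs_recompile(state, 512, 512, 1, 0, "abc123", ""))  # False: reuse
    print(needs_recompile(state, 768, 512, 1, 0, "abc123", ""))  # True: resolution changed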