Merged changes from master

This commit is contained in:
ynimmaga 2023-08-21 17:50:25 -07:00
commit 78480d8762
8 changed files with 136 additions and 43 deletions

View File

@ -1,7 +1,11 @@
# Stable Diffusion web UI
A browser interface based on Gradio library for Stable Diffusion.
# Stable Diffusion web UI with OpenVINO™ Acceleration
A browser interface based on Gradio library for Stable Diffusion with OpenVINO™ Acceleration Script.
![](screenshot.png)
This repo is a fork of [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which includes OpenVINO support through a [custom script](https://github.com/openvinotoolkit/stable-diffusion-webui/blob/master/scripts/openvino_accelerate.py) to run it on Intel CPUs and Intel GPUs.
See wiki page for [Installation-on-Intel-Silicon](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon)
![](screenshot_OpenVINO.png)
## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):

21
first-time-runner.bat Normal file
View File

@ -0,0 +1,21 @@
@echo off
set "filePath=%cd%\webui-user.bat"
(
echo @echo off
echo.
echo set GIT=
echo set VENV_DIR=
echo set COMMANDLINE_ARGS=--skip-torch-cuda-test --precision full --no-half
echo set PYTORCH_TRACING_MODE=TORCHFX
echo.
echo call webui.bat
) > %filepath%
call webui-user.bat
pause

View File

@ -32,4 +32,4 @@ torchdiffeq
torchsde
transformers==4.30.0
diffusers==0.18.2
openvino==2023.1.0.dev20230728
openvino==2023.1.0.dev20230811

View File

@ -1,4 +1,4 @@
GitPython==3.1.30
GitPython==3.1.32
Pillow==9.5.0
accelerate==0.18.0
basicsr==1.4.2
@ -30,5 +30,5 @@ torchdiffeq==0.2.3
torchsde==0.2.5
transformers==4.30.0
diffusers==0.18.2
openvino==2023.1.0.dev20230728
openvino==2023.1.0.dev20230811

BIN
screenshot_OpenVINO.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 615 KiB

View File

@ -14,23 +14,25 @@ import modules
import modules.paths as paths
import modules.scripts as scripts
from modules import images, devices, extra_networks, masking, shared
from modules import images, devices, extra_networks, masking, shared, sd_models_config
from modules.processing import (
StableDiffusionProcessing, Processed, apply_overlay, apply_color_correction,
get_fixed_seed, create_infotext, setup_color_correction,
process_images
)
from modules.sd_models import CheckpointInfo
from modules.sd_models import CheckpointInfo, get_checkpoint_state_dict
from modules.shared import opts, state
from modules.ui_common import create_refresh_button
from modules.timer import Timer
from PIL import Image, ImageOps
from pathlib import Path
from types import MappingProxyType
from typing import Callable, Optional
from typing import Optional
from openvino.frontend import FrontEndManager
from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
from openvino.frontend.pytorch.torchdynamo import backend #, compile # noqa: F401
from openvino.frontend.pytorch.torchdynamo import backend # noqa: F401
from openvino.frontend.pytorch.torchdynamo.partition import Partitioner
from openvino.runtime import Core, Type, PartialShape, serialize
@ -114,7 +116,7 @@ def openvino_fx(subgraph, example_inputs):
example_inputs_reordered = []
if (os.path.isfile(maybe_fs_cached_name + ".txt")):
f = open(maybe_fs_cached_name + ".txt", "r")
for idx, input_data in enumerate(example_inputs):
for input_data in example_inputs:
shape = f.readline()
if (str(input_data.size()) != shape):
for idx1, input_data1 in enumerate(example_inputs):
@ -130,7 +132,7 @@ def openvino_fx(subgraph, example_inputs):
args_reordered = []
if (os.path.isfile(maybe_fs_cached_name + ".txt")):
f = open(maybe_fs_cached_name + ".txt", "r")
for idx, input_data in enumerate(args):
for input_data in args:
shape = f.readline()
if (str(input_data.size()) != shape):
for idx1, input_data1 in enumerate(args):
@ -163,6 +165,7 @@ def openvino_fx(subgraph, example_inputs):
return res
return _call
except Exception as e:
print(e)
return compile_fx(subgraph, example_inputs)
def check_fully_supported(self, graph_module: GraphModule) -> bool:
@ -315,7 +318,7 @@ def cached_model_name(model_hash_str, device, args, cache_root, reversed = False
return None
inputs_str = ""
for idx, input_data in enumerate(args):
for input_data in args:
if reversed:
inputs_str = "_" + str(input_data.type()) + str(input_data.size())[11:-1].replace(" ", "") + inputs_str
else:
@ -380,7 +383,7 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na
input_shapes = []
input_types = []
for idx, input_data in enumerate(args):
for input_data in args:
input_types.append(input_data.type())
input_shapes.append(input_data.size())
@ -394,7 +397,7 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na
serialize(om, file_name + ".xml", file_name + ".bin")
if (model_state.cn_model != "None"):
f = open(file_name + ".txt", "w")
for idx, input_data in enumerate(args):
for input_data in args:
f.write(str(input_data.size()))
f.write("\n")
f.close()
@ -572,7 +575,7 @@ def set_scheduler(sd_model, sampler_name):
return sd_model.scheduler
def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode):
def get_diffusers_sd_model(model_config, sampler_name, enable_caching, openvino_device, mode):
if (model_state.recompile == 1):
model_state.partition_id = 0
torch._dynamo.reset()
@ -580,11 +583,16 @@ def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_cach
curr_dir_path = os.getcwd()
checkpoint_name = shared.opts.sd_model_checkpoint.split(" ")[0]
checkpoint_path = os.path.join(curr_dir_path, 'models', 'Stable-diffusion', checkpoint_name)
if local_config:
checkpoint_info = CheckpointInfo(checkpoint_path)
timer = Timer()
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
print("OpenVINO Script: created model from config : " + checkpoint_config)
if model_config != "None":
local_config_file = os.path.join(curr_dir_path, 'configs', model_config)
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, local_config_file=local_config_file, load_safety_checker=False)
else:
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, load_safety_checker=False, torch_dtype=torch.float32)
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, local_config_file=checkpoint_config, load_safety_checker=False, torch_dtype=torch.float32)
if (mode == 1):
sd_model = StableDiffusionImg2ImgPipeline(**sd_model.components)
elif (mode == 2):
@ -594,6 +602,7 @@ def get_diffusers_sd_model(local_config, model_config, sampler_name, enable_cach
sd_model = StableDiffusionControlNetPipeline(**sd_model.components, controlnet=controlnet)
sd_model.controlnet = torch.compile(sd_model.controlnet, backend="openvino_fx")
checkpoint_info = CheckpointInfo(checkpoint_path)
sd_model.sd_checkpoint_info = checkpoint_info
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
sd_model.safety_checker = None
@ -700,8 +709,7 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
else:
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
def process_images_openvino(p: StableDiffusionProcessing, local_config, model_config, sampler_name, enable_caching, openvino_device, mode) -> Processed:
def process_images_openvino(p: StableDiffusionProcessing, model_config, sampler_name, enable_caching, openvino_device, mode) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
if (mode == 0 and p.enable_hr):
@ -756,7 +764,6 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_co
)
p.scripts.postprocess(p, control_res)
control_image = control_images[0]
#cn_model = "lllyasviel/" + cn_model
mode = 3
infotexts = []
@ -800,7 +807,7 @@ def process_images_openvino(p: StableDiffusionProcessing, local_config, model_co
model_state.cn_model = cn_model
model_state.model_hash = shared.sd_model.sd_model_hash
shared.sd_diffusers_model = get_diffusers_sd_model(local_config, model_config, sampler_name, enable_caching, openvino_device, mode)
shared.sd_diffusers_model = get_diffusers_sd_model(model_config, sampler_name, enable_caching, openvino_device, mode)
shared.sd_diffusers_model.scheduler = set_scheduler(shared.sd_diffusers_model, sampler_name)
extra_network_data = p.parse_extra_network_prompts()
@ -1000,15 +1007,20 @@ class Script(scripts.Script):
def ui(self, is_img2img):
core = Core()
config_dir_list = os.listdir(os.path.join(os.getcwd(), 'configs'))
def get_config_list():
config_dir_list = os.listdir(os.path.join(os.getcwd(), 'configs'))
config_list = []
config_list.append("None")
for file in config_dir_list:
if file.endswith('.yaml'):
config_list.append(file)
return config_list
with gr.Row():
model_config = gr.Dropdown(label="Select a local config for the model from the configs directory of the webui root", choices=get_config_list(), value="None", visible=True)
create_refresh_button(model_config, get_config_list, lambda: {"choices": get_config_list()},"refresh_model_config")
local_config = gr.Checkbox(label="Use a local inference config file", value=False)
model_config = gr.Dropdown(label="Select a config for the model (Below config files are listed from the configs directory of the WebUI root)", choices=config_list, value="v1-inference.yaml", visible=False)
openvino_device = gr.Dropdown(label="Select a device", choices=list(core.available_devices), value=model_state.device)
override_sampler = gr.Checkbox(label="Override the sampling selection from the main UI (Recommended as only below sampling methods have been validated for OpenVINO)", value=True)
sampler_name = gr.Radio(label="Select a sampling method", choices=["Euler a", "Euler", "LMS", "Heun", "DPM++ 2M", "LMS Karras", "DPM++ 2M Karras", "DDIM", "PLMS"], value="Euler a")
@ -1026,13 +1038,6 @@ class Script(scripts.Script):
So it's normal for the first inference after a settings change to be slower, while subsequent inferences use the optimized compiled model and run faster.
""")
def local_config_change(choice):
if choice:
return gr.update(visible=True)
else:
return gr.update(visible=False)
local_config.change(local_config_change, local_config, model_config)
def device_change(choice):
if (model_state.device == choice):
return gr.update(value="Device selected is " + choice, visible=True)
@ -1042,9 +1047,9 @@ class Script(scripts.Script):
return gr.update(value="Device changed to " + choice + ". Model will be re-compiled", visible=True)
openvino_device.change(device_change, openvino_device, warmup_status)
return [local_config, model_config, openvino_device, override_sampler, sampler_name, enable_caching]
return [model_config, openvino_device, override_sampler, sampler_name, enable_caching]
def run(self, p, local_config, model_config, openvino_device, override_sampler, sampler_name, enable_caching):
def run(self, p, model_config, openvino_device, override_sampler, sampler_name, enable_caching):
model_state.partition_id = 0
os.environ["OPENVINO_TORCH_BACKEND_DEVICE"] = str(openvino_device)
@ -1062,14 +1067,14 @@ class Script(scripts.Script):
mode = 0
if self.is_txt2img:
mode = 0
processed = process_images_openvino(p, local_config, model_config, p.sampler_name, enable_caching, openvino_device, mode)
processed = process_images_openvino(p, model_config, p.sampler_name, enable_caching, openvino_device, mode)
else:
if p.image_mask is None:
mode = 1
else:
mode = 2
p.init = functools.partial(init_new, p)
processed = process_images_openvino(p, local_config, model_config, p.sampler_name, enable_caching, openvino_device, mode)
processed = process_images_openvino(p, model_config, p.sampler_name, enable_caching, openvino_device, mode)
return processed

12
torch-install.bat Normal file
View File

@ -0,0 +1,12 @@
@echo off
start /wait cmd /k "%cd%\venv\Scripts\activate && pip install --pre torch==2.1.0.dev20230713+cpu torchvision==0.16.0.dev20230713+cpu -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html && exit"
echo torch 2.1.0 dev installation completed.
powershell -executionpolicy bypass .\torch-install.ps1
echo eval_frame.py modification completed. press any key to exit
pause

51
torch-install.ps1 Normal file
View File

@ -0,0 +1,51 @@
$scriptDirectory = $PSScriptRoot
Set-Location $scriptDirectory
## modify webui-user.bat
$filePath = $pwd.Path + "\webui-user.bat"
$newContent = @"
@echo off
set PYTHON=
set GIT=
set VENV_DIR=
set COMMANDLINE_ARGS=--skip-torch-cuda-test --precision full --no-half --skip-prepare-environment
set PYTORCH_TRACING_MODE=TORCHFX
call webui.bat
"@
$newContent | Set-Content -Path $filePath
### modify eval_frame
$eval_filePath = $pwd.Path + "\venv\Lib\site-packages\torch\_dynamo\eval_frame.py"
#comment out the two lines to test torch.compile on windows
$replacements = @{
" if sys.platform == `"win32`":" = "# if sys.platform == `"win32`":"
" raise RuntimeError(`"Windows not yet supported for torch.compile`")" = "# raise RuntimeError(`"Windows not yet supported for torch.compile`")"
}
$lines = Get-Content -Path $eval_filePath
foreach ($search_Text in $replacements.Keys){
$replaceText = $replacements[$search_text]
$lines = $lines.Replace($search_Text , $replaceText)
}
#write content back to file
$lines | Set-Content -Path $eval_filePath