Merge pull request #23 from openvinotoolkit/develop

Develop
This commit is contained in:
Yamini Nimmagadda 2023-08-18 09:43:50 -07:00 committed by GitHub
commit 434282272d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,15 +14,16 @@ import modules
import modules.paths as paths
import modules.scripts as scripts
from modules import images, devices, extra_networks, masking, shared
from modules import images, devices, extra_networks, masking, shared, sd_models_config
from modules.processing import (
StableDiffusionProcessing, Processed, apply_overlay, apply_color_correction,
get_fixed_seed, create_infotext, setup_color_correction,
process_images
)
from modules.sd_models import CheckpointInfo
from modules.sd_models import CheckpointInfo, get_checkpoint_state_dict
from modules.shared import opts, state
from modules.ui_common import create_refresh_button
from modules.timer import Timer
from PIL import Image, ImageOps
from pathlib import Path
@ -318,16 +319,20 @@ def get_diffusers_sd_model(model_config, sampler_name, enable_caching, openvino_
curr_dir_path = os.getcwd()
checkpoint_name = shared.opts.sd_model_checkpoint.split(" ")[0]
checkpoint_path = os.path.join(curr_dir_path, 'models', 'Stable-diffusion', checkpoint_name)
checkpoint_info = CheckpointInfo(checkpoint_path)
timer = Timer()
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
print("OpenVINO Script: created model from config : " + checkpoint_config)
if model_config != "None":
local_config_file = os.path.join(curr_dir_path, 'configs', model_config)
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, local_config_file=local_config_file, load_safety_checker=False)
else:
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, load_safety_checker=False, torch_dtype=torch.float32)
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path, local_config_file=checkpoint_config, load_safety_checker=False, torch_dtype=torch.float32)
if (mode == 1):
sd_model = StableDiffusionImg2ImgPipeline(**sd_model.components)
elif (mode == 2):
sd_model = StableDiffusionInpaintPipeline(**sd_model.components)
checkpoint_info = CheckpointInfo(checkpoint_path)
sd_model.sd_checkpoint_info = checkpoint_info
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
sd_model.safety_checker = None