mirror of
https://github.com/openvinotoolkit/stable-diffusion-webui.git
synced 2024-12-14 22:53:25 +03:00
Added fixes for image to image and inpainting
This commit is contained in:
parent
d8c7ce3135
commit
9f31570ae8
@ -31,6 +31,8 @@ from openvino.runtime import Core
|
|||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
|
StableDiffusionImg2ImgPipeline,
|
||||||
|
StableDiffusionInpaintPipeline,
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DPMSolverMultistepScheduler,
|
DPMSolverMultistepScheduler,
|
||||||
EulerAncestralDiscreteScheduler,
|
EulerAncestralDiscreteScheduler,
|
||||||
@ -47,6 +49,7 @@ class ModelState:
|
|||||||
self.height = 512
|
self.height = 512
|
||||||
self.width = 512
|
self.width = 512
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
|
self.mode = 0
|
||||||
|
|
||||||
model_state = ModelState()
|
model_state = ModelState()
|
||||||
|
|
||||||
@ -90,7 +93,7 @@ def set_scheduler(sd_model, sampler_name):
|
|||||||
|
|
||||||
return sd_model.scheduler
|
return sd_model.scheduler
|
||||||
|
|
||||||
def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device):
|
def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device, mode):
|
||||||
if (model_state.recompile == 1):
|
if (model_state.recompile == 1):
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
openvino_clear_caches()
|
openvino_clear_caches()
|
||||||
@ -99,7 +102,12 @@ def get_diffusers_sd_model(sampler_name, enable_caching, openvino_device):
|
|||||||
checkpoint_name = shared.opts.sd_model_checkpoint.split(" ")[0]
|
checkpoint_name = shared.opts.sd_model_checkpoint.split(" ")[0]
|
||||||
checkpoint_path = curr_dir_path + model_path + checkpoint_name
|
checkpoint_path = curr_dir_path + model_path + checkpoint_name
|
||||||
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path)
|
sd_model = StableDiffusionPipeline.from_single_file(checkpoint_path)
|
||||||
|
if (mode == 1):
|
||||||
|
sd_model = StableDiffusionImg2ImgPipeline.from_single_file(checkpoint_path)
|
||||||
|
elif (mode == 2):
|
||||||
|
sd_model = StableDiffusionInpaintPipeline.from_pretrained(curr_dir_path + model_path, **sd_model.components, local_files_only=True)
|
||||||
checkpoint_info = CheckpointInfo(checkpoint_path)
|
checkpoint_info = CheckpointInfo(checkpoint_path)
|
||||||
|
#model_state.mode = mode
|
||||||
sd_model.sd_checkpoint_info = checkpoint_info
|
sd_model.sd_checkpoint_info = checkpoint_info
|
||||||
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
sd_model.safety_checker = None
|
sd_model.safety_checker = None
|
||||||
@ -160,7 +168,6 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
|
|||||||
self.color_corrections = []
|
self.color_corrections = []
|
||||||
imgs = []
|
imgs = []
|
||||||
for img in self.init_images:
|
for img in self.init_images:
|
||||||
|
|
||||||
# Save init image
|
# Save init image
|
||||||
if opts.save_init_img:
|
if opts.save_init_img:
|
||||||
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||||
@ -174,7 +181,8 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
|
|||||||
if image_mask is not None:
|
if image_mask is not None:
|
||||||
image_masked = Image.new('RGBa', (image.width, image.height))
|
image_masked = Image.new('RGBa', (image.width, image.height))
|
||||||
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
||||||
|
self.mask = image_mask
|
||||||
|
image_mask.save("/home/yamini/pytorch_fx/test/mask_image.jpg")
|
||||||
self.overlay_images.append(image_masked.convert('RGBA'))
|
self.overlay_images.append(image_masked.convert('RGBA'))
|
||||||
|
|
||||||
# crop_region is not None if we are doing inpaint full res
|
# crop_region is not None if we are doing inpaint full res
|
||||||
@ -182,6 +190,7 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
|
|||||||
image = image.crop(crop_region)
|
image = image.crop(crop_region)
|
||||||
image = images.resize_image(2, image, self.width, self.height)
|
image = images.resize_image(2, image, self.width, self.height)
|
||||||
|
|
||||||
|
self.init_images = image
|
||||||
if image_mask is not None:
|
if image_mask is not None:
|
||||||
if self.inpainting_fill != 1:
|
if self.inpainting_fill != 1:
|
||||||
image = masking.fill(image, latent_mask)
|
image = masking.fill(image, latent_mask)
|
||||||
@ -208,34 +217,8 @@ def init_new(self, all_prompts, all_seeds, all_subseeds):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
||||||
|
|
||||||
image = torch.from_numpy(batch_images)
|
|
||||||
image = 2. * image - 1.
|
|
||||||
image = image.to(shared.device)
|
|
||||||
|
|
||||||
self.init_latent = shared.sd_diffusers_model.vae.encode(image).latent_dist.sample()
|
def process_images_openvino(p: StableDiffusionProcessing, sampler_name, enable_caching, openvino_device, mode) -> Processed:
|
||||||
|
|
||||||
if self.resize_mode == 3:
|
|
||||||
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // 8, self.width // 8), mode="bilinear")
|
|
||||||
|
|
||||||
if image_mask is not None:
|
|
||||||
init_mask = latent_mask
|
|
||||||
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
|
||||||
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
|
||||||
latmask = latmask[0]
|
|
||||||
latmask = np.around(latmask)
|
|
||||||
latmask = np.tile(latmask[None], (4, 1, 1))
|
|
||||||
|
|
||||||
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(shared.sd_diffusers_model.vae.dtype)
|
|
||||||
self.nmask = torch.asarray(latmask).to(shared.device).type(shared.sd_diffusers_model.vae.dtype)
|
|
||||||
|
|
||||||
# this needs to be fixed to be done in sample() using actual seeds for batches
|
|
||||||
if self.inpainting_fill == 2:
|
|
||||||
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
|
|
||||||
elif self.inpainting_fill == 3:
|
|
||||||
self.init_latent = self.init_latent * self.mask
|
|
||||||
|
|
||||||
|
|
||||||
def process_images_openvino(p: StableDiffusionProcessing, sampler_name, enable_caching, openvino_device) -> Processed:
|
|
||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
if type(p.prompt) == list:
|
if type(p.prompt) == list:
|
||||||
@ -243,11 +226,6 @@ def process_images_openvino(p: StableDiffusionProcessing, sampler_name, enable_c
|
|||||||
else:
|
else:
|
||||||
assert p.prompt is not None
|
assert p.prompt is not None
|
||||||
|
|
||||||
if openvino_device[:3] == "GPU":
|
|
||||||
img_size_err_message = "Image height and width should be equal or less than 728 for GPU execution"
|
|
||||||
assert p.height <= 728, img_size_err_message
|
|
||||||
assert p.width <= 728, img_size_err_message
|
|
||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
seed = get_fixed_seed(p.seed)
|
seed = get_fixed_seed(p.seed)
|
||||||
@ -304,13 +282,14 @@ def process_images_openvino(p: StableDiffusionProcessing, sampler_name, enable_c
|
|||||||
if len(p.prompts) == 0:
|
if len(p.prompts) == 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
if (model_state.height != p.height or model_state.width != p.width or model_state.batch_size != p.batch_size):
|
if (model_state.height != p.height or model_state.width != p.width or model_state.batch_size != p.batch_size or model_state.mode != mode):
|
||||||
model_state.recompile = 1
|
model_state.recompile = 1
|
||||||
model_state.height = p.height
|
model_state.height = p.height
|
||||||
model_state.width = p.width
|
model_state.width = p.width
|
||||||
model_state.batch_size = p.batch_size
|
model_state.batch_size = p.batch_size
|
||||||
|
model_state.mode = mode
|
||||||
|
|
||||||
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device)
|
shared.sd_diffusers_model = get_diffusers_sd_model(sampler_name, enable_caching, openvino_device, mode)
|
||||||
shared.sd_diffusers_model.scheduler = set_scheduler(shared.sd_diffusers_model, sampler_name)
|
shared.sd_diffusers_model.scheduler = set_scheduler(shared.sd_diffusers_model, sampler_name)
|
||||||
|
|
||||||
extra_network_data = p.parse_extra_network_prompts()
|
extra_network_data = p.parse_extra_network_prompts()
|
||||||
@ -348,19 +327,47 @@ def process_images_openvino(p: StableDiffusionProcessing, sampler_name, enable_c
|
|||||||
time_stamps.append(time.time())
|
time_stamps.append(time.time())
|
||||||
|
|
||||||
time_stamps.append(time.time())
|
time_stamps.append(time.time())
|
||||||
output = shared.sd_diffusers_model(
|
if (mode == 0):
|
||||||
prompt=p.prompts,
|
output = shared.sd_diffusers_model(
|
||||||
negative_prompt=p.negative_prompts,
|
prompt=p.prompts,
|
||||||
num_inference_steps=p.steps,
|
negative_prompt=p.negative_prompts,
|
||||||
guidance_scale=p.cfg_scale,
|
num_inference_steps=p.steps,
|
||||||
height=p.height,
|
guidance_scale=p.cfg_scale,
|
||||||
width=p.width,
|
width = p.width,
|
||||||
generator=generator,
|
height = p.height,
|
||||||
output_type="np",
|
generator=generator,
|
||||||
callback = callback,
|
output_type="np",
|
||||||
callback_steps = 1
|
callback = callback,
|
||||||
)
|
callback_steps = 1,
|
||||||
|
)
|
||||||
|
elif (mode == 1):
|
||||||
|
output = shared.sd_diffusers_model(
|
||||||
|
prompt=p.prompts,
|
||||||
|
negative_prompt=p.negative_prompts,
|
||||||
|
num_inference_steps=p.steps,
|
||||||
|
guidance_scale=p.cfg_scale,
|
||||||
|
image = p.init_images,
|
||||||
|
strength = p.denoising_strength,
|
||||||
|
generator=generator,
|
||||||
|
output_type="np",
|
||||||
|
callback = callback,
|
||||||
|
callback_steps = 1,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = shared.sd_diffusers_model(
|
||||||
|
prompt=p.prompts,
|
||||||
|
negative_prompt=p.negative_prompts,
|
||||||
|
num_inference_steps=p.steps,
|
||||||
|
guidance_scale=p.cfg_scale,
|
||||||
|
mask_image = p.mask,
|
||||||
|
image = p.init_images,
|
||||||
|
strength = p.denoising_strength,
|
||||||
|
generator=generator,
|
||||||
|
output_type="np",
|
||||||
|
callback = callback,
|
||||||
|
callback_steps = 1,
|
||||||
|
)
|
||||||
|
|
||||||
model_state.recompile = 0
|
model_state.recompile = 0
|
||||||
|
|
||||||
warmup_duration = time_stamps[1] - time_stamps[0]
|
warmup_duration = time_stamps[1] - time_stamps[0]
|
||||||
@ -523,11 +530,18 @@ class Script(scripts.Script):
|
|||||||
if (p.sampler_name not in supported_samplers):
|
if (p.sampler_name not in supported_samplers):
|
||||||
p.sampler_name = "Euler a"
|
p.sampler_name = "Euler a"
|
||||||
|
|
||||||
|
# mode can be 0, 1, 2 corresponding to txt2img, img2img, inpaint respectively
|
||||||
|
mode = 0
|
||||||
if self.is_txt2img:
|
if self.is_txt2img:
|
||||||
processed = process_images_openvino(p, p.sampler_name, enable_caching, openvino_device)
|
mode = 0
|
||||||
|
processed = process_images_openvino(p, p.sampler_name, enable_caching, openvino_device, mode)
|
||||||
else:
|
else:
|
||||||
|
if p.image_mask is None:
|
||||||
|
mode = 1
|
||||||
|
else:
|
||||||
|
mode = 2
|
||||||
p.init = functools.partial(init_new, p)
|
p.init = functools.partial(init_new, p)
|
||||||
processed = process_images_openvino(p, p.sampler_name, enable_caching, openvino_device)
|
processed = process_images_openvino(p, p.sampler_name, enable_caching, openvino_device, mode)
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user