2023-08-13 08:24:16 +03:00
|
|
|
from __future__ import annotations
|
2022-09-03 12:08:45 +03:00
|
|
|
import json
|
2023-05-29 08:54:13 +03:00
|
|
|
import logging
|
2022-09-03 12:08:45 +03:00
|
|
|
import math
|
|
|
|
import os
|
|
|
|
import sys
|
2023-04-06 19:42:26 +03:00
|
|
|
import hashlib
|
2023-08-13 08:24:16 +03:00
|
|
|
from dataclasses import dataclass, field
|
2022-09-03 12:08:45 +03:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
2023-06-09 22:47:27 +03:00
|
|
|
from PIL import Image, ImageOps
|
2022-09-03 12:08:45 +03:00
|
|
|
import random
|
2022-09-13 12:51:57 +03:00
|
|
|
import cv2
|
|
|
|
from skimage import exposure
|
2023-08-13 08:24:16 +03:00
|
|
|
from typing import Any
|
2022-09-03 12:08:45 +03:00
|
|
|
|
2022-09-05 03:25:37 +03:00
|
|
|
import modules.sd_hijack
|
2023-08-09 08:43:31 +03:00
|
|
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
|
2023-08-09 21:24:16 +03:00
|
|
|
from modules.rng import slerp # noqa: F401
|
2022-09-03 12:08:45 +03:00
|
|
|
from modules.sd_hijack import model_hijack
|
2023-08-04 13:23:14 +03:00
|
|
|
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
2022-09-03 12:08:45 +03:00
|
|
|
from modules.shared import opts, cmd_opts, state
|
|
|
|
import modules.shared as shared
|
2023-01-25 19:15:42 +03:00
|
|
|
import modules.paths as paths
|
2022-09-07 12:32:28 +03:00
|
|
|
import modules.face_restoration
|
2022-09-03 12:08:45 +03:00
|
|
|
import modules.images as images
|
2022-09-09 23:16:02 +03:00
|
|
|
import modules.styles
|
2022-11-29 06:11:29 +03:00
|
|
|
import modules.sd_models as sd_models
|
|
|
|
import modules.sd_vae as sd_vae
|
2022-12-09 03:14:35 +03:00
|
|
|
from ldm.data.util import AddMiDaS
|
|
|
|
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
2022-09-03 12:08:45 +03:00
|
|
|
|
2022-12-09 03:14:35 +03:00
|
|
|
from einops import repeat, rearrange
|
2022-12-12 02:03:36 +03:00
|
|
|
from blendmodes.blend import blendLayers, BlendType
|
2023-04-10 11:37:15 +03:00
|
|
|
|
|
|
|
|
2022-09-03 12:08:45 +03:00
|
|
|
# These values were deliberately taken out of the user-facing options:
# changing them would break the model.
opt_C = 4  # presumably the latent channel count — confirm against the model config
opt_f = 8  # presumably the pixel-to-latent downscale factor — confirm against the model config
|
|
|
|
|
|
|
|
|
2022-09-13 12:51:57 +03:00
|
|
|
def setup_color_correction(image):
    """Build the color-correction reference for ``image``.

    A copy of the image is turned into a numpy array and converted from RGB
    to the LAB color space. The resulting array is the ``correction`` value
    later passed to ``apply_color_correction``.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)
|
|
|
|
|
|
|
|
|
2022-12-12 02:03:36 +03:00
|
|
|
def apply_color_correction(correction, original_image):
    """Match the colors of ``original_image`` to a stored LAB reference.

    Args:
        correction: LAB-space reference array (as built by ``setup_color_correction``).
        original_image: PIL image to correct.

    Returns:
        An RGB PIL image. The LAB histograms of ``original_image`` are matched
        to ``correction``, converted back to RGB, and then combined with
        ``original_image`` via ``blendLayers`` in LUMINOSITY mode.
    """
    logging.info("Applying color correction.")

    source_lab = cv2.cvtColor(np.asarray(original_image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(source_lab, correction, channel_axis=2)
    matched_rgb = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")

    blended = blendLayers(Image.fromarray(matched_rgb), original_image, BlendType.LUMINOSITY)
    return blended.convert('RGB')
|
2022-09-13 12:51:57 +03:00
|
|
|
|
2022-10-24 09:15:26 +03:00
|
|
|
|
|
|
|
def apply_overlay(image, paste_loc, index, overlays):
    """Alpha-composite ``overlays[index]`` on top of ``image``.

    If there is no overlay for ``index`` (``overlays`` is None or shorter than
    ``index + 1``), ``image`` is returned untouched.

    When ``paste_loc`` is an ``(x, y, w, h)`` tuple, ``image`` is first
    resized to ``w`` x ``h`` and pasted at ``(x, y)`` onto a new RGBA canvas
    with the overlay's dimensions.

    Returns:
        The composited image converted to RGB.
    """
    has_overlay = overlays is not None and index < len(overlays)
    if not has_overlay:
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        left, top, width, height = paste_loc
        canvas = Image.new('RGBA', (overlay.width, overlay.height))
        canvas.paste(images.resize_image(1, image, width, height), (left, top))
        image = canvas

    composited = image.convert('RGBA')
    composited.alpha_composite(overlay)
    return composited.convert('RGB')
|
2022-09-13 12:51:57 +03:00
|
|
|
|
2022-10-09 03:13:13 +03:00
|
|
|
|
2023-01-04 17:58:07 +03:00
|
|
|
def txt2img_image_conditioning(sd_model, x, width, height):
    """Build the image-conditioning tensor used for a txt2img pass.

    Args:
        sd_model: the loaded model; ``sd_model.model.conditioning_key`` picks the branch.
        x: latent tensor; only its batch size (``x.shape[0]``), dtype and device are used.
        width: output width in pixels (inpainting branch only).
        height: output height in pixels (inpainting branch only).

    Returns:
        * inpainting models (``'hybrid'``/``'concat'``): an all-ones mask channel
          prepended to the encoding of a fully masked (neutral gray) image,
          cast to ``x.dtype``;
        * unCLIP models (``'crossattn-adm'``): zeros of shape
          ``(batch, 2 * sd_model.noise_augmentor.time_embed.dim)``;
        * any other model: dummy zeros of shape ``(batch, 5, 1, 1)``.
    """
    conditioning_key = sd_model.model.conditioning_key

    if conditioning_key in {'hybrid', 'concat'}:  # Inpainting models
        # The whole image is masked, so the "masked image" carries no content.
        # images_tensor_to_samples is fed images in the [0, 1] range (other
        # callers convert with ``image * 0.5 + 0.5`` first). The neutral value
        # is therefore 0.5, which matches the 0 that inpainting_image_conditioning
        # writes into masked pixels in [-1, 1] space. Zeros here would encode a
        # solid black image instead.
        image_conditioning = torch.full((x.shape[0], 3, height, width), 0.5, device=x.device)
        image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))

        # Prepend the all-ones "fully masked" channel in front of the latent channels.
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        return image_conditioning.to(x.dtype)

    if conditioning_key == "crossattn-adm":  # UnCLIP models
        return x.new_zeros(x.shape[0], 2 * sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)

    # Dummy zero conditioning when neither inpainting nor unCLIP conditioning is used.
    # Only its batch size matters, so a 1x1 spatial size keeps it cheap and avoids an encoder call.
    return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
2023-01-04 17:58:07 +03:00
|
|
|
|
|
|
|
|
2023-08-13 08:24:16 +03:00
|
|
|
@dataclass(repr=False)
class StableDiffusionProcessing:
    """Generation parameters plus the per-run state shared by all processing passes.

    Fields declared with ``field(..., init=False)`` are runtime state filled in
    during processing; the remaining annotated fields are constructor arguments.
    ``sample`` raises ``NotImplementedError`` here and must be provided by a
    subclass; ``init`` is a no-op hook.
    """

    sd_model: object = None  # shadowed by the ``sd_model`` property below; assignments are ignored
    outpath_samples: str = None
    outpath_grids: str = None
    prompt: str = ""
    prompt_for_display: str = None
    negative_prompt: str = ""
    styles: list[str] = None
    seed: int = -1
    subseed: int = -1
    subseed_strength: float = 0
    seed_resize_from_h: int = -1
    seed_resize_from_w: int = -1
    seed_enable_extras: bool = True
    sampler_name: str = None
    batch_size: int = 1
    n_iter: int = 1
    steps: int = 50
    cfg_scale: float = 7.0
    width: int = 512
    height: int = 512
    restore_faces: bool = None
    tiling: bool = None
    do_not_save_samples: bool = False
    do_not_save_grid: bool = False
    extra_generation_params: dict[str, Any] = None
    overlay_images: list = None
    eta: float = None
    do_not_reload_embeddings: bool = False
    denoising_strength: float = 0
    ddim_discretize: str = None
    s_min_uncond: float = None
    s_churn: float = None
    s_tmax: float = None
    s_tmin: float = None
    s_noise: float = None
    override_settings: dict[str, Any] = None
    override_settings_restore_afterwards: bool = True
    sampler_index: int = None  # deprecated: __post_init__ only warns about it
    refiner_checkpoint: str = None
    refiner_switch_at: float = None
    # Unannotated, so these are plain class attributes rather than constructor arguments.
    token_merging_ratio = 0
    token_merging_ratio_hr = 0
    disable_extra_networks: bool = False

    # Backing storage for the ``scripts`` / ``script_args`` properties.
    scripts_value: scripts.ScriptRunner = field(default=None, init=False)
    script_args_value: list = field(default=None, init=False)
    scripts_setup_complete: bool = field(default=False, init=False)

    # Class-level condition caches shared by every instance (unannotated, so not
    # dataclass fields). Each is a ``[params, result]`` pair; see get_conds_with_caching.
    cached_uc = [None, None]
    cached_c = [None, None]

    comments: dict = None
    sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)
    is_using_inpainting_conditioning: bool = field(default=False, init=False)
    paste_to: tuple | None = field(default=None, init=False)

    is_hr_pass: bool = field(default=False, init=False)

    c: tuple = field(default=None, init=False)
    uc: tuple = field(default=None, init=False)

    rng: rng.ImageRNG | None = field(default=None, init=False)
    step_multiplier: int = field(default=1, init=False)
    color_corrections: list = field(default=None, init=False)

    all_prompts: list = field(default=None, init=False)
    all_negative_prompts: list = field(default=None, init=False)
    all_seeds: list = field(default=None, init=False)
    all_subseeds: list = field(default=None, init=False)
    iteration: int = field(default=0, init=False)
    main_prompt: str = field(default=None, init=False)
    main_negative_prompt: str = field(default=None, init=False)

    prompts: list = field(default=None, init=False)
    negative_prompts: list = field(default=None, init=False)
    seeds: list = field(default=None, init=False)
    subseeds: list = field(default=None, init=False)
    extra_network_data: dict = field(default=None, init=False)

    user: str = field(default=None, init=False)

    sd_model_name: str = field(default=None, init=False)
    sd_model_hash: str = field(default=None, init=False)
    sd_vae_name: str = field(default=None, init=False)
    sd_vae_hash: str = field(default=None, init=False)

    def __post_init__(self):
        """Normalise constructor arguments and fill unset sampler options from ``opts``."""
        if self.sampler_index is not None:
            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

        # NOTE(review): this discards any ``comments`` passed to the constructor.
        self.comments = {}

        if self.styles is None:
            self.styles = []

        self.sampler_noise_scheduler_override = None
        self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond
        self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn
        self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin
        # A falsy s_tmax (e.g. 0) means "no upper bound".
        self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')
        self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise

        self.extra_generation_params = self.extra_generation_params or {}
        self.override_settings = self.override_settings or {}
        self.script_args = self.script_args or {}

        self.refiner_checkpoint_info = None

        if not self.seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        # Bind the class-level cache lists current at construction time.
        self.cached_uc = StableDiffusionProcessing.cached_uc
        self.cached_c = StableDiffusionProcessing.cached_c

    @property
    def sd_model(self):
        """Always the globally loaded ``shared.sd_model``."""
        return shared.sd_model

    @sd_model.setter
    def sd_model(self, value):
        # Intentionally ignored, including the assignment made by the dataclass __init__.
        pass

    @property
    def scripts(self):
        return self.scripts_value

    @scripts.setter
    def scripts(self, value):
        self.scripts_value = value

        # Run script setup once, as soon as both the runner and its args are known.
        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    @property
    def script_args(self):
        return self.script_args_value

    @script_args.setter
    def script_args(self, value):
        self.script_args_value = value

        # Run script setup once, as soon as both the runner and its args are known.
        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:
            self.setup_scripts()

    def setup_scripts(self):
        """Mark setup as done, then let the script runner set itself up for this object."""
        self.scripts_setup_complete = True

        # NOTE(review): spelling must match the ScriptRunner method name (not visible here).
        self.scripts.setup_scrips(self)

    def comment(self, text):
        """Record ``text`` once in ``self.comments`` (dict keys keep insertion order, no duplicates)."""
        self.comments[text] = 1

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Delegate to the module-level helper, defaulting the size to this object's width/height."""
        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}

        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)

    def depth2img_image_conditioning(self, source_image):
        """Return depth-map conditioning for depth2img models, rescaled to [-1, 1].

        The depth model runs on ``source_image[0]`` after MiDaS ``dpt_hybrid``
        preprocessing, repeated ``batch_size`` times. Its output is resized
        bicubically to the latent resolution of ``source_image``.
        """
        # Use the AddMiDaS helper to format the source image for the MiDaS model.
        transformer = AddMiDaS(model_type="dpt_hybrid")
        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)

        # source_image is in [-1, 1]; images_tensor_to_samples is fed [0, 1] images.
        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
        conditioning = torch.nn.functional.interpolate(
            self.sd_model.depth_model(midas_in),
            size=conditioning_image.shape[2:],
            mode="bicubic",
            align_corners=False,
        )

        (depth_min, depth_max) = torch.aminmax(conditioning)
        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
        return conditioning

    def edit_image_conditioning(self, source_image):
        """Return the encoded ``source_image`` (mapped from [-1, 1] to [0, 1] first) for ``edit`` models."""
        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))

        return conditioning_image

    def unclip_image_conditioning(self, source_image):
        """Return the unCLIP ``c_adm`` embedding of ``source_image``.

        When the model has a noise augmentor, the embedding is noised at level 0
        and concatenated with the noise-level embedding along dim 1.
        """
        c_adm = self.sd_model.embedder(source_image)
        if self.sd_model.noise_augmentor is not None:
            noise_level = 0  # TODO: Allow other noise levels?
            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Build the ``c_concat`` conditioning (mask channel + masked-image latents) for inpainting models.

        Args:
            source_image: image tensor; pixels where the mask is 1 are zeroed.
            latent_image: latent tensor whose spatial size the mask is resized to.
            image_mask: optional tensor or PIL mask. PIL masks are converted to
                grayscale, scaled to [0, 1] and rounded to 0/1. ``None`` means
                "everything masked".

        Returns:
            ``cat([mask, encoded_masked_image], dim=1)`` on ``shared.device``, cast to the model dtype.
        """
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs.
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter
        # (weight 0 keeps the original image, weight 1 uses the fully masked one).
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
        """Pick the image-conditioning builder that matches the loaded model type."""
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
        # identify itself with a field common to all models. The conditioning_key is also hybrid.
        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
            return self.depth2img_image_conditioning(source_image)

        if self.sd_model.cond_stage_key == "edit":
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)

        # Dummy zero conditioning if we're not using inpainting or depth model.
        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Hook run before sampling; a no-op in this base class."""
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        """Produce samples; must be implemented by a subclass."""
        raise NotImplementedError()

    def close(self):
        """Drop the sampler and conds; reset the shared caches unless ``opts.persistent_cond_cache`` is set."""
        self.sampler = None
        self.c = None
        self.uc = None
        if not opts.persistent_cond_cache:
            StableDiffusionProcessing.cached_c = [None, None]
            StableDiffusionProcessing.cached_uc = [None, None]

    def get_token_merging_ratio(self, for_hr=False):
        """Return the first truthy token-merging ratio; the hr pass prefers the ``*_hr`` values."""
        if for_hr:
            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio

        return self.token_merging_ratio or opts.token_merging_ratio

    def setup_prompts(self):
        """Expand prompt/negative prompt to one entry per image and apply the selected styles."""
        if isinstance(self.prompt, list):
            self.all_prompts = self.prompt
        else:
            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]

        if isinstance(self.negative_prompt, list):
            self.all_negative_prompts = self.negative_prompt
        else:
            self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]

        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

        self.main_prompt = self.all_prompts[0]
        self.main_negative_prompt = self.all_negative_prompts[0]

    def cached_params(self, required_prompts, steps, extra_network_data):
        """Returns parameters that invalidate the cond cache if changed"""

        return (
            required_prompts,
            steps,
            opts.CLIP_stop_at_last_layers,
            shared.sd_model.sd_checkpoint_info,
            extra_network_data,
            opts.sdxl_crop_left,
            opts.sdxl_crop_top,
            self.width,
            self.height,
        )

    def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
        """
        Returns the result of calling function(shared.sd_model, required_prompts, steps)
        using a cache to store the result if the same arguments have been used before.

        cache is an array containing two elements. The first element is a tuple
        representing the previously used arguments, or None if no arguments
        have been used before. The second element is where the previously
        computed result is stored.

        caches is a list with items described above.
        """

        cached_params = self.cached_params(required_prompts, steps, extra_network_data)

        for cache in caches:
            if cache[0] is not None and cached_params == cache[0]:
                return cache[1]

        # Cache miss: compute and store into the first cache.
        cache = caches[0]

        with devices.autocast():
            cache[1] = function(shared.sd_model, required_prompts, steps)

        cache[0] = cached_params
        return cache[1]

    def setup_conds(self):
        """Compute (or fetch from cache) the conditional and unconditional conds for this batch."""
        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

        sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
        total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
        self.step_multiplier = total_steps // self.steps
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)

    def get_conds(self):
        """Return the ``(c, uc)`` pair computed by setup_conds."""
        return self.c, self.uc

    def parse_extra_network_prompts(self):
        """Strip extra-network syntax from ``self.prompts``, storing the parsed data in ``extra_network_data``."""
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

    def save_samples(self) -> bool:
        """Returns whether generated images need to be written to disk"""
        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or (not state.interrupted and not state.skipped))
|
|
|
|
|
2022-09-03 12:08:45 +03:00
|
|
|
|
|
|
|
class Processed:
    """Result of a processing run: the generated images plus the parameters used.

    Most attributes are copied from the ``StableDiffusionProcessing`` object
    ``p`` (and from global ``opts``/``state``) at construction time. ``js()``
    serialises them to JSON.
    """

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
        """Snapshot the results of ``p``.

        The ``comments`` argument is not used; comments are taken from ``p.comments``.
        List-valued prompt/seed values are reduced to their first element; a
        ``None`` seed/subseed becomes ``-1``.
        """
        self.images = images_list
        self.prompt = p.prompt
        self.negative_prompt = p.negative_prompt
        self.seed = seed
        self.subseed = subseed
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.comments = "".join(f"{comment}\n" for comment in p.comments)
        self.width = p.width
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
        self.sd_model_name = p.sd_model_name
        self.sd_model_hash = p.sd_model_hash
        self.sd_vae_name = p.sd_vae_name
        self.sd_vae_hash = p.sd_vae_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers
        self.token_merging_ratio = p.token_merging_ratio
        self.token_merging_ratio_hr = p.token_merging_ratio_hr

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.s_min_uncond = p.s_min_uncond
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override

        # Reduce list-valued prompts/seeds to their first entry.
        self.prompt = self.prompt[0] if isinstance(self.prompt, list) else self.prompt
        self.negative_prompt = self.negative_prompt[0] if isinstance(self.negative_prompt, list) else self.negative_prompt
        if self.seed is None:
            self.seed = -1
        else:
            self.seed = int(self.seed[0] if isinstance(self.seed, list) else self.seed)
        if self.subseed is None:
            self.subseed = -1
        else:
            self.subseed = int(self.subseed[0] if isinstance(self.subseed, list) else self.subseed)
        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning

        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]
        self.all_negative_prompts = all_negative_prompts or p.all_negative_prompts or [self.negative_prompt]
        self.all_seeds = all_seeds or p.all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Return a JSON string describing this result."""
        obj = {
            "prompt": self.all_prompts[0],
            "all_prompts": self.all_prompts,
            "negative_prompt": self.all_negative_prompts[0],
            "all_negative_prompts": self.all_negative_prompts,
            "seed": self.seed,
            "all_seeds": self.all_seeds,
            "subseed": self.subseed,
            "all_subseeds": self.all_subseeds,
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,
            "face_restoration_model": self.face_restoration_model,
            "sd_model_name": self.sd_model_name,
            "sd_model_hash": self.sd_model_hash,
            "sd_vae_name": self.sd_vae_name,
            "sd_vae_hash": self.sd_vae_hash,
            "seed_resize_from_w": self.seed_resize_from_w,
            "seed_resize_from_h": self.seed_resize_from_h,
            "denoising_strength": self.denoising_strength,
            "extra_generation_params": self.extra_generation_params,
            "index_of_first_image": self.index_of_first_image,
            "infotexts": self.infotexts,
            "styles": self.styles,
            "job_timestamp": self.job_timestamp,
            "clip_skip": self.clip_skip,
            "is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
        }

        return json.dumps(obj)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Build the infotext for image ``index`` (a flat index across batches)."""
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)

    def get_token_merging_ratio(self, for_hr=False):
        """Return the stored token-merging ratio for the main or hr pass."""
        return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
|
|
|
|
|
2022-09-19 09:02:10 +03:00
|
|
|
|
2022-09-13 21:49:58 +03:00
|
|
|
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Return the first tensor produced by an ``rng.ImageRNG`` built from these arguments.

    The ``p`` argument is not used.
    """
    generator = rng.ImageRNG(
        shape,
        seeds,
        subseeds=subseeds,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
    )
    return generator.next()
|
2022-09-03 12:08:45 +03:00
|
|
|
|
|
|
|
|
2023-07-30 15:30:33 +03:00
|
|
|
class DecodedSamples(list):
    """A ``list`` of samples tagged with ``already_decoded = True``.

    NOTE(review): presumably checked by callers to avoid decoding twice — verify at call sites.
    """

    already_decoded: bool = True
|
|
|
|
|
|
|
|
|
2023-07-19 20:23:30 +03:00
|
|
|
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    """Decode ``batch`` one latent at a time and collect the results.

    Args:
        model: model passed to ``decode_first_stage``.
        batch: latents; indexed along dim 0 one sample at a time.
        target_device: if given, each decoded sample is moved there.
        check_for_nans: when True, each sample is checked with
            ``devices.test_for_nans``. On a NaN result, the error is re-raised if
            the VAE already runs in float32 or ``opts.auto_vae_precision`` is off.
            Otherwise an explanation is printed, the global ``devices.dtype_vae``
            is switched to float32, the first-stage model and ``batch`` are
            converted, and the sample is decoded again.

    Returns:
        A ``DecodedSamples`` list with one decoded tensor per latent.
    """
    decoded = DecodedSamples()

    for index in range(batch.shape[0]):
        latent = batch[index:index + 1]
        sample = decode_first_stage(model, latent)[0]

        if check_for_nans:
            try:
                devices.test_for_nans(sample, "vae")
            except devices.NansException as e:
                can_retry_in_fp32 = devices.dtype_vae != torch.float32 and shared.opts.auto_vae_precision
                if not can_retry_in_fp32:
                    raise e

                errors.print_error_explanation(
                    "A tensor with all NaNs was produced in VAE.\n"
                    "Web UI will now convert VAE into 32-bit float and retry.\n"
                    "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."
                )

                # Switch the VAE (and the remaining latents) to float32 for good, then retry this sample.
                devices.dtype_vae = torch.float32
                model.first_stage_model.to(devices.dtype_vae)
                batch = batch.to(devices.dtype_vae)

                sample = decode_first_stage(model, batch[index:index + 1])[0]

        if target_device is not None:
            sample = sample.to(target_device)

        decoded.append(sample)

    return decoded
|
|
|
|
|
|
|
|
|
2022-10-04 17:36:39 +03:00
|
|
|
def get_fixed_seed(seed):
    """Turn a user-supplied seed into a concrete one.

    ``None``, ``''`` and strings that do not parse as an int are all treated as
    -1. A value of -1 is replaced by a random seed from ``random.randrange``.
    Any other value, including non-string types such as lists, is returned
    unchanged.
    """
    if isinstance(seed, str):
        # '' also fails to parse here and therefore becomes -1
        try:
            seed = int(seed)
        except Exception:
            seed = -1
    elif seed is None:
        seed = -1

    if seed == -1:
        return int(random.randrange(4294967294))

    return seed
|
|
|
|
|
|
|
|
|
2022-09-09 17:54:04 +03:00
|
|
|
def fix_seed(p):
    """Resolve ``p.seed`` and then ``p.subseed`` in place via ``get_fixed_seed``."""
    for attr_name in ("seed", "subseed"):
        setattr(p, attr_name, get_fixed_seed(getattr(p, attr_name)))
|
2022-09-07 01:44:44 +03:00
|
|
|
|
|
|
|
|
2023-05-08 15:23:49 +03:00
|
|
|
def program_version():
    """Return ``launch.git_tag()``, or None when it reports ``"<none>"``."""
    import launch

    tag = launch.git_tag()
    return None if tag == "<none>" else tag
|
|
|
|
|
|
|
|
|
2023-07-26 06:36:06 +03:00
|
|
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
    """Build the generation-parameters text stored with an image.

    The text has three parts, joined by newlines and then stripped:
    the prompt, an optional "Negative prompt: ..." line (omitted when the
    negative prompt is empty), and a comma-separated "key: value" list of
    generation parameters. Parameters whose value is None are left out.

    Args:
        p: the processing object the parameters are read from.
        all_prompts, all_seeds, all_subseeds: per-image lists indexed by ``index``.
        comments: not used by this function; kept for signature compatibility.
        iteration, position_in_batch: used to compute ``index`` when it is None.
        use_main_prompt: describe the job as a whole (``p.main_prompt``,
            ``p.main_negative_prompt``, first seed/subseed) instead of one image,
            e.g. for the grid.
        index: explicit position in the per-image lists.
        all_negative_prompts: defaults to ``p.all_negative_prompts``.

    Returns:
        The infotext string.
    """
    if index is None:
        index = position_in_batch + iteration * p.batch_size

    if all_negative_prompts is None:
        all_negative_prompts = p.all_negative_prompts

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    enable_hr = getattr(p, 'enable_hr', False)
    token_merging_ratio = p.get_token_merging_ratio()
    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)

    # ENSD is only recorded when it is non-zero and the sampler actually uses it.
    uses_ensd = opts.eta_noise_seed_delta != 0
    if uses_ensd:
        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)

    generation_params = {
        "Steps": p.steps,
        "Sampler": p.sampler_name,
        "CFG scale": p.cfg_scale,
        "Image CFG scale": getattr(p, 'image_cfg_scale', None),
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": opts.face_restoration_model if p.restore_faces else None,
        "Size": f"{p.width}x{p.height}",
        "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None,
        "Model": p.sd_model_name if opts.add_model_name_to_info else None,
        "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None,
        "VAE": p.sd_vae_name if opts.add_model_name_to_info else None,
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
        "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
        "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
        "Init image hash": getattr(p, 'init_img_hash', None),
        "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
        "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
        "Tiling": "True" if p.tiling else None,
        **p.extra_generation_params,
        "Version": program_version() if opts.add_version_to_infotext else None,
        "User": p.user if opts.add_user_name_to_info else None,
    }

    # None-valued entries are dropped; an entry whose value equals its key is written bare.
    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]

    # Decide whether to emit the negative-prompt line from the same text that is
    # written. Checking a per-image entry while printing the main negative prompt
    # could drop a non-empty one, emit an empty one, or raise IndexError.
    negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]
    negative_prompt_text = f"\nNegative prompt: {negative_prompt}" if negative_prompt else ""

    return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
|
2022-09-19 09:02:10 +03:00
|
|
|
|
|
|
|
|
2022-09-03 12:08:45 +03:00
|
|
|
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run a generation job for ``p`` with its temporary option overrides applied.

    Steps:
    1. Runs the scripts' ``before_process`` hook.
    2. Snapshots the current value of every overridden option.
    3. Makes sure the configured checkpoint is loaded.
    4. Applies ``p.override_settings``, reloading model/VAE weights when those
       options change.
    5. Enables token merging and delegates to ``process_images_inner``.

    Token merging is always reset in ``finally``. Overridden options are
    restored only when ``p.override_settings_restore_afterwards`` is set.
    """
    if p.scripts is not None:
        p.scripts.before_process(p)

    # Snapshot the current values of every option this job overrides so they can be restored later.
    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

    try:
        # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
        if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
            sd_models.reload_model_weights()

        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
            sd_models.reload_model_weights()

        for k, v in p.override_settings.items():
            setattr(opts, k, v)

            # these two options only take effect after their weights are reloaded
            if k == 'sd_model_checkpoint':
                sd_models.reload_model_weights()

            if k == 'sd_vae':
                sd_vae.reload_vae_weights()

        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        res = process_images_inner(p)

    finally:
        # turn token merging back off even if generation raised
        sd_models.apply_token_merging(p.sd_model, 0)

        # restore opts to original state
        # NOTE(review): only the VAE is reloaded here. A restored sd_model_checkpoint
        # value appears to take effect through the checkpoint comparison at the
        # top of the next call — confirm this is intended.
        if p.override_settings_restore_afterwards:
            for k, v in stored_opts.items():
                setattr(opts, k, v)

                if k == 'sd_vae':
                    sd_vae.reload_vae_weights()

    return res
|
|
|
|
|
|
|
|
|
|
|
|
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """Main generation loop shared by txt2img and img2img.

    Calls ``p.init`` once (inside the no-grad / EMA / autocast scopes) and
    ``p.sample`` once per batch. For each batch it then decodes the latents,
    runs script post-processing, face restoration and color correction,
    optionally saves the images, and collects images and infotexts. An optional
    grid image is prepended at the end.

    If the job is interrupted before any image is produced, the returned
    ``Processed`` still carries one job-level infotext instead of failing on an
    empty list.
    """

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    # unset per-job flags fall back to the global options
    if p.restore_faces is None:
        p.restore_faces = opts.face_restoration

    if p.tiling is None:
        p.tiling = opts.tiling

    if p.refiner_checkpoint not in (None, "", "None"):
        p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)
        if p.refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')

    # record which model/VAE produced this job (used by create_infotext)
    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
    p.sd_model_hash = shared.sd_model.sd_model_hash
    p.sd_vae_name = sd_vae.get_loaded_vae_name()
    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    p.setup_prompts()

    # An explicit list of seeds is used as-is. Otherwise seeds are consecutive,
    # except that when variation seeds are in use (subseed_strength != 0) every
    # image shares the same main seed.
    if type(seed) == list:
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if type(subseed) == list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.process(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            p.iteration = n

            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            sd_models.reload_model_weights()  # model can be changed for example by refiner

            # slice out this batch's share of the per-image lists
            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

            if p.scripts is not None:
                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            if len(p.prompts) == 0:
                break

            p.parse_extra_network_prompts()

            if not p.disable_extra_networks:
                with devices.autocast():
                    extra_networks.activate(p, p.extra_network_data)

            if p.scripts is not None:
                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

            # params.txt should be saved after scripts.process_batch, since the
            # infotext could be modified by that callback
            # Example: a wildcard processed by process_batch sets an extra model
            # strength, which is saved as "Model Strength: 1.0" in the infotext
            if n == 0:
                with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
                    processed = Processed(p, [])
                    file.write(processed.infotext(p, 0))

            p.setup_conds()

            for comment in model_hijack.comments:
                p.comment(comment)

            p.extra_generation_params.update(model_hijack.extra_generation_params)

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            # samplers may hand back images already decoded (see DecodedSamples)
            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            # map decoder output from [-1, 1] into [0, 1]
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if lowvram.is_enabled(shared.sd_model):
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
                x_samples_ddim = batch_params.images

            def infotext(index=0, use_main_prompt=False):
                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

            save_samples = p.save_samples()

            for i, x_sample in enumerate(x_samples_ddim):
                p.batch_index = i

                # CHW float in [0, 1] -> HWC uint8 for PIL
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    if save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.scripts is not None:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image
                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)

                text = infotext(i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)
                if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
                    image_mask = p.mask_for_overlay.convert('RGB')
                    image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')

                    if opts.save_mask:
                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")

                    if opts.save_mask_composite:
                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")

                    if opts.return_mask:
                        output_images.append(image_mask)

                    if opts.return_mask_composite:
                        output_images.append(image_mask_composite)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        # Interrupted/skipped before any image was produced: still describe the job,
        # so the infotexts[0] lookup below cannot raise IndexError.
        if not infotexts:
            infotexts.append(Processed(p, []).infotext(p, 0))

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                text = infotext(use_main_prompt=True)
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1
            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)

    devices.torch_gc()

    res = Processed(
        p,
        images_list=output_images,
        seed=p.all_seeds[0],
        info=infotexts[0],
        subseed=p.all_subseeds[0],
        index_of_first_image=index_of_first_image,
        infotexts=infotexts,
    )

    if p.scripts is not None:
        p.scripts.postprocess(p, res)

    return res
|
2022-09-03 12:08:45 +03:00
|
|
|
|
|
|
|
|
2023-01-09 14:57:47 +03:00
|
|
|
def old_hires_fix_first_pass_dimensions(width, height):
    """Legacy auto-calculation of the hires-fix first-pass size.

    Scales ``(width, height)`` uniformly so the area is about 512*512 pixels,
    then rounds each side up to a multiple of 64. Returns a ``(width, height)``
    tuple.
    """
    scale = math.sqrt((512 * 512) / (width * height))

    def round_up_to_64(side):
        return math.ceil(scale * side / 64) * 64

    return round_up_to_64(width), round_up_to_64(height)
|
|
|
|
|
|
|
|
|
2023-08-13 08:24:16 +03:00
|
|
|
@dataclass(repr=False)
|
2022-09-03 12:08:45 +03:00
|
|
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
2023-08-13 08:24:16 +03:00
|
|
|
enable_hr: bool = False
|
|
|
|
denoising_strength: float = 0.75
|
|
|
|
firstphase_width: int = 0
|
|
|
|
firstphase_height: int = 0
|
|
|
|
hr_scale: float = 2.0
|
|
|
|
hr_upscaler: str = None
|
|
|
|
hr_second_pass_steps: int = 0
|
|
|
|
hr_resize_x: int = 0
|
|
|
|
hr_resize_y: int = 0
|
|
|
|
hr_checkpoint_name: str = None
|
|
|
|
hr_sampler_name: str = None
|
|
|
|
hr_prompt: str = ''
|
|
|
|
hr_negative_prompt: str = ''
|
|
|
|
|
2023-06-08 07:53:02 +03:00
|
|
|
cached_hr_uc = [None, None]
|
|
|
|
cached_hr_c = [None, None]
|
2022-09-19 16:42:56 +03:00
|
|
|
|
2023-08-13 08:24:16 +03:00
|
|
|
hr_checkpoint_info: dict = field(default=None, init=False)
|
|
|
|
hr_upscale_to_x: int = field(default=0, init=False)
|
|
|
|
hr_upscale_to_y: int = field(default=0, init=False)
|
|
|
|
truncate_x: int = field(default=0, init=False)
|
|
|
|
truncate_y: int = field(default=0, init=False)
|
|
|
|
applied_old_hires_behavior_to: tuple = field(default=None, init=False)
|
|
|
|
latent_scale_mode: dict = field(default=None, init=False)
|
|
|
|
hr_c: tuple | None = field(default=None, init=False)
|
|
|
|
hr_uc: tuple | None = field(default=None, init=False)
|
|
|
|
all_hr_prompts: list = field(default=None, init=False)
|
|
|
|
all_hr_negative_prompts: list = field(default=None, init=False)
|
|
|
|
hr_prompts: list = field(default=None, init=False)
|
|
|
|
hr_negative_prompts: list = field(default=None, init=False)
|
|
|
|
hr_extra_network_data: list = field(default=None, init=False)
|
|
|
|
|
|
|
|
def __post_init__(self):
|
|
|
|
super().__post_init__()
|
|
|
|
|
|
|
|
if self.firstphase_width != 0 or self.firstphase_height != 0:
|
2023-01-09 14:57:47 +03:00
|
|
|
self.hr_upscale_to_x = self.width
|
|
|
|
self.hr_upscale_to_y = self.height
|
2023-08-13 08:24:16 +03:00
|
|
|
self.width = self.firstphase_width
|
|
|
|
self.height = self.firstphase_height
|
2023-05-18 20:16:09 +03:00
|
|
|
|
2023-06-08 07:53:02 +03:00
|
|
|
self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
|
|
|
|
self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
|
2023-05-18 20:16:09 +03:00
|
|
|
|
2023-08-10 11:20:46 +03:00
|
|
|
def calculate_target_resolution(self):
|
|
|
|
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
|
|
|
|
self.hr_resize_x = self.width
|
|
|
|
self.hr_resize_y = self.height
|
|
|
|
self.hr_upscale_to_x = self.width
|
|
|
|
self.hr_upscale_to_y = self.height
|
|
|
|
|
|
|
|
self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
|
|
|
|
self.applied_old_hires_behavior_to = (self.width, self.height)
|
|
|
|
|
|
|
|
if self.hr_resize_x == 0 and self.hr_resize_y == 0:
|
|
|
|
self.extra_generation_params["Hires upscale"] = self.hr_scale
|
|
|
|
self.hr_upscale_to_x = int(self.width * self.hr_scale)
|
|
|
|
self.hr_upscale_to_y = int(self.height * self.hr_scale)
|
|
|
|
else:
|
|
|
|
self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
|
|
|
|
|
|
|
|
if self.hr_resize_y == 0:
|
|
|
|
self.hr_upscale_to_x = self.hr_resize_x
|
|
|
|
self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
|
|
|
|
elif self.hr_resize_x == 0:
|
|
|
|
self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
|
|
|
|
self.hr_upscale_to_y = self.hr_resize_y
|
|
|
|
else:
|
|
|
|
target_w = self.hr_resize_x
|
|
|
|
target_h = self.hr_resize_y
|
|
|
|
src_ratio = self.width / self.height
|
|
|
|
dst_ratio = self.hr_resize_x / self.hr_resize_y
|
|
|
|
|
|
|
|
if src_ratio < dst_ratio:
|
|
|
|
self.hr_upscale_to_x = self.hr_resize_x
|
|
|
|
self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
|
|
|
|
else:
|
|
|
|
self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
|
|
|
|
self.hr_upscale_to_y = self.hr_resize_y
|
|
|
|
|
|
|
|
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
|
|
|
|
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
|
|
|
|
|
2022-09-19 16:42:56 +03:00
|
|
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
|
|
|
if self.enable_hr:
|
2023-07-30 13:48:27 +03:00
|
|
|
if self.hr_checkpoint_name:
|
|
|
|
self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
|
|
|
|
|
|
|
|
if self.hr_checkpoint_info is None:
|
|
|
|
raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
|
|
|
|
|
|
|
|
self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
|
|
|
|
|
2023-05-18 20:16:09 +03:00
|
|
|
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
|
|
|
|
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
|
|
|
|
|
|
|
|
if tuple(self.hr_prompt) != tuple(self.prompt):
|
|
|
|
self.extra_generation_params["Hires prompt"] = self.hr_prompt
|
2023-02-05 18:24:41 +03:00
|
|
|
|
2023-05-18 20:16:09 +03:00
|
|
|
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
|
|
|
|
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
|
2023-02-05 18:24:41 +03:00
|
|
|
|
2023-07-30 13:48:27 +03:00
|
|
|
self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
|
|
|
|
if self.enable_hr and self.latent_scale_mode is None:
|
|
|
|
if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
|
|
|
|
raise Exception(f"could not find upscaler named {self.hr_upscaler}")
|
|
|
|
|
2023-08-10 11:20:46 +03:00
|
|
|
self.calculate_target_resolution()
|
2023-01-04 22:04:40 +03:00
|
|
|
|
2023-01-05 01:25:52 +03:00
|
|
|
if not state.processing_has_refined_job_count:
|
|
|
|
if state.job_count == -1:
|
|
|
|
state.job_count = self.n_iter
|
2022-10-14 23:19:05 +03:00
|
|
|
|
2023-01-05 01:25:52 +03:00
|
|
|
shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count)
|
|
|
|
state.job_count = state.job_count * 2
|
|
|
|
state.processing_has_refined_job_count = True
|
2022-10-14 23:19:05 +03:00
|
|
|
|
2023-01-04 22:04:40 +03:00
|
|
|
if self.hr_second_pass_steps:
|
|
|
|
self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps
|
2022-10-14 23:19:05 +03:00
|
|
|
|
2023-01-02 19:42:10 +03:00
|
|
|
if self.hr_upscaler is not None:
|
|
|
|
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
|
2022-10-14 23:19:05 +03:00
|
|
|
|
2023-05-18 20:16:09 +03:00
|
|
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
2022-11-19 12:01:51 +03:00
|
|
|
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
2022-10-20 01:09:43 +03:00
|
|
|
|
2023-08-09 08:43:31 +03:00
|
|
|
x = self.rng.next()
|
2023-01-02 19:42:10 +03:00
|
|
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
2023-07-30 14:10:33 +03:00
|
|
|
del x
|
2023-01-02 19:42:10 +03:00
|
|
|
|
2022-10-20 01:09:43 +03:00
|
|
|
if not self.enable_hr:
|
2022-09-19 16:42:56 +03:00
|
|
|
return samples
|
|
|
|
|
2023-07-30 15:12:09 +03:00
|
|
|
if self.latent_scale_mode is None:
|
2023-07-31 13:20:26 +03:00
|
|
|
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
2023-07-30 15:12:09 +03:00
|
|
|
else:
|
|
|
|
decoded_samples = None
|
|
|
|
|
2023-07-30 13:48:27 +03:00
|
|
|
current = shared.sd_model.sd_checkpoint_info
|
|
|
|
try:
|
|
|
|
if self.hr_checkpoint_info is not None:
|
2023-07-31 09:13:07 +03:00
|
|
|
self.sampler = None
|
2023-07-30 13:48:27 +03:00
|
|
|
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
2023-07-30 19:36:24 +03:00
|
|
|
devices.torch_gc()
|
2023-07-30 13:48:27 +03:00
|
|
|
|
2023-07-30 15:12:09 +03:00
|
|
|
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
2023-07-30 13:48:27 +03:00
|
|
|
finally:
|
2023-07-31 09:13:07 +03:00
|
|
|
self.sampler = None
|
2023-07-30 13:48:27 +03:00
|
|
|
sd_models.reload_model_weights(info=current)
|
2023-07-30 19:36:24 +03:00
|
|
|
devices.torch_gc()
|
2023-07-30 13:48:27 +03:00
|
|
|
|
2023-07-30 15:12:09 +03:00
|
|
|
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
    """Run the hires-fix second pass: upscale the first-pass result and refine it with img2img.

    Args:
        samples: first-pass latents.
        decoded_samples: first-pass images already decoded by the VAE; only used when upscaling
            in pixel space (``self.latent_scale_mode is None``).
        seeds, prompts: per-image values, used when saving the pre-hires-fix intermediates.
        subseeds, subseed_strength: accepted for interface compatibility; the second-pass noise
            is built from ``self.seeds`` / ``self.subseeds`` / ``self.subseed_strength``.

    Returns:
        The second-pass latents decoded by ``decode_latent_batch`` (target device: CPU).

    ``self.is_hr_pass`` is True while this runs and is reset even if an exception escapes.
    The hires token-merging ratio applied to ``self.sd_model`` is likewise always restored.
    """
    self.is_hr_pass = True
    try:
        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y

        def save_intermediate(image, index):
            """Save an image before the hires fix is applied, if enabled in options.

            ``image`` is either a PIL image or a batch of latent-space samples (converted via
            ``sd_samplers.sample_to_image`` for the given ``index``).
            """
            if not self.save_samples() or not opts.save_images_before_highres_fix:
                return

            if not isinstance(image, Image.Image):
                image = sd_samplers.sample_to_image(image, index, approximation=0)

            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
            images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")

        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

        if self.latent_scale_mode is not None:
            # Upscale directly in latent space.
            for i in range(samples.shape[0]):
                save_intermediate(samples, i)

            samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)
        else:
            # Upscale in pixel space with the selected upscaler, then re-encode to latents.
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)

                save_intermediate(image, i)

                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

            if opts.sd_vae_encode_method != 'Full':
                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
            samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]

        self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
        noise = self.rng.next()

        # GC now before running the next img2img to prevent running out of memory
        devices.torch_gc()

        if not self.disable_extra_networks:
            with devices.autocast():
                extra_networks.activate(self, self.hr_extra_network_data)

        with devices.autocast():
            self.calculate_hr_conds()

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
        try:
            if self.scripts is not None:
                self.scripts.before_hr(self)

            samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
        finally:
            # Restore the first-pass ratio on the shared model even if sampling fails.
            sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

        self.sampler = None
        devices.torch_gc()

        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
    finally:
        self.is_hr_pass = False

    return decoded_samples
|
2022-09-03 12:08:45 +03:00
|
|
|
|
2023-05-18 20:16:09 +03:00
|
|
|
def close(self):
    """Release the hires-fix conditionings after the base-class cleanup.

    The class-level hr cond caches are also reset unless ``opts.persistent_cond_cache`` is set.
    """
    super().close()

    self.hr_c = self.hr_uc = None

    if opts.persistent_cond_cache:
        return

    StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
    StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
|
2023-05-18 20:16:09 +03:00
|
|
|
|
|
|
|
def setup_prompts(self):
    """Extend the base prompt setup with the per-image hires-fix prompts.

    Does nothing extra unless hires fix is enabled. An empty hr prompt / hr negative prompt
    falls back to the first-pass prompt / negative prompt. A list is used as-is; any other
    value is repeated ``batch_size * n_iter`` times. Styles from ``self.styles`` are then
    applied to every entry of ``all_hr_prompts`` and ``all_hr_negative_prompts``.
    """
    super().setup_prompts()

    if not self.enable_hr:
        return

    if self.hr_prompt == '':
        self.hr_prompt = self.prompt

    if self.hr_negative_prompt == '':
        self.hr_negative_prompt = self.negative_prompt

    # isinstance (rather than an exact type() comparison) also accepts list subclasses.
    if isinstance(self.hr_prompt, list):
        self.all_hr_prompts = self.hr_prompt
    else:
        self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]

    if isinstance(self.hr_negative_prompt, list):
        self.all_hr_negative_prompts = self.hr_negative_prompt
    else:
        self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]

    self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
    self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
|
|
|
|
|
2023-06-04 13:07:22 +03:00
|
|
|
def calculate_hr_conds(self):
    """Compute ``self.hr_uc`` and ``self.hr_c`` for the hires pass; no-op when ``self.hr_c`` is already set.

    Prompts are wrapped in ``prompt_parser.SdConditioning`` sized to the hires target resolution.
    The step count is the second-pass step count (or ``self.steps``). When a sampler config is found
    for the hr sampler (or the first-pass sampler), the count goes through that config's ``total_steps``.
    Results go through ``get_conds_with_caching``, with the hr caches listed before the first-pass caches.
    """
    if self.hr_c is not None:
        return

    size = dict(width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
    positive = prompt_parser.SdConditioning(self.hr_prompts, **size)
    negative = prompt_parser.SdConditioning(self.hr_negative_prompts, **size, is_negative_prompt=True)

    config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)
    hr_steps = self.hr_second_pass_steps or self.steps
    total_steps = hr_steps if not config else config.total_steps(hr_steps)

    self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
    self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, positive, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
|
2023-06-04 13:07:22 +03:00
|
|
|
|
2023-05-18 20:16:09 +03:00
|
|
|
def setup_conds(self):
    """Set up conditionings for the first pass and, where possible, the hires pass.

    During the hr pass (a call coming from the refiner), only the hires conds are rebuilt.
    Otherwise the base-class conds are set up and the hires conds are cleared. The hires conds
    are then computed immediately only when hires fix is on and no separate hr checkpoint is
    used; with an hr checkpoint they are computed later, in ``sample_hr_pass``.
    """
    if self.is_hr_pass:
        # Called from the refiner during the hr pass: leave first-pass conds and the model alone.
        self.hr_c = None
        self.calculate_hr_conds()
        return

    super().setup_conds()

    self.hr_uc = None
    self.hr_c = None

    if not self.enable_hr or self.hr_checkpoint_info is not None:
        return

    if shared.opts.hires_fix_use_firstpass_conds:
        self.calculate_hr_conds()
    elif lowvram.is_enabled(shared.sd_model):
        # In lowvram mode the conds must be computed right away, before the cond NN is unloaded.
        # The hr extra networks are activated for that, then the first-pass ones are re-activated.
        with devices.autocast():
            extra_networks.activate(self, self.hr_extra_network_data)

        self.calculate_hr_conds()

        with devices.autocast():
            extra_networks.activate(self, self.extra_network_data)
|
2023-05-18 20:16:09 +03:00
|
|
|
|
2023-08-06 17:53:33 +03:00
|
|
|
def get_conds(self):
    """Return ``(cond, uncond)``: the hires conds during the hr pass, otherwise the base-class conds."""
    if not self.is_hr_pass:
        return super().get_conds()

    return self.hr_c, self.hr_uc
|
|
|
|
|
2023-05-18 20:16:09 +03:00
|
|
|
def parse_extra_network_prompts(self):
    """Parse extra-network tags in the prompts for the current iteration.

    On top of the base-class behaviour, when hires fix is enabled this slices the current
    batch out of the hr prompt lists. It then strips extra-network tags from the hr prompts
    into ``self.hr_extra_network_data``. Returns the base-class result unchanged.
    """
    res = super().parse_extra_network_prompts()

    if self.enable_hr:
        first = self.iteration * self.batch_size
        batch = slice(first, first + self.batch_size)

        self.hr_prompts = self.all_hr_prompts[batch]
        self.hr_negative_prompts = self.all_hr_negative_prompts[batch]

        self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)

    return res
|
|
|
|
|
2022-09-03 12:08:45 +03:00
|
|
|
|
2023-08-13 08:24:16 +03:00
|
|
|
@dataclass(repr=False)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """img2img / inpainting job: ``init_images`` are VAE-encoded and sampling starts from those latents."""

    # --- constructor arguments ---
    init_images: list = None  # source images; each is flattened, resized and encoded in init()
    resize_mode: int = 0  # passed to images.resize_image; 3 = resize the encoded latent instead of the image
    denoising_strength: float = 0.75
    image_cfg_scale: float = None  # kept only when the model's cond_stage_key is "edit" (see init())
    # Caller-supplied inpainting mask. __post_init__ moves it to image_mask; init() later reuses
    # self.mask for the latent-space "keep" mask tensor.
    mask: Any = None
    mask_blur_x: int = 4  # Gaussian blur sigma along x applied to the mask in init()
    mask_blur_y: int = 4  # Gaussian blur sigma along y applied to the mask in init()
    # NOTE: the mask_blur property defined later in this class body rebinds this name, so the
    # dataclass default is that property object. Its setter ignores non-int values, and only an
    # int passed here sets both mask_blur_x and mask_blur_y.
    mask_blur: int = None
    # != 1: masked area pre-filled via masking.fill;
    # 2: masked latent additionally replaced by noise; 3: masked latent zeroed
    inpainting_fill: int = 0
    inpaint_full_res: bool = True  # crop to the masked region, process it at width x height, record paste_to
    inpaint_full_res_padding: int = 0  # padding passed to masking.get_crop_region
    inpainting_mask_invert: int = 0  # truthy: invert the mask before use
    initial_noise_multiplier: float = None  # None -> opts.initial_noise_multiplier; scales the starting noise in sample()
    latent_mask: Image.Image = None  # optional mask used instead of the image mask for masking.fill and the latent mask

    # --- derived state (not constructor arguments) ---
    image_mask: Any = field(default=None, init=False)  # the mask given as `mask`, stored by __post_init__

    nmask: torch.Tensor = field(default=None, init=False)  # latent-space mask, 1 over the area to regenerate
    image_conditioning: torch.Tensor = field(default=None, init=False)
    init_img_hash: str = field(default=None, init=False)  # md5 of the last init image saved when opts.save_init_img
    mask_for_overlay: Image.Image = field(default=None, init=False)
    init_latent: torch.Tensor = field(default=None, init=False)  # encoded init images the sampler starts from
|
|
|
|
|
|
|
|
def __post_init__(self):
    """Finish dataclass construction.

    The caller's ``mask`` is moved to ``image_mask``; ``self.mask`` is cleared so that ``init()``
    can store the latent-space mask there. A missing ``initial_noise_multiplier`` falls back to
    ``opts.initial_noise_multiplier``.
    """
    super().__post_init__()

    self.image_mask, self.mask = self.mask, None

    if self.initial_noise_multiplier is None:
        self.initial_noise_multiplier = opts.initial_noise_multiplier
|
2022-09-03 12:08:45 +03:00
|
|
|
|
2023-08-03 05:03:35 +03:00
|
|
|
@property
def mask_blur(self):
    """Single blur value: ``mask_blur_x`` when both axes match, otherwise None."""
    same = self.mask_blur_x == self.mask_blur_y
    return self.mask_blur_x if same else None

@mask_blur.setter
def mask_blur(self, value):
    """Set ``mask_blur_x`` and ``mask_blur_y`` together; non-int values (e.g. None) are ignored."""
    if not isinstance(value, int):
        return
    self.mask_blur_x = self.mask_blur_y = value
|
2023-08-03 05:03:35 +03:00
|
|
|
|
2022-09-19 16:42:56 +03:00
|
|
|
def init(self, all_prompts, all_seeds, all_subseeds):
    """Prepare an img2img / inpainting job before sampling.

    Steps:
    - Create the sampler.
    - Preprocess the inpainting mask: invert, blur, then crop (full-res inpaint) or resize.
    - Build the batch of init images and VAE-encode it into ``self.init_latent``.
    - When a mask is present, derive the latent masks ``self.mask`` / ``self.nmask`` and
      pre-fill the masked latent area according to ``inpainting_fill``.
    - Finally compute ``self.image_conditioning``.

    Args:
        all_prompts, all_subseeds: not used here.
        all_seeds: the first ``init_latent.shape[0]`` entries seed the noise fill (``inpainting_fill == 2``).

    Raises:
        RuntimeError: more than one init image and more init images than ``batch_size``.
    """
    # image CFG is only kept for models whose conditioning stage is "edit"
    self.image_cfg_scale = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None

    self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
    crop_region = None

    image_mask = self.image_mask
    if image_mask is not None:
        image_mask = image_mask.convert('L')

        if self.inpainting_mask_invert:
            image_mask = ImageOps.invert(image_mask)

        # Separable Gaussian blur: horizontal pass with mask_blur_x, then vertical pass with mask_blur_y.
        for sigma, horizontal in ((self.mask_blur_x, True), (self.mask_blur_y, False)):
            if sigma <= 0:
                continue
            ksize = 2 * int(2.5 * sigma + 0.5) + 1
            kernel = (ksize, 1) if horizontal else (1, ksize)
            image_mask = Image.fromarray(cv2.GaussianBlur(np.array(image_mask), kernel, sigma))

        if self.inpaint_full_res:
            # Work only on the (padded) masked region, scaled up to width x height.
            self.mask_for_overlay = image_mask
            gray_mask = image_mask.convert('L')
            crop_region = masking.get_crop_region(np.array(gray_mask), self.inpaint_full_res_padding)
            crop_region = masking.expand_crop_region(crop_region, self.width, self.height, gray_mask.width, gray_mask.height)
            left, top, right, bottom = crop_region

            image_mask = images.resize_image(2, gray_mask.crop(crop_region), self.width, self.height)
            self.paste_to = (left, top, right - left, bottom - top)
        else:
            image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
            # overlay mask has doubled intensity, clipped to 255
            boosted = np.clip(np.array(image_mask).astype(np.float32) * 2, 0, 255).astype(np.uint8)
            self.mask_for_overlay = Image.fromarray(boosted)

        self.overlay_images = []

    latent_mask = image_mask if self.latent_mask is None else self.latent_mask

    add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
    if add_color_corrections:
        self.color_corrections = []

    prepared = []
    for source in self.init_images:
        # Save init image
        if opts.save_init_img:
            self.init_img_hash = hashlib.md5(source.tobytes()).hexdigest()
            images.save_image(source, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)

        image = images.flatten(source, opts.img2img_background_color)

        # resize_mode 3 resizes the encoded latent later instead of the image
        if crop_region is None and self.resize_mode != 3:
            image = images.resize_image(self.resize_mode, image, self.width, self.height)

        if image_mask is not None:
            masked = Image.new('RGBa', (image.width, image.height))
            masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
            self.overlay_images.append(masked.convert('RGBA'))

        # crop_region is not None if we are doing inpaint full res
        if crop_region is not None:
            image = images.resize_image(2, image.crop(crop_region), self.width, self.height)

        if image_mask is not None and self.inpainting_fill != 1:
            image = masking.fill(image, latent_mask)

        if add_color_corrections:
            self.color_corrections.append(setup_color_correction(image))

        chw = np.moveaxis(np.array(image).astype(np.float32) / 255.0, 2, 0)
        prepared.append(chw)

    count = len(prepared)
    if count == 1:
        # a single init image is repeated across the whole batch
        batch_images = np.expand_dims(prepared[0], axis=0).repeat(self.batch_size, axis=0)
        if self.overlay_images is not None:
            self.overlay_images = self.overlay_images * self.batch_size

        if self.color_corrections is not None and len(self.color_corrections) == 1:
            self.color_corrections = self.color_corrections * self.batch_size
    elif count <= self.batch_size:
        self.batch_size = count
        batch_images = np.array(prepared)
    else:
        raise RuntimeError(f"bad number of images passed: {count}; expecting {self.batch_size} or less")

    image = torch.from_numpy(batch_images).to(shared.device, dtype=devices.dtype_vae)

    if opts.sd_vae_encode_method != 'Full':
        self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

    self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
    devices.torch_gc()

    if self.resize_mode == 3:
        self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

    if image_mask is not None:
        latent_h, latent_w = self.init_latent.shape[2], self.init_latent.shape[3]
        small_mask = latent_mask.convert('RGB').resize((latent_w, latent_h))
        # first channel, scaled to [0, 1], rounded to 0/1, then repeated over 4 latent channels
        channels = np.moveaxis(np.array(small_mask, dtype=np.float32), 2, 0) / 255
        masked_area = np.tile(np.around(channels[0])[None], (4, 1, 1))

        self.mask = torch.asarray(1.0 - masked_area).to(shared.device).type(self.sd_model.dtype)
        self.nmask = torch.asarray(masked_area).to(shared.device).type(self.sd_model.dtype)

        # this needs to be fixed to be done in sample() using actual seeds for batches
        if self.inpainting_fill == 2:
            self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
        elif self.inpainting_fill == 3:
            self.init_latent = self.init_latent * self.mask

    self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
|
2022-10-20 01:09:43 +03:00
|
|
|
|
2022-11-02 12:45:03 +03:00
|
|
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    """Run img2img sampling starting from ``self.init_latent``.

    Noise comes from ``self.rng`` and is scaled in place by ``initial_noise_multiplier`` when that
    is not 1.0; the multiplier is then also recorded in the generation parameters. With a mask,
    the unmasked area of the result is taken back from ``self.init_latent``.
    ``seeds``, ``subseeds``, ``subseed_strength`` and ``prompts`` are not used here.
    """
    noise = self.rng.next()

    multiplier = self.initial_noise_multiplier
    if multiplier != 1.0:
        self.extra_generation_params["Noise multiplier"] = multiplier
        noise *= multiplier

    samples = self.sampler.sample_img2img(self, self.init_latent, noise, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

    if self.mask is not None:
        # nmask selects the regenerated area, mask the area kept from the init latent
        samples = samples * self.nmask + self.init_latent * self.mask

    del noise
    devices.torch_gc()

    return samples
|
2023-05-17 20:22:38 +03:00
|
|
|
|
|
|
|
def get_token_merging_ratio(self, for_hr=False):
    """Return the token-merging ratio to use for img2img (``for_hr`` is accepted but unused).

    Priority:
    1. A truthy ``self.token_merging_ratio``.
    2. The global ratio, when ``token_merging_ratio`` is among this job's override settings and truthy.
    3. ``opts.token_merging_ratio_img2img`` if truthy.
    4. Otherwise ``opts.token_merging_ratio``.
    """
    explicit = self.token_merging_ratio
    if explicit:
        return explicit

    if "token_merging_ratio" in self.override_settings and opts.token_merging_ratio:
        return opts.token_merging_ratio

    return opts.token_merging_ratio_img2img or opts.token_merging_ratio
|