From a00d827cd8a4fd6c0c7135cdcb88768bcdf473a9 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Thu, 8 Sep 2022 11:41:04 +0100 Subject: [PATCH 01/27] Dev merge (#819) * #715 #699 #698 #663 #625 #617 #611 #604 (#716) * Update README.md * Add sampler name to metadata (#695) Co-authored-by: EliEron * old-dev-merge Co-authored-by: EliEron Co-authored-by: EliEron * img2img-fix (#717) * Revert "img2img-fix (#717)" This reverts commit 70d4b1ca2a27ff6e67aada0a47cb02670adfe056. * img2img fixes * Revert "img2img fixes" This reverts commit e66eddc6217d37deaa5e3086366a6f208a688969. * Revert "Revert "img2img-fix (#717)"" This reverts commit bf08b617d4fc97551cd9f264556b2e875e54b831. * img2img fixed * - Removed duplicated calls to save_sample. - Change variables and arguments to be more self-explanatory and easier to understand what they do. * Moved streamlit files to their proper location, before they were incorrectly added to the repository root folder. * Added retry dependency for the streamlit version. * Added .cmd file for easy running and updating the streamlit version of the UI. * Removed duplicated entry for streamlit on the environment.yaml file. * Removed some unnecessary lines from the the webui_streamlit.cmd file. * add gfpgan folder to gitignore, auto gen by imglab * added placeholder text similar to gradio * added auto conversion for 4 channel PNG to RGB * fix: regex escape characters * Update Readme links to sd-webui when appropriate (#781) * Update link to sd-webui when appropriate * added LDSR instruction per devilismyfriend guide * fix: stack overflow during recursion call (#784) * Added option to set default sampler name from config file, will be useful for those wanting to change the default sampler and have it persist even when closing the UI and opening it again. * Added try and except block to handle basic errors like StopException which is raised by streamlit when you hit the stop button and KeyError which happens also when stopping the generation because it tries to check the model at the end which is not loaded at that time, this can be ignored and so thats the reason for the exception. * separate css to external file * Added "git pull" and "git stash" to the commands run by the cmd scripts when launching the UI, this should make it so people who use it can automatically update the code from the repo and be up to date without manually using those commands everytime. 
* resolve conflict with master Co-authored-by: EliEron Co-authored-by: EliEron Co-authored-by: ZeroCool Co-authored-by: ZeroCool940711 Co-authored-by: Hafiidz <3688500+Hafiidz@users.noreply.github.com> Co-authored-by: Thomas Mello --- .gitignore | 1 + configs/webui/webui_streamlit.yaml | 103 ++ environment.yaml | 2 +- frontend/css/streamlit.main.css | 15 + frontend/frontend.py | 3 +- scripts/webui_streamlit.py | 1754 ++++++++++++++++++++++++++++ webui-streamlit.cmd | 53 + webui.cmd | 3 + 8 files changed, 1932 insertions(+), 2 deletions(-) create mode 100644 configs/webui/webui_streamlit.yaml create mode 100644 frontend/css/streamlit.main.css create mode 100644 scripts/webui_streamlit.py create mode 100644 webui-streamlit.cmd diff --git a/.gitignore b/.gitignore index 62482b6..46597f4 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,4 @@ condaenv.*.requirements.txt /log/**/*.png /log/log.csv /flagged/* +/gfpgan/* diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml new file mode 100644 index 0000000..84263bd --- /dev/null +++ b/configs/webui/webui_streamlit.yaml @@ -0,0 +1,103 @@ +# UI defaults configuration file. It is automatically loaded if located at configs/webui/webui_streamlit.yaml. +# Any changes made here will be available automatically on the web app without having to stop it. +general: + gpu: 0 + outdir: outputs + ckpt: "models/ldm/stable-diffusion-v1/model.ckpt" + fp: + name: 'embeddings/alex/embeddings_gs-11000.pt' + GFPGAN_dir: "./src/gfpgan" + RealESRGAN_dir: "./src/realesrgan" + RealESRGAN_model: "RealESRGAN_x4plus" + outdir_txt2img: outputs/txt2img-samples + outdir_img2img: outputs/img2img-samples + gfpgan_cpu: False + esrgan_cpu: False + extra_models_cpu: False + extra_models_gpu: False + save_metadata: True + skip_grid: False + skip_save: False + grid_format: "jpg:95" + n_rows: -1 + no_verify_input: False + no_half: False + precision: "autocast" + optimized: False + optimized_turbo: False + update_preview: True + update_preview_frequency: 1 + +txt2img: + prompt: + height: 512 + width: 512 + cfg_scale: 5.0 + seed: "" + batch_count: 1 + batch_size: 1 + sampling_steps: 50 + default_sampler: "k_lms" + separate_prompts: False + normalize_prompt_weights: True + save_individual_images: True + save_grid: True + group_by_prompt: True + save_as_jpg: False + use_GFPGAN: True + use_RealESRGAN: True + RealESRGAN_model: "RealESRGAN_x4plus" + variant_amount: 0.0 + variant_seed: "" + +img2img: + prompt: + sampling_steps: 50 + # Adding an int to toggles enables the corresponding feature. 
+ # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them) + # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0) + # 2: Loopback (use images from previous batch when creating next batch) + # 3: Random loopback seed + # 4: Save individual images + # 5: Save grid + # 6: Sort samples by prompt + # 7: Write sample info files + # 8: jpg samples + # 9: Fix faces using GFPGAN + # 10: Upscale images using Real-ESRGAN + sampler_name: k_lms + denoising_strength: 0.45 + # 0: Keep masked area + # 1: Regenerate only masked area + mask_mode: 0 + # 0: Just resize + # 1: Crop and resize + # 2: Resize and fill + resize_mode: 0 + # Leave blank for random seed: + seed: "" + ddim_eta: 0.0 + cfg_scale: 5.0 + batch_count: 1 + batch_size: 1 + height: 512 + width: 512 + # Textual inversion embeddings file path: + fp: "" + loopback: True + random_seed_loopback: True + separate_prompts: False + normalize_prompt_weights: True + save_individual_images: True + save_grid: True + group_by_prompt: True + save_as_jpg: False + use_GFPGAN: True + use_RealESRGAN: True + RealESRGAN_model: "RealESRGAN_x4plus" + variant_amount: 0.0 + variant_seed: "" + +gfpgan: + strength: 100 + diff --git a/environment.yaml b/environment.yaml index d7328b3..5bb3bf8 100644 --- a/environment.yaml +++ b/environment.yaml @@ -20,7 +20,6 @@ dependencies: - pytorch-lightning==1.4.2 - omegaconf==2.1.1 - test-tube>=0.7.5 - - streamlit>=0.73.1 - einops==0.3.0 - torch-fidelity==0.3.0 - transformers==4.19.2 @@ -33,6 +32,7 @@ dependencies: - facexlib>=0.2.3 - python-slugify>=6.1.2 - streamlit>=1.12.2 + - retry>=0.9.2 - -e git+https://github.com/CompVis/taming-transformers#egg=taming-transformers - -e git+https://github.com/openai/CLIP#egg=clip - -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN diff --git a/frontend/css/streamlit.main.css b/frontend/css/streamlit.main.css new file mode 100644 index 0000000..4e11b77 --- /dev/null +++ b/frontend/css/streamlit.main.css @@ -0,0 +1,15 @@ +.css-18e3th9 { + padding-top: 2rem; + padding-bottom: 10rem; + padding-left: 5rem; + padding-right: 5rem; +} +.css-1d391kg { + padding-top: 3.5rem; + padding-right: 1rem; + padding-bottom: 3.5rem; + padding-left: 1rem; +} +button[data-baseweb="tab"] { + font-size: 25px; +} diff --git a/frontend/frontend.py b/frontend/frontend.py index adef7d3..41e8672 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -98,7 +98,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], value='RealESRGAN_x4plus', - visible=False) # RealESRGAN is not None # invisible until removed) # TODO: Feels like I shouldnt slot it in here. + visible=False) # RealESRGAN is not None # invisible until removed) # TODO: Feels like I shouldnt slot it in here. 
+ txt2img_ddim_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=txt2img_defaults['ddim_eta'], visible=False) txt2img_variant_amount = gr.Slider(minimum=0.0, maximum=1.0, label='Variation Amount', diff --git a/scripts/webui_streamlit.py b/scripts/webui_streamlit.py new file mode 100644 index 0000000..bd680da --- /dev/null +++ b/scripts/webui_streamlit.py @@ -0,0 +1,1754 @@ +import warnings +import streamlit as st +from streamlit import StopException, StreamlitAPIException + +import base64, cv2 +import argparse, os, sys, glob, re, random, datetime +from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps +from PIL.PngImagePlugin import PngInfo +import requests +from scipy import integrate +import torch +from torchdiffeq import odeint +from tqdm.auto import trange, tqdm +import k_diffusion as K +import math +import mimetypes +import numpy as np +import pynvml +import threading, asyncio +import time +import torch +from torch import autocast +from torchvision import transforms +import torch.nn as nn +import yaml +from typing import List, Union +from pathlib import Path +from tqdm import tqdm +from contextlib import contextmanager, nullcontext +from einops import rearrange, repeat +from itertools import islice +from omegaconf import OmegaConf +from io import BytesIO +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor +from retry import retry + +# we use python-slugify to make the filenames safe for windows and linux, its better than doing it manually +# install it with 'pip install python-slugify' +from slugify import slugify + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. + from transformers import logging + + logging.set_verbosity_error() +except: + pass + +# remove some annoying deprecation warnings that show every now and then. +warnings.filterwarnings("ignore", category=DeprecationWarning) + +defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml") + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +# some of those options should not be changed at all because they would break the model, so I removed them from options. +opt_C = 4 +opt_f = 8 + +# should and will be moved to a settings menu in the UI at some point +grid_format = [s.lower() for s in defaults.general.grid_format.split(':')] +grid_lossless = False +grid_quality = 100 +if grid_format[0] == 'png': + grid_ext = 'png' + grid_format = 'png' +elif grid_format[0] in ['jpg', 'jpeg']: + grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 + grid_ext = 'jpg' + grid_format = 'jpeg' +elif grid_format[0] == 'webp': + grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 + grid_ext = 'webp' + grid_format = 'webp' + if grid_quality < 0: # e.g. 
webp:-100 for lossless mode + grid_lossless = True + grid_quality = abs(grid_quality) + +# this should force GFPGAN and RealESRGAN onto the selected gpu as well +os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 +os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu) + +@retry(tries=5) +def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus"): + """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ + + print ("Loading models.") + + # Generate random run ID + # Used to link runs linked w/ continue_prev_run which is not yet implemented + # Use URL and filesystem safe version just in case. + st.session_state["run_id"] = base64.urlsafe_b64encode( + os.urandom(6) + ).decode("ascii") + + # check what models we want to use and if the they are already loaded. + + if use_GFPGAN: + if "GFPGAN" in st.session_state: + print("GFPGAN already loaded") + else: + # Load GFPGAN + if os.path.exists(defaults.general.GFPGAN_dir): + try: + st.session_state["GFPGAN"] = load_GFPGAN() + print("Loaded GFPGAN") + except Exception: + import traceback + print("Error loading GFPGAN:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if "GFPGAN" in st.session_state: + del st.session_state["GFPGAN"] + + if use_RealESRGAN: + if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model: + print("RealESRGAN already loaded") + else: + #Load RealESRGAN + try: + # We first remove the variable in case it has something there, + # some errors can load the model incorrectly and leave things in memory. + del st.session_state["RealESRGAN"] + except KeyError: + pass + + if os.path.exists(defaults.general.RealESRGAN_dir): + # st.session_state is used for keeping the models in memory across multiple pages or runs. 
+ st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model) + print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name) + + else: + if "RealESRGAN" in st.session_state: + del st.session_state["RealESRGAN"] + + + if "model" in st.session_state: + print("Model already loaded") + else: + config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml") + model = load_model_from_config(config, defaults.general.ckpt) + + st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu") + st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] ) + + print("Model loaded.") + + +def load_model_from_config(config, ckpt, verbose=False): + + print(f"Loading model from {ckpt}") + + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + +def load_sd_from_config(ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + return sd +# +@retry(tries=5) +def generation_callback(img, i=0): + + try: + if i == 0: + if img['i']: i = img['i'] + except TypeError: + pass + + + if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview: + #print (img) + #print (type(img)) + # The following lines will convert the tensor we got on img to an actual image we can render on the UI. + # It can probably be done in a better way for someone who knows what they're doing. I don't. + #print (img,isinstance(img, torch.Tensor)) + if isinstance(img, torch.Tensor): + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img) + else: + # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: + # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"]) + + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) + + # update image on the UI so we can see the progress + st.session_state["preview_image"].image(pil_image) + + # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, + # Dont worry, it doesnt affect the performance. 
+ if st.session_state["generation_mode"] == "txt2img": + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") + else: + round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) + percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) + st.session_state["progress_bar_text"].text( + f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") + + st.session_state["progress_bar"].progress(percent if percent < 100 else 100) + + + +class MemUsageMonitor(threading.Thread): + stop_flag = False + max_usage = 0 + total = -1 + + def __init__(self, name): + threading.Thread.__init__(self) + self.name = name + + def run(self): + try: + pynvml.nvmlInit() + except: + print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n") + return + print(f"[{self.name}] Recording max memory usage...\n") + handle = pynvml.nvmlDeviceGetHandleByIndex(defaults.general.gpu) + self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total + while not self.stop_flag: + m = pynvml.nvmlDeviceGetMemoryInfo(handle) + self.max_usage = max(self.max_usage, m.used) + # print(self.max_usage) + time.sleep(0.1) + print(f"[{self.name}] Stopped recording.\n") + pynvml.nvmlShutdown() + + def read(self): + return self.max_usage, self.total + + def stop(self): + self.stop_flag = True + + def read_and_stop(self): + self.stop_flag = True + return self.max_usage, self.total + +class CFGMaskedDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi): + x_in = x + x_in = torch.cat([x_in] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + denoised = uncond + (cond - uncond) * cond_scale + + if mask is not None: + assert x0 is not None + img_orig = x0 + mask_inv = 1. - mask + denoised = (img_orig * mask_inv) + (mask * denoised) + + return denoised + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): + """Constructs the noise schedule of Karras et al. 
(2022).""" + ramp = torch.linspace(0, 1, n) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + + +def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): + """Constructs an exponential noise schedule.""" + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + return append_zero(sigmas) + + +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): + """Constructs a continuous VP noise schedule.""" + t = torch.linspace(1, eps_s, n, device=device) + sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) + return append_zero(sigmas) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + +class KDiffusionSampler: + def __init__(self, m, sampler): + self.model = m + self.model_wrap = K.external.CompVisDenoiser(m) + self.schedule = sampler + def get_sampler_name(self): + return self.schedule + def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None): + sigmas = self.model_wrap.get_sigmas(S) + x = x_T * sigmas[0] + model_wrap_cfg = CFGDenoiser(self.model_wrap) + samples_ddim = None + samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback) + # + return samples_ddim, None + + +@torch.no_grad() +def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + v = torch.randint_like(x, 2) * 2 - 1 + fevals = 0 + def ode_fn(sigma, x): + nonlocal fevals + with torch.enable_grad(): + x = x[0].detach().requires_grad_() + denoised = model(x, sigma * s_in, **extra_args) + d = to_d(x, sigma, denoised) + fevals += 1 + grad = torch.autograd.grad((d * v).sum(), x)[0] + d_ll = (v * grad).flatten(1).sum(1) + return d.detach(), d_ll + x_min = x, x.new_zeros([x.shape[0]]) + t = x.new_tensor([sigma_min, sigma_max]) + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + latent, delta_ll = sol[0][-1], sol[1][-1] + ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + return ll_prior + delta_ll, {'fevals': fevals} + + +def create_random_tensors(shape, seeds): + xs = [] + for seed in seeds: + torch.manual_seed(seed) + + # randn results depend on device; gpu and cpu get different results for same seed; + # the way I see it, it's better to do this on CPU, so that everyone gets same result; + # but the original script had it like this so i do not dare change it for now because + # it will break everyone's seeds. 
+ xs.append(torch.randn(shape, device=defaults.general.gpu)) + x = torch.stack(xs) + return x + +def torch_gc(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + +def load_GFPGAN(): + model_name = 'GFPGANv1.3' + model_path = os.path.join(defaults.general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + raise Exception("GFPGAN model not found at path "+model_path) + + sys.path.append(os.path.abspath(defaults.general.GFPGAN_dir)) + from gfpgan import GFPGANer + + if defaults.general.gfpgan_cpu or defaults.general.extra_models_cpu: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu')) + elif defaults.general.extra_models_gpu: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gfpgan_gpu}')) + else: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gpu}')) + return instance + +def load_RealESRGAN(model_name: str): + from basicsr.archs.rrdbnet_arch import RRDBNet + RealESRGAN_models = { + 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), + 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + } + + model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")): + raise Exception(model_name+".pth not found at path "+model_path) + + sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir)) + from realesrgan import RealESRGANer + + if defaults.general.esrgan_cpu or defaults.general.extra_models_cpu: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half + instance.device = torch.device('cpu') + instance.model.to('cpu') + elif defaults.general.extra_models_gpu: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.esrgan_gpu}')) + else: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.gpu}')) + instance.model.name = model_name + + return instance + +prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + [^:]+ # match one or more non ':' characters + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P # capture group for 'weight' + -?\\d+(?:\\.\\d+)? # match positive or negative decimal number + )? 
# end weight capture group, make optional + \\s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group +""", re.VERBOSE) + +# grabs all text up to the first occurrence of ':' as sub-prompt +# takes the value following ':' as weight +# if ':' has no value defined, defaults to 1.0 +# repeats until no text remaining +def split_weighted_subprompts(input_string, normalize=True): + parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)] + if not normalize: + return parsed_prompts + # this probably still doesn't handle negative weights very well + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + +def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995): + v0 = v0.detach().cpu().numpy() + v1 = v1.detach().cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + v2 = torch.from_numpy(v2).to(device) + + return v2 + + +def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): + #get global variables + global_vars = globals() + #check if m is in globals + if unload: + for m in models: + if m in global_vars: + #if it is, delete it + del global_vars[m] + if defaults.general.optimized: + if m == 'model': + del global_vars[m+'FS'] + del global_vars[m+'CS'] + if m =='model': + m='Stable Diffusion' + print('Unloaded ' + m) + if load: + for m in models: + if m not in global_vars or m in global_vars and type(global_vars[m]) == bool: + #if it isn't, load it + if m == 'GFPGAN': + global_vars[m] = load_GFPGAN() + elif m == 'model': + sdLoader = load_sd_from_config() + global_vars[m] = sdLoader[0] + if defaults.general.optimized: + global_vars[m+'CS'] = sdLoader[1] + global_vars[m+'FS'] = sdLoader[2] + elif m == 'RealESRGAN': + global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name) + elif m == 'LDSR': + global_vars[m] = load_LDSR() + if m =='model': + m='Stable Diffusion' + print('Loaded ' + m) + torch_gc() + + + +def get_font(fontsize): + fonts = ["arial.ttf", "DejaVuSans.ttf"] + for font_name in fonts: + try: + return ImageFont.truetype(font_name, fontsize) + except OSError: + pass + + # ImageFont.load_default() is practically unusable as it only supports + # latin1, so raise an exception instead if no usable font was found + raise Exception(f"No usable font found (tried {', '.join(fonts)})") + +def load_embeddings(fp): + if fp is not None and hasattr(st.session_state["model"], "embedding_manager"): + st.session_state["model"].embedding_manager.load(fp['name']) + +def image_grid(imgs, batch_size, force_n_rows=None, captions=None): + #print (len(imgs)) + if force_n_rows is not None: + rows = force_n_rows + elif defaults.general.n_rows > 0: + rows = defaults.general.n_rows + elif defaults.general.n_rows == 0: + rows = batch_size + else: + rows = math.sqrt(len(imgs)) + rows = round(rows) + + cols = math.ceil(len(imgs) / rows) + + w, h = imgs[0].size + grid = Image.new('RGB', size=(cols * w, rows * h), color='black') + + fnt = get_font(30) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + if 
captions and i= 2**32: + n = n >> 32 + return n + +def check_prompt_length(prompt, comments): + """this function tests if prompt is too long, and if so, adds a message to comments""" + + tokenizer = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer + max_length = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.max_length + + info = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, + return_overflowing_tokens=True, padding="max_length", return_tensors="pt") + ovf = info['overflowing_tokens'][0] + overflowing_count = ovf.shape[0] + if overflowing_count == 0: + return + + vocab = {v: k for k, v in tokenizer.get_vocab().items()} + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + + comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + +def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images): + + filename_i = os.path.join(sample_path_i, filename) + + if not jpg_sample: + if defaults.general.save_metadata: + metadata = PngInfo() + metadata.add_text("SD:prompt", prompts[i]) + metadata.add_text("SD:seed", str(seeds[i])) + metadata.add_text("SD:width", str(width)) + metadata.add_text("SD:height", str(height)) + metadata.add_text("SD:steps", str(steps)) + metadata.add_text("SD:cfg_scale", str(cfg_scale)) + metadata.add_text("SD:normalize_prompt_weights", str(normalize_prompt_weights)) + if init_img is not None: + metadata.add_text("SD:denoising_strength", str(denoising_strength)) + metadata.add_text("SD:GFPGAN", str(use_GFPGAN and st.session_state["GFPGAN"] is not None)) + image.save(f"{filename_i}.png", pnginfo=metadata) + else: + image.save(f"{filename_i}.png") + else: + image.save(f"{filename_i}.jpg", 'jpeg', quality=100, optimize=True) + + if write_info_files: + # toggles differ for txt2img vs. 
img2img: + offset = 0 if init_img is None else 2 + toggles = [] + if prompt_matrix: + toggles.append(0) + if normalize_prompt_weights: + toggles.append(1) + if init_img is not None: + if uses_loopback: + toggles.append(2) + if uses_random_seed_loopback: + toggles.append(3) + if save_individual_images: + toggles.append(2 + offset) + if save_grid: + toggles.append(3 + offset) + if sort_samples: + toggles.append(4 + offset) + if write_info_files: + toggles.append(5 + offset) + if use_GFPGAN: + toggles.append(6 + offset) + info_dict = dict( + target="txt2img" if init_img is None else "img2img", + prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, + ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, + seed=seeds[i], width=width, height=height + ) + if init_img is not None: + # Not yet any use for these, but they bloat up the files: + #info_dict["init_img"] = init_img + #info_dict["init_mask"] = init_mask + info_dict["denoising_strength"] = denoising_strength + info_dict["resize_mode"] = resize_mode + with open(f"{filename_i}.yaml", "w", encoding="utf8") as f: + yaml.dump(info_dict, f, allow_unicode=True, width=10000) + + # render the image on the frontend + st.session_state["preview_image"].image(image) + +def get_next_sequence_number(path, prefix=''): + """ + Determines and returns the next sequence number to use when saving an + image in the specified directory. + + If a prefix is given, only consider files whose names start with that + prefix, and strip the prefix from filenames before extracting their + sequence number. + + The sequence starts at 0. + """ + result = -1 + for p in Path(path).iterdir(): + if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix): + tmp = p.name[len(prefix):] + try: + result = max(int(tmp.split('-')[0]), result) + except ValueError: + pass + return result + 1 + + +def oxlamon_matrix(prompt, seed, n_iter, batch_size): + pattern = re.compile(r'(,\s){2,}') + + class PromptItem: + def __init__(self, text, parts, item): + self.text = text + self.parts = parts + if item: + self.parts.append( item ) + + def clean(txt): + return re.sub(pattern, ', ', txt) + + def getrowcount( txt ): + for data in re.finditer( ".*?\\((.*?)\\).*", txt ): + if data: + return len(data.group(1).split("|")) + break + return None + + def repliter( txt ): + for data in re.finditer( ".*?\\((.*?)\\).*", txt ): + if data: + r = data.span(1) + for item in data.group(1).split("|"): + yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip()) + break + + def iterlist( items ): + outitems = [] + for item in items: + for newitem, newpart in repliter(item.text): + outitems.append( PromptItem(newitem, item.parts.copy(), newpart) ) + + return outitems + + def getmatrix( prompt ): + dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ] + while True: + newdataitems = iterlist( dataitems ) + if len( newdataitems ) == 0: + return dataitems + dataitems = newdataitems + + def classToArrays( items, seed, n_iter ): + texts = [] + parts = [] + seeds = [] + + for item in items: + itemseed = seed + for i in range(n_iter): + texts.append( item.text ) + parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) ) + seeds.append( itemseed ) + itemseed += 1 + + return seeds, texts, parts + + all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter) + n_iter = math.ceil(len(all_prompts) / batch_size) + + needrows = getrowcount(prompt) + if needrows: + xrows = math.sqrt(len(all_prompts)) + xrows = 
round(xrows) + # if columns is to much + cols = math.ceil(len(all_prompts) / xrows) + if cols > needrows*4: + needrows *= 2 + + return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows + + +def process_images( + outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, + n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, + fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, + keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, + uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, + variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): + """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + assert prompt is not None + torch_gc() + # start time after garbage collection (or before?) + start_time = time.time() + + # We will use this date here later for the folder name, need to start_time if not need + run_start_dt = datetime.datetime.now() + + mem_mon = MemUsageMonitor('MemMon') + mem_mon.start() + + if hasattr(st.session_state["model"], "embedding_manager"): + load_embeddings(fp) + + os.makedirs(outpath, exist_ok=True) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + + if not ("|" in prompt) and prompt.startswith("@"): + prompt = prompt[1:] + + comments = [] + + prompt_matrix_parts = [] + simple_templating = False + add_original_image = not (use_RealESRGAN or use_GFPGAN) + + if prompt_matrix: + if prompt.startswith("@"): + simple_templating = True + add_original_image = not (use_RealESRGAN or use_GFPGAN) + all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size) + else: + all_prompts = [] + prompt_matrix_parts = prompt.split("|") + combination_count = 2 ** (len(prompt_matrix_parts) - 1) + for combination_num in range(combination_count): + current = prompt_matrix_parts[0] + + for n, text in enumerate(prompt_matrix_parts[1:]): + if combination_num & (2 ** n) > 0: + current += ("" if text.strip().startswith(",") else ", ") + text + + all_prompts.append(current) + + n_iter = math.ceil(len(all_prompts) / batch_size) + all_seeds = len(all_prompts) * [seed] + + print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") + else: + + if not defaults.general.no_verify_input: + try: + check_prompt_length(prompt, comments) + except: + import traceback + print("Error verifying input:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + all_prompts = batch_size * n_iter * [prompt] + all_seeds = [seed + x for x in range(len(all_prompts))] + + precision_scope = autocast if defaults.general.precision == "autocast" else nullcontext + output_images = [] + grid_captions = [] + stats = [] + with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not defaults.general.optimized else nullcontext()): + init_data = func_init() + tic = time.time() + + + # if variant_amount > 0.0 create noise from base seed + base_x = None + if variant_amount > 0.0: + target_seed_randomizer = seed_to_int('') # random seed + torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration) + base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed]) + # we don't want all_seeds to be 
sequential from starting seed with variants, + # since that makes the same variants each time, + # so we add target_seed_randomizer as a random offset + for si in range(len(all_seeds)): + all_seeds[si] += target_seed_randomizer + + for n in range(n_iter): + print(f"Iteration: {n+1}/{n_iter}") + prompts = all_prompts[n * batch_size:(n + 1) * batch_size] + captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size] + seeds = all_seeds[n * batch_size:(n + 1) * batch_size] + + print(prompt) + + if defaults.general.optimized: + modelCS.to(defaults.general.gpu) + + uc = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(len(prompts) * [""]) + + if isinstance(prompts, tuple): + prompts = list(prompts) + + # split the prompt if it has : for weighting + # TODO for speed it might help to have this occur when all_prompts filled?? + weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights) + + # sub-prompt weighting used if more than 1 + if len(weighted_subprompts) > 1: + c = torch.zeros_like(uc) # i dont know if this is correct.. but it works + for i in range(0, len(weighted_subprompts)): + # note if alpha negative, it functions same as torch.sub + c = torch.add(c, (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1]) + else: # just behave like usual + c = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(prompts) + + + shape = [opt_C, height // opt_f, width // opt_f] + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelCS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + if variant_amount == 0.0: + # we manually generate all input noises because each one should have a specific seed + x = create_random_tensors(shape, seeds=seeds) + + else: # we are making variants + # using variant_seed as sneaky toggle, + # when not None or '' use the variant_seed + # otherwise use seeds + if variant_seed != None and variant_seed != '': + specified_variant_seed = seed_to_int(variant_seed) + torch.manual_seed(specified_variant_seed) + seeds = [specified_variant_seed] + target_x = create_random_tensors(shape, seeds=seeds) + # finally, slerp base_x noise to target_x noise for creating a variant + x = slerp(defaults.general.gpu, max(0.0, min(1.0, variant_amount)), base_x, target_x) + + samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) + + if defaults.general.optimized: + modelFS.to(defaults.general.gpu) + + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for i, x_sample in enumerate(x_samples_ddim): + sanitized_prompt = slugify(prompts[i]) + + if sort_samples: + full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt) + + + sanitized_prompt = sanitized_prompt[:220-len(full_path)] + sample_path_i = os.path.join(sample_path, sanitized_prompt) + + #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}") + #print(os.path.join(os.getcwd(), sample_path_i)) + + os.makedirs(sample_path_i, exist_ok=True) + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}" + else: + full_path = 
os.path.join(os.getcwd(), sample_path) + sample_path_i = sample_path + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before + + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = x_sample.astype(np.uint8) + image = Image.fromarray(x_sample) + original_sample = x_sample + original_filename = filename + + if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN: + #skip_save = True # #287 >_> + torch_gc() + cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + gfpgan_image = Image.fromarray(gfpgan_sample) + gfpgan_filename = original_filename + '-gfpgan' + + save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, + uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta, + n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(gfpgan_image) #287 + if simple_templating: + grid_captions.append( captions[i] + "\ngfpgan" ) + + if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: + #skip_save = True # #287 >_> + torch_gc() + + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) + esrgan_filename = original_filename + '-esrgan4x' + esrgan_sample = output[:,:,::-1] + esrgan_image = Image.fromarray(esrgan_sample) + + #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, + #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) + + save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(esrgan_image) #287 + if simple_templating: + grid_captions.append( captions[i] + "\nesrgan" ) + + if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None: + #skip_save = True # #287 >_> + torch_gc() + cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) + 
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' + gfpgan_esrgan_sample = output[:,:,::-1] + gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) + + save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(gfpgan_esrgan_image) #287 + + if simple_templating: + grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) + + if save_individual_images: + save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images) + + if not use_GFPGAN or not use_RealESRGAN: + output_images.append(image) + + #if add_original_image or not simple_templating: + #output_images.append(image) + #if simple_templating: + #grid_captions.append( captions[i] ) + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + if prompt_matrix or save_grid: + if prompt_matrix: + if simple_templating: + grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions) + else: + grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) + try: + grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) + except: + import traceback + print("Error creating prompt_matrix text:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + grid = image_grid(output_images, batch_size) + + if grid and (batch_size > 1 or n_iter > 1): + output_images.insert(0, grid) + + grid_count = get_next_sequence_number(outpath, 'grid-') + grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}" + grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True) + + toc = time.time() + + mem_max_used, mem_total = mem_mon.read_and_stop() + time_diff = time.time()-start_time + + info = f""" + {prompt} + Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' 
if prompt_matrix else ''}""".strip() + stats = f''' + Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) + Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' + + for comment in comments: + info += "\n\n" + comment + + #mem_mon.stop() + #del mem_mon + torch_gc() + + return output_images, seed, info, stats + + +def resize_image(resize_mode, im, width, height): + LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) + if resize_mode == 0: + res = im.resize((width, height), resample=LANCZOS) + elif resize_mode == 1: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio > src_ratio else im.width * height // im.height + src_h = height if ratio <= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + else: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio < src_ratio else im.width * height // im.height + src_h = height if ratio >= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + + return res + +def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, + ddim_steps: int = 50, sampler_name: str = 'DDIM', + n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, + seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, + variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, + write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", + separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False, + random_seed_loopback: bool = False + ): + + outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples" + err = False + #loopback = False + #skip_save = False + seed = seed_to_int(seed) + + batch_size = 1 + + #prompt_matrix = 0 + #normalize_prompt_weights = 1 in toggles + #loopback = 2 in toggles + #random_seed_loopback = 3 in toggles + #skip_save = 4 not in toggles + #save_grid = 5 in toggles + #sort_samples = 6 in toggles + #write_info_files = 7 in toggles + #write_sample_info_to_log_file = 8 in toggles + #jpg_sample = 9 in toggles + #use_GFPGAN = 10 in toggles + #use_RealESRGAN = 11 in toggles + + if sampler_name == 'PLMS': + sampler = PLMSSampler(st.session_state["model"]) + elif 
sampler_name == 'DDIM': + sampler = DDIMSampler(st.session_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(st.session_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(st.session_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(st.session_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + init_img = init_info + init_mask = None + keep_mask = False + + assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(denoising_strength * ddim_steps) + + def init(): + + image = init_img + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + + mask = None + if defaults.general.optimized: + modelFS.to(st.session_state["device"] ) + + init_image = 2. * image - 1. + init_image = init_image.to(st.session_state["device"]) + init_latent = (st.session_state["model"] if not defaults.general.optimized else modelFS).get_first_stage_encoding((st.session_state["model"] if not defaults.general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + return init_latent, mask, + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + t_enc_steps = t_enc + obliterate = False + if ddim_steps == t_enc_steps: + t_enc_steps = t_enc_steps - 1 + obliterate = True + + if sampler_name != 'DDIM': + x0, z_mask = init_data + + sigmas = sampler.model_wrap.get_sigmas(ddim_steps) + noise = x * sigmas[ddim_steps - t_enc_steps - 1] + + xi = x0 + noise + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=xi.device) + xi = (z_mask * noise) + ((1-z_mask) * xi) + + sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] + model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, + callback=generation_callback) + else: + + x0, z_mask = init_data + + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) + z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] )) + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=z_enc.device) + z_enc = (z_mask * random) + ((1-z_mask) * z_enc) + + # decode it + samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, + z_mask=z_mask, x0=x0) + return samples_ddim + + + + if loopback: + output_images, info = None, None + history = [] + initial_seed = None + + do_color_correction = False + try: + from skimage import exposure + do_color_correction = True + except: + 
print("Install scikit-image to perform color correction on loopback") + + for i in range(1): + if do_color_correction and i == 0: + correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) + + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=1, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback + realesrgan_model_name=RealESRGAN_model, + fp=fp, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + keep_mask=keep_mask, + mask_blur_strength=mask_blur_strength, + denoising_strength=denoising_strength, + resize_mode=resize_mode, + uses_loopback=loopback, + uses_random_seed_loopback=random_seed_loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + if initial_seed is None: + initial_seed = seed + + init_img = output_images[0] + + if do_color_correction and correction_target is not None: + init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(init_img), + cv2.COLOR_RGB2LAB + ), + correction_target, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8")) + + if not random_seed_loopback: + seed = seed + 1 + else: + seed = seed_to_int(None) + + denoising_strength = max(denoising_strength * 0.95, 0.1) + history.append(init_img) + + output_images = history + seed = initial_seed + + else: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=RealESRGAN_model, + fp=fp, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + keep_mask=keep_mask, + mask_blur_strength=2, + denoising_strength=denoising_strength, + resize_mode=resize_mode, + uses_loopback=loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + del sampler + + return output_images, seed, info, stats + +#@retry(RuntimeError, tries=3) +def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str, + n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None], + height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, + RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None, + variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True): + + outpath = defaults.general.outdir_txt2img or defaults.general.outdir or "outputs/txt2img-samples" + + err = False + seed = seed_to_int(seed) + + #prompt_matrix = 0 in toggles + #normalize_prompt_weights = 1 in toggles + #skip_save = 2 not in 
toggles + #save_grid = 3 not in toggles + #sort_samples = 4 in toggles + #write_info_files = 5 in toggles + #jpg_sample = 6 in toggles + #use_GFPGAN = 7 in toggles + #use_RealESRGAN = 8 in toggles + + if sampler_name == 'PLMS': + sampler = PLMSSampler(st.session_state["model"]) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(st.session_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(st.session_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(st.session_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(st.session_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + def init(): + pass + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback, + log_every_t=int(defaults.general.update_preview_frequency)) + + return samples_ddim + + #try: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=realesrgan_model_name, + fp=fp, + ddim_eta=ddim_eta, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg, + variant_amount=variant_amount, + variant_seed=variant_seed, + ) + + del sampler + + return output_images, seed, info, stats + + #except RuntimeError as e: + #err = e + #err_msg = f'CRASHED:


Please wait while the program restarts.' + #stats = err_msg + #return [], seed, 'err', stats + + + + +# functions to load css locally OR remotely starts here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection +# TODO, maybe look into async loading the file especially for remote fetching +def local_css(file_name): + with open(file_name) as f: + st.markdown(f'', unsafe_allow_html=True) + +def remote_css(url): + st.markdown(f'', unsafe_allow_html=True) + +def load_css(isLocal, nameOrURL): + if(isLocal): + local_css(nameOrURL) + else: + remote_css(nameOrURL) + + +# main functions to define streamlit layout here +def layout(): + + st.set_page_config(page_title="Stable Diffusion Playground", layout="wide", initial_sidebar_state="collapsed") + + with st.empty(): + # load css as an external file, function has an option to local or remote url. Potential use when running from cloud infra that might not have access to local path. + load_css(True, 'frontend/css/streamlit.main.css') + + # check if the models exist on their respective folders + if os.path.exists(os.path.join(defaults.general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): + GFPGAN_available = True + else: + GFPGAN_available = False + + if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")): + RealESRGAN_available = True + else: + RealESRGAN_available = False + + with st.sidebar: + # we should use an expander and group things together when more options are added so the sidebar is not too messy. + #with st.expander("Global Settings:"): + st.write("Global Settings:") + defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview, + help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \ + Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.") + defaults.general.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency, + help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.") + + + + txt2img_tab, img2img_tab, txt2video, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", "Text-to-Video","Post-Processing"]) + + with txt2img_tab: + with st.form("txt2img-inputs"): + st.session_state["generation_mode"] = "txt2img" + + input_col1, generate_col1 = st.columns([10,1]) + with input_col1: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. 
+ generate_col1.write("") + generate_col1.write("") + generate_button = generate_col1.form_submit_button("Generate") + + # creating the page layout using columns + col1, col2, col3 = st.columns([1,2,1], gap="large") + + with col1: + width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.txt2img.width, step=64) + height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.txt2img.height, step=64) + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") + seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") + #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2img.batch_size, step=1, + #help="How many images are at once in a batch.\ + #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + #Default: 1") + + with col2: + preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) + + with preview_tab: + #st.write("Image") + #Image for testing + #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') + #new_image = image.resize((175, 240)) + #preview_image = st.image(image) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + st.session_state["preview_image"] = st.empty() + + st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + message = st.empty() + + with gallery_tab: + st.write('Here should be the image gallery, if I could make a grid in streamlit.') + + with col3: + st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250) + + sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + sampler_name = st.selectbox("Sampling method", sampler_name_list, + index=sampler_name_list.index(defaults.txt2img.default_sampler), help="Sampling method to use. 
Default: k_euler") + + + + #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) + + #with basic_tab: + #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, + #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") + + with st.expander("Advanced"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=False, help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=True, help="Ensure the sum of all weights add up to 1.0") + save_individual_images = st.checkbox("Save individual images.", value=True, help="Save each image generated before any filter or enhancement is applied.") + save_grid = st.checkbox("Save grid",value=True, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=True, + help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=True, help="Save a file next to the image with informartion about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=False, help="Saves the images as jpg instead of png.") + + if GFPGAN_available: + use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + use_GFPGAN = False + + if RealESRGAN_available: + use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.txt2img.use_RealESRGAN, help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") + RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + use_RealESRGAN = False + RealESRGAN_model = "RealESRGAN_x4plus" + + variant_amount = st.slider("Variant Amount:", value=defaults.txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=defaults.txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") + + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. + load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model) + + try: + output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1, + cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, + save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp, + variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) + + message.success('Done!', icon="✅") + + except (StopException, KeyError): + print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. 
+ # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images) + + with img2img_tab: + with st.form("img2img-inputs"): + st.session_state["generation_mode"] = "img2img" + + img2img_input_col, img2img_generate_col = st.columns([10,1]) + with img2img_input_col: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. + img2img_generate_col.write("") + img2img_generate_col.write("") + generate_button = img2img_generate_col.form_submit_button("Generate") + + + # creating the page layout using columns + col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small") + + with col1_img2img_layout: + st.session_state["sampling_steps"] = st.slider("Sampling Steps", value=defaults.img2img.sampling_steps, min_value=1, max_value=250) + st.session_state["sampler_name"] = st.selectbox("Sampling method", ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"], + index=0, help="Sampling method to use. Default: k_lms") + + uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg"], + help="Upload an image which will be used for the image to image generation." + ) + + width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.img2img.width, step=64) + height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.img2img.height, step=64) + seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.img2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") + + # + with st.expander("Advanced"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=defaults.img2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=defaults.img2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") + loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.") + random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed") + save_individual_images = st.checkbox("Save individual images.", value=True, help="Save each image generated before any filter or enhancement is applied.") + save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=defaults.img2img.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. 
When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=True, help="Save a file next to the image with informartion about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=False, help="Saves the images as jpg instead of png.") + + if GFPGAN_available: + use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ + This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + use_GFPGAN = False + + if RealESRGAN_available: + use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.img2img.use_RealESRGAN, help="Uses the RealESRGAN model to upscale the images after the generation.\ + This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") + RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + use_RealESRGAN = False + RealESRGAN_model = "RealESRGAN_x4plus" + + variant_amount = st.slider("Variant Amount:", value=defaults.img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=defaults.img2img.variant_seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.img2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") + batch_size = st.slider("Batch size", min_value=1, max_value=100, value=defaults.img2img.batch_size, step=1, + help="How many images are at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + Default: 1") + + st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, min_value=0.01, max_value=1.0, step=0.01) + + + with col2_img2img_layout: + editor_tab = st.tabs(["Editor"]) + + editor_image = st.empty() + st.session_state["editor_image"] = editor_image + + if uploaded_images: + image = Image.open(uploaded_images).convert('RGB') + #img_array = np.array(image) # if you want to pass it to OpenCV + new_img = image.resize((width, height)) + st.image(new_img) + + + with col3_img2img_layout: + result_tab = st.tabs(["Result"]) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + preview_image = st.empty() + st.session_state["preview_image"] = preview_image + + #st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + + message = st.empty() + + #if uploaded_images: + #image = Image.open(uploaded_images).convert('RGB') + ##img_array = np.array(image) # if you want to pass it to OpenCV + #new_img = image.resize((width, height)) + #st.image(new_img, use_column_width=True) + + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
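# The comment above describes lazy, once-per-session model loading: the heavy models are
# loaded on the first click of Generate and then reused from st.session_state. A minimal
# sketch of that pattern follows; expensive_load() and the "sd_model" key are illustrative
# placeholders, not names taken from this script.
import streamlit as st

def expensive_load():
    # stand-in for loading the Stable Diffusion / GFPGAN / RealESRGAN weights
    return object()

def get_model():
    # load once on first use, then reuse the object cached in session_state
    if "sd_model" not in st.session_state:
        st.session_state["sd_model"] = expensive_load()
    return st.session_state["sd_model"]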
+ load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model) + if uploaded_images: + image = Image.open(uploaded_images).convert('RGB') + new_img = image.resize((width, height)) + #img_array = np.array(image) # if you want to pass it to OpenCV + + try: + output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=st.session_state["sampling_steps"], + sampler_name=st.session_state["sampler_name"], n_iter=batch_count, + cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed, + seed=seed, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount, + ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=RealESRGAN_model, + separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, save_grid=save_grid, + group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN if not loopback else False, loopback=loopback + ) + + #show a message when the generation is complete. + message.success('Done!', icon="✅") + + except (StopException, KeyError): + print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images, width=750) + + +if __name__ == '__main__': + layout() \ No newline at end of file diff --git a/webui-streamlit.cmd b/webui-streamlit.cmd new file mode 100644 index 0000000..2513b5a --- /dev/null +++ b/webui-streamlit.cmd @@ -0,0 +1,53 @@ +@echo off +set conda_env_name=ldm + +:: Put the path to conda directory after "=" sign if it's installed at non-standard path: +set custom_conda_path= + +IF NOT "%custom_conda_path%"=="" ( + set paths=%custom_conda_path%;%paths% +) +:: Put the path to conda directory in a file called "custom-conda-path.txt" if it's installed at non-standard path: +FOR /F %%i IN (custom-conda-path.txt) DO set custom_conda_path=%%i + +set paths=%ProgramData%\miniconda3 +set paths=%paths%;%USERPROFILE%\miniconda3 +set paths=%paths%;%ProgramData%\anaconda3 +set paths=%paths%;%USERPROFILE%\anaconda3 + +for %%a in (%paths%) do ( + IF NOT "%custom_conda_path%"=="" ( + set paths=%custom_conda_path%;%paths% + ) +) + +for %%a in (%paths%) do ( + if EXIST "%%a\Scripts\activate.bat" ( + SET CONDA_PATH=%%a + echo anaconda3/miniconda3 detected in %%a + goto :foundPath + ) +) + +IF "%CONDA_PATH%"=="" ( + echo anaconda3/miniconda3 not found. Install from here https://docs.conda.io/en/latest/miniconda.html + exit /b 1 +) + +call git stash +call git pull + +:foundPath +call "%CONDA_PATH%\Scripts\activate.bat" +call conda env create -n "%conda_env_name%" -f environment.yaml +call conda env update --name "%conda_env_name%" -f environment.yaml +call "%CONDA_PATH%\Scripts\activate.bat" "%conda_env_name%" +::python "%CD%"\scripts\relauncher.py + +:PROMPT +set SETUPTOOLS_USE_DISTUTILS=stdlib +IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( + python -m streamlit run scripts\webui_streamlit.py +) ELSE ( + ECHO Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'. 
+) \ No newline at end of file diff --git a/webui.cmd b/webui.cmd index 6caeb33..d621d79 100644 --- a/webui.cmd +++ b/webui.cmd @@ -35,6 +35,9 @@ IF "%CONDA_PATH%"=="" ( exit /b 1 ) +call git stash +call git pull + :foundPath call "%CONDA_PATH%\Scripts\activate.bat" call conda env create -n "%conda_env_name%" -f environment.yaml From 0dffc3918d596ad36a32ac56ecf4d523f490ae5e Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Thu, 8 Sep 2022 15:43:08 +0300 Subject: [PATCH 02/27] fix: advanced editor (#827), close #811 refactor js_Call hook to take all gradio arguments --- frontend/css_and_js.py | 18 +++--------------- frontend/frontend.py | 2 +- frontend/js/index.js | 13 ++++++++----- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/frontend/css_and_js.py b/frontend/css_and_js.py index 266cf43..64e6dd5 100644 --- a/frontend/css_and_js.py +++ b/frontend/css_and_js.py @@ -24,19 +24,6 @@ def js(opt): return data - -def js_painterro_launch(to_id): - return w(f"Painterro.init('{to_id}')") - -def js_move_image(from_id, to_id): - return w(f"moveImageFromGallery('{from_id}', '{to_id}')") - -def js_copy_to_clipboard(from_id): - return w(f"copyImageFromGalleryToClipboard('{from_id}')") - -def js_img2img_submit(prompt_row_id): - return w(f"clickFirstVisibleButton('{prompt_row_id}')") - # TODO : @altryne fix this to the new JS format js_copy_txt2img_output = "(x) => {navigator.clipboard.writeText(document.querySelector('gradio-app').shadowRoot.querySelector('#highlight .textfield').textContent.replace(/\s+/g,' ').replace(/: /g,':'))}" @@ -93,12 +80,13 @@ return [txt2img_prompt, parseInt(txt2img_width), parseInt(txt2img_height), parse """ -# @altryne this came up as conflict, still needed or no? # Wrap the typical SD method call into async closure for ease of use # Supplies the js function with a params object # That includes all the passed arguments and input from Gradio: x +# ATTENTION: x is an array of values of all components passed to your +# python event handler # Example call in Gradio component's event handler (pass the result to _js arg): # _js=call_JS("myJsMethod", arg1="string", arg2=100, arg3=[]) def call_JS(sd_method, **kwargs): param_str = json.dumps(kwargs) - return f"async (x) => {{ return await SD.{sd_method}({{ x, ...{param_str} }}) ?? []; }}" + return f"async (...x) => {{ return await SD.{sd_method}({{ x, ...{param_str} }}) ?? []; }}" diff --git a/frontend/frontend.py b/frontend/frontend.py index 41e8672..365f9d3 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -372,7 +372,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda rowId="prompt_row")) img2img_painterro_btn.click(None, - [img2img_image_editor], + [img2img_image_editor, img2img_image_mask, img2img_image_editor_mode], [img2img_image_editor, img2img_image_mask], _js=call_JS("Painterro.init", toId="img2img_editor") ) diff --git a/frontend/js/index.js b/frontend/js/index.js index 64b9f18..2afe2db 100644 --- a/frontend/js/index.js +++ b/frontend/js/index.js @@ -6,8 +6,9 @@ window.SD = (() => { class PainterroClass { static isOpen = false; static async init ({ x, toId }) { - const img = x; - const originalImage = Array.isArray(img) ? img[0] : img; + console.log(x) + + const originalImage = x[2] === 'Mask' ? x[1]?.image : x[0]; if (window.Painterro === undefined) { try { @@ -52,8 +53,8 @@ window.SD = (() => { return result ? 
this.success(result) : this.fallback(originalImage); } - static success (result) { return [result, result]; } - static fallback (image) { return [image, image]; } + static success (result) { return [result, { image: result, mask: result }] }; + static fallback (image) { return [image, { image: image, mask: image }] }; static load () { return new Promise((resolve, reject) => { const scriptId = '__painterro-script'; @@ -112,6 +113,7 @@ window.SD = (() => { el = new ElementCache(); Painterro = PainterroClass; moveImageFromGallery ({ x, fromId, toId }) { + x = x[0]; if (!Array.isArray(x) || x.length === 0) return; this.clearImageInput(this.el.get(`#${toId}`)); @@ -121,6 +123,7 @@ window.SD = (() => { return [x[i].replace('data:;','data:image/png;')]; } async copyImageFromGalleryToClipboard ({ x, fromId }) { + x = x[0]; if (!Array.isArray(x) || x.length === 0) return; const i = this.#getGallerySelectedIndex(this.el.get(`#${fromId}`)); @@ -147,7 +150,7 @@ window.SD = (() => { } } } - async gradioInputToClipboard ({ x }) { return this.copyToClipboard(x); } + async gradioInputToClipboard ({ x }) { return this.copyToClipboard(x[0]); } async copyToClipboard (value) { if (!value || typeof value === 'boolean') return; try { From c3b1facef3a17330b10b650bed3cc7399a7f1f88 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Thu, 8 Sep 2022 23:57:26 +0100 Subject: [PATCH 03/27] Update FUNDING.yml --- .github/FUNDING.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 481557a..3d3c92e 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,2 +1,2 @@ ko_fi: hlky_ -github: [hlky, altryne] +github: [hlky, ZeroCool940711, altryne] From a05a4a9b0a38431b403faca7f154fdbd6841a068 Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Fri, 9 Sep 2022 13:09:13 +0300 Subject: [PATCH 04/27] chore: update maintenance scripts and docs (#891) * automate conda_env_name as per name in yaml * Embed installation links directly in README.md Include links to Windows, Linux, and Google Colab installations. * Fix conda update in webui.sh for pip bug * Add info about new PRs Co-authored-by: Hafiidz <3688500+Hafiidz@users.noreply.github.com> Co-authored-by: Tom Pham <54967380+TomPham97@users.noreply.github.com> Co-authored-by: GRMrGecko --- README.md | 7 ++++++- webui-streamlit.cmd | 6 +++++- webui.cmd | 5 ++++- webui.sh | 4 ++-- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ce90ca3..e0f3295 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,17 @@ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/altryne/sd-webui-colab/blob/main/Stable_Diffusion_WebUi_Altryne.ipynb) -# [Installation](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation) +# Installation instructions for [Windows](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation), [Linux](https://github.com/sd-webui/stable-diffusion-webui/wiki/Linux-Automated-Setup-Guide), or [Google Colab](https://colab.research.google.com/github/altryne/sd-webui-colab/blob/main/Stable_Diffusion_WebUi_Altryne.ipynb) ### Have an **issue**? 
* If the issue involves _a bug_ in **textual-inversion** create the issue on **_[sd-webui/stable-diffusion-webui](https://github.com/sd-webui/stable-diffusion-webui)_** * If you want to know how to **activate** or **use** textual-inversion see **_[hlky/sd-enable-textual-inversion](https://github.com/hlky/sd-enable-textual-inversion)_**. Activation not working? create the issue on **_[sd-webui/stable-diffusion-webui](https://github.com/sd-webui/stable-diffusion-webui)_** +### Want to contribute? + +Open new Pull Requests against `dev` branch! + +**If you're thinking about adding a new feature to Web UI focus on the Streamlit version (webui_streamlit.py) which is in active development.** ## More documentation about features, troubleshooting, common issues very soon ### Want to help with documentation? Documented something? Use [Discussions](https://github.com/sd-webui/stable-diffusion-webui/discussions) diff --git a/webui-streamlit.cmd b/webui-streamlit.cmd index 2513b5a..6ba0597 100644 --- a/webui-streamlit.cmd +++ b/webui-streamlit.cmd @@ -1,5 +1,9 @@ @echo off -set conda_env_name=ldm + +:: copy over the first line from environment.yaml, e.g. name: ldm, and take the second word after splitting by ":" delimiter +set /p first_line=< environment.yaml +for /f "tokens=2 delims=:" %%i in ("%first_line%") do set conda_env_name=%%i +echo Environment name is set as %conda_env_name% as per environment.yaml :: Put the path to conda directory after "=" sign if it's installed at non-standard path: set custom_conda_path= diff --git a/webui.cmd b/webui.cmd index d621d79..9ba6695 100644 --- a/webui.cmd +++ b/webui.cmd @@ -1,6 +1,9 @@ @echo off -set conda_env_name=ldm +:: copy over the first line from environment.yaml, e.g. name: ldm, and take the second word after splitting by ":" delimiter +set /p first_line=< environment.yaml +for /f "tokens=2 delims=:" %%i in ("%first_line%") do set conda_env_name=%%i +echo Environment name is set as %conda_env_name% as per environment.yaml :: Put the path to conda directory after "=" sign if it's installed at non-standard path: set custom_conda_path= diff --git a/webui.sh b/webui.sh index ea7028f..7b07a49 100755 --- a/webui.sh +++ b/webui.sh @@ -37,7 +37,7 @@ if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then ENV_UPDATED=1 elif [[ ! -z $CONDA_FORCE_UPDATE && $CONDA_FORCE_UPDATE == "true" ]] || (( $ENV_MODIFIED > $ENV_MODIFIED_CACHED )); then echo "Updating conda env: ${ENV_NAME} ..." - conda env update --file $ENV_FILE --prune + PIP_EXISTS_ACTION=w conda env update --file $ENV_FILE --prune ENV_UPDATED=1 fi @@ -56,4 +56,4 @@ if [ ! -e "models/ldm/stable-diffusion-v1/model.ckpt" ]; then exit 1 fi -python scripts/relauncher.py \ No newline at end of file +python scripts/relauncher.py From ae77fcb8ae534945fb5e0adf34ba2210d3629bd9 Mon Sep 17 00:00:00 2001 From: Hafiidz <3688500+Hafiidz@users.noreply.github.com> Date: Fri, 9 Sep 2022 19:11:11 +0800 Subject: [PATCH 05/27] Urgent fix to remove whitespace in conda env name --- webui-streamlit.cmd | 4 +++- webui.cmd | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/webui-streamlit.cmd b/webui-streamlit.cmd index 6ba0597..b952a21 100644 --- a/webui-streamlit.cmd +++ b/webui-streamlit.cmd @@ -2,7 +2,9 @@ :: copy over the first line from environment.yaml, e.g. 
name: ldm, and take the second word after splitting by ":" delimiter set /p first_line=< environment.yaml -for /f "tokens=2 delims=:" %%i in ("%first_line%") do set conda_env_name=%%i +for /f "tokens=2 delims=:" %%i in ("%first_line%") do set untrimmed_conda_env_name=%%i +:: remove whitespaces +for /f "tokens=* delims= " %%a in ("%untrimmed_conda_env_name%") do set conda_env_name=%%a echo Environment name is set as %conda_env_name% as per environment.yaml :: Put the path to conda directory after "=" sign if it's installed at non-standard path: diff --git a/webui.cmd b/webui.cmd index 9ba6695..378f8d7 100644 --- a/webui.cmd +++ b/webui.cmd @@ -2,7 +2,9 @@ :: copy over the first line from environment.yaml, e.g. name: ldm, and take the second word after splitting by ":" delimiter set /p first_line=< environment.yaml -for /f "tokens=2 delims=:" %%i in ("%first_line%") do set conda_env_name=%%i +for /f "tokens=2 delims=:" %%i in ("%first_line%") do set untrimmed_conda_env_name=%%i +:: remove whitespaces +for /f "tokens=* delims= " %%a in ("%untrimmed_conda_env_name%") do set conda_env_name=%%a echo Environment name is set as %conda_env_name% as per environment.yaml :: Put the path to conda directory after "=" sign if it's installed at non-standard path: From e4a0047c774aa61730c62535ef305355c9233f72 Mon Sep 17 00:00:00 2001 From: AK391 <81195143+AK391@users.noreply.github.com> Date: Fri, 9 Sep 2022 08:51:51 -0400 Subject: [PATCH 06/27] add gradio contribution --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e0f3295..cb91a4f 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,17 @@ ### Want to contribute? +Gradio version (stable) + +Open new Pull request on 'main' branch! + +for Gradio check out the [docs](https://gradio.app/docs/) to contribute + +Have an issue or feature request with Gradio? open a issue/feature request on github for support: https://github.com/gradio-app/gradio/issues + Open new Pull Requests against `dev` branch! -**If you're thinking about adding a new feature to Web UI focus on the Streamlit version (webui_streamlit.py) which is in active development.** +**New features can be added to Gradio or Streamlit version** ## More documentation about features, troubleshooting, common issues very soon ### Want to help with documentation? Documented something? Use [Discussions](https://github.com/sd-webui/stable-diffusion-webui/discussions) From b9d97c9816251933d094f1dae43d2c631a07db7a Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Fri, 9 Sep 2022 17:24:27 +0300 Subject: [PATCH 07/27] chore: fix branch names --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index cb91a4f..4e9e054 100644 --- a/README.md +++ b/README.md @@ -11,15 +11,14 @@ Gradio version (stable) -Open new Pull request on 'main' branch! +Open new Pull Requests on `dev` branch! for Gradio check out the [docs](https://gradio.app/docs/) to contribute Have an issue or feature request with Gradio? open a issue/feature request on github for support: https://github.com/gradio-app/gradio/issues -Open new Pull Requests against `dev` branch! -**New features can be added to Gradio or Streamlit version** +**New features can be added to Gradio or Streamlit versions** ## More documentation about features, troubleshooting, common issues very soon ### Want to help with documentation? Documented something? 
Use [Discussions](https://github.com/sd-webui/stable-diffusion-webui/discussions) From f62e46934b6c2c2fdf7e894d6a3a514cc6d93341 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Fri, 9 Sep 2022 22:30:54 +0100 Subject: [PATCH 08/27] Revert "Merge pull request #894 from Hafiidz/master" This reverts commit 67515f0627135f6a0d3f6f61680bc45eb01a30f6, reversing changes made to a05a4a9b0a38431b403faca7f154fdbd6841a068. --- webui-streamlit.cmd | 4 +--- webui.cmd | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/webui-streamlit.cmd b/webui-streamlit.cmd index b952a21..6ba0597 100644 --- a/webui-streamlit.cmd +++ b/webui-streamlit.cmd @@ -2,9 +2,7 @@ :: copy over the first line from environment.yaml, e.g. name: ldm, and take the second word after splitting by ":" delimiter set /p first_line=< environment.yaml -for /f "tokens=2 delims=:" %%i in ("%first_line%") do set untrimmed_conda_env_name=%%i -:: remove whitespaces -for /f "tokens=* delims= " %%a in ("%untrimmed_conda_env_name%") do set conda_env_name=%%a +for /f "tokens=2 delims=:" %%i in ("%first_line%") do set conda_env_name=%%i echo Environment name is set as %conda_env_name% as per environment.yaml :: Put the path to conda directory after "=" sign if it's installed at non-standard path: diff --git a/webui.cmd b/webui.cmd index 378f8d7..9ba6695 100644 --- a/webui.cmd +++ b/webui.cmd @@ -2,9 +2,7 @@ :: copy over the first line from environment.yaml, e.g. name: ldm, and take the second word after splitting by ":" delimiter set /p first_line=< environment.yaml -for /f "tokens=2 delims=:" %%i in ("%first_line%") do set untrimmed_conda_env_name=%%i -:: remove whitespaces -for /f "tokens=* delims= " %%a in ("%untrimmed_conda_env_name%") do set conda_env_name=%%a +for /f "tokens=2 delims=:" %%i in ("%first_line%") do set conda_env_name=%%i echo Environment name is set as %conda_env_name% as per environment.yaml :: Put the path to conda directory after "=" sign if it's installed at non-standard path: From f659e5349ddfe924f95b2c51702abb1e47dd15e1 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Fri, 9 Sep 2022 22:34:24 +0100 Subject: [PATCH 09/27] revert chore: update maintenance scripts and docs (#891) --- webui-streamlit.cmd | 6 +----- webui.cmd | 5 +---- webui.sh | 4 ++-- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/webui-streamlit.cmd b/webui-streamlit.cmd index 6ba0597..2513b5a 100644 --- a/webui-streamlit.cmd +++ b/webui-streamlit.cmd @@ -1,9 +1,5 @@ @echo off - -:: copy over the first line from environment.yaml, e.g. name: ldm, and take the second word after splitting by ":" delimiter -set /p first_line=< environment.yaml -for /f "tokens=2 delims=:" %%i in ("%first_line%") do set conda_env_name=%%i -echo Environment name is set as %conda_env_name% as per environment.yaml +set conda_env_name=ldm :: Put the path to conda directory after "=" sign if it's installed at non-standard path: set custom_conda_path= diff --git a/webui.cmd b/webui.cmd index 9ba6695..d621d79 100644 --- a/webui.cmd +++ b/webui.cmd @@ -1,9 +1,6 @@ @echo off -:: copy over the first line from environment.yaml, e.g. 
name: ldm, and take the second word after splitting by ":" delimiter -set /p first_line=< environment.yaml -for /f "tokens=2 delims=:" %%i in ("%first_line%") do set conda_env_name=%%i -echo Environment name is set as %conda_env_name% as per environment.yaml +set conda_env_name=ldm :: Put the path to conda directory after "=" sign if it's installed at non-standard path: set custom_conda_path= diff --git a/webui.sh b/webui.sh index 7b07a49..ea7028f 100755 --- a/webui.sh +++ b/webui.sh @@ -37,7 +37,7 @@ if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then ENV_UPDATED=1 elif [[ ! -z $CONDA_FORCE_UPDATE && $CONDA_FORCE_UPDATE == "true" ]] || (( $ENV_MODIFIED > $ENV_MODIFIED_CACHED )); then echo "Updating conda env: ${ENV_NAME} ..." - PIP_EXISTS_ACTION=w conda env update --file $ENV_FILE --prune + conda env update --file $ENV_FILE --prune ENV_UPDATED=1 fi @@ -56,4 +56,4 @@ if [ ! -e "models/ldm/stable-diffusion-v1/model.ckpt" ]; then exit 1 fi -python scripts/relauncher.py +python scripts/relauncher.py \ No newline at end of file From 28338ceb03b6288b011c48c2d030c0c44319ba7d Mon Sep 17 00:00:00 2001 From: AK391 <81195143+AK391@users.noreply.github.com> Date: Sat, 10 Sep 2022 09:03:41 -0400 Subject: [PATCH 10/27] add gradio discord link --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 4e9e054..651fdec 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ for Gradio check out the [docs](https://gradio.app/docs/) to contribute Have an issue or feature request with Gradio? open a issue/feature request on github for support: https://github.com/gradio-app/gradio/issues +Need more support with Gradio? We have a discord channel called `gradio-stable-diffusion` for Q&A with the gradio authors, to join use this link https://discord.gg/Qs8AsnX7Jd, then go to `role-assignment` and click gradio to join the `gradio` channels. **New features can be added to Gradio or Streamlit versions** From 2236e8b5854092054e2c30edc559006ace53bf96 Mon Sep 17 00:00:00 2001 From: node7-ai <44690655+node7-ai@users.noreply.github.com> Date: Sat, 10 Sep 2022 16:48:11 -0600 Subject: [PATCH 11/27] fix: sampler name in GoBig --- scripts/webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/webui.py b/scripts/webui.py index fe39152..0de1c1e 100644 --- a/scripts/webui.py +++ b/scripts/webui.py @@ -1623,7 +1623,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to elif sampler_name == 'k_dpm_2_a': sampler = KDiffusionSampler(model,'dpm_2_ancestral') elif sampler_name == 'k_dpm_2': - sampler_name = KDiffusionSampler(model,'dpm_2') + sampler = KDiffusionSampler(model,'dpm_2') elif sampler_name == 'k_euler_a': sampler = KDiffusionSampler(model,'euler_ancestral') elif sampler_name == 'k_euler': From 7623a5734740025d79b710f3744bff9276e1467b Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Sat, 10 Sep 2022 19:08:30 +0300 Subject: [PATCH 12/27] add PR template --- .github/PULL_REQUEST_TEMPLATE.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..d0ba547 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,16 @@ +# Description + +Please include: +* relevant motivation +* a summary of the change +* which issue is fixed. +* any additional dependencies that are required for this change. 
+ +Closes: # (issue) + +# Checklist: + +- [ ] I have changed the base branch to `dev` +- [ ] I have performed a self-review of my own code +- [ ] I have commented my code in hard-to-understand areas +- [ ] I have made corresponding changes to the documentation \ No newline at end of file From e6a9e5d968aad020e54045a0923d8c6aa0b54591 Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Tue, 13 Sep 2022 00:37:56 +0300 Subject: [PATCH 13/27] fix: disable live prompt parsing, fix #676 --- frontend/frontend.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/frontend/frontend.py b/frontend/frontend.py index 365f9d3..c3c1a95 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -139,14 +139,16 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda # txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) # txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) - live_prompt_params = [txt2img_prompt, txt2img_width, txt2img_height, txt2img_steps, txt2img_seed, - txt2img_batch_count, txt2img_cfg] - txt2img_prompt.change( - fn=None, - inputs=live_prompt_params, - outputs=live_prompt_params, - _js=js_parse_prompt - ) + # Temporarily disable prompt parsing until memory issues could be solved + # See #676 + # live_prompt_params = [txt2img_prompt, txt2img_width, txt2img_height, txt2img_steps, txt2img_seed, + # txt2img_batch_count, txt2img_cfg] + # txt2img_prompt.change( + # fn=None, + # inputs=live_prompt_params, + # outputs=live_prompt_params, + # _js=js_parse_prompt + # ) with gr.TabItem("Image-to-Image Unified", id="img2img_tab"): with gr.Row(elem_id="prompt_row"): From fab5765fe453117590d43d5999b97cbcd782cd08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio?= Date: Wed, 14 Sep 2022 00:42:10 +0300 Subject: [PATCH 14/27] Patch docker conda install pip requirements (#1094) Install getting stuck on user prompt: sd | Obtaining GFPGAN from git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN (from -r /sd/condaenv.h7qk_3wn.requirements.txt (line 25)) sd | What to do? (i)gnore, (w)ipe, (b)ackup sd | failed sd | sd | CondaEnvException: Pip failed this defaults the pip exists action to wiping the already existing lib and re-downloading it --- docker-compose.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index c8e3130..11da1c3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,6 +7,8 @@ services: context: . 
dockerfile: Dockerfile env_file: .env_docker + environment: + PIP_EXISTS_ACTION: w volumes: - .:/sd - ./outputs:/sd/outputs From 833a91047df999302f699637768741cecee9c37b Mon Sep 17 00:00:00 2001 From: Charlie Date: Tue, 13 Sep 2022 18:53:27 -0400 Subject: [PATCH 15/27] Publish Streamlit ports (#1102) --- docker-compose.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docker-compose.yml b/docker-compose.yml index 11da1c3..968df1c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,6 +16,7 @@ services: - root_profile:/root ports: - '7860:7860' + - '8501:8501' deploy: resources: reservations: From 1b08d867874044004138ae07840d46f1e4dc2985 Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Thu, 15 Sep 2022 00:38:52 +0300 Subject: [PATCH 16/27] chore: cmd readability (#1135) --- update_to_latest.cmd | 60 +++++++++++++++++++++++++++++++++++ webui-streamlit.cmd | 73 +++++++++++++++++++++++------------------- webui.cmd | 75 +++++++++++++++++++++++++------------------- 3 files changed, 143 insertions(+), 65 deletions(-) create mode 100644 update_to_latest.cmd diff --git a/update_to_latest.cmd b/update_to_latest.cmd new file mode 100644 index 0000000..8ed32e3 --- /dev/null +++ b/update_to_latest.cmd @@ -0,0 +1,60 @@ +@echo off +cd %~dp0 + +:: Duplicate code to find miniconda + +IF EXIST custom-conda-path.txt ( + FOR /F %%i IN (custom-conda-path.txt) DO set v_custom_path=%%i +) + +set v_paths=%ProgramData%\miniconda3 +set v_paths=%v_paths%;%USERPROFILE%\miniconda3 +set v_paths=%v_paths%;%ProgramData%\anaconda3 +set v_paths=%v_paths%;%USERPROFILE%\anaconda3 + +for %%a in (%v_paths%) do ( + IF NOT "%v_custom_path%"=="" ( + set v_paths=%v_custom_path%;%v_paths% + ) +) + +for %%a in (%v_paths%) do ( + if EXIST "%%a\Scripts\activate.bat" ( + SET v_conda_path=%%a + echo anaconda3/miniconda3 detected in %%a + ) +) + +IF "%v_conda_path%"=="" ( + echo anaconda3/miniconda3 not found. Install from here https://docs.conda.io/en/latest/miniconda.html + pause + exit /b 1 +) + +:: Update + +echo Stashing local changes and pulling latest update... +call git stash +call git pull +echo If you want to restore changes you made before updating, run "git stash pop". +call "%v_conda_path%\Scripts\activate.bat" + +for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a +set /p "v_last_hash="<"z_version_env.tmp" +echo %v_cur_hash%>z_version_env.tmp + +echo Current environment.yaml hash: %v_cur_hash% +echo Previous environment.yaml hash: %v_last_hash% + +if "%v_last_hash%" == "%v_cur_hash%" ( + echo environment.yaml unchanged. dependencies should be up to date. + echo if you still have unresolved dependencies, delete "z_version_env.tmp" + if not defined AUTO pause +) else ( + echo environment.yaml changed. updating dependencies + call conda env create --name "%v_conda_env_name%" -f environment.yaml + call conda env update --name "%v_conda_env_name%" -f environment.yaml + if not defined AUTO pause +) + +::cmd /k diff --git a/webui-streamlit.cmd b/webui-streamlit.cmd index 2513b5a..4bd5e48 100644 --- a/webui-streamlit.cmd +++ b/webui-streamlit.cmd @@ -1,53 +1,62 @@ @echo off -set conda_env_name=ldm -:: Put the path to conda directory after "=" sign if it's installed at non-standard path: -set custom_conda_path= +:: Run all commands using this script's directory as the working directory +cd %~dp0 -IF NOT "%custom_conda_path%"=="" ( - set paths=%custom_conda_path%;%paths% +:: copy over the first line from environment.yaml, e.g. 
name: ldm, and take the second word after splitting by ":" delimiter +for /F "tokens=2 delims=: " %%i in (environment.yaml) DO ( + set v_conda_env_name=%%i + goto EOL ) -:: Put the path to conda directory in a file called "custom-conda-path.txt" if it's installed at non-standard path: -FOR /F %%i IN (custom-conda-path.txt) DO set custom_conda_path=%%i +:EOL -set paths=%ProgramData%\miniconda3 -set paths=%paths%;%USERPROFILE%\miniconda3 -set paths=%paths%;%ProgramData%\anaconda3 -set paths=%paths%;%USERPROFILE%\anaconda3 +echo Environment name is set as %v_conda_env_name% as per environment.yaml -for %%a in (%paths%) do ( - IF NOT "%custom_conda_path%"=="" ( - set paths=%custom_conda_path%;%paths% - ) +:: Put the path to conda directory in a file called "custom-conda-path.txt" if it's installed at non-standard path +IF EXIST custom-conda-path.txt ( + FOR /F %%i IN (custom-conda-path.txt) DO set v_custom_path=%%i ) -for %%a in (%paths%) do ( - if EXIST "%%a\Scripts\activate.bat" ( - SET CONDA_PATH=%%a +set v_paths=%ProgramData%\miniconda3 +set v_paths=%v_paths%;%USERPROFILE%\miniconda3 +set v_paths=%v_paths%;%ProgramData%\anaconda3 +set v_paths=%v_paths%;%USERPROFILE%\anaconda3 + +for %%a in (%v_paths%) do ( + IF NOT "%v_custom_path%"=="" ( + set v_paths=%v_custom_path%;%v_paths% + ) +) + +for %%a in (%v_paths%) do ( + if EXIST "%%a\Scripts\activate.bat" ( + SET v_conda_path=%%a echo anaconda3/miniconda3 detected in %%a - goto :foundPath - ) + goto :CONDA_FOUND + ) ) -IF "%CONDA_PATH%"=="" ( +IF "%v_conda_path%"=="" ( echo anaconda3/miniconda3 not found. Install from here https://docs.conda.io/en/latest/miniconda.html + pause exit /b 1 ) -call git stash -call git pull +:CONDA_FOUND -:foundPath -call "%CONDA_PATH%\Scripts\activate.bat" -call conda env create -n "%conda_env_name%" -f environment.yaml -call conda env update --name "%conda_env_name%" -f environment.yaml -call "%CONDA_PATH%\Scripts\activate.bat" "%conda_env_name%" -::python "%CD%"\scripts\relauncher.py +if not exist "z_version_env.tmp" ( + :: first time running, we need to update + set AUTO=1 + call "update_to_latest.cmd" +) + +call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%" :PROMPT set SETUPTOOLS_USE_DISTUTILS=stdlib IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( - python -m streamlit run scripts\webui_streamlit.py + python -m streamlit run scripts\webui_streamlit.py --theme.base dark ) ELSE ( - ECHO Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'. -) \ No newline at end of file + echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'. + pause +) diff --git a/webui.cmd b/webui.cmd index d621d79..e062f1d 100644 --- a/webui.cmd +++ b/webui.cmd @@ -1,54 +1,63 @@ @echo off -set conda_env_name=ldm +:: Run all commands using this script's directory as the working directory +cd %~dp0 -:: Put the path to conda directory after "=" sign if it's installed at non-standard path: -set custom_conda_path= - -IF NOT "%custom_conda_path%"=="" ( - set paths=%custom_conda_path%;%paths% +:: copy over the first line from environment.yaml, e.g. 
name: ldm, and take the second word after splitting by ":" delimiter +for /F "tokens=2 delims=: " %%i in (environment.yaml) DO ( + set v_conda_env_name=%%i + goto EOL ) -:: Put the path to conda directory in a file called "custom-conda-path.txt" if it's installed at non-standard path: -FOR /F %%i IN (custom-conda-path.txt) DO set custom_conda_path=%%i +:EOL -set paths=%ProgramData%\miniconda3 -set paths=%paths%;%USERPROFILE%\miniconda3 -set paths=%paths%;%ProgramData%\anaconda3 -set paths=%paths%;%USERPROFILE%\anaconda3 +echo Environment name is set as %v_conda_env_name% as per environment.yaml -for %%a in (%paths%) do ( - IF NOT "%custom_conda_path%"=="" ( - set paths=%custom_conda_path%;%paths% - ) +:: Put the path to conda directory in a file called "custom-conda-path.txt" if it's installed at non-standard path +IF EXIST custom-conda-path.txt ( + FOR /F %%i IN (custom-conda-path.txt) DO set v_custom_path=%%i ) -for %%a in (%paths%) do ( - if EXIST "%%a\Scripts\activate.bat" ( - SET CONDA_PATH=%%a +set v_paths=%ProgramData%\miniconda3 +set v_paths=%v_paths%;%USERPROFILE%\miniconda3 +set v_paths=%v_paths%;%ProgramData%\anaconda3 +set v_paths=%v_paths%;%USERPROFILE%\anaconda3 + +for %%a in (%v_paths%) do ( + IF NOT "%v_custom_path%"=="" ( + set v_paths=%v_custom_path%;%v_paths% + ) +) + +for %%a in (%v_paths%) do ( + if EXIST "%%a\Scripts\activate.bat" ( + SET v_conda_path=%%a echo anaconda3/miniconda3 detected in %%a - goto :foundPath - ) + goto :CONDA_FOUND + ) ) -IF "%CONDA_PATH%"=="" ( +IF "%v_conda_path%"=="" ( echo anaconda3/miniconda3 not found. Install from here https://docs.conda.io/en/latest/miniconda.html + pause exit /b 1 ) -call git stash -call git pull +:CONDA_FOUND -:foundPath -call "%CONDA_PATH%\Scripts\activate.bat" -call conda env create -n "%conda_env_name%" -f environment.yaml -call conda env update -n "%conda_env_name%" --file environment.yaml --prune -call "%CONDA_PATH%\Scripts\activate.bat" "%conda_env_name%" -python "%CD%"\scripts\relauncher.py +if not exist "z_version_env.tmp" ( + :: first time running, we need to update + set AUTO=1 + call "update_to_latest.cmd" +) + +call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%" :PROMPT set SETUPTOOLS_USE_DISTUTILS=stdlib IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( - python scripts/relauncher.py + set PYTHONPATH=%~dp0 + python scripts\relauncher.py ) ELSE ( - ECHO Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'. -) \ No newline at end of file + echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'. 
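For reference, the environment-name lookup the .cmd scripts above perform (read the first line of environment.yaml, e.g. "name: ldm", and keep the part after the ":") can be sketched in Python as below; the function name is an illustrative assumption, not part of the repository.

from pathlib import Path

def conda_env_name(environment_yaml: str = "environment.yaml") -> str:
    # the first line is expected to look like "name: ldm"
    first_line = Path(environment_yaml).read_text(encoding="utf-8").splitlines()[0]
    return first_line.split(":", 1)[1].strip()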
+ pause +) From 4a8cc9ba6dd7a205c8c82ee7f4939af253a30389 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Fri, 16 Sep 2022 08:18:29 +0100 Subject: [PATCH 17/27] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 651fdec..36d1dc0 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/altryne/sd-webui-colab/blob/main/Stable_Diffusion_WebUi_Altryne.ipynb) +# [Visit sd-webui's Discord Server](https://discord.gg/gyXNe4NySY) [![Discord Server](https://user-images.githubusercontent.com/5977640/190528254-9b5b4423-47ee-4f24-b4f9-fd13fba37518.png)](https://discord.gg/gyXNe4NySY) -# Installation instructions for [Windows](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation), [Linux](https://github.com/sd-webui/stable-diffusion-webui/wiki/Linux-Automated-Setup-Guide), or [Google Colab](https://colab.research.google.com/github/altryne/sd-webui-colab/blob/main/Stable_Diffusion_WebUi_Altryne.ipynb) +# Installation instructions for [Windows](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation), [Linux](https://github.com/sd-webui/stable-diffusion-webui/wiki/Linux-Automated-Setup-Guide) ### Have an **issue**? From 1af2610c7107426171de82d821d0631c38c4e330 Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Fri, 16 Sep 2022 12:26:01 +0300 Subject: [PATCH 18/27] Update conda environment on startup always (#1171) (#1176) * Update environment on startup always * Message to explicitly state no environment.yaml update required Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> Co-authored-by: Brian Semrau Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> --- update_to_latest.cmd | 39 ++++++++++++++++++++------------------- webui-streamlit.cmd | 18 ++++++++++++++++++ webui.cmd | 18 ++++++++++++++++++ 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/update_to_latest.cmd b/update_to_latest.cmd index 8ed32e3..1fdfdeb 100644 --- a/update_to_latest.cmd +++ b/update_to_latest.cmd @@ -36,25 +36,26 @@ IF "%v_conda_path%"=="" ( echo Stashing local changes and pulling latest update... call git stash call git pull -echo If you want to restore changes you made before updating, run "git stash pop". -call "%v_conda_path%\Scripts\activate.bat" - -for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a -set /p "v_last_hash="<"z_version_env.tmp" -echo %v_cur_hash%>z_version_env.tmp - -echo Current environment.yaml hash: %v_cur_hash% -echo Previous environment.yaml hash: %v_last_hash% - -if "%v_last_hash%" == "%v_cur_hash%" ( - echo environment.yaml unchanged. dependencies should be up to date. - echo if you still have unresolved dependencies, delete "z_version_env.tmp" - if not defined AUTO pause -) else ( - echo environment.yaml changed. updating dependencies - call conda env create --name "%v_conda_env_name%" -f environment.yaml - call conda env update --name "%v_conda_env_name%" -f environment.yaml - if not defined AUTO pause +set /P restore="Do you want to restore changes you made before updating? (Y/N): " +IF /I "%restore%" == "N" ( + echo Removing changes please wait... + call git stash drop + echo Changes removed, press any key to continue... + pause >nul +) ELSE IF /I "%restore%" == "Y" ( + echo Restoring changes, please wait... 
+ call git stash pop --quiet + echo Changes restored, press any key to continue... + pause >nul ) +for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a +echo %v_cur_hash%>z_version_env.tmp + +call conda env create --name "%v_conda_env_name%" -f environment.yaml +call conda env update --name "%v_conda_env_name%" -f environment.yaml +if not defined AUTO pause + +call "%v_conda_path%\Scripts\activate.bat" + ::cmd /k diff --git a/webui-streamlit.cmd b/webui-streamlit.cmd index 4bd5e48..eefd2e7 100644 --- a/webui-streamlit.cmd +++ b/webui-streamlit.cmd @@ -52,6 +52,24 @@ if not exist "z_version_env.tmp" ( call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%" +set v_last_hash=0 +set /p "v_last_hash="<"z_version_env.tmp" +for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a + +if not "%v_last_hash%" == "%v_cur_hash%" ( + set /P updateenv="Do you want to update with the latest environment.yaml? (Y/N): " + if /I "%updateenv%" == "N" ( + echo Starting without updating dependencies. + ) else if /I "%updateenv%" == "Y" ( + echo Updating dependencies... + call conda env create --name "%v_conda_env_name%" -f environment.yaml + call conda env update --name "%v_conda_env_name%" -f environment.yaml + echo %v_cur_hash%>z_version_env.tmp + ) +) else ( + echo No environment.yaml update required. +) + :PROMPT set SETUPTOOLS_USE_DISTUTILS=stdlib IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( diff --git a/webui.cmd b/webui.cmd index e062f1d..dae03a7 100644 --- a/webui.cmd +++ b/webui.cmd @@ -52,6 +52,24 @@ if not exist "z_version_env.tmp" ( call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%" +set v_last_hash=0 +set /p "v_last_hash="<"z_version_env.tmp" +for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a + +if not "%v_last_hash%" == "%v_cur_hash%" ( + set /P updateenv="Do you want to update with the latest environment.yaml? (Y/N): " + if /I "%updateenv%" == "N" ( + echo Starting without updating dependencies. + ) else if /I "%updateenv%" == "Y" ( + echo Updating dependencies... + call conda env create --name "%v_conda_env_name%" -f environment.yaml + call conda env update --name "%v_conda_env_name%" -f environment.yaml + echo %v_cur_hash%>z_version_env.tmp + ) +) else ( + echo No environment.yaml update required. 
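The block above decides whether the conda environment needs rebuilding by comparing the hash of the last commit that touched environment.yaml against the hash cached in z_version_env.tmp. A rough Python equivalent of that check, for illustration only (it assumes git is on PATH and the working directory is the repository root):

import subprocess
from pathlib import Path

def environment_yaml_changed(stamp_file: str = "z_version_env.tmp") -> bool:
    # hash of the most recent commit that modified environment.yaml
    current = subprocess.run(
        ["git", "log", "-1", "--format=%H", "--", "environment.yaml"],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    stamp = Path(stamp_file)
    previous = stamp.read_text().strip() if stamp.exists() else ""
    stamp.write_text(current)
    return current != previous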
+) + :PROMPT set SETUPTOOLS_USE_DISTUTILS=stdlib IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( From 2c1c9465d5e4f87a8061ae30431fd87ec42173c6 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:12:15 +0100 Subject: [PATCH 19/27] environment update from .cmd --- update_to_latest.cmd | 61 -------------------------------------------- webui-streamlit.cmd | 57 ++++++++++++++++++++++++----------------- webui.cmd | 57 ++++++++++++++++++++++++----------------- 3 files changed, 68 insertions(+), 107 deletions(-) delete mode 100644 update_to_latest.cmd diff --git a/update_to_latest.cmd b/update_to_latest.cmd deleted file mode 100644 index 1fdfdeb..0000000 --- a/update_to_latest.cmd +++ /dev/null @@ -1,61 +0,0 @@ -@echo off -cd %~dp0 - -:: Duplicate code to find miniconda - -IF EXIST custom-conda-path.txt ( - FOR /F %%i IN (custom-conda-path.txt) DO set v_custom_path=%%i -) - -set v_paths=%ProgramData%\miniconda3 -set v_paths=%v_paths%;%USERPROFILE%\miniconda3 -set v_paths=%v_paths%;%ProgramData%\anaconda3 -set v_paths=%v_paths%;%USERPROFILE%\anaconda3 - -for %%a in (%v_paths%) do ( - IF NOT "%v_custom_path%"=="" ( - set v_paths=%v_custom_path%;%v_paths% - ) -) - -for %%a in (%v_paths%) do ( - if EXIST "%%a\Scripts\activate.bat" ( - SET v_conda_path=%%a - echo anaconda3/miniconda3 detected in %%a - ) -) - -IF "%v_conda_path%"=="" ( - echo anaconda3/miniconda3 not found. Install from here https://docs.conda.io/en/latest/miniconda.html - pause - exit /b 1 -) - -:: Update - -echo Stashing local changes and pulling latest update... -call git stash -call git pull -set /P restore="Do you want to restore changes you made before updating? (Y/N): " -IF /I "%restore%" == "N" ( - echo Removing changes please wait... - call git stash drop - echo Changes removed, press any key to continue... - pause >nul -) ELSE IF /I "%restore%" == "Y" ( - echo Restoring changes, please wait... - call git stash pop --quiet - echo Changes restored, press any key to continue... - pause >nul -) - -for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a -echo %v_cur_hash%>z_version_env.tmp - -call conda env create --name "%v_conda_env_name%" -f environment.yaml -call conda env update --name "%v_conda_env_name%" -f environment.yaml -if not defined AUTO pause - -call "%v_conda_path%\Scripts\activate.bat" - -::cmd /k diff --git a/webui-streamlit.cmd b/webui-streamlit.cmd index eefd2e7..92b1dfc 100644 --- a/webui-streamlit.cmd +++ b/webui-streamlit.cmd @@ -43,33 +43,42 @@ IF "%v_conda_path%"=="" ( ) :CONDA_FOUND - -if not exist "z_version_env.tmp" ( - :: first time running, we need to update - set AUTO=1 - call "update_to_latest.cmd" +echo Stashing local changes and pulling latest update... +call git stash +call git pull +set /P restore="Do you want to restore changes you made before updating? (Y/N): " +IF /I "%restore%" == "N" ( + echo Removing changes please wait... + call git stash drop + echo Changes removed, press any key to continue... + pause >nul +) ELSE IF /I "%restore%" == "Y" ( + echo Restoring changes, please wait... + call git stash pop --quiet + echo Changes restored, press any key to continue... 
+ pause >nul ) +call "%v_conda_path%\Scripts\activate.bat" + +for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a +set /p "v_last_hash="<"z_version_env.tmp" +echo %v_cur_hash%>z_version_env.tmp + +echo Current environment.yaml hash: %v_cur_hash% +echo Previous environment.yaml hash: %v_last_hash% + +if "%v_last_hash%" == "%v_cur_hash%" ( + echo environment.yaml unchanged. dependencies should be up to date. + echo if you still have unresolved dependencies, delete "z_version_env.tmp" +) else ( + echo environment.yaml changed. updating dependencies + call conda env create --name "%v_conda_env_name%" -f environment.yaml + call conda env update --name "%v_conda_env_name%" -f environment.yaml +) + call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%" -set v_last_hash=0 -set /p "v_last_hash="<"z_version_env.tmp" -for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a - -if not "%v_last_hash%" == "%v_cur_hash%" ( - set /P updateenv="Do you want to update with the latest environment.yaml? (Y/N): " - if /I "%updateenv%" == "N" ( - echo Starting without updating dependencies. - ) else if /I "%updateenv%" == "Y" ( - echo Updating dependencies... - call conda env create --name "%v_conda_env_name%" -f environment.yaml - call conda env update --name "%v_conda_env_name%" -f environment.yaml - echo %v_cur_hash%>z_version_env.tmp - ) -) else ( - echo No environment.yaml update required. -) - :PROMPT set SETUPTOOLS_USE_DISTUTILS=stdlib IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( @@ -78,3 +87,5 @@ IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'. pause ) + +::cmd /k diff --git a/webui.cmd b/webui.cmd index dae03a7..97df705 100644 --- a/webui.cmd +++ b/webui.cmd @@ -43,33 +43,42 @@ IF "%v_conda_path%"=="" ( ) :CONDA_FOUND - -if not exist "z_version_env.tmp" ( - :: first time running, we need to update - set AUTO=1 - call "update_to_latest.cmd" +echo Stashing local changes and pulling latest update... +call git stash +call git pull +set /P restore="Do you want to restore changes you made before updating? (Y/N): " +IF /I "%restore%" == "N" ( + echo Removing changes please wait... + call git stash drop + echo Changes removed, press any key to continue... + pause >nul +) ELSE IF /I "%restore%" == "Y" ( + echo Restoring changes, please wait... + call git stash pop --quiet + echo Changes restored, press any key to continue... + pause >nul ) +call "%v_conda_path%\Scripts\activate.bat" + +for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a +set /p "v_last_hash="<"z_version_env.tmp" +echo %v_cur_hash%>z_version_env.tmp + +echo Current environment.yaml hash: %v_cur_hash% +echo Previous environment.yaml hash: %v_last_hash% + +if "%v_last_hash%" == "%v_cur_hash%" ( + echo environment.yaml unchanged. dependencies should be up to date. + echo if you still have unresolved dependencies, delete "z_version_env.tmp" +) else ( + echo environment.yaml changed. 
updating dependencies + call conda env create --name "%v_conda_env_name%" -f environment.yaml + call conda env update --name "%v_conda_env_name%" -f environment.yaml +) + call "%v_conda_path%\Scripts\activate.bat" "%v_conda_env_name%" -set v_last_hash=0 -set /p "v_last_hash="<"z_version_env.tmp" -for /f "delims=" %%a in ('git log -1 --format^="%%H" -- environment.yaml') DO set v_cur_hash=%%a - -if not "%v_last_hash%" == "%v_cur_hash%" ( - set /P updateenv="Do you want to update with the latest environment.yaml? (Y/N): " - if /I "%updateenv%" == "N" ( - echo Starting without updating dependencies. - ) else if /I "%updateenv%" == "Y" ( - echo Updating dependencies... - call conda env create --name "%v_conda_env_name%" -f environment.yaml - call conda env update --name "%v_conda_env_name%" -f environment.yaml - echo %v_cur_hash%>z_version_env.tmp - ) -) else ( - echo No environment.yaml update required. -) - :PROMPT set SETUPTOOLS_USE_DISTUTILS=stdlib IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( @@ -79,3 +88,5 @@ IF EXIST "models\ldm\stable-diffusion-v1\model.ckpt" ( echo Your model file does not exist! Place it in 'models\ldm\stable-diffusion-v1' with the name 'model.ckpt'. pause ) + +::cmd /k From 4ae7a5805ce82c4ea5df1d1bf529cf74abb75034 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:16:46 +0100 Subject: [PATCH 20/27] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 46597f4..4b30236 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,5 @@ condaenv.*.requirements.txt /log/log.csv /flagged/* /gfpgan/* +/models/* +z_version_env.tmp \ No newline at end of file From b5462536f8670aae860906939ca4021fdd9250be Mon Sep 17 00:00:00 2001 From: bryanlyon <3223233+bryanlyon@users.noreply.github.com> Date: Fri, 16 Sep 2022 12:42:41 -0700 Subject: [PATCH 21/27] Fixed decimal prompt weights without leading digit (#1182) The regex was not accounting properly for prompt weights that didn't begin with a leading number such as .5 or .1 and was instead splitting those off into their own prompt which got everything all screwed up. For example, the prompt string of "Fruit:1 grapes:-.5" should parse as [('Fruit', 1.0), ('grapes', -.5)] but was being incorrectly parsed as [('Fruit', 1.0), ('grapes', 1.0), ('-.5', 1.0)] This fixes that by making the regex properly catch decimals. --- scripts/webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/webui.py b/scripts/webui.py index 0de1c1e..dd64a4c 100644 --- a/scripts/webui.py +++ b/scripts/webui.py @@ -1519,7 +1519,7 @@ prompt_parser = re.compile(""" (?: # non-capture group :+ # match one or more ':' characters (?P # capture group for 'weight' - -?\d+(?:\.\d+)? # match positive or negative integer or decimal number + -?\d*\.{0,1}\d+ # match positive or negative integer or decimal number )? 
# end weight capture group, make optional \s* # strip spaces after weight | # OR From ea6b422bff279eee4f8f1ae4eaea6451c544e00f Mon Sep 17 00:00:00 2001 From: Mark Knol Date: Sun, 18 Sep 2022 05:57:08 +0200 Subject: [PATCH 22/27] very minor spelling error (#762) * Update frontend.py * Update frontend.py Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> --- frontend/frontend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/frontend.py b/frontend/frontend.py index c3c1a95..29d3c50 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -631,9 +631,9 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda """ gr.HTML("""
-			For help and advanced usage guides, visit the Project Wiki
-			Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the main repository.
-			If you would like to contribute to development or test bleeding edge builds, you can visit the developement repository.
+			For help and advanced usage guides, visit the Project Wiki
+			Stable Diffusion WebUI is an open-source project.
+			If you would like to contribute to development or test bleeding edge builds, use the dev branch.
""") # Hack: Detect the load event on the frontend From 74fd533077bacaa2642b4691b819c48f4351f3e5 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Sun, 18 Sep 2022 11:54:17 +0100 Subject: [PATCH 23/27] default img2img denoising strength increased --- configs/webui/webui_streamlit.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml index 36887e1..bf49f7e 100644 --- a/configs/webui/webui_streamlit.yaml +++ b/configs/webui/webui_streamlit.yaml @@ -114,7 +114,7 @@ img2img: # 9: Fix faces using GFPGAN # 10: Upscale images using Real-ESRGAN sampler_name: "k_euler" - denoising_strength: 0.45 + denoising_strength: 0.75 # 0: Keep masked area # 1: Regenerate only masked area mask_mode: 0 From 6f4a1d8a4136d52ec65005e3597a4c28f5b78a34 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Sun, 18 Sep 2022 12:17:57 +0100 Subject: [PATCH 24/27] slider_steps and slider_bounds in defaults config slider_steps and slider_bounds in defaults config --- configs/webui/webui_streamlit.yaml | 127 +++++++++++++++++------------ scripts/img2img.py | 6 +- scripts/txt2img.py | 6 +- scripts/txt2vid.py | 8 +- 4 files changed, 93 insertions(+), 54 deletions(-) diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml index bf49f7e..394494c 100644 --- a/configs/webui/webui_streamlit.yaml +++ b/configs/webui/webui_streamlit.yaml @@ -62,6 +62,15 @@ txt2img: variant_amount: 0.0 variant_seed: "" write_info_files: True + slider_steps: { + sampling: 1 + } + slider_bounds: { + sampling: { + lower: 1, + upper: 150 + } + } txt2vid: default_model: "CompVis/stable-diffusion-v1-4" @@ -97,58 +106,76 @@ txt2vid: beta_end: 0.012 beta_scheduler_type: "linear" max_frames: 1000 + slider_steps: { + sampling: 1 + } + slider_bounds: { + sampling: { + lower: 1, + upper: 150 + } + } img2img: - prompt: - sampling_steps: 30 - # Adding an int to toggles enables the corresponding feature. - # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them) - # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0) - # 2: Loopback (use images from previous batch when creating next batch) - # 3: Random loopback seed - # 4: Save individual images - # 5: Save grid - # 6: Sort samples by prompt - # 7: Write sample info files - # 8: jpg samples - # 9: Fix faces using GFPGAN - # 10: Upscale images using Real-ESRGAN - sampler_name: "k_euler" - denoising_strength: 0.75 - # 0: Keep masked area - # 1: Regenerate only masked area - mask_mode: 0 - mask_restore: False - # 0: Just resize - # 1: Crop and resize - # 2: Resize and fill - resize_mode: 0 - # Leave blank for random seed: - seed: "" - ddim_eta: 0.0 - cfg_scale: 7.5 - batch_count: 1 - batch_size: 1 - height: 512 - width: 512 - # Textual inversion embeddings file path: - fp: "" - loopback: True - random_seed_loopback: True - separate_prompts: False - update_preview: True - update_preview_frequency: 5 - normalize_prompt_weights: True - save_individual_images: True - save_grid: True - group_by_prompt: True - save_as_jpg: False - use_GFPGAN: False - use_RealESRGAN: False - RealESRGAN_model: "RealESRGAN_x4plus" - variant_amount: 0.0 - variant_seed: "" - write_info_files: True + prompt: + sampling_steps: 30 + # Adding an int to toggles enables the corresponding feature. 
+ # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them) + # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0) + # 2: Loopback (use images from previous batch when creating next batch) + # 3: Random loopback seed + # 4: Save individual images + # 5: Save grid + # 6: Sort samples by prompt + # 7: Write sample info files + # 8: jpg samples + # 9: Fix faces using GFPGAN + # 10: Upscale images using Real-ESRGAN + sampler_name: "k_euler" + denoising_strength: 0.75 + # 0: Keep masked area + # 1: Regenerate only masked area + mask_mode: 0 + mask_restore: False + # 0: Just resize + # 1: Crop and resize + # 2: Resize and fill + resize_mode: 0 + # Leave blank for random seed: + seed: "" + ddim_eta: 0.0 + cfg_scale: 7.5 + batch_count: 1 + batch_size: 1 + height: 512 + width: 512 + # Textual inversion embeddings file path: + fp: "" + loopback: True + random_seed_loopback: True + separate_prompts: False + update_preview: True + update_preview_frequency: 5 + normalize_prompt_weights: True + save_individual_images: True + save_grid: True + group_by_prompt: True + save_as_jpg: False + use_GFPGAN: False + use_RealESRGAN: False + RealESRGAN_model: "RealESRGAN_x4plus" + variant_amount: 0.0 + variant_seed: "" + write_info_files: True + slider_steps: { + sampling: 1 + } + slider_bounds: { + sampling: { + lower: 1, + upper: 150 + } + } gfpgan: strength: 100 diff --git a/scripts/img2img.py b/scripts/img2img.py index 0445c6d..142fe81 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -383,7 +383,11 @@ def layout(): st.session_state["custom_model"] = "Stable Diffusion v1.4" - st.session_state["sampling_steps"] = st.slider("Sampling Steps", value=st.session_state['defaults'].img2img.sampling_steps, min_value=1, max_value=500) + st.session_state["sampling_steps"] = st.slider("Sampling Steps", + value=st.session_state['defaults'].img2img.sampling_steps, + min_value=st.session_state['defaults'].img2img.slider_bounds.sampling.lower, + max_value=st.session_state['defaults'].img2img.slider_bounds.sampling.upper, + step=st.session_state['defaults'].img2img.slider_steps.sampling) sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list, diff --git a/scripts/txt2img.py b/scripts/txt2img.py index d85d2a8..6f74143 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -217,7 +217,11 @@ def layout(): the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ will make it easier for you to distinguish it from other models. 
Default: Stable Diffusion v1.4") - st.session_state.sampling_steps = st.slider("Sampling Steps", value=st.session_state['defaults'].txt2img.sampling_steps, min_value=10, max_value=500, step=10) + st.session_state.sampling_steps = st.slider("Sampling Steps", + value=st.session_state['defaults'].txt2img.sampling_steps, + min_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.lower, + max_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.upper, + step=st.session_state['defaults'].txt2img.slider_steps.sampling) sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] sampler_name = st.selectbox("Sampling method", sampler_name_list, diff --git a/scripts/txt2vid.py b/scripts/txt2vid.py index aa26493..e1be209 100644 --- a/scripts/txt2vid.py +++ b/scripts/txt2vid.py @@ -650,8 +650,12 @@ def layout(): #custom_model = "CompVis/stable-diffusion-v1-4" #st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}" - st.session_state.sampling_steps = st.slider("Sampling Steps", value=st.session_state['defaults'].txt2vid.sampling_steps, min_value=10, step=10, max_value=500, - help="Number of steps between each pair of sampled points") + st.session_state.sampling_steps = st.slider("Sampling Steps", + value=st.session_state['defaults'].txt2vid.sampling_steps, + min_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.lower, + max_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.upper, + step=st.session_state['defaults'].txt2vid.slider_steps.sampling, + help="Number of steps between each pair of sampled points") st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps, min_value=10,step=10, max_value=500, help="Higher values (e.g. 
100, 200 etc) can create better images.") From 8540c8d42c5c9c4fea8536edb3da9b642e82d9e0 Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Sun, 18 Sep 2022 15:08:47 +0300 Subject: [PATCH 25/27] fix: copy to clipboard button --- frontend/frontend.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/frontend/frontend.py b/frontend/frontend.py index f0bf874..94c76c9 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -56,11 +56,15 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda gr.Markdown( "Select an image from the gallery, then click one of the buttons below to perform an action.") with gr.Row(elem_id='txt2img_actions_row'): - gr.Button("Copy to clipboard").click(fn=None, - inputs=output_txt2img_gallery, - outputs=[], - # _js=js_copy_to_clipboard( 'txt2img_gallery_output') - ) + gr.Button("Copy to clipboard").click( + fn=None, + inputs=output_txt2img_gallery, + outputs=[], + _js=call_JS( + "copyImageFromGalleryToClipboard", + fromId="txt2img_gallery_output" + ) + ) output_txt2img_copy_to_input_btn = gr.Button("Push to img2img") output_txt2img_to_imglab = gr.Button("Send to Lab", visible=True) From a797312183188867d7b4bdc7f1a1263017540a48 Mon Sep 17 00:00:00 2001 From: Thomas Mello Date: Sun, 18 Sep 2022 16:00:47 +0300 Subject: [PATCH 26/27] The Merge (#1201) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * resolve conflict with master * - Added option to select custom models instead of just using the default one, if you want to use a custom model just place your .ckpt file in "models/custom" and the UI will detect it and let you switch between stable diffusion and your custom model, make sure to give the filename a proper name that is easy to distinguish from other models because that name will be used on the UI. - Implemented basic Text To Video tab, will continue to improve it as it is really basic right now. - Improved the model loading, you now should see less frequently errors about it not been loaded correctly. * fix: advanced editor (#827), close #811 refactor js_Call hook to take all gradio arguments * Added num_inference_steps to config file and fixed incorrectly calls to the config file from the txt2vid tab calling txt2img instead. * update readme as per installation step & format * proposed streamlit code organization changes I want people of all skill levels to be able to contribute This is one way the code could be split up with the aim of making it easy to understand and contribute especially for people on the lower end of the skill spectrum All i've done is split things, I think renaming and reorganising is still needed * Fixed missing diffusers dependency for Streamlit * Streamlit: Allow user defaults to be specified in a userconfig_streamlit.yaml file. * Changed Streamit yaml default configs Changed `update_preview_frequency` from every 1 step to every 5 steps. This results in a massive gain in performance (roughly going from 2-3 times slower to only 10-15% slower) while still showing good image generation output. Changed default GFPGAN and realESRGAN settings to be off by default. That way, users can decide if they want to use them on, and what images they wish to do so. * Made sure img2txt and img2img checkboxes respect YAML defaults * Move location of user file to configs/webui folder * Fixed the path in webui_streamlit.py * Display Info and Stats when render is complete, similar to what Gradio shows. 
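As noted a few items above, Streamlit defaults now come from configs/webui/webui_streamlit.yaml, with optional per-user overrides in configs/webui/userconfig_streamlit.yaml. Conceptually this is just a recursive dictionary merge; the sketch below assumes PyYAML and a hypothetical load_defaults helper, not the repo's actual loader.

    import yaml

    def load_defaults(base="configs/webui/webui_streamlit.yaml",
                      user="configs/webui/userconfig_streamlit.yaml"):
        # Hypothetical helper: apply a user override file (any subset of keys)
        # on top of the shipped defaults.
        with open(base) as f:
            defaults = yaml.safe_load(f)
        try:
            with open(user) as f:
                overrides = yaml.safe_load(f) or {}
        except FileNotFoundError:
            overrides = {}

        def merge(dst, src):
            for key, value in src.items():
                if isinstance(value, dict) and isinstance(dst.get(key), dict):
                    merge(dst[key], value)
                else:
                    dst[key] = value
            return dst

        return merge(defaults, overrides)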
* Add info and stats to img2img * chore: update maintenance scripts and docs (#891) * automate conda_env_name as per name in yaml * Embed installation links directly in README.md Include links to Windows, Linux, and Google Colab installations. * Fix conda update in webui.sh for pip bug * Add info about new PRs Co-authored-by: Hafiidz <3688500+Hafiidz@users.noreply.github.com> Co-authored-by: Tom Pham <54967380+TomPham97@users.noreply.github.com> Co-authored-by: GRMrGecko * Improvements to the txt2vid tab. * Urgent Fix to PR:860 * Update attention.py * Update FUNDING.yml * when in outcrop mode, mask added regions and fill in with voroni noise for better outpainting * frontend: display current device info (#889) Displays the current device info at the bottom of the page. For users who run multiple instances of `sd-webui` on the same system (for multiple GPUs), it helps to know which of the active `CUDA_VISIBLE_DEVICES` is being used. * Fixed aspect ratio box not being updated on txt2img tab, for issue 219 from old repo (#812) * Metadata cleanup - Maintain metadata within UI (#845) * Metadata cleanup - Maintain metadata within UI This commit, when combined with Gradio 3.2.1b1+, maintains image metadata as an image is passed throughout the UI. For example, if you generate an image, send it to Image Lab, upscale it, fix faces, and then drag the resulting image back in to Image Lab, it will still remember the image generation parameters. When the image is saved, the metadata will be stripped from it if save-metadata is not enabled. If the image is saved by *dragging* out of the UI on to the filesystem it may maintain its metadata. Note: I have ran into UI responsiveness issues with upgrading Gradio. Seems there may be some Gradio queue management issues. *Without* the gradio update this commit will maintain current functionality, but will not keep meetadata when dragging an image between UI components. * Move ImageMetadata into its own file Cleans up webui, enables webui_streamlit et al to use it as well. * Fix typo * Add filename formatting argument (#908) * Update webui.py Filename formatting argument * Update scripts/webui.py Co-authored-by: Thomas Mello * Tiling parameter (#911) * tiling * default to False * fix: filename format parameter (#923) * For issue :884, ensure webui.cmd before init src * Remove embeddings file path * Add mask_restore to restore images based on mask, fixing #665 (#898) * Add mask_restore option to give users the option to restore images based on mask, fixing #665. Before commit c73fdd78 (Implement masking during sampling to improve blending, #308) image mask was applied after sampling, resulting in masked parts that are not regenerated to actually stay the same. Since c73fdd78 the masked img2img will change the whole image, even in masked areas. It gives better looking results at first glance, but will result in image degredation when applied a few times. See issue #665. In the workflow of using repeated masked img2img, users may want to use this options to keep the parts of image they actually want to keep without image degradation. A final masked img2img or whole image img2img with mask_restore disabled will give the better blending of "Implement masking during sampling". 
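The mask_restore behaviour described here amounts to an alpha composite of the generated result over the untouched original; a minimal sketch with PIL, using illustrative names rather than the functions in scripts/webui.py:

    from PIL import Image

    def restore_unmasked(original: Image.Image, generated: Image.Image,
                         mask: Image.Image) -> Image.Image:
        # Convention assumed here: white (255) = regenerate, black (0) = keep original.
        mask_l = mask.convert("L").resize(generated.size)
        original_rgb = original.convert("RGB").resize(generated.size)
        # Image.composite takes pixels from the first image where the mask is
        # white and from the second image where it is black.
        return Image.composite(generated.convert("RGB"), original_rgb, mask_l)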
* revert changes of a7be43ba in change_image_editor_mode * fix ui_functions.change_image_editor_mode by adding gr.update to the end of the list it returns * revert inserted newlines and whitespaces to match format of previous code * improve caption of new option mask_restore "Only modify regenerated parts of image" * fix ui_functions.change_image_editor_mode by adding gr.update to the end of the list it returns an old copy of the function exists in webui.py, this superflous function mistakenly was changed by the earlier commit b6a9e16b * remove unused functions that are near duplicates of functions in ui_functions.py * Added CSS to center the image in the txt2img interface * add img2img option for color correction. (#936) color correction is already used for loopback to prevent color drift with the first image as correction target. the option allows to use the color correction even without loopback mode. it helps keeping the colors similar to the input image. * Image transparency is used as mask for inpainting * fix: lost imports from #921 * Changed StreamIt to `k_euler` 30 steps as default * Fixed an issue with the txt2vid model. * Removed old files from a split test we deed that are not needed anymore, we plan to do the split differently. * Changed the scheduler for the txt2vid tab back to LMS, for now we can only use that. * Better support for large batches in optimized mode * Removed some unused lines from the css file for the streamlit version. * Changed the diffusers version to be 0.2.4 or lower as a new version breaks the txt2vid generation. * Added the models/custom folder to gitignore to ignore custom models. * Added two new scripts that will be used for the new implementation of the txt2vid tab which uses the latest version of the diffusers library. * - Improved the progress bar for the txt2vid tab, it now shows more information during generation. - Changed the guidance_scale variable to be cfg_scale. * Perform masked image restoration for GFPGAN, RealESRGAN, fixing #947 * Perform masked image restoration when using GFPGAN or RealESRGAN, fixing #947. Also fixes bug in image display when using masked image restoration with RealESRGAN. When the image is upscaled using RealESRGAN the image restoration can not use the original image because it has wrong resolution. In this case the image restoration will restore the non-regenerated parts of the image with an RealESRGAN upscaled version of the original input image. Modifications from GFPGAN or color correction in (un)masked parts are also restored to the original image by mask blending. * Update scripts/webui.py Co-authored-by: Thomas Mello * fix: sampler name in GoBig #988 * add sampler_name defaults to img2img * add metadata to other file output file types * remove deprecated kwargs/parameter * refactor: sort out dependencies Co-Authored-By: oc013 <101832295+oc013@users.noreply.github.com> Co-Authored-By: Aarni Koskela Co-Authored-By: oc013 <101832295+oc013@users.noreply.github.com> Co-Authored-By: Aarni Koskela * webui: detect scoped-down GPU environment (#993) * webui: detect scoped-down GPU environment check if we're using a scoped-down GPU environment (pynvml does not listen to CUDA_VISIBLE_DEVICES) so that we can measure memory on the correct GPU * remove unnecessary import * Added piexif dependency. 
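PNG saves can carry generation parameters in text chunks, but jpg/webp saves cannot, which is what the piexif dependency is for. A hedged sketch of embedding prompt info as an EXIF UserComment (the helper name and chosen tag are illustrative, not the repo's exact code):

    import piexif
    import piexif.helper
    from PIL import Image

    def save_jpg_with_metadata(image: Image.Image, path: str, info: str):
        # Pack the generation parameters into the EXIF UserComment tag.
        exif_bytes = piexif.dump({
            "Exif": {
                piexif.ExifIFD.UserComment:
                    piexif.helper.UserComment.dump(info, encoding="unicode"),
            }
        })
        image.save(path, format="JPEG", quality=95, exif=exif_bytes)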
* Changed the minimum value for the Sampling Steps and Inference Steps to 10 and added step with a value of 10 to make it easier to move the slider as it will require a higher maximum value than in other tabs for good results on the text2vid tab. * Commented an import that is not used for now but will be used soon. * write same metadata to file and yaml * include piexif in environment needed for exif labelling of non-png files * fix individual image file format saves * introduces a general config setting save_format similar to grid_format for individual file saves * Add NSFW filter to avoid unexpected (#955) * Add NSFW filter to avoid unexpected * Fix img2img configuration numbering * Added some basic layout for the Model Manager tab and added there the models that most people use to make it easy to download instead of having to go do the wiki or searching through discord for links, it also shows the path where you are supposed to put those models for them to work. * webui: display the GPU in use during startup (#994) * webui: display the GPU in use during startup tell the user which GPU the code is actually going to use before spending lots of time loading everything onto the GPU * typo * add some info messages * evaluate current GPU properly * add debug flag gating not everyone wants or needs to see debug messages :) * add in stray debug msg * Docker updates - Add LDSR, streamlit, other updates for new repository * Update util.py * Docker - Set PYTHONPATH to parent directory to avoid `No module named frontend` error * Add missing comma for nsfw toggle in img2img (#1028) * Multiple improvements to the txt2vid tab. - Improved txt2vid speed by 2 times. - Added DDIM scheduler. - Added sliders for beta_start and beta_end to have more control over these parameters on the scheduler. - Added option to select the scheduler type from scaled_linear or linear. - Added option to save info files for the txt2vid tab and improved the information saved to include most of the parameters used to run the generation. - You can now download any model from the huggingface website to use on the txt2vid tab, just add the name to the custom_models_list on the config file. * webui: add prompt output to console (#1031) * webui: add prompt output to console show the user what prompt is currently being rendered * fix prompt print location * support negative prompts separated by ### e.g. "shopping mall ### people" will try to generate an image of a mall without people in it. * Docker validate model files if not a symlink in case user has VALIDATE_MODELS=false set (#1038) * - Added changes made by @Hafiidz on the ui-improvements branch to the css for the streamli-on-hover-tabs component. * Added streamlit-on-Hover-tabs and streamlit-option-menu dependencies to the environment.yaml file. * Changed some values to be dynamic instead of a fixed value so they are more responsive. * Changed the cmd script to use the dark theme by default when launching the streamlit UI. * Removed the padding at the top of the sidebar so we can have more free space. * - Added code for @Hafiidz's changes made on the css. * Fixed an error with the metadata not able to be saved because of the seed was not converted to a string before so it had no attribute encode on it. * add masking to streamlit img2img, find_noise_for_image, matched_noise * Use the webui script directories as PWD (#946) * add Gradio API endpoint settings (#1055) * add Gradio API endpoint settings * Add comments crediting code authors. 
(probably not enough, but better than none) * Renamed the save_grid option for txt2vid on the config file to be save_video, this will be used to determine if the user wants to save a video at the end of the generation or not, similar to the save_grid that is used on txt2img and img2img but for video. * -Added the Update Image Preview option to be part of the current tab options under Preview Settings. - Added Dynamic Preview Frequency option for the txt2vid tab which tries to find the lowest value for update_preview_frequency at which we can update the preview image during generation while at the same time minimizing the impact it has in performance. - Added option to save a video file on the outputs/txt2vid-samples folder after the generation is complete similar to how the save_grid option works on other tabs. - Added a video preview which shows a video on the txt2vid tab when the generation is completed. - Formated some lines of code to make it use less space and fit on the a single screen. - Added a script called Settings.py to the script folder in which Settings for the Setting page will be placed. Empty for now. * Commented some print statements that were used for debugging and forgot to remove previously. * fix: disable live prompt parsing * Fix issue where loopback was using batch mode * Fix indentation error that prevents mask_restore from working unless ESRGAN is turned on * Fixed Sidebar CSS for 4K displays * img2img mask fixes and fix image2noise normalization * Made it so the sampling_steps is added to num_inference_steps, otherwise it would not match the value you set for it on the slider. * Changed the loading of the model on the txt2vid tab so the half models are only loaded if the no_half option on the config file is set to False. * fix: launcher batch file fix #920, fix #605 - Allow reading environment.yaml file in either LF or CRLF - Only update environment if environment.yaml changes - Remove custom_conda_path to discourage changing source file - Fix unable to launch webui due to frontend module missing (#605) * Update README.md (#1075) fix typo * half precision streamlit txt2vid `RuntimeError: expected scalar type Half but found Float` with both `torch_dtype=torch.float16` and `revision="fp16"` * Add mask restore feature to streamlit, prevent color correction from modifying initial image when mask_restore is turned on * Add mask_restore to streamlit config * JobManager: Fix typo breaking jobs close #858 close #1041 * JobManager: Buttons skip queue (#1092) Have JobManager buttons skip Gradio's queue, since otherwise they aren't sending JobManager button presses. * The webui_streamlit.py file has been split into multiple modules containing their own code making it easier to work with than a single big file. The list of modules is as follow: - webuit_streamlit.py: contains the main layout as well as the functions that load the css which is needed by the layout. - webui_streamlit_old.py: contains the code for the previous version of the WebUI. Will be removed once the new UI code starts to get used and if everything works as it should. - txt2img.py: contains the code for the txt2img tab. - img2img.py: contains the code for the img2img tab. - txt2vid.py: contains the code for the txt2vid tab. - sd_utils.py: contains utility functions used by more than one module, any function that meets such condition should be placed here. - ModelManager.py: contains the code for the Model Manager page on the sidebar menu. 
- Settings.py: contains the code for the Settings page on the sidebar menu. - home.py: contains the code for the Home tab, history and gallery implemented by @devilismyfriend. - imglab.py: contains the code for the Image Lab tab implemented by @devilismyfriend * fix: patch docker conda install pip requirements (#1094) (cherry picked from commit fab5765fe453117590d43d5999b97cbcd782cd08) Co-authored-by: Sérgio * Using the Optimization from Dogettx (#974) * Update attention.py * change to dogettx Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> * Update Dockerfile (#1101) add expose for streamlit port * Publish Streamlit ports (#1102) (cherry picked from commit 833a91047df999302f699637768741cecee9c37b) Co-authored-by: Charlie * Forgot to call the layout() for the Model Manager tab after the import so it was not been used and the tab was shown as empty. * Removed the "find_noise_for_image.py" and "matched_noise.py" scripts as their content is now part of "sd_utils.py" * - Added the functions to load the optimized models, this "should" make it so optimized and turbo mode work now but needs to be tested more. - Added function to load LDSR. * Fixed some imports. * Fixed the info message on the txt2img tab not showing the info but just showing the text "Done" * Made the defaults settings from the config file be stored inside st.session_state to avoid loading it multiple times when calling the "sd_utils.py" file from other modules. * Moved defaults to the webui_streamlit.py file and fixed some imports. * Removed condition to check if the defaults are in the st.session_state dictionary, this is not needed and would cause issues with it not being reloaded when the user changes something on it. * Modified the way the defaults settings are loaded from the config file so we only load them on the webui_strealit.py file and use st.session_state to access them from anywhere else, this makes it so the config can be modified externally like before the code split and the changes will be updated on next rerun of the UI. * fix: [streamlit] optimization mode * temp disable nvml support for multiple gpus * Fixed defaults not being loaded correctly or missing in some places. * Add a separate update script instead of git pull on startup (#1106) * - Fixed max_frame not being properly used and instead sampling_steps was the variable being use. - Fixed several issues with wrong variable being used on multiple places. - Addd option to toggle some extra option from the config file for when the model is loading on the txt2vid tab. * Re-merge #611 - View/Cancel in-progress diffusions (#796) * JobManager: Re-merge #611 PR #611 seems to have got lost in the shuffle after the transition to 'dev'. This commit re-merges the feature branch. This adds support for viewing preview images as the image generates, as well as cancelling in-progress images and a couple fixes and clean-ups. * JobManager: Clear jobs that fail to start Sometimes if a job fails to start it will get stuck in the active job list. This commit ensures that jobs that raise exceptions are cleared, and also adds a start timer to clear out jobs that fail to start within a reasonable amount of time. * chore: add breaks to cmds for readability (#1134) * Added custom models list to the txt2img tab. * Small fix to the custom model list. * Corrected breaking issues introduced in #1136 to txt2img and made state variables consistent with img2img. Fixed a bug where switching models after running would not reload the used model. 
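The launcher behaviour referenced above ("Only update environment if environment.yaml changes") caches the last git hash that touched environment.yaml in z_version_env.tmp, as the .cmd scripts earlier in this series do. A rough Python equivalent for illustration only — the env name "ldm" comes from the Docker entrypoint, the rest is an assumption:

    import subprocess
    from pathlib import Path

    STAMP = Path("z_version_env.tmp")

    def environment_changed() -> bool:
        current = subprocess.check_output(
            ["git", "log", "-1", "--format=%H", "--", "environment.yaml"],
            text=True,
        ).strip()
        previous = STAMP.read_text().strip() if STAMP.exists() else ""
        STAMP.write_text(current)
        return current != previous

    if environment_changed():
        # Rebuild/refresh the conda env only when environment.yaml actually changed.
        subprocess.run(["conda", "env", "update", "--name", "ldm",
                        "-f", "environment.yaml"], check=True)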
* Formatted tabs as spaces * Fixed update_preview_frequency and update_preview using defaults from webui_streamlit.yaml instead of state variables from UI. * Prompt user if they want to restore changes (#1137) - After stashing any changes and pulling updates, ask user if they wish to pop changes - If user declines the restore, drop the stash to prevent the case of an ever growing stash pile * Added streamlit_nested_layout component as dependency and imported on the webui_streamli.py file to allow us to use nested columns and expanders. * - Added the Home tab made by @devilismyfriend - Added gallery tab on txt2img. * Added case insensitivity to restore prompt (#1152) * Calculate aspect ratio and pixel count on start (#1157) * Fix errors rendering galleries when there are not enough images to render * Fix the gallery back/next buttons and add a refresh button * Fix invalid invocation of find_noise_for_image * Removed the Home tab until the gallery is fixed. * Fixed a missing import on the ModelManager script. * Added discord server link to the Readme.md * - Increased the max value for the width and height sliders on the txt2img tab. - Fixed a leftover line from removing the home tab. * Update conda environment on startup always (#1171) * Update environment on startup always * Message to explicitly state no environment.yaml update required Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> * environment update from .cmd * Update .gitignore * Enable negative prompts on streamlit * - Bumped the version of diffusers used on the txt2vid tab to be now v0.3.0. - Added initial file for the textual inversion tab. * add missing argument to GoBig sample function, fixes #1183 (#1184) * cherry-pick @Any-Winter-4079's https://github.com/lstein/stable-diffusion/pull/540. this is a collaboration incorporating a lot of people's contributions -- including for example @Doggettx and the original code from @neonsecret on which the Doggetx optimizations were based (see https://github.com/lstein/stable-diffusion/issues/431, https://github.com/sd-webui/stable-diffusion-webui/pull/771\#issuecomment-1239716055). Takes exactly the same amount of time to run 8 steps as original CompVis code does (10.4 secs, ~1.25s/it). (#1177) Co-authored-by: Alex Birch * allow webp uploads to img2img tab #991 * Don't attempt mask restoration when there is no mask given (#1186) * When running a batch with preview turned on, produce a grid of preview images * When early terminating, generation_callback gets invoked but st.session_state is empty. When this happens, just bail. * Collect images for final display This is a collection of several changes to enhance image display: * When using GFPGAN or RealESRGAN, only the final output will be displayed. * In batch>1 mode, each final image will be collected into an image grid for display * The image is constrained to a reasonable size to ensure that batch grids of RealESRGAN'd images don't end up spitting out a massive image that the browser then has to handle. * Additionally, the progress bar indicator is updated as each image is post-processed. * Display the final image before running postprocessing, and don't preview when i=0 * Added a config option to use embeddings from the huggingface stable diffusion concept library. * Added option to enable enable_attention_slicing and enable_minimal_memory_usage, this for now only works on txt2vid which uses diffusers. * Basic implementation for the Concept Library tab made by cloning the Home tab. 
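For the diffusers-based txt2vid path, the enable_attention_slicing option mentioned above maps onto the pipeline method of the same name in recent diffusers releases; a minimal sketch (the model id is the txt2vid default from the config, everything else is assumed):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
    ).to("cuda")

    # Compute attention in slices instead of one large matmul, trading a little
    # speed for a noticeably lower peak VRAM usage.
    pipe.enable_attention_slicing()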
* Temporarily hide sd_concept_library due to missing layout() * st.session_state["defaults"] fix * Used loaded_model state variable in .yaml generation (#1196) Used loaded_model state variable in .yaml generation * Streamlit txt2img page settings now follow defaults (#1195) * Some options on the Streamlit txt2img page now follow the defaults from the relevant config files. * Fixed a copy-paste gone wrong in my previous commit. * st.session_state["defaults"] fix Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> * default img2img denoising strength increased * slider_steps and slider_bounds in defaults config slider_steps and slider_bounds in defaults config * fix: copy to clipboard button Co-authored-by: ZeroCool940711 Co-authored-by: ZeroCool Co-authored-by: Hafiidz <3688500+Hafiidz@users.noreply.github.com> Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> Co-authored-by: Joshua Kimsey Co-authored-by: Tony Beeman Co-authored-by: Tom Pham <54967380+TomPham97@users.noreply.github.com> Co-authored-by: GRMrGecko Co-authored-by: TingTingin <36141041+TingTingin@users.noreply.github.com> Co-authored-by: Logan zoellner Co-authored-by: M Co-authored-by: James Pound Co-authored-by: cobryan05 <13701027+cobryan05@users.noreply.github.com> Co-authored-by: Michoko Co-authored-by: VulumeCode <2590984+VulumeCode@users.noreply.github.com> Co-authored-by: xaedes Co-authored-by: Michael Hearn Co-authored-by: Soul-Burn Co-authored-by: JJ Co-authored-by: oc013 <101832295+oc013@users.noreply.github.com> Co-authored-by: Aarni Koskela Co-authored-by: osi1880vr <87379616+osi1880vr@users.noreply.github.com> Co-authored-by: Rae Fu Co-authored-by: Brian Semrau Co-authored-by: Matt Soucy Co-authored-by: endomorphosis Co-authored-by: unnamedplugins <79282950+unnamedplugins@users.noreply.github.com> Co-authored-by: Syahmi Azhar Co-authored-by: Ahmad Abdullah <83442967+ahmad1284@users.noreply.github.com> Co-authored-by: Sérgio Co-authored-by: Charlie Co-authored-by: protoplm Co-authored-by: Ascended Co-authored-by: JuanLagu <32816584+JuanLagu@users.noreply.github.com> Co-authored-by: Chris Heald Co-authored-by: Charles Galant Co-authored-by: Alex Birch Co-authored-by: protoplm <57930981+protoplm@users.noreply.github.com> Co-authored-by: Dekker3D --- .dockerignore | 3 + .env_docker.example | 10 +- .gitignore | 9 +- Dockerfile | 9 +- README.md | 6 +- configs/webui/webui.yaml | 11 +- configs/webui/webui_streamlit.yaml | 190 +- docker-compose.yml | 5 +- docker-reset.sh | 9 +- entrypoint.sh | 73 +- environment.yaml | 62 +- frontend/css/streamlit.main.css | 122 +- frontend/frontend.py | 80 +- frontend/image_metadata.py | 57 + frontend/job_manager.py | 221 +- frontend/ui_functions.py | 4 +- images/nsfw.jpeg | Bin 0 -> 25276 bytes ldm/modules/attention.py | 105 +- ldm/modules/diffusionmodules/model.py | 145 +- ldm/modules/diffusionmodules/util.py | 5 +- scripts/DeforumStableDiffusion.py | 1312 ++++++++++++ scripts/ModelManager.py | 46 + scripts/Settings.py | 5 + scripts/home.py | 216 ++ scripts/img2img.py | 592 ++++++ scripts/imglab.py | 161 ++ scripts/perlin.py | 48 + scripts/relauncher.py | 4 + scripts/sd_utils.py | 1728 ++++++++++++++++ scripts/stable_diffusion_pipeline.py | 233 +++ scripts/stable_diffusion_walk.py | 218 ++ scripts/textual_inversion.py | 57 + scripts/txt2img.py | 368 ++++ scripts/txt2vid.py | 780 +++++++ scripts/webui.py | 577 ++++-- scripts/webui_streamlit.py | 1805 +--------------- scripts/webui_streamlit_old.py | 2738 +++++++++++++++++++++++++ setup.py | 2 +- webui.sh | 4 +- 39 
files changed, 9911 insertions(+), 2109 deletions(-) create mode 100644 .dockerignore mode change 100644 => 100755 docker-reset.sh create mode 100644 frontend/image_metadata.py create mode 100644 images/nsfw.jpeg create mode 100644 scripts/DeforumStableDiffusion.py create mode 100644 scripts/ModelManager.py create mode 100644 scripts/Settings.py create mode 100644 scripts/home.py create mode 100644 scripts/img2img.py create mode 100644 scripts/imglab.py create mode 100644 scripts/perlin.py create mode 100644 scripts/sd_utils.py create mode 100644 scripts/stable_diffusion_pipeline.py create mode 100644 scripts/stable_diffusion_walk.py create mode 100644 scripts/textual_inversion.py create mode 100644 scripts/txt2img.py create mode 100644 scripts/txt2vid.py create mode 100644 scripts/webui_streamlit_old.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..4ac12df --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +models/ +outputs/ +src/ diff --git a/.env_docker.example b/.env_docker.example index 5a34945..51eb059 100644 --- a/.env_docker.example +++ b/.env_docker.example @@ -6,9 +6,13 @@ CONDA_FORCE_UPDATE=false # (useful to set to false after you're sure the model files are already in place) VALIDATE_MODELS=true -#Automatically relaunch the webui on crashes +# Automatically relaunch the webui on crashes WEBUI_RELAUNCH=true -#Pass cli arguments to webui.py e.g: -#WEBUI_ARGS=--gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1 +# Which webui to launch +# WEBUI_SCRIPT=webui_streamlit.py +WEBUI_SCRIPT=webui.py + +# Pass cli arguments to webui.py e.g: +# WEBUI_ARGS=--optimized --extra-models-cpu --gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1 WEBUI_ARGS= diff --git a/.gitignore b/.gitignore index 4b30236..b014154 100644 --- a/.gitignore +++ b/.gitignore @@ -47,16 +47,21 @@ MANIFEST .env_updated condaenv.*.requirements.txt +# Visual Studio directories +.vs/ +.vscode/ # =========================================================================== # # Repo-specific # =========================================================================== # +/configs/webui/userconfig_streamlit.yaml /custom-conda-path.txt /src/* -/outputs/* +/outputs +/model_cache /log/**/*.png /log/log.csv /flagged/* /gfpgan/* /models/* -z_version_env.tmp \ No newline at end of file +z_version_env.tmp diff --git a/Dockerfile b/Dockerfile index 8d5ecb4..2b061b0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,10 @@ FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04 -ENV DEBIAN_FRONTEND=noninteractive +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONUNBUFFERED=1 \ + PYTHONIOENCODING=UTF-8 \ + CONDA_DIR=/opt/conda + WORKDIR /sd SHELL ["/bin/bash", "-c"] @@ -11,7 +15,6 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* # Install miniconda -ENV CONDA_DIR /opt/conda RUN wget -O ~/miniconda.sh -q --show-progress --progress=bar:force https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ /bin/bash ~/miniconda.sh -b -p $CONDA_DIR && \ rm ~/miniconda.sh @@ -20,7 +23,7 @@ ENV PATH=$CONDA_DIR/bin:$PATH # Install font for prompt matrix COPY /data/DejaVuSans.ttf /usr/share/fonts/truetype/ -EXPOSE 7860 +EXPOSE 7860 8501 COPY ./entrypoint.sh /sd/ ENTRYPOINT /sd/entrypoint.sh diff --git a/README.md b/README.md index 36d1dc0..f5d96ba 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,8 @@ Features: * Gradio GUI: Idiot-proof, fully featured frontend for both txt2img and img2img generation * No more manually typing parameters, now all you have to do is write your prompt and adjust sliders -* GFPGAN Face Correction 🔥: [Download 
the model](https://github.com/sd-webui/stable-diffusion-webui#gfpgan)Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second -* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui#realesrgan) Boosts the resolution of images with a built-in RealESRGAN option +* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second +* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Boosts the resolution of images with a built-in RealESRGAN option * :computer: esrgan/gfpgan on cpu support :computer: * Textual inversion 🔥: [info](https://textual-inversion.github.io/) - requires enabling, see [here](https://github.com/hlky/sd-enable-textual-inversion), script works as usual without it enabled * Advanced img2img editor :art: :fire: :art: @@ -106,7 +106,7 @@ that are not in original script. ### GFPGAN Lets you improve faces in pictures using the GFPGAN model. There is a checkbox in every tab to use GFPGAN at 100%, and -also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strongthe effect is. +also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strong the effect is. ![](images/GFPGAN.png) diff --git a/configs/webui/webui.yaml b/configs/webui/webui.yaml index b7bd258..25d222b 100644 --- a/configs/webui/webui.yaml +++ b/configs/webui/webui.yaml @@ -12,8 +12,9 @@ txt2img: # 5: Write sample info files # 6: write sample info to log file # 7: jpg samples - # 8: Fix faces using GFPGAN - # 9: Upscale images using RealESRGAN + # 8: Filter NSFW content + # 9: Fix faces using GFPGAN + # 10: Upscale images using RealESRGAN toggles: [1, 2, 3, 4, 5] sampler_name: k_lms ddim_eta: 0.0 # legacy name, applies to all algorithms. @@ -40,8 +41,10 @@ img2img: # 6: Sort samples by prompt # 7: Write sample info files # 8: jpg samples - # 9: Fix faces using GFPGAN - # 10: Upscale images using Real-ESRGAN + # 9: Color correction + # 10: Filter NSFW content + # 11: Fix faces using GFPGAN + # 12: Upscale images using Real-ESRGAN toggles: [1, 4, 5, 6, 7] sampler_name: k_lms ddim_eta: 0.0 diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml index 84263bd..394494c 100644 --- a/configs/webui/webui_streamlit.yaml +++ b/configs/webui/webui_streamlit.yaml @@ -1,14 +1,19 @@ # UI defaults configuration file. It is automatically loaded if located at configs/webui/webui_streamlit.yaml. # Any changes made here will be available automatically on the web app without having to stop it. +# You may add overrides in a file named "userconfig_streamlit.yaml" in this folder, which can contain any subset +# of the properties below. 
general: gpu: 0 outdir: outputs - ckpt: "models/ldm/stable-diffusion-v1/model.ckpt" - fp: - name: 'embeddings/alex/embeddings_gs-11000.pt' + default_model: "Stable Diffusion v1.4" + default_model_config: "configs/stable-diffusion/v1-inference.yaml" + default_model_path: "models/ldm/stable-diffusion-v1/model.ckpt" + use_sd_concepts_library: True + sd_concepts_library_folder: "models/custom/sd-concepts-library" GFPGAN_dir: "./src/gfpgan" RealESRGAN_dir: "./src/realesrgan" RealESRGAN_model: "RealESRGAN_x4plus" + LDSR_dir: "./src/latent-diffusion" outdir_txt2img: outputs/txt2img-samples outdir_img2img: outputs/img2img-samples gfpgan_cpu: False @@ -16,88 +21,161 @@ general: extra_models_cpu: False extra_models_gpu: False save_metadata: True + save_format: "png" skip_grid: False skip_save: False grid_format: "jpg:95" n_rows: -1 no_verify_input: False no_half: False + use_float16: False precision: "autocast" optimized: False optimized_turbo: False + optimized_config: "optimizedSD/v1-inference.yaml" + enable_attention_slicing: False + enable_minimal_memory_usage : False update_preview: True - update_preview_frequency: 1 + update_preview_frequency: 5 txt2img: prompt: height: 512 width: 512 - cfg_scale: 5.0 + cfg_scale: 7.5 seed: "" batch_count: 1 batch_size: 1 - sampling_steps: 50 - default_sampler: "k_lms" + sampling_steps: 30 + default_sampler: "k_euler" separate_prompts: False + update_preview: True + update_preview_frequency: 5 normalize_prompt_weights: True save_individual_images: True save_grid: True group_by_prompt: True save_as_jpg: False - use_GFPGAN: True - use_RealESRGAN: True + use_GFPGAN: False + use_RealESRGAN: False RealESRGAN_model: "RealESRGAN_x4plus" variant_amount: 0.0 variant_seed: "" + write_info_files: True + slider_steps: { + sampling: 1 + } + slider_bounds: { + sampling: { + lower: 1, + upper: 150 + } + } + +txt2vid: + default_model: "CompVis/stable-diffusion-v1-4" + custom_models_list: ["CompVis/stable-diffusion-v1-4", "naclbit/trinart_stable_diffusion_v2", "hakurei/waifu-diffusion", "osanseviero/BigGAN-deep-128"] + prompt: + height: 512 + width: 512 + cfg_scale: 7.5 + seed: "" + batch_count: 1 + batch_size: 1 + sampling_steps: 30 + num_inference_steps: 200 + default_sampler: "k_euler" + scheduler_name: "klms" + separate_prompts: False + update_preview: True + update_preview_frequency: 5 + dynamic_preview_frequency: True + normalize_prompt_weights: True + save_individual_images: True + save_video: True + group_by_prompt: True + write_info_files: True + do_loop: False + save_as_jpg: False + use_GFPGAN: False + use_RealESRGAN: False + RealESRGAN_model: "RealESRGAN_x4plus" + variant_amount: 0.0 + variant_seed: "" + beta_start: 0.00085 + beta_end: 0.012 + beta_scheduler_type: "linear" + max_frames: 1000 + slider_steps: { + sampling: 1 + } + slider_bounds: { + sampling: { + lower: 1, + upper: 150 + } + } img2img: - prompt: - sampling_steps: 50 - # Adding an int to toggles enables the corresponding feature. 
- # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them) - # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0) - # 2: Loopback (use images from previous batch when creating next batch) - # 3: Random loopback seed - # 4: Save individual images - # 5: Save grid - # 6: Sort samples by prompt - # 7: Write sample info files - # 8: jpg samples - # 9: Fix faces using GFPGAN - # 10: Upscale images using Real-ESRGAN - sampler_name: k_lms - denoising_strength: 0.45 - # 0: Keep masked area - # 1: Regenerate only masked area - mask_mode: 0 - # 0: Just resize - # 1: Crop and resize - # 2: Resize and fill - resize_mode: 0 - # Leave blank for random seed: - seed: "" - ddim_eta: 0.0 - cfg_scale: 5.0 - batch_count: 1 - batch_size: 1 - height: 512 - width: 512 - # Textual inversion embeddings file path: - fp: "" - loopback: True - random_seed_loopback: True - separate_prompts: False - normalize_prompt_weights: True - save_individual_images: True - save_grid: True - group_by_prompt: True - save_as_jpg: False - use_GFPGAN: True - use_RealESRGAN: True - RealESRGAN_model: "RealESRGAN_x4plus" - variant_amount: 0.0 - variant_seed: "" + prompt: + sampling_steps: 30 + # Adding an int to toggles enables the corresponding feature. + # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them) + # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0) + # 2: Loopback (use images from previous batch when creating next batch) + # 3: Random loopback seed + # 4: Save individual images + # 5: Save grid + # 6: Sort samples by prompt + # 7: Write sample info files + # 8: jpg samples + # 9: Fix faces using GFPGAN + # 10: Upscale images using Real-ESRGAN + sampler_name: "k_euler" + denoising_strength: 0.75 + # 0: Keep masked area + # 1: Regenerate only masked area + mask_mode: 0 + mask_restore: False + # 0: Just resize + # 1: Crop and resize + # 2: Resize and fill + resize_mode: 0 + # Leave blank for random seed: + seed: "" + ddim_eta: 0.0 + cfg_scale: 7.5 + batch_count: 1 + batch_size: 1 + height: 512 + width: 512 + # Textual inversion embeddings file path: + fp: "" + loopback: True + random_seed_loopback: True + separate_prompts: False + update_preview: True + update_preview_frequency: 5 + normalize_prompt_weights: True + save_individual_images: True + save_grid: True + group_by_prompt: True + save_as_jpg: False + use_GFPGAN: False + use_RealESRGAN: False + RealESRGAN_model: "RealESRGAN_x4plus" + variant_amount: 0.0 + variant_seed: "" + write_info_files: True + slider_steps: { + sampling: 1 + } + slider_bounds: { + sampling: { + lower: 1, + upper: 150 + } + } gfpgan: strength: 100 - diff --git a/docker-compose.yml b/docker-compose.yml index 968df1c..f378963 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.3' services: stable-diffusion: - container_name: sd + container_name: sd-webui build: context: . dockerfile: Dockerfile @@ -12,6 +12,7 @@ services: volumes: - .:/sd - ./outputs:/sd/outputs + - ./model_cache:/sd/model_cache - conda_env:/opt/conda - root_profile:/root ports: @@ -21,7 +22,7 @@ services: resources: reservations: devices: - - capabilities: [gpu] + - capabilities: [ gpu ] volumes: conda_env: diff --git a/docker-reset.sh b/docker-reset.sh old mode 100644 new mode 100755 index 3ca3158..5042026 --- a/docker-reset.sh +++ b/docker-reset.sh @@ -10,12 +10,13 @@ echo $(pwd) read -p "Is the directory above correct to run reset on? 
(y/n) " -n 1 DIRCONFIRM if [[ $DIRCONFIRM =~ ^[Yy]$ ]]; then docker compose down - docker image rm stable-diffusion_stable-diffusion:latest - docker volume rm stable-diffusion_conda_env - docker volume rm stable-diffusion_root_profile + docker image rm stable-diffusion-webui_stable-diffusion:latest + docker volume rm stable-diffusion-webui_conda_env + docker volume rm stable-diffusion-webui_root_profile echo "Remove ./src" sudo rm -rf src - sudo rm -rf latent_diffusion.egg-info + sudo rm -rf gfpgan + sudo rm -rf sd_webui.egg-info sudo rm .env_updated else echo "Exited without resetting" diff --git a/entrypoint.sh b/entrypoint.sh index 21ab01e..e130ea0 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -3,26 +3,36 @@ # Starts the gui inside the docker container using the conda env # +# set -x + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR +export PYTHONPATH=$SCRIPT_DIR + +MODEL_DIR="${SCRIPT_DIR}/model_cache" # Array of model files to pre-download # local filename # local path in container (no trailing slash) # download URL # sha256sum MODEL_FILES=( - 'model.ckpt /sd/models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556' - 'GFPGANv1.3.pth /sd/src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70' - 'RealESRGAN_x4plus.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1' - 'RealESRGAN_x4plus_anime_6B.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da' + 'model.ckpt models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556' + 'GFPGANv1.3.pth src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70' + 'RealESRGAN_x4plus.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1' + 'RealESRGAN_x4plus_anime_6B.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da' + 'project.yaml src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1 9d6ad53c5dafeb07200fb712db14b813b527edd262bc80ea136777bdb41be2ba' + 'model.ckpt src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1 c209caecac2f97b4bb8f4d726b70ac2ac9b35904b7fc99801e1f5e61f9210c13' ) # Conda environment installs/updates # @see https://github.com/ContinuumIO/docker-images/issues/89#issuecomment-467287039 ENV_NAME="ldm" -ENV_FILE="/sd/environment.yaml" +ENV_FILE="${SCRIPT_DIR}/environment.yaml" ENV_UPDATED=0 ENV_MODIFIED=$(date -r $ENV_FILE "+%s") -ENV_MODIFED_FILE="/sd/.env_updated" 
+ENV_MODIFED_FILE="${SCRIPT_DIR}/.env_updated" if [[ -f $ENV_MODIFED_FILE ]]; then ENV_MODIFIED_CACHED=$(<${ENV_MODIFED_FILE}); else ENV_MODIFIED_CACHED=0; fi +export PIP_EXISTS_ACTION=w # Create/update conda env if needed if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then @@ -51,54 +61,67 @@ conda info | grep active # Function to checks for valid hash for model files and download/replaces if invalid or does not exist validateDownloadModel() { local file=$1 - local path=$2 + local path="${SCRIPT_DIR}/${2}" local url=$3 local hash=$4 echo "checking ${file}..." - sha256sum --check --status <<< "${hash} ${path}/${file}" + sha256sum --check --status <<< "${hash} ${MODEL_DIR}/${file}.${hash}" if [[ $? == "1" ]]; then echo "Downloading: ${url} please wait..." mkdir -p ${path} - wget --output-document=${path}/${file} --no-verbose --show-progress --progress=dot:giga ${url} - echo "saved ${file}" + wget --output-document=${MODEL_DIR}/${file}.${hash} --no-verbose --show-progress --progress=dot:giga ${url} + ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file} + if [[ -e "${path}/${file}" ]]; then + echo "saved ${file}" + else + echo "error saving ${path}/${file}!" + exit 1 + fi else - echo -e "${file} is valid!\n" + if [[ ! -e ${path}/${file} || ! -L ${path}/${file} ]]; then + mkdir -p ${path} + ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file} + echo -e "linked valid ${file}\n" + else + echo -e "${file} is valid!\n" + fi fi } # Validate model files -if [[ -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then - echo "Validating model files..." - for models in "${MODEL_FILES[@]}"; do - model=($models) +echo "Validating model files..." +for models in "${MODEL_FILES[@]}"; do + model=($models) + if [[ ! -e ${model[1]}/${model[0]} || ! -L ${model[1]}/${model[0]} || -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then validateDownloadModel ${model[0]} ${model[1]} ${model[2]} ${model[3]} - done -fi + fi +done # Launch web gui -cd /sd - -if [[ -z $WEBUI_ARGS ]]; then - launch_message="entrypoint.sh: Launching..." +if [[ ! -z $WEBUI_SCRIPT && $WEBUI_SCRIPT == "webui_streamlit.py" ]]; then + launch_command="streamlit run scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS" else - launch_message="entrypoint.sh: Launching with arguments ${WEBUI_ARGS}" + launch_command="python scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS" fi +launch_message="entrypoint.sh: Run ${launch_command}..." if [[ -z $WEBUI_RELAUNCH || $WEBUI_RELAUNCH == "true" ]]; then n=0 while true; do - echo $launch_message + if (( $n > 0 )); then echo "Relaunch count: ${n}" fi - python -u scripts/webui.py $WEBUI_ARGS + + $launch_command + echo "entrypoint.sh: Process is ending. Relaunching in 0.5s..." 
((n++)) sleep 0.5 done else echo $launch_message - python -u scripts/webui.py $WEBUI_ARGS + $launch_command fi diff --git a/environment.yaml b/environment.yaml index 5bb3bf8..b5bd8e3 100644 --- a/environment.yaml +++ b/environment.yaml @@ -3,39 +3,47 @@ channels: - pytorch - defaults dependencies: - - git - - python=3.8.5 - - pip=20.3 - cudatoolkit=11.3 + - git + - numpy=1.22.3 + - pip=20.3 + - python=3.8.5 - pytorch=1.11.0 + - scikit-image=0.19.2 - torchvision=0.12.0 - - numpy=1.19.2 - pip: - - albumentations==0.4.3 - - opencv-python==4.1.2.30 - - opencv-python-headless==4.1.2.30 - - pudb==2019.2 - - imageio==2.9.0 - - imageio-ffmpeg==0.4.2 - - pytorch-lightning==1.4.2 - - omegaconf==2.1.1 - - test-tube>=0.7.5 - - einops==0.3.0 - - torch-fidelity==0.3.0 - - transformers==4.19.2 - - torchmetrics==0.6.0 - - kornia==0.6 - - gradio==3.1.6 - - accelerate==0.12.0 - - pynvml==11.4.1 - - basicsr>=1.3.4.0 - - facexlib>=0.2.3 - - python-slugify>=6.1.2 - - streamlit>=1.12.2 - - retry>=0.9.2 + - -e . - -e git+https://github.com/CompVis/taming-transformers#egg=taming-transformers - -e git+https://github.com/openai/CLIP#egg=clip - -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN - -e git+https://github.com/xinntao/Real-ESRGAN#egg=realesrgan - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion - - -e . \ No newline at end of file + - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion + - accelerate==0.12.0 + - albumentations==0.4.3 + - basicsr>=1.3.4.0 + - diffusers==0.3.0 + - einops==0.3.0 + - facexlib>=0.2.3 + - gradio==3.1.6 + - imageio-ffmpeg==0.4.2 + - imageio==2.9.0 + - kornia==0.6 + - omegaconf==2.1.1 + - opencv-python-headless==4.6.0.66 + - pandas==1.4.3 + - piexif==1.1.3 + - pudb==2019.2 + - pynvml==11.4.1 + - python-slugify>=6.1.2 + - pytorch-lightning==1.4.2 + - retry>=0.9.2 + - streamlit>=1.12.2 + - streamlit-on-Hover-tabs==1.0.1 + - streamlit-option-menu==0.3.2 + - streamlit_nested_layout + - test-tube>=0.7.5 + - tensorboard + - torch-fidelity==0.3.0 + - torchmetrics==0.6.0 + - transformers==4.19.2 diff --git a/frontend/css/streamlit.main.css b/frontend/css/streamlit.main.css index 4e11b77..a11d21d 100644 --- a/frontend/css/streamlit.main.css +++ b/frontend/css/streamlit.main.css @@ -1,15 +1,111 @@ -.css-18e3th9 { - padding-top: 2rem; - padding-bottom: 10rem; - padding-left: 5rem; - padding-right: 5rem; -} -.css-1d391kg { - padding-top: 3.5rem; - padding-right: 1rem; - padding-bottom: 3.5rem; - padding-left: 1rem; -} +/*********************************************************** +* Additional CSS for streamlit builtin components * +************************************************************/ + +/* Tab name (e.g. 
Text-to-Image) */ button[data-baseweb="tab"] { - font-size: 25px; + font-size: 25px; //improve legibility } + +/* Image Container (only appear after run finished) */ +.css-du1fp8 { + justify-content: center; //center the image, especially better looks in wide screen +} + +/* Streamlit header */ +.css-1avcm0n { + background-color: transparent; +} + +/* Main streamlit container (below header) */ +.css-18e3th9 { + padding-top: 2rem; //reduce the empty spaces +} + +/* @media only for widescreen, to ensure enough space to see all */ +@media (min-width: 1024px) { + /* Main streamlit container (below header) */ + .css-18e3th9 { + padding-top: 0px; //reduce the empty spaces, can go fully to the top on widescreen devices + } +} + +/*********************************************************** +* Additional CSS for streamlit custom/3rd party components * +************************************************************/ +/* For stream_on_hover */ +section[data-testid="stSidebar"] > div:nth-of-type(1) { + background-color: #111; +} + +button[kind="header"] { + background-color: transparent; + color: rgb(180, 167, 141); +} + +@media (hover) { + /* header element */ + header[data-testid="stHeader"] { + /* display: none;*/ /*suggested behavior by streamlit hover components*/ + pointer-events: none; /* disable interaction of the transparent background */ + } + + /* The button on the streamlit navigation menu */ + button[kind="header"] { + /* display: none;*/ /*suggested behavior by streamlit hover components*/ + pointer-events: auto; /* enable interaction of the button even if parents intereaction disabled */ + } + + /* added to avoid main sectors (all element to the right of sidebar from) moving */ + section[data-testid="stSidebar"] { + width: 3.5% !important; + min-width: 3.5% !important; + } + + /* The navigation menu specs and size */ + section[data-testid="stSidebar"] > div { + height: 100%; + width: 2% !important; + min-width: 100% !important; + position: relative; + z-index: 1; + top: 0; + left: 0; + background-color: #111; + overflow-x: hidden; + transition: 0.5s ease-in-out; + padding-top: 0px; + white-space: nowrap; + } + + /* The navigation menu open and close on hover and size */ + section[data-testid="stSidebar"] > div:hover { + width: 300px !important; + } +} + +@media (max-width: 272px) { + section[data-testid="stSidebar"] > div { + width: 15rem; + } +} + +/*********************************************************** +* Additional CSS for other elements +************************************************************/ +button[data-baseweb="tab"] { + font-size: 20px; +} + +@media (min-width: 1200px){ +h1 { + font-size: 1.75rem; +} +} +#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 { + width: 50rem; + align-self: center; +} +div.gallery:hover { + border: 1px solid #777; +} \ No newline at end of file diff --git a/frontend/frontend.py b/frontend/frontend.py index 29d3c50..94c76c9 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -3,6 +3,8 @@ from frontend.css_and_js import css, js, call_JS, js_parse_prompt, js_copy_txt2i from frontend.job_manager import JobManager import frontend.ui_functions as uifn import uuid +import torch + def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda x: x, txt2img_defaults={}, @@ -36,8 +38,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda value=txt2img_defaults['cfg_scale'], elem_id='cfg_slider') txt2img_seed = gr.Textbox(label="Seed (blank to randomize)", lines=1, 
max_lines=1, value=txt2img_defaults["seed"]) + txt2img_batch_size = gr.Slider(minimum=1, maximum=50, step=1, + label='Images per batch', + value=txt2img_defaults['batch_size']) txt2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1, - label='Number of images to generate', + label='Number of batches to generate', value=txt2img_defaults['n_iter']) txt2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None @@ -51,11 +56,15 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda gr.Markdown( "Select an image from the gallery, then click one of the buttons below to perform an action.") with gr.Row(elem_id='txt2img_actions_row'): - gr.Button("Copy to clipboard").click(fn=None, - inputs=output_txt2img_gallery, - outputs=[], - # _js=js_copy_to_clipboard( 'txt2img_gallery_output') - ) + gr.Button("Copy to clipboard").click( + fn=None, + inputs=output_txt2img_gallery, + outputs=[], + _js=call_JS( + "copyImageFromGalleryToClipboard", + fromId="txt2img_gallery_output" + ) + ) output_txt2img_copy_to_input_btn = gr.Button("Push to img2img") output_txt2img_to_imglab = gr.Button("Send to Lab", visible=True) @@ -91,9 +100,6 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda with gr.TabItem('Advanced'): txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles, value=txt2img_toggle_defaults, type="index") - txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1, - label='Batch size (how many images are in a batch; memory-hungry)', - value=txt2img_defaults['batch_size']) txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], @@ -124,20 +130,27 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda inputs=txt2img_inputs, outputs=txt2img_outputs ) + use_queue = False + else: + use_queue = True txt2img_btn.click( txt2img_func, txt2img_inputs, - txt2img_outputs + txt2img_outputs, + api_name='txt2img', + queue=use_queue ) txt2img_prompt.submit( txt2img_func, txt2img_inputs, - txt2img_outputs + txt2img_outputs, + queue=use_queue ) - # txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) - # txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) + txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) + txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) + txt2img_dimensions_info_text_box.value = uifn.update_dimensions_info(txt2img_width.value, txt2img_height.value) # Temporarily disable prompt parsing until memory issues could be solved # See #676 @@ -189,8 +202,9 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda with gr.TabItem("Editor Options"): with gr.Row(): # disable Uncrop for now - # choices=["Mask", "Crop", "Uncrop"] - img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"], + choices=["Mask", "Crop", "Uncrop"] + #choices=["Mask", "Crop"] + img2img_image_editor_mode = gr.Radio(choices=choices, label="Image Editor Mode", value="Mask", elem_id='edit_mode_select', visible=True) @@ -199,9 +213,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=True) - 
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1, + img2img_mask_restore = gr.Checkbox(label="Only modify regenerated parts of image", + value=img2img_defaults['mask_restore'], + visible=True) + + img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1, label="How much blurry should the mask be? (to avoid hard edges)", - value=3, visible=False) + value=3, visible=True) img2img_resize = gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", @@ -293,7 +311,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda img2img_height ], [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask, - img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength] + img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_restore] ) # img2img_image_editor_mode.change( @@ -334,8 +352,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda ) img2img_func = img2img - img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, - img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, + img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength, + img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_image_editor, img2img_image_mask, img2img_embeddings] @@ -349,11 +367,16 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda inputs=img2img_inputs, outputs=img2img_outputs, ) + use_queue = False + else: + use_queue = True img2img_btn_mask.click( img2img_func, img2img_inputs, - img2img_outputs + img2img_outputs, + api_name="img2img", + queue=use_queue ) def img2img_submit_params(): @@ -383,6 +406,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda outputs=img2img_dimensions_info_text_box) img2img_height.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height], outputs=img2img_dimensions_info_text_box) + img2img_dimensions_info_text_box.value = uifn.update_dimensions_info(img2img_width.value, img2img_height.value) with gr.TabItem("Image Lab", id='imgproc_tab'): gr.Markdown("Post-process results") @@ -397,8 +421,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda # value=gfpgan_defaults['strength']) # select folder with images to process with gr.TabItem('Batch Process'): - imgproc_folder = gr.File(label="Batch Process", file_count="multiple", source="upload", - interactive=True, type="file") + imgproc_folder = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") imgproc_pngnfo = gr.Textbox(label="PNG Metadata", placeholder="PngNfo", visible=False, max_lines=5) with gr.Row(): @@ -540,7 +563,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda imgproc_width, imgproc_cfg, imgproc_denoising, imgproc_seed, imgproc_gfpgan_strength, imgproc_ldsr_steps, imgproc_ldsr_pre_downSample, imgproc_ldsr_post_downSample], - [imgproc_output]) + [imgproc_output], api_name="imgproc") imgproc_source.change( uifn.get_png_nfo, @@ -631,11 +654,12 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda """ gr.HTML("""
-

For help and advanced usage guides, visit the Project Wiki

-

Stable Diffusion WebUI is an open-source project. - If you would like to contribute to development or test bleeding edge builds, use the dev branch.

+

For help and advanced usage guides, visit the Project Wiki

+

Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the main repository. + If you would like to contribute to development or test bleeding edge builds, you can visit the development repository.

+

Device ID {current_device_index}: {current_device_name}
{total_device_count} total devices

- """) + """.format(current_device_name=torch.cuda.get_device_name(), current_device_index=torch.cuda.current_device(), total_device_count=torch.cuda.device_count())) # Hack: Detect the load event on the frontend # Won't be needed in the next version of gradio # See the relevant PR: https://github.com/gradio-app/gradio/pull/2108 diff --git a/frontend/image_metadata.py b/frontend/image_metadata.py new file mode 100644 index 0000000..8448088 --- /dev/null +++ b/frontend/image_metadata.py @@ -0,0 +1,57 @@ +''' Class to store image generation parameters to be stored as metadata in the image''' +from __future__ import annotations +from dataclasses import dataclass, asdict +from typing import Dict, Optional +from PIL import Image +from PIL.PngImagePlugin import PngInfo +import copy + +@dataclass +class ImageMetadata: + prompt: str = None + seed: str = None + width: str = None + height: str = None + steps: str = None + cfg_scale: str = None + normalize_prompt_weights: str = None + denoising_strength: str = None + GFPGAN: str = None + + def as_png_info(self) -> PngInfo: + info = PngInfo() + for key, value in self.as_dict().items(): + info.add_text(key, value) + return info + + def as_dict(self) -> Dict[str, str]: + return {f"SD:{key}": str(value) for key, value in asdict(self).items() if value is not None} + + @classmethod + def set_on_image(cls, image: Image, metadata: ImageMetadata) -> None: + ''' Sets metadata on image, in both text form and as an ImageMetadata object ''' + if metadata: + image.info = metadata.as_dict() + else: + metadata = ImageMetadata() + image.info["ImageMetadata"] = copy.copy(metadata) + + @classmethod + def get_from_image(cls, image: Image) -> Optional[ImageMetadata]: + ''' Gets metadata from an image, first looking for an ImageMetadata, + then if not found tries to construct one from the info ''' + metadata = image.info.get("ImageMetadata", None) + if not metadata: + found_metadata = False + metadata = ImageMetadata() + for key, value in image.info.items(): + if key.lower().startswith("sd:"): + key = key[3:] + if f"{key}" in metadata.__dict__: + metadata.__dict__[key] = value + found_metadata = True + if not found_metadata: + metadata = None + if not metadata: + print("Couldn't find metadata on image") + return metadata diff --git a/frontend/job_manager.py b/frontend/job_manager.py index 8eda8d9..026742f 100644 --- a/frontend/job_manager.py +++ b/frontend/job_manager.py @@ -1,7 +1,7 @@ ''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations ''' from __future__ import annotations import gradio as gr -from gradio.components import Component, Gallery +from gradio.components import Component, Gallery, Slider from threading import Event, Timer from typing import Callable, List, Dict, Tuple, Optional, Any from dataclasses import dataclass, field @@ -9,6 +9,7 @@ from functools import partial from PIL.Image import Image import uuid import traceback +import time @dataclass(eq=True, frozen=True) @@ -30,9 +31,21 @@ class JobInfo: session_key: str job_token: Optional[int] = None images: List[Image] = field(default_factory=list) + active_image: Image = None + rec_steps_enabled: bool = False + rec_steps_imgs: List[Image] = field(default_factory=list) + rec_steps_intrvl: int = None + rec_steps_to_gallery: bool = False + rec_steps_to_file: bool = False should_stop: Event = field(default_factory=Event) + refresh_active_image_requested: Event = field(default_factory=Event) + refresh_active_image_done: Event = 
field(default_factory=Event) + stop_cur_iter: Event = field(default_factory=Event) + active_iteration_cnt: int = field(default_factory=int) job_status: str = field(default_factory=str) finished: bool = False + started: bool = False + timestamp: float = None removed_output_idxs: List[int] = field(default_factory=list) @@ -76,7 +89,7 @@ class JobManagerUi: ''' return self._job_manager._wrap_func( func=func, inputs=inputs, outputs=outputs, - refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text + job_ui=self ) _refresh_btn: gr.Button @@ -84,10 +97,19 @@ class JobManagerUi: _status_text: gr.Textbox _stop_all_session_btn: gr.Button _free_done_sessions_btn: gr.Button + _active_image: gr.Image + _active_image_stop_btn: gr.Button + _active_image_refresh_btn: gr.Button + _rec_steps_intrvl_sldr: gr.Slider + _rec_steps_checkbox: gr.Checkbox + _save_rec_steps_to_gallery_chkbx: gr.Checkbox + _save_rec_steps_to_file_chkbx: gr.Checkbox _job_manager: JobManager class JobManager: + JOB_MAX_START_TIME = 5.0 # How long can a job be stuck 'starting' before assuming it isn't running + def __init__(self, max_jobs: int): self._max_jobs: int = max_jobs self._avail_job_tokens: List[Any] = list(range(max_jobs)) @@ -102,11 +124,23 @@ class JobManager: ''' assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context" with gr.Tabs(): - with gr.TabItem("Current Session"): + with gr.TabItem("Job Controls"): with gr.Row(): - stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary") - refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary") + stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary") + refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary") status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False) + with gr.Row(): + active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary") + active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary") + active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image") + with gr.TabItem("Batch Progress Settings"): + with gr.Row(): + record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid") + record_steps_interval_slider = gr.Slider( + value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1) + with gr.Row() as record_steps_box: + steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery") + steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File") with gr.TabItem("Maintenance"): with gr.Row(): gr.Markdown( @@ -118,9 +152,15 @@ class JobManager: free_done_sessions_btn = gr.Button( "Clear Finished Jobs", elem_id="clear_finished", variant="secondary" ) + return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text, _stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn, - _job_manager=self) + _active_image=active_image, _active_image_stop_btn=active_image_stop_btn, + _active_image_refresh_btn=active_image_refresh_btn, + _rec_steps_checkbox=record_steps_checkbox, + _save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox, + _save_rec_steps_to_file_chkbx=steps_to_file_checkbox, + _rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self) def clear_all_finished_jobs(self): ''' Removes all currently finished jobs, across all sessions. 
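The new "Job Controls" / "Batch Progress Settings" tabs are wired to a generation function through JobManagerUi.wrap_func, which returns a wrapped callable plus the adjusted input/output component lists. A minimal sketch of that wiring, assuming a hypothetical generate() stub and placeholder component names (the real hookup lives in frontend/frontend.py):

import gradio as gr
from frontend.job_manager import JobManager

# Hypothetical generation stub; the JobManager calls it with job_info injected,
# so it can append progress images and honor the should_stop / stop_cur_iter events.
def generate(prompt, steps, job_info=None):
    images = []  # run the sampler here, collecting PIL images
    if job_info is not None:
        job_info.images.extend(images)
    return [images]

job_manager = JobManager(max_jobs=1)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    steps = gr.Slider(1, 150, value=30, step=1, label="Sampling Steps")
    gallery = gr.Gallery()
    run_btn = gr.Button("Generate")

    # Both calls must happen inside the gr.Blocks context (asserted above).
    job_ui = job_manager.draw_gradio_ui()  # draws the Job Controls / Batch Progress tabs
    wrapped_fn, wrapped_inputs, wrapped_outputs = job_ui.wrap_func(
        func=generate, inputs=[prompt, steps], outputs=[gallery])
    run_btn.click(wrapped_fn, wrapped_inputs, wrapped_outputs)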
@@ -134,6 +174,7 @@ class JobManager: for session in self._sessions.values(): for job in session.jobs.values(): job.should_stop.set() + job.stop_cur_iter.set() def _get_job_token(self, block: bool = False) -> Optional[int]: ''' Attempts to acquire a job token, optionally blocking until available ''' @@ -175,6 +216,26 @@ class JobManager: job_info.should_stop.set() return "Stopping after current batch finishes" + def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]: + ''' Updates information from the active iteration ''' + session_info, job_info = self._get_call_info(func_key, session_key) + if job_info is None: + return [None, f"Session {session_key} was not running function {func_key}"] + + job_info.refresh_active_image_requested.set() + if job_info.refresh_active_image_done.wait(timeout=20.0): + job_info.refresh_active_image_done.clear() + return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"] + return [gr.Image.update(visible=False), "Timed out getting image"] + + def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]: + ''' Marks that the active iteration should be stopped''' + session_info, job_info = self._get_call_info(func_key, session_key) + if job_info is None: + return [None, f"Session {session_key} was not running function {func_key}"] + job_info.stop_cur_iter.set() + return [gr.Image.update(visible=False), "Stopping current iteration"] + def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]: ''' Helper to get the SessionInfo and JobInfo. ''' session_info = self._sessions.get(session_key, None) @@ -207,19 +268,22 @@ class JobManager: def _pre_call_func( self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button, - status_text: gr.Textbox, session_key: str) -> List[Component]: + status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button, + session_key: str) -> List[Component]: ''' Called when a job is about to start ''' session_info, job_info = self._get_call_info(func_key, session_key) # If we didn't already get a token then queue up for one if job_info.job_token is None: - job_info.token = self._get_job_token(block=True) + job_info.job_token = self._get_job_token(block=True) # Buttons don't seem to update unless value is set on them as well... return {output_dummy_obj: triggerChangeEvent(), refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value), stop_btn: gr.Button.update(variant="primary", value=stop_btn.value), - status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates") + status_text: gr.Textbox.update(value="Generation has started. 
Click 'Refresh' to see finished images, 'View Batch Progress' for active images"), + active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value), + active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value), } def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]: @@ -228,12 +292,19 @@ class JobManager: if session_info is None or job_info is None: return [] + job_info.started = True try: + if job_info.should_stop.is_set(): + raise Exception(f"Job {job_info} requested a stop before execution began") outputs = job_info.func(*job_info.inputs, job_info=job_info) except Exception as e: job_info.job_status = f"Error: {e}" print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}") - outputs = [] + raise + finally: + job_info.finished = True + session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key) + self._release_job_token(job_info.job_token) # Filter the function output for any removed outputs filtered_output = [] @@ -241,11 +312,6 @@ class JobManager: if idx not in job_info.removed_output_idxs: filtered_output.append(output) - job_info.finished = True - session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key) - - self._release_job_token(job_info.job_token) - # The wrapper added a dummy JSON output. Append a random text string # to fire the dummy objects 'change' event to notify that the job is done filtered_output.append(triggerChangeEvent()) @@ -254,12 +320,16 @@ class JobManager: def _post_call_func( self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button, - status_text: gr.Textbox, session_key: str) -> List[Component]: + status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button, + session_key: str) -> List[Component]: ''' Called when a job completes ''' return {output_dummy_obj: triggerChangeEvent(), refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value), stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value), - status_text: gr.Textbox.update(value="Generation has finished!") + status_text: gr.Textbox.update(value="Generation has finished!"), + active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value), + active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value), + active_image: gr.Image.update(visible=False) } def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]: @@ -270,21 +340,17 @@ class JobManager: if session_info is None or job_info is None: return [] - if job_info.finished: - session_info.finished_jobs.pop(func_key) - return job_info.images - def _wrap_func( - self, func: Callable, inputs: List[Component], outputs: List[Component], - refresh_btn: gr.Button = None, stop_btn: gr.Button = None, - status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]: + def _wrap_func(self, func: Callable, inputs: List[Component], + outputs: List[Component], + job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]: ''' handles JobManageUI's wrap_func''' assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context" # Create a unique key for this job - func_key = FuncKey(job_id=uuid.uuid4(), func=func) + func_key = FuncKey(job_id=uuid.uuid4().hex, func=func) # Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/) if self._session_key is None: @@ 
-302,31 +368,59 @@ class JobManager: del outputs[idx] break - # Add the session key to the inputs - inputs += [self._session_key] - # Create dummy objects update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject") update_gallery_obj.change( partial(self._update_gallery_event, func_key), [self._session_key], - [gallery_comp] + [gallery_comp], + queue=False ) - if refresh_btn: - refresh_btn.variant = 'secondary' - refresh_btn.click( + if job_ui._refresh_btn: + job_ui._refresh_btn.variant = 'secondary' + job_ui._refresh_btn.click( partial(self._refresh_func, func_key), [self._session_key], - [update_gallery_obj, status_text] + [update_gallery_obj, job_ui._status_text], + queue=False ) - if stop_btn: - stop_btn.variant = 'secondary' - stop_btn.click( + if job_ui._stop_btn: + job_ui._stop_btn.variant = 'secondary' + job_ui._stop_btn.click( partial(self._stop_wrapped_func, func_key), [self._session_key], - [status_text] + [job_ui._status_text], + queue=False + ) + + if job_ui._active_image and job_ui._active_image_refresh_btn: + job_ui._active_image_refresh_btn.click( + partial(self._refresh_cur_iter_func, func_key), + [self._session_key], + [job_ui._active_image, job_ui._status_text], + queue=False + ) + + if job_ui._active_image_stop_btn: + job_ui._active_image_stop_btn.click( + partial(self._stop_cur_iter_func, func_key), + [self._session_key], + [job_ui._active_image, job_ui._status_text], + queue=False + ) + + if job_ui._stop_all_session_btn: + job_ui._stop_all_session_btn.click( + self.stop_all_jobs, [], [], + queue=False + ) + + if job_ui._free_done_sessions_btn: + job_ui._free_done_sessions_btn.click( + self.clear_all_finished_jobs, [], [], + queue=False ) # (ab)use gr.JSON to forward events. @@ -343,7 +437,8 @@ class JobManager: # Since some parameters are optional it makes sense to use the 'dict' return value type, which requires # the Component as a key... 
so group together the UI components that the event listeners are going to update # to make it easy to append to function calls and outputs - job_ui_params = [refresh_btn, stop_btn, status_text] + job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text, + job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn] job_ui_outputs = [comp for comp in job_ui_params if comp is not None] # Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call, @@ -352,44 +447,70 @@ class JobManager: post_call_dummyobj.change( partial(self._post_call_func, func_key, update_gallery_obj, *job_ui_params), [self._session_key], - [update_gallery_obj] + job_ui_outputs + [update_gallery_obj] + job_ui_outputs, + queue=False ) call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_runCall") call_dummyobj.change( partial(self._call_func, func_key), [self._session_key], - outputs + [post_call_dummyobj] + outputs + [post_call_dummyobj], + queue=False ) pre_call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_preCall") pre_call_dummyobj.change( partial(self._pre_call_func, func_key, call_dummyobj, *job_ui_params), [self._session_key], - [call_dummyobj] + job_ui_outputs + [call_dummyobj] + job_ui_outputs, + queue=False ) - # Now replace the original function with one that creates a JobInfo and triggers the dummy obj + # Add any components that we want the runtime values for + added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx, + job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr] - def wrapped_func(*inputs): - session_key = inputs[-1] - inputs = inputs[:-1] + # Now replace the original function with one that creates a JobInfo and triggers the dummy obj + def wrapped_func(*wrapped_inputs): + # Remove the added_inputs (pop opposite order of list) + + wrapped_inputs = list(wrapped_inputs) + rec_steps_interval: int = wrapped_inputs.pop() + save_rec_steps_file: bool = wrapped_inputs.pop() + save_rec_steps_grid: bool = wrapped_inputs.pop() + record_steps_enabled: bool = wrapped_inputs.pop() + session_key: str = wrapped_inputs.pop() + job_inputs = tuple(wrapped_inputs) # Get or create a session for this key session_info = self._sessions.setdefault(session_key, SessionInfo()) # Is this session already running this job? if func_key in session_info.jobs: - return {status_text: "This session is already running that function!"} + job_info = session_info.jobs[func_key] + # If the job seems stuck in 'starting' then go ahead and toss it + if not job_info.started and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME: + job_info.should_stop.set() + job_info.stop_cur_iter.set() + session_info.jobs.pop(func_key) + return {job_ui._status_text: "Canceled possibly hung job. Try again"} + return {job_ui._status_text: "This session is already running that function!"} + + # Is this a new run of a previously finished job? 
Clear old info + if func_key in session_info.finished_jobs: + session_info.finished_jobs.pop(func_key) job_token = self._get_job_token(block=False) - job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key, - job_token=job_token) + job = JobInfo( + inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key, + job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval, + rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file, timestamp=time.time()) session_info.jobs[func_key] = job ret = {pre_call_dummyobj: triggerChangeEvent()} if job_token is None: - ret[status_text] = "Job is queued" + ret[job_ui._status_text] = "Job is queued" return ret - return wrapped_func, inputs, [pre_call_dummyobj, status_text] + return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text] diff --git a/frontend/ui_functions.py b/frontend/ui_functions.py index 6557841..ee6af8d 100644 --- a/frontend/ui_functions.py +++ b/frontend/ui_functions.py @@ -9,10 +9,10 @@ import re def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height): if choice == "Mask": update_image_result = update_image_mask(cropped_image, resize_mode, width, height) - return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] + return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)] update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height) - return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] + return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)] def update_image_mask(cropped_image, resize_mode, width, height): resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None diff --git a/images/nsfw.jpeg b/images/nsfw.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..0ecf3b68b55af111bf8a73f5436add752b1c95c3 GIT binary patch literal 25276 zcmeHv2UrtJ*Z(F|DN+OkK@6aXN+1akiXy#A2fLJzgramp5tSk;NL4^VMFD9l3L+{Z zO{7^-Kvaq#L8V!U1raIVB!F_?dtcv;?{}Z?`9J@=>}Jo*{LYzk&YYduomqDJ`SeSO z-_+2=5JDgjkTLj!rhBTq902TY>cxrpv7nDoq(2` z#UKHVMIa$waP|WS4$wm291afmC{_k<3ZV1A0jgu9=gvkZ=2j@20v4m7gae!y98LqH zu7Sm%Flrh&tOiyEpb?orbphK2QC?nAoRowVj_Gas79_w)*NFuo4Iu>(ECPt>I%pZtlMS>qy|HIE1d@f7jh%yY z0T(v_T;qp;r&y4zENpD7tN^_altZinY=Tl)9ri`m1P*CmAzZ}KEKV8S;u>L_=8@$} z&VKtBaEUAy6$HgDOwjYuN9xVpLT^rr*_ z1_kd9i9B%dP*n8cn55*CW5-j|PMkcOos*lFfA0K+D+cJp>^U4292oxAs1 zTA#E%efGS)qnFzE>h+ucfx#i#$4{TXjD8*aHcq#TZs*)KW7xmgB>?O~va+(Ua?tHU zAcMfcBEZTfg=H7ivF0H7E|SJYa0=-j%_^>0Afse6BJAwf%q6m1xpxJPZW_Zfc!ur& zEz5ow_S3Eoh=&CM8jnQ)(t^HC#mHqMfA_dF==f3e$~06IJq>-9+PJ#=*fb;+t2jQE zI1T9vp;lWP2tmJlY<#$KN)NB~g|=+!!0OXGTEbo_5vL(XfoW(!_W+xQ+}Uuh-#yB6 z!tU&>ovM(lr*~OJ&ZGif1rAu?Cp*s%ur^@qzKXSLXRKLQKfBM7pcYQ(UE6WYL zsw58Hy9YE7k1nqI)$j}LNk*;yMvLce(k2N^bwzR5U~;Z6?-Vbedxi!L%($Xc82 zcaI${kVad$=KAK@0MUA-Lip*B|0+Pr=mGN6AtBbC1?|gB`n1MIvJfCX{)-F7J>#QD 
z$uv|rI1OcS6>O&Xp7y#G($z>$LAk#hb#%VX6#sayTjM(zqH;HVF8ao^HSAek*f{@x zsdAzFx?z#I%h0}g?CD#r;Wyw?v0_{O-^5&bj>N7_94+*VZsY2*Tqq!xKSZ3-jc5xE zoQ;9xe)CxUCK>cWrsK- zkxZ^98xt>ghdpAHHIIr{fBTX#HF**$GO+xSSYI174XrN;>nYLZx85|HnudBxjGilM z^|#O~LskmQ)F^viqaL-)UM19no`!Dp9E0#NfsIqcp>}6iiN9~?ACN<&*d1DsLaKYY z=X~B<@Anz^!o>{~o=y!;oa^X$p>+&eAymHod;Z5BKq2 zRojZvslHIV-QJT8^rYQgcYoBq?e}WitWE~aX%cojd`)LZ@PWDmp9RkLh;d1RP^#!AdtMB$qGboPsLz@k0Mq9qX@0j3icrocabS5 zXR@oim!|ZG@@i?6J4sV|vx+&!+((b>=58F~N45>Iup@@-BC3<5wbrB71#1L*`goEl z1XQr6hnK%bu%&aBCZglLj*UNG=*S z`i8R-KuS}3R@I=OAcY_$1#drBMXb8Ix*`Urh{K@)0`0%si$VxSd-==ENzfH`vwP(}%$UJo^I$ z=pUfJRGwZnuuwzan;1Z^#6(|Hnl4F$&65Io75OjJ`>Ml0h~Flc2Zf(u&Rnc$38 zR&~LVRn=Tv2qfYhzonl$2qS{We0}J0NOU>s&gxhdJQ+=PQBp@ME0eHjXOfEwT3y+N zsG@?wC@Z;;rBNiJhM~8gCjqpByC=bwthm#M>?)0-`%6Q|%0yEdrvS8NB&|FM6c>Qi zls0qs3J9LfwR86*+foR0&tdU+oEjF7QN^ocFe*6wEMgPc&mY7KT@>R{%ma7QuXQG-CFN4KW5KOvAzlAg^Zx)How$>2H#u`&a5|BI$m z!#b-Gu>=y@#hIjnR>lICgXUI3V~OfSb!Qba38Shq$M5g$LJ1=Hk#$@_903o2C}KPc zqU0FfT{VXbawF4S2Yi6WsG)Hdc%yaofF7BSL_|BL$Llt@Zs6%`dU8HXo< z_#=VPC6JWSB+#CCq8eFEiA0#w>v!~5M`JN_ZHS+3L%RNo^sZ#4DHQ*;l^A97x}7;V z=*#HjkFIe&>Ywc&5B%eSe?0Jy2mbNEKOXr1&jUXnuE<{CPAv#LI!#lNmreBa9Ib3D z4Nc4pz!W){-p=$Odit|sAjs2;;%8%|ivqK?D2_K^W}OSn;)A(+0@2?`$I8lVZqE46 z=h3J1>2C;XMbl;dx%`(oyd>}t1?HMj0Hsaz@uL9R3DEc;iVvMm1GEsh;Uoe2BB15{ zfP#R2PS1Cqq2JPJR|ZXAfPjSj>}>RaZ9sbnT%fTT589^qHIb;Q)AZI85a)&%13WNq{FOcI8*?@B0 zf5X3o!EXh)oB@|RLKB5g7asy$t{cY(IVL??PV|0ULrx7lPQjr>8%y0LwGU z5HxvYdis0T^z`Ieu)NU)LAShq@_T1OklGiJf8=N0@(c*#I{-n~>we}rr$SI^6j&l@ z_aXQZ81>LsLXa+CDd=N91o3W#phb@%h->?d-#{6?9LjzML3Y4bW)C6g#7PK}bp^6_ z{2ROJ>q$TS_DjsH{}@Xub5r}jIwO$?1lSSCx$P%~WGt!tWc~we4*nm{)3?Er3d;fT zhXAVyb4x0c^lJ#oSbd;7gT5^QNmh0?4o)txG_ho+c&23LmxT>1MX*oLPVe)x@Ph`O zhO}PJ$NZ9(j~_jHGs@?enE8cU?!S{cz3}DytY6Z_i<^x?#I%2jnOBG)6rD{}#)Z$z z`X!01b2L~wK>j6Wej!_IRf~qz`}tYFq`8;w6Y|YU`6Xt4p@_qG<7T}Z^Rs?Q%fB8y zn!tqnx9%CpHsHzprOf$-SEdr=n)|_(4Rd2V`Ps+NQC$rt9CV=7*)hty?`p$kK)7m3 zSSL1_!o@?}dj_b{4t`d!8MZyLzArRU4W9wM`O@o{stxC0*C3Mk#IOF5;uF{m+2bWmA(R{jb7i5We3`GwxCEfX%@5dSEQaJN+GO1~$Jc>d~v3nXnlO zH+Ko*qbgwIg{dF*>$L=h!(~9$IQkfMNRGh8LtDI!>LvH~yTN9V=gX!hb~DimB7)zK zYd!{d1u!>ceUD$GsbnThI0S9!9LPS#M5k9=>T9&xL>sPr1hJ5spqvyA7mvK}Xt2bM zi7_Ydv#&vhx*S~jkizEdm?hts=u-nUC9H~==~E95P@}4t=u_>zjf7TxV!}lqdJH(_ zm~c^URSeqd9&8VB^X(G~W@cEHd@n~cf!_l1yj|AgN$`h|&BHfxc$a6w*1t5US>Ke2 zyCScmGG{a&e0|Vb+0!d+^O@;MFMi%@ZP~=ckoKg(ZOmsu&*Im{NndpYL4ByC(XQCUrtg2x& zA>RUepV(V&c&G=)$A#hdLI78C32=iuGkj-7>{Plmv<}}g6)2`lh?=M+(GaPOO z2$I)sVt^|A+Z6fSNAkP!m}r88>#q}3m}!FRUBwreX#&H}|agXAFRwrpX!f%2#l{q%jh4As4hmCW1cVxm13Fy)bnu+8BZMb;E zmObA{Upe6GvpRLJP3eC4c)NvT6%Hfrt;8KIyLLrfeK2lBc^=qQmh2T9mKwFB-$AKC zsn){`QoAovR!cY&QDXMq)l2F5Q?rulO+3@k+QHJK{-*Td0=41d8;whb8zNHG8#{U& z2?~2mRQtoUPVYS=Gvr=)dg)$lbHkxb8Ts*O{Aef90Gj8;nL=;7t>K2+`5)zn=2#HUfLCcQ~1oI`$R)6I{KRy=*D zi;q_G-f%7XsvZF5f@S%^?Fy2eK3P2HTLe3&6L{O6LqK>jo=rqnaFGz07v==z2sUJz zYEs)`jCX)rT$*ar35<6o&Z&pHPnP#>70#(UmimUbt7^5=ix9J#^3UfZgA`-ZRK1R2 zye2`Oo5oRF_BI{6gK4TfwzYZ>#N(Ic=+~4{rv#MP$HK%g-pd_QF?V?MvX)}*=#70l zq`SEB0gf;K0d5gE=;!$d`)TV|_Q_Tk-mjKsX{yw*_Q-u`>k{^_t+Kqsxevsf_wJ~v zq?APl<=ffSlnIRM*OU;1-P9wUpFPgw%@@4MU&3CuEi;cd^hz!Z#%puHOKDl&PmVy+ zy{m?1;$O|vRL57J7sq%~Y;8oGIv(}r@irtY>(|_rj6MT<$WO3e*|?KWwr z&VbMEk>s3s4vfd?g--1j8G32>N=*@`?&3sV;Nf~&3H#W#Y@k=3nRxpIsirEU!A!hO z)~zhk=l0}TS>BV0ErL$%7n(URp278p4yNI|6g!qUb?{UQI(2C|9L0D$=uJeI9nx#P zpzG9iS*~8+Ogz`sRMzQ5G1mdRnhL@(R`x0KUulUMsiQV+a8cx*v)-QeThR;Q)vSH0Y7aazLcv=AA&GpqLS zk*)YNwxmIy-MBOfT&m~^jT1S((*Cyh^m`Y|T7W%arA6dD$^iAi+@3N*W$g0j=wy_`lRwC}-u6ey&^K$#8h~Qn;S32U4 zKk>MpO}-|MKBvg$MkHJn=goY6iM`Pczed~SyON|BH+Anp%IhT_jeL1ge6mIh^JaHI 
z&+I>P<_Miv<8D#>vQhkb^LCJwjUas_iCD-Q@#Z1Ps-%?5;KEiE0RK+H#=^pi0Bbvp zOB;c7f>^LPUseHY!9~qV5j7)z{HUWsQWzbdEXEZMIsq+2@Z!V4g+;<`)SDF>>H~sP zXHtFGeJfpKhG_NM_(JhANi=%De)S}XBx@UW)- zl-7#Ai(eLvO{6BAr6npZHRzLc(1>kVR}ru{Zg=odrA}fJtt}*S*^0gkU)a}gm=Qup z&Pk@%1nQqvW%K6Q%Ks(nNBw`SWLC8~O=s#<6#Ynh;(&eKcdmkXpI*c;@=uZ-l)OZz zp<>_$VEe53z!^Vqb3*BL{6XBD=0K-EVb=FX#eF{{sq|iS<QHu<;_V?%1+@eHDeQ zV#CV@_I*tly52)r`r1bZC- z925KT;YSx7PaU6zezc7fB9{WfZMJLK^mexGLt}%{RI>^ejF;GqD@>CusuXoHCE_T`-2iebmBpDBG}m3SdeV2j1^lDtPq=%Uoi`-kdm@>MAnt&5dpzP{x$qqoQZXZw|-V|4w2Fxhcd}^rrIt^v`)%ihHX1pqcYP ztAc+^T77fgj{aP1>`}?#OIhImSU|bYz`ONXgZj>|S~~XW0mQ>x$0wSr;&yawoSc`m z|IaeB&B`8`5*It1_CG0$4H$f;ewMh{OE+KH%z3!A&|kknAs*Kc1kXwOJe=N)jV-b$ zd$>sS=1ihg$`-H&zFJL3S8DVF)<3E}TZQ2X|1&8+|D?lzmO6uY zcgODEc>hJx`FK|6_jm0Y_G5!c7l=BdiGwOGkX5nu=xx#hRta3SF*1hQ+;lzDxYgUDiHrn9;~kuL|^zl`t&a; zo`TaymYv^Gju(04f4_v_@paw*0*Nh`p$ls+`xGgDQ+2(M_rbNt$85{~cp^B(-~S0A zI7LoFE1P&}E1$JKxpFay;AazG2eM#F_89ff~{@DGohFrPj+?7j0EPL-g z$Cf(O#J>hIt`r4^hw5bnHgrrq`SkvIr(?&NeVyXFzS}-3m;JWWa?o=9yO^ub{og+? zFpru~zAc!%{DsbPFQl7eQeaAW>OtzYml>AMTDPm)lS8k`e9q4ZH2?S{qvg!LcGY9a z-#}HKc9yTiySI05k-w}J?rkVD!lNg$Mk;wykM+?$tL_U~s0#Py}( z8EW{y_`v`-_p_^l2iQOl4dCXVKZHQT{wCgoC0!riKHskSNa2Ck9YEfG8~-R>oZGeR zb^FD)^=Y4ZLEos$&*&Rj*b$r@3m79e8~FE77N=_h%7H1ZZAhggDV*UZmt0wPL22uV zn&u}$B3K>&fUHG$8)v_x$+uf(Z`awikm(Dk2d`$99A0N}Jw?QK<>99``}2AtZ9d_ zN!=N_?IsJZRD;=D2C z@~6WJ_d{n91*-`8f?H0I<&>$Gr86qerWReyWlJ}q2K>Rw7P^&;dNW*S43H0sIo^V3 zT9HAnUaS~Y5|O@P>zgm1ofaK-^||>h^m<_!J!^D3Eo2A5g<;|`fU*fy`>)N{eblr;fRBXAKz~#FpcCmGh#R~a-8N%os zmN{Q;*!$$>k%hi%K2nY$#&B!>8h5Fcg2%s|_e4ohO`f*XaE;%$ zxYs4fY&H{6J6gAMaQ{j{wRn4sKGjH|AdfTMWc?$-i$T{8rJHKLFQw)-jfVRlcDNt9 z#@?7s?f6@o-#A9B?z)1bAwunV>BIX2w_F7)tK{vC*wm8!GpGSYqd$!J_hAvZn-Le* zz%e%cdw{MbliO}4u=}E8Q=4x!t*{@ee>Cdjc!0Wbs$jwKVRsPU z-d`9opIm7QqELKfCcs(`ui5*6-ik7ZMQmA`mfH`hP-+$(=H6RZcv9Bap?`P8Mm5_& zw94o?z3bPi^Oii2>O48N3lrtFYw+UjrDa8jN;m`GREhm-cD|AvzM*>QgT&4y6MJx4 zyLNq>c>L~CQDceT==JI|5BWMei+#K9@fH{5ec}v!T6GHWZyTQ{Hjg28UTwYeYMaS+ zr7O{!{!@vQbR}kmRNP*A4ak0BK1 z=_}k*{Ya3dXL6Bxx%BGcgJ-+yHGQsbPs=ZqjP(}^&};E>PB<50{=mA8&!j9yBWq9< zYh`$ldkDyOS{snzFGg9aN%eQ!VZdOkdL=lnFZx<~c%evvR?)z{+1I?JxF_P(du z99K4(QMH_}Oes*`c9eFzc}8YWB9K{5yLm&ovs%0$(i|pPxwrs~NxVJrZ8dp4rqZcl z`3cK<*wUw=cgLrpyDT3cd>yIZ^nHDmnihrqL*r0X-ixtGpl0>53zEq}eP3W$1dJB} z#($wJdgew5de-=m-hooAw`$Ub!-``YI>WcH3{tsr|BBl{8~TV)Jn z|E$vJU3{qKiDlQWTK8qIiDk!CsZ{LPp!TvZW}p~#EBveah?8E>gKy`@v;wYIPF?~- z&1x{zAX(Vy(?7GbKL`OQ2wTH1rDILVLizYc94!v-9XUdqn*q|23RdWrJ^pkozb-nv z&OlOCj^^ndwmDg&P;NlXlnkMM0fR;{#qvT*sz&4jR+Ym2)&5suWj+nQSB3 zni^SmOe>{6LXeenY2CVwvdIm?!Q)P6CFSxTSlY$JpQCx`1m|T{QfR!Dl5&);V2*-D zgLNh$I`&qdSsJUd`ZyE3`?BlJbpqCBR#5cn1x{8_`OeG9^n3^|I+xwPt->mJXv2|i z>0S1H$KM3sNg|hr71X`)>qhY{J=80S(N1ox|5VgYC8nP`H$c-%-*j+``=`TGs8y;5 z9ac{*d@NEV$SSGAZ(68cH@0}$Wtyv8@X#W}Ck~^S#KmnARO9G%UM5s2y#AT>9O4Z{ zlKc033C>lKRDYJ3zBtx2=$Hdaj^?BuW?UrVE2;K6A+|x{oaA!WQA>ejEw98y0^FD7 zXiAqwUhl4p;E02IBB;C8ZBdGgN=SY>K?y!bsouA{o`9#R$gz!{qp>!XdSB{m=I|@x z&tHb&6Rh1pYooH%YlLm>`%>GUUAM!wj`X~+9(8Ncw|io3LfPcN^qco8q*rqh2DrnH z#SU8Q-oB-?YKZ{aK-bj*XXU=}&>0WMs{%DP2f4&G)(MPcdj`H^jocycvQK~`igTd_ zDhZK<9=KKDuOf!@Q;j&;zhsaxP4IgG z`_+7U9~RfL-{_U|>8*|CR)~^ZWarj=R$fZ(oVrWX6QPEu$O%U%{T3$+h*HMqDHtyS zgNaSJ8wxHEeANZ+#(3o$F8FW2Y5xr}+2G5}gbjt?PE5EIzEDosiZTI2!NxLSQ#fp) z@U5Sjdx0g`2NsGBS_m77(1s0#-%fKw84S^xFfcO^_!mAD4qGVvc7iPuLcp&c1zYO> z@r4gtPw>w=u(3R_k$;Oz|DOpcCR_^Jqkj*OiM~`Dwg&L~9oQn70dm0xGtrmI!4}HI zg&z)Es3>eC6D zw#b=py)prZ1!bb&LBSG82Qd*Q;jo3mcPae~v@kbqm^fx&Ot_R8v2d|}OScDQ-h#qS zP~G8>IBJE4i()6VI8)x_Rg$&{)C-B|>N;%_>=5qs1Yi8S!|w=YW##Cv?y%GU`5Sy| 
zm9ZfpCjqQ4_ztXfGl3>`G-8Awh4Ep`$jh$eSxB$5T2c7N{Lrz8*3DS^M?d_L(c^@oxa@|8yGgXt*Ho}yAyxd$`kClSjRBCE^qfOy{6+U8u$bQ| zbpM&lqlY;6Mn%@>N_}G;2~J$|RPxzH&7v)P94?Pm65fT>c6^DtHq5u{sbsSiPEE5& zw0c*AVh8G0#P@I3BLWK|rXg+CTjzGmWg|^jvz`bZ+*8vccWW{-$0XwWZFcGPpZ7ct z=S;l3vSKOte3p*B*(*xLn-_}er5l<&TOj>K-^Z*zS{t&a*e$@(TE0f+2#mTiJ+y zR2N4@BSR5=Iapo2Z;%;RR8rQfOoYJXtwqQf%<{9DCaDRj6-!ZPW&Jd?IroQ!aY>p+ zC55vZjX-HakhIo$erch-Y|9Wvx2#8m_OkLvBZBL{J77bf#W`NS*p?Ll z?~&5qN}GKt4SZ(+5@@DbAI*}&IuSxu(Zf zy4aTozU#bccitxF?cFo};~8crK)TfH!3i8a{lH;nw|rnRvs*yPWfyTicWYPjr2px3 zi)=pTYK*Y;iR*PwM=Kr7^SRn2a^iNg`O@318roc^4yZlL+__WxQ6{gsdydfQhZ`*S z)aO@HAF~;J8&4UkS<^OPD#v9QEzynmOljeEdsn-pK_zY%Vzp;+0@&>-E3ioW*L#xr(K_`~fjoI8Sf&_turA6q_B{VwG#n++nV2Z z1mzP^6W&Za(#vRg0M_E~mVoRhcXz99D~qdD^roH~Yt`s1>-pNPwDs;-`UkOGwEf#X zAH>$Zp0Egu@jG+3Uy&jcd#)Xyduqq#CO-LiP&U*sYW=#@6+WBe+Ip?Mqq(N%)$7<%el#)sbl3gIphw zGP2+;nBexEn+;7es%Un}J6ma|slMy2E^T=7Cd6BCSJs)riVtGPBfFkuP&32?P410J zd=Tc)Q|PU54^%2JHqLsqj~dUTa`@20s}|)A0q#pVEvCvFKN3GoVk=yS1h?Ur2?s?> zSxE{zPkimheURICq;+G34X60#q8jtQZ!BM@&ULRl<#|4|il!uN*mg=xV_|tcbix+J za$&f_X=5d^>|KSEjnK|me6xo4k}rEqZcLTNXS41xFD$4aBS~Kcy1G}i9kIJ(uVPOP z$+D}=x_`7=J}H~L`#r%Uh_BcBKP(!%=rZA;8Yid}YAN1;LEU{~ zxu9ZaRn8@8Hony0_94gIE2ru`3arJBhj!MpTU zKP{drck-DjU(=qtw38H*qP6uCP1|H8g2d}pbiB(iJoXk&c$p{HG<4Hjyeevs*uu@- zax~+I-zvNXO|MooE@dgCB{w;3MOXpGLBSdv02fbk?b5ZK*s9T)M8u{d{O?{K6B( zr-oIdMa?2R&;d)SW~Sz{h2=Woyu2&$JK{0xT{yF?YkWGl?L7JM*1!(4K;6TSQq4GY zhn%|JU2lt9$k+Q6n@uU;#^hX-5j*Ex$m?`{Z|L*X{KGk0r*3$Osw7`uy`ubO4VEX# zYirGZvgoI-$q}f~W@6mu^Ag;KDsKy4lf7%s4%QAx>qp;J+^wl$Q|nq2(|i4Do676q zjCGW!eoy1rD{S=?JK67C5R&*-ci!XpgL>nz(YXtYmnaVFyMS@BTOXh*fPXU1>X8z)Kk zI=HYptw6TK#42oI?ci3~p(dAq%s1H2d8$4a*3>tN5+z z*tTms6EAlt<*A~gkimx#zI)olRv@)l0-z&LN@M-z! G_WuEesGD~H literal 0 HcmV?d00001 diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39..4485c1e 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -7,6 +7,8 @@ from einops import rearrange, repeat from ldm.modules.diffusionmodules.util import checkpoint +import psutil + def exists(val): return val is not None @@ -167,30 +169,98 @@ class CrossAttention(nn.Module): nn.Dropout(dropout) ) + if torch.cuda.is_available(): + self.einsum_op = self.einsum_op_cuda + else: + self.mem_total = psutil.virtual_memory().total / (1024**3) + self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2 + + def einsum_op_compvis(self, q, k, v, r1): + s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + r1 = einsum('b i j, b j d -> b i d', s2, v) + del s2 + return r1 + + def einsum_op_mps_v1(self, q, k, v, r1): + if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 + r1 = self.einsum_op_compvis(q, k, v, r1) + else: + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + s2 = s1.softmax(dim=-1, dtype=r1.dtype) + del s1 + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + return r1 + + def einsum_op_mps_v2(self, q, k, v, r1): + if self.mem_total >= 8 and q.shape[1] <= 4096: + r1 = self.einsum_op_compvis(q, k, v, r1) + else: + slice_size = 1 + for i in range(0, q.shape[0], slice_size): + end = min(q.shape[0], i + slice_size) + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + s2 = s1.softmax(dim=-1, dtype=r1.dtype) + del s1 + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + return r1 + + def 
einsum_op_cuda(self, q, k, v, r1): + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = min(q.shape[1], i + slice_size) + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + s2 = s1.softmax(dim=-1, dtype=r1.dtype) + del s1 + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + return r1 + def forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) + del x k = self.to_k(context) v = self.to_v(context) + del context q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + r1 = self.einsum_op(q, k, v, r1) + del q, k, v + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + return self.to_out(r2) class BasicTransformerBlock(nn.Module): @@ -209,9 +279,10 @@ class BasicTransformerBlock(nn.Module): return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) def _forward(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x + x = x.contiguous() if x.device.type == 'mps' else x + x += self.attn1(self.norm1(x)) + x += self.attn2(self.norm2(x), context=context) + x += self.ff(self.norm3(x)) return x diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589..dbbb325 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -119,18 +120,30 @@ class ResnetBlock(nn.Module): padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = 
nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -138,7 +151,7 @@ class ResnetBlock(nn.Module): else: x = self.nin_shortcut(x) - return x+h + return x + h8 class LinAttnBlock(LinearAttention): @@ -178,28 +191,65 @@ class AttnBlock(nn.Module): def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + q2 = q1.reshape(b, c, h*w) + del q1 - h_ = self.proj_out(h_) + q = q2.permute(0, 2, 1) # b,hw,c + del q2 - return x+h_ + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 def make_attn(in_channels, attn_type="vanilla"): @@ -540,31 +590,54 @@ class Decoder(nn.Module): temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + gc.collect() + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + 
return h diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index a952e6c..f872ba0 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -54,7 +54,8 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep # assert ddim_timesteps.shape[0] == num_ddim_timesteps # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 + # steps_out = ddim_timesteps + 1 # removed due to some issues when reaching 1000 + steps_out = np.where(ddim_timesteps != 999, ddim_timesteps+1, ddim_timesteps) if verbose: print(f'Selected timesteps for ddim sampler: {steps_out}') return steps_out @@ -264,4 +265,4 @@ class HybridConditioner(nn.Module): def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise() diff --git a/scripts/DeforumStableDiffusion.py b/scripts/DeforumStableDiffusion.py new file mode 100644 index 0000000..cd88539 --- /dev/null +++ b/scripts/DeforumStableDiffusion.py @@ -0,0 +1,1312 @@ +#Deforum Stable Diffusion v0.4 +#Stable Diffusion by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the Stability.ai Team. K Diffusion by Katherine Crowson. You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from HuggingFace. + +#Notebook by deforum +#Local Version by DGSpitzer 大谷的游戏创作小屋 + +import os, time +def get_output_folder(output_path, batch_folder): + out_path = os.path.join(output_path,time.strftime('%Y-%m')) + if batch_folder != "": + out_path = os.path.join(out_path, batch_folder) + os.makedirs(out_path, exist_ok=True) + return out_path + + +def main(): + + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument( + "--settings", + type=str, + default="./examples/runSettings_StillImages.txt", + help="Settings file", + ) + + parser.add_argument( + "--enable_animation_mode", + default=False, + action='store_true', + help="Enable animation mode settings", + ) + + opt = parser.parse_args() + + #@markdown **Model and Output Paths** + # ask for the link + print("Local Path Variables:\n") + + models_path = "./models" #@param {type:"string"} + output_path = "./output" #@param {type:"string"} + + #@markdown **Google Drive Path Variables (Optional)** + mount_google_drive = False #@param {type:"boolean"} + force_remount = False + + + + + if mount_google_drive: + from google.colab import drive # type: ignore + try: + drive_path = "/content/drive" + drive.mount(drive_path,force_remount=force_remount) + models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"} + output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"} + models_path = models_path_gdrive + output_path = output_path_gdrive + except: + print("...error mounting drive or with drive path variables") + print("...reverting to default path variables") + + import os + os.makedirs(models_path, exist_ok=True) + os.makedirs(output_path, exist_ok=True) + + print(f"models_path: {models_path}") + print(f"output_path: {output_path}") + + + + #@markdown **Python Definitions** + import IPython + import json + from IPython import display + + import gc, math, os, pathlib, shutil, 
subprocess, sys, time + import cv2 + import numpy as np + import pandas as pd + import random + import requests + import torch, torchvision + import torch.nn as nn + import torchvision.transforms as T + import torchvision.transforms.functional as TF + from contextlib import contextmanager, nullcontext + from einops import rearrange, repeat + from itertools import islice + from omegaconf import OmegaConf + from PIL import Image + from pytorch_lightning import seed_everything + from skimage.exposure import match_histograms + from torchvision.utils import make_grid + from tqdm import tqdm, trange + from types import SimpleNamespace + from torch import autocast + + sys.path.extend([ + 'src/taming-transformers', + 'src/clip', + 'stable-diffusion/', + 'k-diffusion', + 'pytorch3d-lite', + 'AdaBins', + 'MiDaS', + ]) + + import py3d_tools as p3d + + from helpers import DepthModel, sampler_fn + from k_diffusion.external import CompVisDenoiser + from ldm.util import instantiate_from_config + from ldm.models.diffusion.ddim import DDIMSampler + from ldm.models.diffusion.plms import PLMSSampler + + + #Read settings files + def load_args(path): + with open(path, "r") as f: + loaded_args = json.load(f)#, ensure_ascii=False, indent=4) + return loaded_args + + master_args = load_args(opt.settings) + + + def sanitize(prompt): + whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ') + tmp = ''.join(filter(whitelist.__contains__, prompt)) + return tmp.replace(' ', '_') + + def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx): + angle = keys.angle_series[frame_idx] + zoom = keys.zoom_series[frame_idx] + translation_x = keys.translation_x_series[frame_idx] + translation_y = keys.translation_y_series[frame_idx] + + center = (args.W // 2, args.H // 2) + trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]]) + rot_mat = cv2.getRotationMatrix2D(center, angle, zoom) + trans_mat = np.vstack([trans_mat, [0,0,1]]) + rot_mat = np.vstack([rot_mat, [0,0,1]]) + xform = np.matmul(rot_mat, trans_mat) + + return cv2.warpPerspective( + prev_img_cv2, + xform, + (prev_img_cv2.shape[1], prev_img_cv2.shape[0]), + borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE + ) + + def anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx): + TRANSLATION_SCALE = 1.0/200.0 # matches Disco + translate_xyz = [ + -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE, + keys.translation_y_series[frame_idx] * TRANSLATION_SCALE, + -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE + ] + rotate_xyz = [ + math.radians(keys.rotation_3d_x_series[frame_idx]), + math.radians(keys.rotation_3d_y_series[frame_idx]), + math.radians(keys.rotation_3d_z_series[frame_idx]) + ] + rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0) + result = transform_image_3d(prev_img_cv2, depth, rot_mat, translate_xyz, anim_args) + torch.cuda.empty_cache() + return result + + def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor: + return sample + torch.randn(sample.shape, device=sample.device) * noise_amt + + def load_img(path, shape, use_alpha_as_mask=False): + # use_alpha_as_mask: Read the alpha channel of the image as the mask image + if path.startswith('http://') or path.startswith('https://'): + image = Image.open(requests.get(path, stream=True).raw) + else: + image = Image.open(path) + + if use_alpha_as_mask: + image = image.convert('RGBA') + else: + image = image.convert('RGB') + + image = 
image.resize(shape, resample=Image.LANCZOS) + + mask_image = None + if use_alpha_as_mask: + # Split alpha channel into a mask_image + red, green, blue, alpha = Image.Image.split(image) + mask_image = alpha.convert('L') + image = image.convert('RGB') + + image = np.array(image).astype(np.float16) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + image = 2.*image - 1. + + return image, mask_image + + def load_mask_latent(mask_input, shape): + # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object + # shape (list-like len(4)): shape of the image to match, usually latent_image.shape + + if isinstance(mask_input, str): # mask input is probably a file name + if mask_input.startswith('http://') or mask_input.startswith('https://'): + mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA') + else: + mask_image = Image.open(mask_input).convert('RGBA') + elif isinstance(mask_input, Image.Image): + mask_image = mask_input + else: + raise Exception("mask_input must be a PIL image or a file name") + + mask_w_h = (shape[-1], shape[-2]) + mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS) + mask = mask.convert("L") + return mask + + def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0): + # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object + # shape (list-like len(4)): shape of the image to match, usually latent_image.shape + # mask_brightness_adjust (non-negative float): amount to adjust brightness of the iamge, + # 0 is black, 1 is no adjustment, >1 is brighter + # mask_contrast_adjust (non-negative float): amount to adjust contrast of the image, + # 0 is a flat grey image, 1 is no adjustment, >1 is more contrast + + mask = load_mask_latent(mask_input, mask_shape) + + # Mask brightness/contrast adjustments + if mask_brightness_adjust != 1: + mask = TF.adjust_brightness(mask, mask_brightness_adjust) + if mask_contrast_adjust != 1: + mask = TF.adjust_contrast(mask, mask_contrast_adjust) + + # Mask image to array + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask,(4,1,1)) + mask = np.expand_dims(mask,axis=0) + mask = torch.from_numpy(mask) + + if args.invert_mask: + mask = ( (mask - 0.5) * -1) + 0.5 + + mask = np.clip(mask,0,1) + return mask + + def maintain_colors(prev_img, color_match_sample, mode): + if mode == 'Match Frame 0 RGB': + return match_histograms(prev_img, color_match_sample, multichannel=True) + elif mode == 'Match Frame 0 HSV': + prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV) + color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV) + matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True) + return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB) + else: # Match Frame 0 LAB + prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB) + color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB) + matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True) + return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB) + + + def make_callback(sampler_name, dynamic_threshold=None, static_threshold=None, mask=None, init_latent=None, sigmas=None, sampler=None, masked_noise_modifier=1.0): + # Creates the callback function to be passed into the samplers + # The callback function is applied to the image at each step + def dynamic_thresholding_(img, threshold): + # Dynamic thresholding from Imagen paper (May 2022) + s = 
np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim))) + s = np.max(np.append(s,1.0)) + torch.clamp_(img, -1*s, s) + torch.FloatTensor.div_(img, s) + + # Callback for samplers in the k-diffusion repo, called thus: + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + def k_callback_(args_dict): + if dynamic_threshold is not None: + dynamic_thresholding_(args_dict['x'], dynamic_threshold) + if static_threshold is not None: + torch.clamp_(args_dict['x'], -1*static_threshold, static_threshold) + if mask is not None: + init_noise = init_latent + noise * args_dict['sigma'] + is_masked = torch.logical_and(mask >= mask_schedule[args_dict['i']], mask != 0 ) + new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1) + args_dict['x'].copy_(new_img) + + # Function that is called on the image (img) and step (i) at each step + def img_callback_(img, i): + # Thresholding functions + if dynamic_threshold is not None: + dynamic_thresholding_(img, dynamic_threshold) + if static_threshold is not None: + torch.clamp_(img, -1*static_threshold, static_threshold) + if mask is not None: + i_inv = len(sigmas) - i - 1 + init_noise = sampler.stochastic_encode(init_latent, torch.tensor([i_inv]*batch_size).to(device), noise=noise) + is_masked = torch.logical_and(mask >= mask_schedule[i], mask != 0 ) + new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1) + img.copy_(new_img) + + if init_latent is not None: + noise = torch.randn_like(init_latent, device=device) * masked_noise_modifier + if sigmas is not None and len(sigmas) > 0: + mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas)) + elif len(sigmas) == 0: + mask = None # no mask needed if no steps (usually happens because strength==1.0) + if sampler_name in ["plms","ddim"]: + # Callback function formated for compvis latent diffusion samplers + if mask is not None: + assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable" + batch_size = init_latent.shape[0] + + callback = img_callback_ + else: + # Default callback function uses k-diffusion sampler variables + callback = k_callback_ + + return callback + + def sample_from_cv2(sample: np.ndarray) -> torch.Tensor: + sample = ((sample.astype(float) / 255.0) * 2) - 1 + sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16) + sample = torch.from_numpy(sample) + return sample + + def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray: + sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32) + sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1) + sample_int8 = (sample_f32 * 255) + return sample_int8.astype(type) + + def transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate, anim_args): + # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion + w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] + + aspect_ratio = float(w)/float(h) + near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov + persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device) + persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device) + + # range of [-1,1] is important to torch grid_sample's padding handling + y,x = 
torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device)) + z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device) + xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1) + + xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] + xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] + + offset_xy = xyz_new_cam_xy - xyz_old_cam_xy + # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation. + identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0) + # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs. + coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False) + offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0) + + image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device) + new_image = torch.nn.functional.grid_sample( + image_tensor.add(1/512 - 0.0001).unsqueeze(0), + offset_coords_2d, + mode=anim_args.sampling_mode, + padding_mode=anim_args.padding_mode, + align_corners=False + ) + + # convert back to cv2 style numpy array + result = rearrange( + new_image.squeeze().clamp(0,255), + 'c h w -> h w c' + ).cpu().numpy().astype(prev_img_cv2.dtype) + return result + + def generate(args, return_latent=False, return_sample=False, return_c=False): + seed_everything(args.seed) + os.makedirs(args.outdir, exist_ok=True) + + sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model) + model_wrap = CompVisDenoiser(model) + batch_size = args.n_samples + prompt = args.prompt + assert prompt is not None + data = [batch_size * [prompt]] + precision_scope = autocast if args.precision == "autocast" else nullcontext + + init_latent = None + mask_image = None + init_image = None + if args.init_latent is not None: + init_latent = args.init_latent + elif args.init_sample is not None: + with precision_scope("cuda"): + init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample)) + elif args.use_init and args.init_image != None and args.init_image != '': + init_image, mask_image = load_img(args.init_image, + shape=(args.W, args.H), + use_alpha_as_mask=args.use_alpha_as_mask) + init_image = init_image.to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + with precision_scope("cuda"): + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + if not args.use_init and args.strength > 0 and args.strength_0_no_init: + print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.") + print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n") + args.strength = 0 + + # Mask functions + if args.use_mask: + assert args.mask_file is not None or mask_image is not None, "use_mask==True: An mask image is required for a mask. 
Please enter a mask_file or use an init image with an alpha channel" + assert args.use_init, "use_mask==True: use_init is required for a mask" + assert init_latent is not None, "use_mask==True: An latent init image is required for a mask" + + mask = prepare_mask(args.mask_file if mask_image is None else mask_image, + init_latent.shape, + args.mask_contrast_adjust, + args.mask_brightness_adjust) + + if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask: + raise Warning("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.") + + mask = mask.to(device) + mask = repeat(mask, '1 ... -> b ...', b=batch_size) + else: + mask = None + + t_enc = int((1.0-args.strength) * args.steps) + + # Noise schedule for the k-diffusion samplers (used for masking) + k_sigmas = model_wrap.get_sigmas(args.steps) + k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:] + + if args.sampler in ['plms','ddim']: + sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False) + + callback = make_callback(sampler_name=args.sampler, + dynamic_threshold=args.dynamic_threshold, + static_threshold=args.static_threshold, + mask=mask, + init_latent=init_latent, + sigmas=k_sigmas, + sampler=sampler) + + results = [] + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + for prompts in data: + uc = None + if args.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + if args.init_c != None: + c = args.init_c + + if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]: + samples = sampler_fn( + c=c, + uc=uc, + args=args, + model_wrap=model_wrap, + init_latent=init_latent, + t_enc=t_enc, + device=device, + cb=callback) + else: + # args.sampler == 'plms' or args.sampler == 'ddim': + if init_latent is not None and args.strength > 0: + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + else: + z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device) + if args.sampler == 'ddim': + samples = sampler.decode(z_enc, + c, + t_enc, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + img_callback=callback) + elif args.sampler == 'plms': # no "decode" function in plms, so use "sample" + shape = [args.C, args.H // args.f, args.W // args.f] + samples, _ = sampler.sample(S=args.steps, + conditioning=c, + batch_size=args.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + eta=args.ddim_eta, + x_T=z_enc, + img_callback=callback) + else: + raise Exception(f"Sampler {args.sampler} not recognised.") + + if return_latent: + results.append(samples.clone()) + + x_samples = model.decode_first_stage(samples) + if return_sample: + results.append(x_samples.clone()) + + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + if return_c: + results.append(c.clone()) + + for x_sample in x_samples: + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + image = Image.fromarray(x_sample.astype(np.uint8)) + results.append(image) + return results + + #@markdown **Select and Load Model** + + model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"] + model_checkpoint = "sd-v1-4.ckpt" #@param ["custom","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt"] + custom_config_path = "" #@param {type:"string"} + custom_checkpoint_path = "" #@param {type:"string"} + + load_on_run_all = True #@param {type: 'boolean'} + half_precision = True # check + check_sha256 = True #@param {type:"boolean"} + + model_map = { + "sd-v1-4-full-ema.ckpt": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'}, + "sd-v1-4.ckpt": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'}, + "sd-v1-3-full-ema.ckpt": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'}, + "sd-v1-3.ckpt": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'}, + "sd-v1-2-full-ema.ckpt": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'}, + "sd-v1-2.ckpt": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'}, + "sd-v1-1-full-ema.ckpt": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'}, + "sd-v1-1.ckpt": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'} + } + + # config path + ckpt_config_path = custom_config_path if model_config == "custom" else os.path.join(models_path, model_config) + if os.path.exists(ckpt_config_path): + print(f"{ckpt_config_path} exists") + else: + ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml" + print(f"Using config: {ckpt_config_path}") + + # checkpoint path or download + ckpt_path = custom_checkpoint_path if model_checkpoint == "custom" else os.path.join(models_path, model_checkpoint) + ckpt_valid = True + if os.path.exists(ckpt_path): + print(f"{ckpt_path} exists") + else: + print(f"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}") + ckpt_valid = False + + if check_sha256 and model_checkpoint != "custom" and ckpt_valid: + import hashlib + print("\n...checking sha256") + with open(ckpt_path, "rb") as f: + bytes = f.read() + hash = hashlib.sha256(bytes).hexdigest() + del bytes + if model_map[model_checkpoint]["sha256"] == hash: + print("hash is correct\n") + else: + print("hash in not correct\n") + ckpt_valid = False + + if ckpt_valid: + print(f"Using ckpt: {ckpt_path}") + + def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True): + map_location = "cuda" #@param ["cpu", "cuda"] + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location=map_location) + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if half_precision: + model = model.half().to(device) + else: + model = model.to(device) + model.eval() + return model + + if load_on_run_all and ckpt_valid: + local_config = OmegaConf.load(f"{ckpt_config_path}") + model = load_model_from_config(local_config, f"{ckpt_path}", 
half_precision=half_precision) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + + def DeforumAnimArgs(): + + #@markdown ####**Animation:** + if opt.enable_animation_mode == True: + animation_mode = master_args["animation_mode"] #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} + max_frames = master_args["max_frames"] #@param {type:"number"} + border = master_args["border"] #@param ['wrap', 'replicate'] {type:'string'} + + #@markdown ####**Motion Parameters:** + angle = master_args["angle"]#@param {type:"string"} + zoom = master_args["zoom"] #@param {type:"string"} + translation_x = master_args["translation_x"] #@param {type:"string"} + translation_y = master_args["translation_y"] #@param {type:"string"} + translation_z = master_args["translation_z"] #@param {type:"string"} + rotation_3d_x = master_args["rotation_3d_x"] #@param {type:"string"} + rotation_3d_y = master_args["rotation_3d_y"] #@param {type:"string"} + rotation_3d_z = master_args["rotation_3d_z"] #@param {type:"string"} + noise_schedule = master_args["noise_schedule"] #@param {type:"string"} + strength_schedule = master_args["strength_schedule"] #@param {type:"string"} + contrast_schedule = master_args["contrast_schedule"] #@param {type:"string"} + + #@markdown ####**Coherence:** + color_coherence = master_args["color_coherence"] #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} + diffusion_cadence = master_args["diffusion_cadence"] #@param ['1','2','3','4','5','6','7','8'] {type:'string'} + + #@markdown #### Depth Warping + use_depth_warping = master_args["use_depth_warping"] #@param {type:"boolean"} + midas_weight = master_args["midas_weight"] #@param {type:"number"} + near_plane = master_args["near_plane"] + far_plane = master_args["far_plane"] + fov = master_args["fov"] #@param {type:"number"} + padding_mode = master_args["padding_mode"] #@param ['border', 'reflection', 'zeros'] {type:'string'} + sampling_mode = master_args["sampling_mode"] #@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} + save_depth_maps = master_args["save_depth_maps"] #@param {type:"boolean"} + + #@markdown ####**Video Input:** + video_init_path = master_args["video_init_path"] #@param {type:"string"} + extract_nth_frame = master_args["extract_nth_frame"] #@param {type:"number"} + + #@markdown ####**Interpolation:** + interpolate_key_frames = master_args["interpolate_key_frames"] #@param {type:"boolean"} + interpolate_x_frames = master_args["interpolate_x_frames"] #@param {type:"number"} + + #@markdown ####**Resume Animation:** + resume_from_timestring = master_args["resume_from_timestring"] #@param {type:"boolean"} + resume_timestring = master_args["resume_timestring"] #@param {type:"string"} + else: + #@markdown ####**Still image mode:** + animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} + max_frames = 10 #@param {type:"number"} + border = 'wrap' #@param ['wrap', 'replicate'] {type:'string'} + + #@markdown ####**Motion Parameters:** + angle = "0:(0)"#@param {type:"string"} + zoom = "0:(1.04)"#@param {type:"string"} + translation_x = "0:(0)"#@param {type:"string"} + translation_y = "0:(2)"#@param {type:"string"} + translation_z = "0:(0.5)"#@param {type:"string"} + rotation_3d_x = "0:(0)"#@param {type:"string"} + rotation_3d_y = "0:(0)"#@param {type:"string"} + rotation_3d_z = "0:(0)"#@param {type:"string"} + noise_schedule = "0: (0.02)"#@param 
{type:"string"} + strength_schedule = "0: (0.6)"#@param {type:"string"} + contrast_schedule = "0: (1.0)"#@param {type:"string"} + + #@markdown ####**Coherence:** + color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} + diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'} + + #@markdown #### Depth Warping + use_depth_warping = True #@param {type:"boolean"} + midas_weight = 0.3#@param {type:"number"} + near_plane = 200 + far_plane = 10000 + fov = 40#@param {type:"number"} + padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'} + sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} + save_depth_maps = False #@param {type:"boolean"} + + #@markdown ####**Video Input:** + video_init_path ='./input/video_in.mp4'#@param {type:"string"} + extract_nth_frame = 1#@param {type:"number"} + + #@markdown ####**Interpolation:** + interpolate_key_frames = True #@param {type:"boolean"} + interpolate_x_frames = 100 #@param {type:"number"} + + #@markdown ####**Resume Animation:** + resume_from_timestring = False #@param {type:"boolean"} + resume_timestring = "20220829210106" #@param {type:"string"} + + return locals() + + class DeformAnimKeys(): + def __init__(self, anim_args): + self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle)) + self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom)) + self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x)) + self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y)) + self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z)) + self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x)) + self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y)) + self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z)) + self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule)) + self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule)) + self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule)) + + + def get_inbetweens(key_frames, integer=False, interp_method='Linear'): + key_frame_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) + + for i, value in key_frames.items(): + key_frame_series[i] = value + key_frame_series = key_frame_series.astype(float) + + if interp_method == 'Cubic' and len(key_frames.items()) <= 3: + interp_method = 'Quadratic' + if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: + interp_method = 'Linear' + + key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] + key_frame_series[anim_args.max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] + key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both') + if integer: + return key_frame_series.astype(int) + return key_frame_series + + def parse_key_frames(string, prompt_parser=None): + import re + pattern = r'((?P[0-9]+):[\s]*[\(](?P[\S\s]*?)[\)])' + frames = dict() + for match_object in re.finditer(pattern, string): + frame = int(match_object.groupdict()['frame']) + param = match_object.groupdict()['param'] + if prompt_parser: + frames[frame] = prompt_parser(param) + else: + frames[frame] = param + if frames == {} and len(string) != 0: + raise 
RuntimeError('Key Frame string not correctly formatted') + return frames + + #Prompt will be put in here: for example: + ''' + prompts = [ + "a beaufiful young girl holding a flower, art by huang guangjian and gil elvgren and sachin teng, trending on artstation", + "a beaufiful young girl holding a flower, art by greg rutkowski and alphonse mucha, trending on artstation", + #"the third prompt I don't want it I commented it with an", + ] + + animation_prompts = { + 0: "amazing alien landscape with lush vegetation and colourful galaxy foreground, digital art, breathtaking, golden ratio, extremely detailed, hyper - detailed, establishing shot, hyperrealistic, cinematic lighting, particles, unreal engine, simon stalenhag, rendered by beeple, makoto shinkai, syd meade, kentaro miura, jean giraud, environment concept, artstation, octane render, 8k uhd image", + 50: "desolate landscape fill with giant flowers, moody :: by James Jean, Jeff Koons, Dan McPharlin Daniel Merrian :: ornate, dynamic, particulate, rich colors, intricate, elegant, highly detailed, centered, artstation, smooth, sharp focus, octane render, 3d", + } + ''' + + #Replace by text file + prompts = master_args["prompts"] + + if opt.enable_animation_mode: + animation_prompts = master_args["animation_prompts"] + else: + animation_prompts = {} + + + + def DeforumArgs(): + + #@markdown **Image Settings** + W = master_args["width"] #@param + H = master_args["height"] #@param + W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 + + #@markdown **Sampling Settings** + seed = master_args["seed"] #@param + sampler = master_args["sampler"] #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"] + steps = master_args["steps"] #@param + scale = master_args["scale"] #@param + ddim_eta = master_args["ddim_eta"] #@param + dynamic_threshold = None + static_threshold = None + + #@markdown **Save & Display Settings** + save_samples = True #@param {type:"boolean"} + save_settings = True #@param {type:"boolean"} + display_samples = True #@param {type:"boolean"} + + #@markdown **Batch Settings** + n_batch = master_args["n_batch"] #@param + batch_name = master_args["batch_name"] #@param {type:"string"} + filename_format = master_args["filename_format"] #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"] + seed_behavior = master_args["seed_behavior"] #@param ["iter","fixed","random"] + make_grid = False #@param {type:"boolean"} + grid_rows = 2 #@param + outdir = get_output_folder(output_path, batch_name) + + #@markdown **Init Settings** + use_init = master_args["use_init"] #@param {type:"boolean"} + strength = master_args["strength"] #@param {type:"number"} + init_image = master_args["init_image"] #@param {type:"string"} + strength_0_no_init = True # Set the strength to 0 automatically when no init image is used + # Whiter areas of the mask are areas that change more + use_mask = master_args["use_mask"] #@param {type:"boolean"} + use_alpha_as_mask = master_args["use_alpha_as_mask"] # use the alpha channel of the init image as the mask + mask_file = master_args["mask_file"] #@param {type:"string"} + invert_mask = master_args["invert_mask"] #@param {type:"boolean"} + # Adjust mask image, 1.0 is no adjustment. Should be positive numbers. 
+ mask_brightness_adjust = 1.0 #@param {type:"number"} + mask_contrast_adjust = 1.0 #@param {type:"number"} + + n_samples = 1 # doesnt do anything + precision = 'autocast' + C = 4 + f = 8 + + prompt = "" + timestring = "" + init_latent = None + init_sample = None + init_c = None + + return locals() + + + + def next_seed(args): + if args.seed_behavior == 'iter': + args.seed += 1 + elif args.seed_behavior == 'fixed': + pass # always keep seed the same + else: + args.seed = random.randint(0, 2**32) + return args.seed + + def render_image_batch(args): + args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)} + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + if args.save_settings or args.save_samples: + print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*") + + # save settings for the batch + if args.save_settings: + filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(filename, "w+", encoding="utf-8") as f: + dictlist = dict(args.__dict__) + del dictlist['master_args'] + json.dump(dictlist, f, ensure_ascii=False, indent=4) + + index = 0 + + # function for init image batching + init_array = [] + if args.use_init: + if args.init_image == "": + raise FileNotFoundError("No path was given for init_image") + if args.init_image.startswith('http://') or args.init_image.startswith('https://'): + init_array.append(args.init_image) + elif not os.path.isfile(args.init_image): + if args.init_image[-1] != "/": # avoids path error by adding / to end if not there + args.init_image += "/" + for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array + if image.split(".")[-1] in ("png", "jpg", "jpeg"): + init_array.append(args.init_image + image) + else: + init_array.append(args.init_image) + else: + init_array = [""] + + # when doing large batches don't flood browser with images + clear_between_batches = args.n_batch >= 32 + + for iprompt, prompt in enumerate(prompts): + args.prompt = prompt + print(f"Prompt {iprompt+1} of {len(prompts)}") + print(f"{args.prompt}") + + all_images = [] + + for batch_index in range(args.n_batch): + if clear_between_batches and batch_index % 32 == 0: + display.clear_output(wait=True) + print(f"Batch {batch_index+1} of {args.n_batch}") + + for image in init_array: # iterates the init images + args.init_image = image + results = generate(args) + for image in results: + if args.make_grid: + all_images.append(T.functional.pil_to_tensor(image)) + if args.save_samples: + if args.filename_format == "{timestring}_{index}_{prompt}.png": + filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png" + else: + filename = f"{args.timestring}_{index:05}_{args.seed}.png" + image.save(os.path.join(args.outdir, filename)) + if args.display_samples: + display.display(image) + index += 1 + args.seed = next_seed(args) + + #print(len(all_images)) + if args.make_grid: + grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows)) + grid = rearrange(grid, 'c h w -> h w c').cpu().numpy() + filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png" + grid_image = Image.fromarray(grid.astype(np.uint8)) + grid_image.save(os.path.join(args.outdir, filename)) + display.clear_output(wait=True) + display.display(grid_image) + + + def render_animation(args, anim_args): + # animations use key framed prompts + args.prompts = animation_prompts + + # expand key frame strings to values + keys = DeformAnimKeys(anim_args) + + # resume animation + start_frame = 0 + if 
anim_args.resume_from_timestring: + for tmp in os.listdir(args.outdir): + if tmp.split("_")[0] == anim_args.resume_timestring: + start_frame += 1 + start_frame = start_frame - 1 + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to {args.outdir}") + + # save settings for the batch + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(settings_filename, "w+", encoding="utf-8") as f: + s = {**dict(args.__dict__), **dict(anim_args.__dict__)} + del s['master_args'] + del s['opt'] + json.dump(s, f, ensure_ascii=False, indent=4) + + # resume from timestring + if anim_args.resume_from_timestring: + args.timestring = anim_args.resume_timestring + + # expand prompts out to per-frame + prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) + for i, prompt in animation_prompts.items(): + prompt_series[int(i)] = prompt + prompt_series = prompt_series.ffill().bfill() + + # check for video inits + using_vid_init = anim_args.animation_mode == 'Video Input' + + # load depth model for 3D + predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps + if predict_depths: + depth_model = DepthModel(device) + depth_model.load_midas(models_path) + if anim_args.midas_weight < 1.0: + depth_model.load_adabins() + else: + depth_model = None + anim_args.save_depth_maps = False + + # state for interpolating between diffusion steps + turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence) + turbo_prev_image, turbo_prev_frame_idx = None, 0 + turbo_next_image, turbo_next_frame_idx = None, 0 + + # resume animation + prev_sample = None + color_match_sample = None + if anim_args.resume_from_timestring: + last_frame = start_frame-1 + if turbo_steps > 1: + last_frame -= last_frame%turbo_steps + path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png") + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + prev_sample = sample_from_cv2(img) + if anim_args.color_coherence != 'None': + color_match_sample = img + if turbo_steps > 1: + turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame + turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx + start_frame = last_frame+turbo_steps + + args.n_samples = 1 + frame_idx = start_frame + while frame_idx < anim_args.max_frames: + print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}") + noise = keys.noise_schedule_series[frame_idx] + strength = keys.strength_schedule_series[frame_idx] + contrast = keys.contrast_schedule_series[frame_idx] + depth = None + + # emit in-between frames + if turbo_steps > 1: + tween_frame_start_idx = max(0, frame_idx-turbo_steps) + for tween_frame_idx in range(tween_frame_start_idx, frame_idx): + tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx) + print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}") + + advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx + advance_next = tween_frame_idx > turbo_next_frame_idx + + if depth_model is not None: + assert(turbo_next_image is not None) + depth = depth_model.predict(turbo_next_image, anim_args) + + if anim_args.animation_mode == '2D': + if advance_prev: + turbo_prev_image = anim_frame_warp_2d(turbo_prev_image, args, anim_args, keys, tween_frame_idx) + if advance_next: + turbo_next_image = 
anim_frame_warp_2d(turbo_next_image, args, anim_args, keys, tween_frame_idx) + else: # '3D' + if advance_prev: + turbo_prev_image = anim_frame_warp_3d(turbo_prev_image, depth, anim_args, keys, tween_frame_idx) + if advance_next: + turbo_next_image = anim_frame_warp_3d(turbo_next_image, depth, anim_args, keys, tween_frame_idx) + turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx + + if turbo_prev_image is not None and tween < 1.0: + img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween + else: + img = turbo_next_image + + filename = f"{args.timestring}_{tween_frame_idx:05}.png" + cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR)) + if anim_args.save_depth_maps: + depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth) + if turbo_next_image is not None: + prev_sample = sample_from_cv2(turbo_next_image) + + # apply transforms to previous frame + if prev_sample is not None: + if anim_args.animation_mode == '2D': + prev_img = anim_frame_warp_2d(sample_to_cv2(prev_sample), args, anim_args, keys, frame_idx) + else: # '3D' + prev_img_cv2 = sample_to_cv2(prev_sample) + depth = depth_model.predict(prev_img_cv2, anim_args) if depth_model else None + prev_img = anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx) + + # apply color matching + if anim_args.color_coherence != 'None': + if color_match_sample is None: + color_match_sample = prev_img.copy() + else: + prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence) + + # apply scaling + contrast_sample = prev_img * contrast + # apply frame noising + noised_sample = add_noise(sample_from_cv2(contrast_sample), noise) + + # use transformed previous frame as init for current + args.use_init = True + if half_precision: + args.init_sample = noised_sample.half().to(device) + else: + args.init_sample = noised_sample.to(device) + args.strength = max(0.0, min(1.0, strength)) + + # grab prompt for current frame + args.prompt = prompt_series[frame_idx] + print(f"{args.prompt} {args.seed}") + + # grab init image for current frame + if using_vid_init: + init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:04}.jpg") + print(f"Using video init frame {init_frame}") + args.init_image = init_frame + + # sample the diffusion model + sample, image = generate(args, return_latent=False, return_sample=True) + if not using_vid_init: + prev_sample = sample + + if turbo_steps > 1: + turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx + turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx + frame_idx += turbo_steps + else: + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + if anim_args.save_depth_maps: + if depth is None: + depth = depth_model.predict(sample_to_cv2(sample), anim_args) + depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + def render_input_video(args, anim_args): + # create a folder for the video input frames to live in + video_in_frame_path = os.path.join(args.outdir, 'inputframes') + os.makedirs(video_in_frame_path, exist_ok=True) + + # save the video frames from input video + print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...") + try: + for f in 
pathlib.Path(video_in_frame_path).glob('*.jpg'): + f.unlink() + except: + pass + vf = r'select=not(mod(n\,'+str(anim_args.extract_nth_frame)+'))' + subprocess.run([ + 'ffmpeg', '-i', f'{anim_args.video_init_path}', + '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', + '-loglevel', 'error', '-stats', + os.path.join(video_in_frame_path, '%04d.jpg') + ], stdout=subprocess.PIPE).stdout.decode('utf-8') + + # determine max frames from length of input frames + anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')]) + + args.use_init = True + print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}") + render_animation(args, anim_args) + + def render_interpolation(args, anim_args): + # animations use key framed prompts + args.prompts = animation_prompts + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to {args.outdir}") + + # save settings for the batch + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(settings_filename, "w+", encoding="utf-8") as f: + s = {**dict(args.__dict__), **dict(anim_args.__dict__)} + del s['master_args'] + del s['opt'] + json.dump(s, f, ensure_ascii=False, indent=4) + + # Interpolation Settings + args.n_samples = 1 + args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available + prompts_c_s = [] # cache all the text embeddings + + print(f"Preparing for interpolation of the following...") + + for i, prompt in animation_prompts.items(): + args.prompt = prompt + + # sample the diffusion model + results = generate(args, return_c=True) + c, image = results[0], results[1] + prompts_c_s.append(c) + + # display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + display.clear_output(wait=True) + print(f"Interpolation start...") + + frame_idx = 0 + + if anim_args.interpolate_key_frames: + for i in range(len(prompts_c_s)-1): + dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0] + if dist_frames <= 0: + print("key frames duplicated or reversed. 
interpolation skipped.") + return + else: + for j in range(dist_frames): + # interpolate the text embedding + prompt1_c = prompts_c_s[i] + prompt2_c = prompts_c_s[i+1] + args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames)) + + # sample the diffusion model + results = generate(args) + image = results[0] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + else: + for i in range(len(prompts_c_s)-1): + for j in range(anim_args.interpolate_x_frames+1): + # interpolate the text embedding + prompt1_c = prompts_c_s[i] + prompt2_c = prompts_c_s[i+1] + args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1))) + + # sample the diffusion model + results = generate(args) + image = results[0] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + # generate the last prompt + args.init_c = prompts_c_s[-1] + results = generate(args) + image = results[0] + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + + display.clear_output(wait=True) + display.display(image) + args.seed = next_seed(args) + + #clear init_c + args.init_c = None + + + args = SimpleNamespace(**DeforumArgs()) + anim_args = SimpleNamespace(**DeforumAnimArgs()) + + args.timestring = time.strftime('%Y%m%d%H%M%S') + args.strength = max(0.0, min(1.0, args.strength)) + + if args.seed == -1: + args.seed = random.randint(0, 2**32 - 1) + if not args.use_init: + args.init_image = None + if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'): + print(f"Init images aren't supported with PLMS yet, switching to KLMS") + args.sampler = 'klms' + if args.sampler != 'ddim': + args.ddim_eta = 0 + + if anim_args.animation_mode == 'None': + anim_args.max_frames = 1 + elif anim_args.animation_mode == 'Video Input': + args.use_init = True + + # clean up unused memory + gc.collect() + torch.cuda.empty_cache() + + # dispatch to appropriate renderer + if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D': + render_animation(args, anim_args) + elif anim_args.animation_mode == 'Video Input': + render_input_video(args, anim_args) + elif anim_args.animation_mode == 'Interpolation': + render_interpolation(args, anim_args) + else: + render_image_batch(args) + + + skip_video_for_run_all = False #@param {type: 'boolean'} + fps = 12 #@param {type:"number"} + #@markdown **Manual Settings** + use_manual_settings = False #@param {type:"boolean"} + image_path = "./output/out_%05d.png" #@param {type:"string"} + mp4_path = "./output/out_%05d.mp4" #@param {type:"string"} + + + if skip_video_for_run_all == True or opt.enable_animation_mode == False: + print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it') + else: + import os + import subprocess + from base64 import b64encode + + print(f"{image_path} -> {mp4_path}") + + if use_manual_settings: + max_frames = "200" #@param {type:"string"} + else: + image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png") + mp4_path = os.path.join(args.outdir, f"{args.timestring}.mp4") + max_frames = str(anim_args.max_frames) + + # make video + cmd = [ + 'ffmpeg', + '-y', + '-vcodec', 'png', + '-r', str(fps), + 
'-start_number', str(0), + '-i', image_path, + '-frames:v', max_frames, + '-c:v', 'libx264', + '-vf', + f'fps={fps}', + '-pix_fmt', 'yuv420p', + '-crf', '17', + '-preset', 'veryfast', + mp4_path + ] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + if process.returncode != 0: + print(stderr) + raise RuntimeError(stderr) + + mp4 = open(mp4_path,'rb').read() + data_url = "data:video/mp4;base64," + b64encode(mp4).decode() + display.display( display.HTML(f'') ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/ModelManager.py b/scripts/ModelManager.py new file mode 100644 index 0000000..983f85b --- /dev/null +++ b/scripts/ModelManager.py @@ -0,0 +1,46 @@ +# base webui import and utils. +from webui_streamlit import st +from sd_utils import * + +# streamlit imports + + +#other imports +import pandas as pd +from io import StringIO + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +def layout(): + #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="") + + csvString = f""" + ,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media + ,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth + ,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth + ,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth + ,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt + ,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt + ,Stable Diffusion Concept Library , ./models/customsd-concepts-library , https://github.com/sd-webui/sd-concepts-library + """ + colms = st.columns((1, 3, 5, 5)) + columns = ["№",'Model Name','Save Location','Download Link'] + + # Convert String into StringIO + csvStringIO = StringIO(csvString) + df = pd.read_csv(csvStringIO, sep=",", header=None, names=columns) + + for col, field_name in zip(colms, columns): + # table header + col.write(field_name) + + for x, model_name in enumerate(df["Model Name"]): + col1, col2, col3, col4 = st.columns((1, 3, 4, 6)) + col1.write(x) # index + col2.write(df['Model Name'][x]) + col3.write(df['Save Location'][x]) + col4.write(df['Download Link'][x]) \ No newline at end of file diff --git a/scripts/Settings.py b/scripts/Settings.py new file mode 100644 index 0000000..a1e21ec --- /dev/null +++ b/scripts/Settings.py @@ -0,0 +1,5 @@ +from webui_streamlit import st + +# The global settings section will be moved to the Settings page. +#with st.expander("Global Settings:"): +st.write("Global Settings:") diff --git a/scripts/home.py b/scripts/home.py new file mode 100644 index 0000000..2702fcc --- /dev/null +++ b/scripts/home.py @@ -0,0 +1,216 @@ +# base webui import and utils. 
+from webui_streamlit import st +from sd_utils import * + +# streamlit imports + + +#other imports + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +import os +from PIL import Image + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. + from transformers import logging + + logging.set_verbosity_error() +except: + pass + +class plugin_info(): + plugname = "home" + description = "Home" + isTab = True + displayPriority = 0 + +def getLatestGeneratedImagesFromPath(): + #get the latest images from the generated images folder + #get the path to the generated images folder + generatedImagesPath = os.path.join(os.getcwd(),'outputs') + #get all the files from the folders and subfolders + files = [] + #get the latest 10 images from the output folder without walking the subfolders + for r, d, f in os.walk(generatedImagesPath): + for file in f: + if '.png' in file: + files.append(os.path.join(r, file)) + #sort the files by date + files.sort(reverse=True, key=os.path.getmtime) + latest = files[:90] + latest.reverse() + + # reverse the list so the latest images are first and truncate to + # a reasonable number of images, 10 pages worth + return [Image.open(f) for f in latest] + +def get_images_from_lexica(): + #scrape images from lexica.art + #get the html from the page + #get the html with cookies and javascript + apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D' + #REST API call + # + from requests_html import HTMLSession + session = HTMLSession() + + response = session.get(apiEndpoint) + #req = requests.Session() + #req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36' + #response = req.get(apiEndpoint) + print(response.status_code) + print(response.text) + #get the json from the response + #json = response.json() + #get the prompts from the json + print(response) + #session = requests.Session() + #parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False) + #print(parseEndpointJson) + #print('test2') + #page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'}) + #parse the html + #soup = BeautifulSoup(page.content, 'html.parser') + #find all the images + #print(soup) + #images = soup.find_all('alt-image') + #create a list to store the image urls + image_urls = [] + #loop through the images + for image in images: + #get the url + image_url = image['src'] + #add it to the list + image_urls.append('http://www.lexica.art/'+image_url) + #return the list + print(image_urls) + return image_urls + +def layout(): + #streamlit home page layout + #center the title + st.markdown("

Welcome, let's make some 🎨", unsafe_allow_html=True) + #make a gallery of images + #st.markdown("Gallery", unsafe_allow_html=True) + #create a gallery of images using columns + #col1, col2, col3 = st.columns(3) + #load the images + #create 3 columns + # create a tab for the gallery + #st.markdown("Gallery", unsafe_allow_html=True) + #st.markdown("Gallery
", unsafe_allow_html=True) + + history_tab, discover_tabs = st.tabs(["History","Discover"]) + + latestImages = getLatestGeneratedImagesFromPath() + st.session_state['latestImages'] = latestImages + + with history_tab: + ##--------------------------------------------------------- + ## image slideshow test + ## Number of entries per screen + #slideshow_N = 9 + #slideshow_page_number = 0 + #slideshow_last_page = len(latestImages) // slideshow_N + + ## Add a next button and a previous button + + #slideshow_prev, slideshow_image_col , slideshow_next = st.columns([1, 10, 1]) + + #with slideshow_image_col: + #slideshow_image = st.empty() + + #slideshow_image.image(st.session_state['latestImages'][0]) + + #current_image = 0 + + #if slideshow_next.button("Next", key=1): + ##print (current_image+1) + #current_image = current_image+1 + #slideshow_image.image(st.session_state['latestImages'][current_image+1]) + #if slideshow_prev.button("Previous", key=0): + ##print ([current_image-1]) + #current_image = current_image-1 + #slideshow_image.image(st.session_state['latestImages'][current_image - 1]) + + + #--------------------------------------------------------- + + # image gallery + # Number of entries per screen + gallery_N = 9 + if not "galleryPage" in st.session_state: + st.session_state["galleryPage"] = 0 + gallery_last_page = len(latestImages) // gallery_N + + # Add a next button and a previous button + + gallery_prev, gallery_refresh, gallery_pagination , gallery_next = st.columns([2, 2, 8, 1]) + + # the pagination doesnt work for now so its better to disable the buttons. + + if gallery_refresh.button("Refresh", key=4): + st.session_state["galleryPage"] = 0 + + if gallery_next.button("Next", key=3): + + if st.session_state["galleryPage"] + 1 > gallery_last_page: + st.session_state["galleryPage"] = 0 + else: + st.session_state["galleryPage"] += 1 + + if gallery_prev.button("Previous", key=2): + + if st.session_state["galleryPage"] - 1 < 0: + st.session_state["galleryPage"] = gallery_last_page + else: + st.session_state["galleryPage"] -= 1 + + print(st.session_state["galleryPage"]) + # Get start and end indices of the next page of the dataframe + gallery_start_idx = st.session_state["galleryPage"] * gallery_N + gallery_end_idx = (1 + st.session_state["galleryPage"]) * gallery_N + + #--------------------------------------------------------- + + placeholder = st.empty() + + #populate the 3 images per column + with placeholder.container(): + col1, col2, col3 = st.columns(3) + col1_cont = st.container() + col2_cont = st.container() + col3_cont = st.container() + + print (len(st.session_state['latestImages'])) + images = list(reversed(st.session_state['latestImages']))[gallery_start_idx:(gallery_start_idx+gallery_N)] + + with col1_cont: + with col1: + [st.image(images[index]) for index in [0, 3, 6] if index < len(images)] + with col2_cont: + with col2: + [st.image(images[index]) for index in [1, 4, 7] if index < len(images)] + with col3_cont: + with col3: + [st.image(images[index]) for index in [2, 5, 8] if index < len(images)] + + + st.session_state['historyTab'] = [history_tab,col1,col2,col3,placeholder,col1_cont,col2_cont,col3_cont] + + with discover_tabs: + st.markdown("

Soon :)", unsafe_allow_html=True) + + #display the images + #add a button to the gallery + #st.markdown("

Try it out
", unsafe_allow_html=True) + #create a button to the gallery + #if st.button("Try it out"): + #if the button is clicked, go to the gallery + #st.experimental_rerun() diff --git a/scripts/img2img.py b/scripts/img2img.py new file mode 100644 index 0000000..142fe81 --- /dev/null +++ b/scripts/img2img.py @@ -0,0 +1,592 @@ +# base webui import and utils. +from webui_streamlit import st +from sd_utils import * + +# streamlit imports +from streamlit import StopException + +#other imports +import cv2 +from PIL import Image, ImageOps +import torch +import k_diffusion as K +import numpy as np +import time +import torch +import skimage +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. + from transformers import logging + + logging.set_verbosity_error() +except: + pass + +def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, + mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM', + n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, + seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, + variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, + write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", + separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False, + random_seed_loopback: bool = False + ): + + outpath = st.session_state['defaults'].general.outdir_img2img or st.session_state['defaults'].general.outdir or "outputs/img2img-samples" + #err = False + #loopback = False + #skip_save = False + seed = seed_to_int(seed) + + batch_size = 1 + + #prompt_matrix = 0 + #normalize_prompt_weights = 1 in toggles + #loopback = 2 in toggles + #random_seed_loopback = 3 in toggles + #skip_save = 4 not in toggles + #save_grid = 5 in toggles + #sort_samples = 6 in toggles + #write_info_files = 7 in toggles + #write_sample_info_to_log_file = 8 in toggles + #jpg_sample = 9 in toggles + #use_GFPGAN = 10 in toggles + #use_RealESRGAN = 11 in toggles + + if sampler_name == 'PLMS': + sampler = PLMSSampler(st.session_state["model"]) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(st.session_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(st.session_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(st.session_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(st.session_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + def process_init_mask(init_mask: Image): + if init_mask.mode 
== "RGBA": + init_mask = init_mask.convert('RGBA') + background = Image.new('RGBA', init_mask.size, (0, 0, 0)) + init_mask = Image.alpha_composite(background, init_mask) + init_mask = init_mask.convert('RGB') + return init_mask + + init_img = init_info + init_mask = None + if mask_mode == 0: + if init_info_mask: + init_mask = process_init_mask(init_info_mask) + elif mask_mode == 1: + if init_info_mask: + init_mask = process_init_mask(init_info_mask) + init_mask = ImageOps.invert(init_mask) + elif mask_mode == 2: + init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1') + init_mask = init_img_transparency + init_mask = init_mask.convert("RGB") + init_mask = resize_image(resize_mode, init_mask, width, height) + init_mask = init_mask.convert("RGB") + + assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(denoising_strength * ddim_steps) + + if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None: + noise_q = 0.99 + color_variation = 0.0 + mask_blend_factor = 1.0 + + np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing + np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64) + np_mask_rgb -= np.min(np_mask_rgb) + np_mask_rgb /= np.max(np_mask_rgb) + np_mask_rgb = 1. - np_mask_rgb + np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64) + blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) + blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) + #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants + #np_mask_rgb = np_mask_rgb + blurred + np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.) + np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.) + + noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation) + blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor) + noised = noise_rgb[:] + blend_mask_rgb **= (2.) + noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb + + np_mask_grey = np.sum(np_mask_rgb, axis=2)/3. + ref_mask = np_mask_grey < 1e-3 + + all_mask = np.ones((height, width), dtype=bool) + noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1) + + init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB") + st.session_state["editor_image"].image(init_img) # debug + + def init(): + image = init_img.convert('RGB') + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + + mask_channel = None + if init_mask: + alpha = resize_image(resize_mode, init_mask, width // 8, height // 8) + mask_channel = alpha.split()[-1] + + mask = None + if mask_channel is not None: + mask = np.array(mask_channel).astype(np.float32) / 255.0 + mask = (1 - mask) + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) + mask = torch.from_numpy(mask).to(st.session_state["device"]) + + if st.session_state['defaults'].general.optimized: + st.session_state.modelFS.to(st.session_state["device"] ) + + init_image = 2. * image - 1. 
+ init_image = init_image.to(st.session_state["device"]) + init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space + + if st.session_state['defaults'].general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + st.session_state.modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + return init_latent, mask, + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + t_enc_steps = t_enc + obliterate = False + if ddim_steps == t_enc_steps: + t_enc_steps = t_enc_steps - 1 + obliterate = True + + if sampler_name != 'DDIM': + x0, z_mask = init_data + + sigmas = sampler.model_wrap.get_sigmas(ddim_steps) + noise = x * sigmas[ddim_steps - t_enc_steps - 1] + + xi = x0 + noise + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=xi.device) + xi = (z_mask * noise) + ((1-z_mask) * xi) + + sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] + model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, + callback=generation_callback) + else: + + x0, z_mask = init_data + + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) + z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] )) + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=z_enc.device) + z_enc = (z_mask * random) + ((1-z_mask) * z_enc) + + # decode it + samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, + z_mask=z_mask, x0=x0) + return samples_ddim + + + + if loopback: + output_images, info = None, None + history = [] + initial_seed = None + + do_color_correction = False + try: + from skimage import exposure + do_color_correction = True + except: + print("Install scikit-image to perform color correction on loopback") + + for i in range(n_iter): + if do_color_correction and i == 0: + correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) + + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=1, + n_iter=1, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback + realesrgan_model_name=RealESRGAN_model, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + mask_blur_strength=mask_blur_strength, + mask_restore=mask_restore, + denoising_strength=denoising_strength, + noise_mode=noise_mode, + find_noise_steps=find_noise_steps, + resize_mode=resize_mode, + uses_loopback=loopback, + uses_random_seed_loopback=random_seed_loopback, + 
sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + if initial_seed is None: + initial_seed = seed + + input_image = init_img + init_img = output_images[0] + + if do_color_correction and correction_target is not None: + init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(init_img), + cv2.COLOR_RGB2LAB + ), + correction_target, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8")) + if mask_restore is True and init_mask is not None: + color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) + color_mask = color_mask.convert('L') + source_image = input_image.convert('RGB') + target_image = init_img.convert('RGB') + + init_img = Image.composite(source_image, target_image, color_mask) + + if not random_seed_loopback: + seed = seed + 1 + else: + seed = seed_to_int(None) + + denoising_strength = max(denoising_strength * 0.95, 0.1) + history.append(init_img) + + output_images = history + seed = initial_seed + + else: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=RealESRGAN_model, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + mask_blur_strength=mask_blur_strength, + denoising_strength=denoising_strength, + noise_mode=noise_mode, + find_noise_steps=find_noise_steps, + mask_restore=mask_restore, + resize_mode=resize_mode, + uses_loopback=loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + del sampler + + return output_images, seed, info, stats + +# + + +def layout(): + with st.form("img2img-inputs"): + st.session_state["generation_mode"] = "img2img" + + img2img_input_col, img2img_generate_col = st.columns([10,1]) + with img2img_input_col: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. + img2img_generate_col.write("") + img2img_generate_col.write("") + generate_button = img2img_generate_col.form_submit_button("Generate") + + + # creating the page layout using columns + col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small") + + with col1_img2img_layout: + # If we have custom models available on the "models/custom" + #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD + if st.session_state["CustomModel_available"]: + st.session_state["custom_model"] = st.selectbox("Custom Model:", st.session_state["custom_models"], + index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model), + help="Select the model you want to use. This option is only available if you have custom models \ + on your 'models/custom' folder. 
The model name that will be shown here is the same as the name\ + the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") + else: + st.session_state["custom_model"] = "Stable Diffusion v1.4" + + + st.session_state["sampling_steps"] = st.slider("Sampling Steps", + value=st.session_state['defaults'].img2img.sampling_steps, + min_value=st.session_state['defaults'].img2img.slider_bounds.sampling.lower, + max_value=st.session_state['defaults'].img2img.slider_bounds.sampling.upper, + step=st.session_state['defaults'].img2img.slider_steps.sampling) + + sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list, + index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.") + + mask_mode_list = ["Mask", "Inverted mask", "Image alpha"] + mask_mode = st.selectbox("Mask Mode", mask_mode_list, + help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\ + \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent." + ) + mask_mode = mask_mode_list.index(mask_mode) + + width = st.slider("Width:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.width, step=64) + height = st.slider("Height:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.height, step=64) + seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"] + noise_mode = st.selectbox( + "Noise Mode", noise_mode_list, + help="" + ) + noise_mode = noise_mode_list.index(noise_mode) + find_noise_steps = st.slider("Find Noise Steps", value=100, min_value=1, max_value=500) + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_count, step=1, + help="How many iterations or batches of images to generate in total.") + + # + with st.expander("Advanced"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts, + help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights, + help="Ensure the sum of all weights add up to 1.0") + loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from previous batch when creating next batch.") + random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Random loopback seed") + img2img_mask_restore = st.checkbox("Only modify regenerated parts of image", + value=st.session_state['defaults'].img2img.mask_restore, + help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail") + save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images, + help="Save each image generated before any filter or enhancement is applied.") + save_grid = 
st.checkbox("Save grid",value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. \ + When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files, + help="Save a file next to the image with informartion about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.") + + if st.session_state["GFPGAN_available"]: + use_GFPGAN = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ + This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + use_GFPGAN = False + + if st.session_state["RealESRGAN_available"]: + st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN, + help="Uses the RealESRGAN model to upscale the images after the generation.\ + This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") + st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + st.session_state["use_RealESRGAN"] = False + st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" + + variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed, + help="The seed to use when generating a variant, if left blank a random seed will be generated.") + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].img2img.cfg_scale, step=0.5, + help="How strongly the image should follow the prompt.") + batch_size = st.slider("Batch size", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_size, step=1, + help="How many images are at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish \ + generation as more images are generated at once.\ + Default: 1") + + st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength, + min_value=0.01, max_value=1.0, step=0.01) + + with st.expander("Preview Settings"): + st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].img2img.update_preview, + help="If enabled the image preview will be updated during the generation instead of at the end. \ + You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. 
\ + By default this is enabled and the frequency is set to 1 step.") + + st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].img2img.update_preview_frequency, + help="Frequency in steps at which the the preview image is updated. By default the frequency \ + is set to 1 step.") + + with col2_img2img_layout: + editor_tab = st.tabs(["Editor"]) + + editor_image = st.empty() + st.session_state["editor_image"] = editor_image + + st.form_submit_button("Refresh") + + masked_image_holder = st.empty() + image_holder = st.empty() + + uploaded_images = st.file_uploader( + "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], + help="Upload an image which will be used for the image to image generation.", + ) + if uploaded_images: + image = Image.open(uploaded_images).convert('RGBA') + new_img = image.resize((width, height)) + image_holder.image(new_img) + + mask_holder = st.empty() + + uploaded_masks = st.file_uploader( + "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], + help="Upload an mask image which will be used for masking the image to image generation.", + ) + if uploaded_masks: + mask = Image.open(uploaded_masks) + if mask.mode == "RGBA": + mask = mask.convert('RGBA') + background = Image.new('RGBA', mask.size, (0, 0, 0)) + mask = Image.alpha_composite(background, mask) + mask = mask.resize((width, height)) + mask_holder.image(mask) + + if uploaded_images and uploaded_masks: + if mask_mode != 2: + final_img = new_img.copy() + alpha_layer = mask.convert('L') + strength = st.session_state["denoising_strength"] + if mask_mode == 0: + alpha_layer = ImageOps.invert(alpha_layer) + alpha_layer = alpha_layer.point(lambda a: a * strength) + alpha_layer = ImageOps.invert(alpha_layer) + elif mask_mode == 1: + alpha_layer = alpha_layer.point(lambda a: a * strength) + alpha_layer = ImageOps.invert(alpha_layer) + + final_img.putalpha(alpha_layer) + + with masked_image_holder.container(): + st.text("Masked Image Preview") + st.image(final_img) + + + with col3_img2img_layout: + result_tab = st.tabs(["Result"]) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + preview_image = st.empty() + st.session_state["preview_image"] = preview_image + + #st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + + message = st.empty() + + #if uploaded_images: + #image = Image.open(uploaded_images).convert('RGB') + ##img_array = np.array(image) # if you want to pass it to OpenCV + #new_img = image.resize((width, height)) + #st.image(new_img, use_column_width=True) + + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
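# --- Illustrative aside (not part of the original patch): the masked-image preview in
# the layout above attaches a strength-scaled alpha channel to the init image, so the
# region kept as-is fades in proportion to the denoising strength while the region to
# be repainted stays opaque ("Mask" mode). A standalone sketch with dummy images;
# sizes and the strength value are arbitrary.
from PIL import Image, ImageOps

init_img = Image.new("RGBA", (512, 512), (128, 128, 128, 255))
mask = Image.new("L", (512, 512), 0)        # black = keep
mask.paste(255, (128, 128, 384, 384))       # white box = area that will be repainted
strength = 0.8

preview = init_img.copy()
alpha = ImageOps.invert(mask)               # repaint area -> 0, keep area -> 255
alpha = alpha.point(lambda a: int(a * strength))
alpha = ImageOps.invert(alpha)              # keep-area alpha becomes 255 - 255*strength
preview.putalpha(alpha)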
+ load_models(False, use_GFPGAN, st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"], + st.session_state["custom_model"]) + + if uploaded_images: + image = Image.open(uploaded_images).convert('RGBA') + new_img = image.resize((width, height)) + #img_array = np.array(image) # if you want to pass it to OpenCV + new_mask = None + if uploaded_masks: + mask = Image.open(uploaded_masks).convert('RGBA') + new_mask = mask.resize((width, height)) + + try: + output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode, + mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"], + sampler_name=st.session_state["sampler_name"], n_iter=batch_count, + cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed, + seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width, + height=height, variant_amount=variant_amount, + ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=st.session_state["RealESRGAN_model"], + separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, save_grid=save_grid, + group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN, + use_RealESRGAN=st.session_state["use_RealESRGAN"] if not loopback else False, loopback=loopback + ) + + #show a message when the generation is complete. + message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") + + except (StopException, KeyError): + print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images, width=750) + +#on import run init diff --git a/scripts/imglab.py b/scripts/imglab.py new file mode 100644 index 0000000..eb09c6a --- /dev/null +++ b/scripts/imglab.py @@ -0,0 +1,161 @@ +# base webui import and utils. +from webui_streamlit import st +from sd_utils import * + +#home plugin +import os +from PIL import Image +#from bs4 import BeautifulSoup +from streamlit.runtime.in_memory_file_manager import in_memory_file_manager +from streamlit.elements import image as STImage + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. 
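# --- Illustrative aside (not part of the original patch): both home.py above and
# imglab.py below collect recent results by walking the outputs folder and ordering
# files by modification time. A compact equivalent using glob (the directory name and
# count are just examples, and glob is a swap for the os.walk loop in the scripts):
import glob, os

def latest_pngs(root="outputs", count=10):
    files = glob.glob(os.path.join(root, "**", "*.png"), recursive=True)
    files.sort(key=os.path.getmtime, reverse=True)   # newest first
    return files[:count]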
+ from transformers import logging + + logging.set_verbosity_error() +except: + pass + +class plugin_info(): + plugname = "imglab" + description = "Image Lab" + isTab = True + displayPriority = 3 + +def getLatestGeneratedImagesFromPath(): + #get the latest images from the generated images folder + #get the path to the generated images folder + generatedImagesPath = os.path.join(os.getcwd(),'outputs') + #get all the files from the folders and subfolders + files = [] + #get the laest 10 images from the output folder without walking the subfolders + for r, d, f in os.walk(generatedImagesPath): + for file in f: + if '.png' in file: + files.append(os.path.join(r, file)) + #sort the files by date + files.sort(key=os.path.getmtime) + #reverse the list so the latest images are first + for f in files: + img = Image.open(f) + files[files.index(f)] = img + #get the latest 10 files + #get all the files with the .png or .jpg extension + #sort files by date + #get the latest 10 files + latestFiles = files[-10:] + #reverse the list + latestFiles.reverse() + return latestFiles + +def getImagesFromLexica(): + #scrape images from lexica.art + #get the html from the page + #get the html with cookies and javascript + apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D' + #REST API call + # + from requests_html import HTMLSession + session = HTMLSession() + + response = session.get(apiEndpoint) + #req = requests.Session() + #req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36' + #response = req.get(apiEndpoint) + print(response.status_code) + print(response.text) + #get the json from the response + #json = response.json() + #get the prompts from the json + print(response) + #session = requests.Session() + #parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False) + #print(parseEndpointJson) + #print('test2') + #page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'}) + #parse the html + #soup = BeautifulSoup(page.content, 'html.parser') + #find all the images + #print(soup) + #images = soup.find_all('alt-image') + #create a list to store the image urls + image_urls = [] + #loop through the images + for image in images: + #get the url + image_url = image['src'] + #add it to the list + image_urls.append('http://www.lexica.art/'+image_url) + #return the list + print(image_urls) + return image_urls +def changeImage(): + #change the image in the image holder + #check if the file is not empty + if len(st.session_state['uploaded_file']) > 0: + #read the file + print('test2') + uploaded = st.session_state['uploaded_file'][0].read() + #show the image in the image holder + st.session_state['previewImg'].empty() + st.session_state['previewImg'].image(uploaded,use_column_width=True) +def createHTMLGallery(images): + html3 = """ + ' + return html3 +def layout(): + + col1, col2 = st.columns(2) + with col1: + st.session_state['uploaded_file'] = st.file_uploader("Choose an image or images", type=["png", "jpg", "jpeg", "webp"],accept_multiple_files=True,on_change=changeImage) + if 'previewImg' not in st.session_state: + st.session_state['previewImg'] = st.empty() + else: + if len(st.session_state['uploaded_file']) > 0: + st.session_state['previewImg'].empty() + st.session_state['previewImg'].image(st.session_state['uploaded_file'][0],use_column_width=True) + else: + 
st.session_state['previewImg'] = st.empty() + diff --git a/scripts/perlin.py b/scripts/perlin.py new file mode 100644 index 0000000..327a994 --- /dev/null +++ b/scripts/perlin.py @@ -0,0 +1,48 @@ +import numpy as np + +def perlin(x, y, seed=0): + # permutation table + np.random.seed(seed) + p = np.arange(256, dtype=int) + np.random.shuffle(p) + p = np.stack([p, p]).flatten() + # coordinates of the top-left + xi, yi = x.astype(int), y.astype(int) + # internal coordinates + xf, yf = x - xi, y - yi + # fade factors + u, v = fade(xf), fade(yf) + # noise components + n00 = gradient(p[p[xi] + yi], xf, yf) + n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1) + n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1) + n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf) + # combine noises + x1 = lerp(n00, n10, u) + x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01 + return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here + +def lerp(a, b, x): + "linear interpolation" + return a + x * (b - a) + +def fade(t): + "6t^5 - 15t^4 + 10t^3" + return 6 * t**5 - 15 * t**4 + 10 * t**3 + +def gradient(h, x, y): + "grad converts h to the right gradient vector and return the dot product with (x,y)" + vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]]) + g = vectors[h % 4] + return g[:, :, 0] * x + g[:, :, 1] * y + +lin = np.linspace(0, 5, 100, endpoint=False) +x, y = np.meshgrid(lin, lin) + + + +def perlinNoise(height,width,octavesx=5,octavesy=5,seed=None): + linx = np.linspace(0,octavesx,width,endpoint=False) + liny = np.linspace(0,octavesy,height,endpoint=False) + x,y = np.meshgrid(linx,liny) + return perlin(x,y,seed=seed) \ No newline at end of file diff --git a/scripts/relauncher.py b/scripts/relauncher.py index 7179d7f..457d539 100644 --- a/scripts/relauncher.py +++ b/scripts/relauncher.py @@ -19,6 +19,8 @@ optimized_turbo = False # Creates a public xxxxx.gradio.app share link to allow others to use your interface (requires properly forwarded ports to work correctly) share = False +# Generate tiling images +tiling = False # Enter other `--arguments` you wish to use - Must be entered as a `--argument ` syntax additional_arguments = "" @@ -37,6 +39,8 @@ if optimized_turbo == True: common_arguments += "--optimized-turbo " if optimized == True: common_arguments += "--optimized " +if tiling == True: + common_arguments += "--tiling " if share == True: common_arguments += "--share " diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py new file mode 100644 index 0000000..6983edb --- /dev/null +++ b/scripts/sd_utils.py @@ -0,0 +1,1728 @@ +# base webui import and utils. 
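# --- Illustrative aside (not part of the original patch): perlin.py above exposes
# perlinNoise(height, width, octavesx, octavesy, seed), classic 2-D Perlin noise on a
# regular grid, seeded for reproducibility. A usage sketch, assuming scripts/perlin.py
# is importable as a module:
from perlin import perlinNoise

noise = perlinNoise(512, 512, octavesx=5, octavesy=5, seed=42)
print(noise.shape)                 # (512, 512)
print(noise.min(), noise.max())    # values fall roughly within [-1, 1]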
+from webui_streamlit import st + + +# streamlit imports +from streamlit import StopException +#other imports + +import warnings +import json + +import base64 +import os, sys, re, random, datetime, time, math, glob +from PIL import Image, ImageFont, ImageDraw, ImageFilter +from PIL.PngImagePlugin import PngInfo +from scipy import integrate +import torch +from torchdiffeq import odeint +import k_diffusion as K +import math +import mimetypes +import numpy as np +import pynvml +import threading +import torch +from torch import autocast +from torchvision import transforms +import torch.nn as nn +from omegaconf import OmegaConf +import yaml +from pathlib import Path +from contextlib import nullcontext +from einops import rearrange +from ldm.util import instantiate_from_config +from retry import retry +from slugify import slugify +import skimage +import piexif +import piexif.helper +from tqdm import trange + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. + from transformers import logging + + logging.set_verbosity_error() +except: + pass + +# remove some annoying deprecation warnings that show every now and then. +warnings.filterwarnings("ignore", category=DeprecationWarning) + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +# some of those options should not be changed at all because they would break the model, so I removed them from options. +opt_C = 4 +opt_f = 8 + +if not "defaults" in st.session_state: + st.session_state["defaults"] = {} + +st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml") + +if (os.path.exists("configs/webui/userconfig_streamlit.yaml")): + user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml") + st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults) + + +# should and will be moved to a settings menu in the UI at some point +grid_format = [s.lower() for s in st.session_state["defaults"].general.grid_format.split(':')] +grid_lossless = False +grid_quality = 100 +if grid_format[0] == 'png': + grid_ext = 'png' + grid_format = 'png' +elif grid_format[0] in ['jpg', 'jpeg']: + grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 + grid_ext = 'jpg' + grid_format = 'jpeg' +elif grid_format[0] == 'webp': + grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 + grid_ext = 'webp' + grid_format = 'webp' + if grid_quality < 0: # e.g. webp:-100 for lossless mode + grid_lossless = True + grid_quality = abs(grid_quality) + +# should and will be moved to a settings menu in the UI at some point +save_format = [s.lower() for s in st.session_state["defaults"].general.save_format.split(':')] +save_lossless = False +save_quality = 100 +if save_format[0] == 'png': + save_ext = 'png' + save_format = 'png' +elif save_format[0] in ['jpg', 'jpeg']: + save_quality = int(save_format[1]) if len(save_format) > 1 else 100 + save_ext = 'jpg' + save_format = 'jpeg' +elif save_format[0] == 'webp': + save_quality = int(save_format[1]) if len(save_format) > 1 else 100 + save_ext = 'webp' + save_format = 'webp' + if save_quality < 0: # e.g. 
webp:-100 for lossless mode + save_lossless = True + save_quality = abs(save_quality) + +# this should force GFPGAN and RealESRGAN onto the selected gpu as well +os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 +os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu) + +@retry(tries=5) +def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus", + CustomModel_available=False, custom_model="Stable Diffusion v1.4"): + """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ + + print ("Loading models.") + + st.session_state["progress_bar_text"].text("Loading models...") + + # Generate random run ID + # Used to link runs linked w/ continue_prev_run which is not yet implemented + # Use URL and filesystem safe version just in case. + st.session_state["run_id"] = base64.urlsafe_b64encode( + os.urandom(6) + ).decode("ascii") + + # check what models we want to use and if the they are already loaded. + + if use_GFPGAN: + if "GFPGAN" in st.session_state: + print("GFPGAN already loaded") + else: + # Load GFPGAN + if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir): + try: + st.session_state["GFPGAN"] = load_GFPGAN() + print("Loaded GFPGAN") + except Exception: + import traceback + print("Error loading GFPGAN:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if "GFPGAN" in st.session_state: + del st.session_state["GFPGAN"] + + if use_RealESRGAN: + if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model: + print("RealESRGAN already loaded") + else: + #Load RealESRGAN + try: + # We first remove the variable in case it has something there, + # some errors can load the model incorrectly and leave things in memory. + del st.session_state["RealESRGAN"] + except KeyError: + pass + + if os.path.exists(st.session_state["defaults"].general.RealESRGAN_dir): + # st.session_state is used for keeping the models in memory across multiple pages or runs. + st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model) + print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name) + + else: + if "RealESRGAN" in st.session_state: + del st.session_state["RealESRGAN"] + + if "model" in st.session_state: + if "model" in st.session_state and st.session_state["loaded_model"] == custom_model: + # TODO: check if the optimized mode was changed? 
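# --- Illustrative aside (not part of the original patch): grid_format / save_format in
# the config are "ext[:quality]" strings, e.g. "jpg:95"; a negative webp quality such as
# "webp:-100" selects lossless mode. A standalone restatement of the parsing above
# (parse_image_format is a hypothetical helper name, not part of the patch):
def parse_image_format(spec):
    parts = [s.lower() for s in spec.split(":")]
    ext, quality, lossless = parts[0], 100, False
    if ext in ("jpg", "jpeg", "webp") and len(parts) > 1:
        quality = int(parts[1])
        if ext == "webp" and quality < 0:
            lossless, quality = True, abs(quality)
    pil_format = "jpeg" if ext in ("jpg", "jpeg") else ext
    file_ext = "jpg" if pil_format == "jpeg" else ext
    return pil_format, file_ext, quality, lossless

print(parse_image_format("jpg:95"))     # ('jpeg', 'jpg', 95, False)
print(parse_image_format("webp:-100"))  # ('webp', 'webp', 100, True)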
+ print("Model already loaded") + + return + else: + try: + del st.session_state.model + del st.session_state.modelCS + del st.session_state.modelFS + del st.session_state.loaded_model + except KeyError: + pass + + # At this point the model is either + # is not loaded yet or have been evicted: + # load new model into memory + st.session_state.custom_model = custom_model + + config, device, model, modelCS, modelFS = load_sd_model(custom_model) + + st.session_state.device = device + st.session_state.model = model + st.session_state.modelCS = modelCS + st.session_state.modelFS = modelFS + st.session_state.loaded_model = custom_model + + if st.session_state.defaults.general.enable_attention_slicing: + st.session_state.model.enable_attention_slicing() + + if st.session_state.defaults.general.enable_minimal_memory_usage: + st.session_state.model.enable_minimal_memory_usage() + + print("Model loaded.") + + +def load_model_from_config(config, ckpt, verbose=False): + + print(f"Loading model from {ckpt}") + + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def load_sd_from_config(ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + return sd + + +class MemUsageMonitor(threading.Thread): + stop_flag = False + max_usage = 0 + total = -1 + + def __init__(self, name): + threading.Thread.__init__(self) + self.name = name + + def run(self): + try: + pynvml.nvmlInit() + except: + print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n") + return + print(f"[{self.name}] Recording max memory usage...\n") + # Missing context + #handle = pynvml.nvmlDeviceGetHandleByIndex(st.session_state['defaults'].general.gpu) + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total + while not self.stop_flag: + m = pynvml.nvmlDeviceGetMemoryInfo(handle) + self.max_usage = max(self.max_usage, m.used) + # print(self.max_usage) + time.sleep(0.1) + print(f"[{self.name}] Stopped recording.\n") + pynvml.nvmlShutdown() + + def read(self): + return self.max_usage, self.total + + def stop(self): + self.stop_flag = True + + def read_and_stop(self): + self.stop_flag = True + return self.max_usage, self.total + +class CFGMaskedDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi): + x_in = x + x_in = torch.cat([x_in] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + denoised = uncond + (cond - uncond) * cond_scale + + if mask is not None: + assert x0 is not None + img_orig = x0 + mask_inv = 1. 
- mask + denoised = (img_orig * mask_inv) + (mask * denoised) + + return denoised + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): + """Constructs the noise schedule of Karras et al. (2022).""" + ramp = torch.linspace(0, 1, n) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + +# +# helper fft routines that keep ortho normalization and auto-shift before and after fft +def _fft2(data): + if data.ndim > 2: # has channels + out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:,:,c] + out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho") + out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c]) + else: # one channel + out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho") + out_fft[:,:] = np.fft.ifftshift(out_fft[:,:]) + + return out_fft + +def _ifft2(data): + if data.ndim > 2: # has channels + out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) + for c in range(data.shape[2]): + c_data = data[:,:,c] + out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho") + out_ifft[:,:,c] = np.fft.ifftshift(out_ifft[:,:,c]) + else: # one channel + out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) + out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho") + out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:]) + + return out_ifft + +def _get_gaussian_window(width, height, std=3.14, mode=0): + + window_scale_x = float(width / min(width, height)) + window_scale_y = float(height / min(width, height)) + + window = np.zeros((width, height)) + x = (np.arange(width) / width * 2. - 1.) * window_scale_x + for y in range(height): + fy = (y / height * 2. - 1.) * window_scale_y + if mode == 0: + window[:, y] = np.exp(-(x**2+fy**2) * std) + else: + window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian + + return window + +def _get_masked_window_rgb(np_mask_grey, hardness=1.): + np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3)) + if hardness != 1.: + hardened = np_mask_grey[:] ** hardness + else: + hardened = np_mask_grey[:] + for c in range(3): + np_mask_rgb[:,:,c] = hardened[:] + return np_mask_rgb + +def get_matched_noise(_np_src_image, np_mask_rgb, noise_q, color_variation): + """ + Explanation: + Getting good results in/out-painting with stable diffusion can be challenging. 
+ Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data + in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well. + Provided here is my take on a potential solution to this problem. + + By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src. + Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence + by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can + be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color + or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently) + This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area. + The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle. + + Carefully managing color and brightness with histogram matching is also essential to achieving good coherence. + + noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.) + color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01) + This code is provided as is under the Unlicense (https://unlicense.org/) + Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz]. + + Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/) + This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot) + + """ + + global DEBUG_MODE + global TMP_ROOT_PATH + + width = _np_src_image.shape[0] + height = _np_src_image.shape[1] + num_channels = _np_src_image.shape[2] + + np_src_image = _np_src_image[:] * (1. - np_mask_rgb) + np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.) + np_src_grey = (np.sum(np_src_image, axis=2)/3.) 
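# --- Illustrative aside (not part of the original patch): a toy single-channel version
# of the spectral shaping the docstring above describes -- fresh white noise is given
# the frequency-magnitude profile (and phase) of the source image so its feature scales
# line up with the unmasked content. Deliberately simplified relative to
# get_matched_noise: no windowing, masks, noise_q exponent or histogram matching.
import numpy as np

rng = np.random.default_rng(0)
src = rng.random((64, 64))        # stand-in for the (unmasked) source image
noise = rng.random((64, 64))      # white noise to be shaped

src_fft = np.fft.fft2(src, norm="ortho")
src_dist = np.abs(src_fft)
src_phase = src_fft / (src_dist + 1e-12)

shaped_fft = np.abs(np.fft.fft2(noise, norm="ortho")) * src_dist * src_phase
shaped = np.real(np.fft.ifft2(shaped_fft, norm="ortho"))
shaped = (shaped - shaped.min()) / (shaped.max() - shaped.min())   # normalize to 0..1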
+ all_mask = np.ones((width, height), dtype=bool) + img_mask = np_mask_grey > 1e-6 + ref_mask = np_mask_grey < 1e-3 + + windowed_image = _np_src_image * (1.-_get_masked_window_rgb(np_mask_grey)) + windowed_image /= np.max(windowed_image) + windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color + #windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area + #_save_debug_img(windowed_image, "windowed_src_img") + + src_fft = _fft2(windowed_image) # get feature statistics from masked src img + src_dist = np.absolute(src_fft) + src_phase = src_fft / src_dist + #_save_debug_img(src_dist, "windowed_src_dist") + + noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise + noise_rgb = np.random.random_sample((width, height, num_channels)) + noise_grey = (np.sum(noise_rgb, axis=2)/3.) + noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter + for c in range(num_channels): + noise_rgb[:,:,c] += (1. - color_variation) * noise_grey + + noise_fft = _fft2(noise_rgb) + for c in range(num_channels): + noise_fft[:,:,c] *= noise_window + noise_rgb = np.real(_ifft2(noise_fft)) + shaped_noise_fft = _fft2(noise_rgb) + shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping + + brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now + contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2. + + # scikit-image is used for histogram matching, very convenient! + shaped_noise = np.real(_ifft2(shaped_noise_fft)) + shaped_noise -= np.min(shaped_noise) + shaped_noise /= np.max(shaped_noise) + shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1) + shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb + #_save_debug_img(shaped_noise, "shaped_noise") + + matched_noise = np.zeros((width, height, num_channels)) + matched_noise = shaped_noise[:] + #matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1) + #matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb + + #_save_debug_img(matched_noise, "matched_noise") + + """ + todo: + color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled + """ + + return np.clip(matched_noise, 0., 1.) + + +# +def find_noise_for_image(model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None): + image = np.array(init_image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + image = 2. * image - 1. 
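# --- Illustrative aside (not part of the original patch): the color/brightness
# coherence step above uses scikit-image histogram matching to pull the tone of the
# shaped noise toward the unmasked source pixels. A standalone sketch with dummy
# arrays; shapes and values are arbitrary.
import numpy as np
from skimage import exposure

rng = np.random.default_rng(0)
reference = rng.normal(0.5, 0.1, (32, 32, 3)).clip(0.0, 1.0)   # pixels whose tone we want to copy
candidate = rng.random((32, 32, 3))                             # shaped noise to be re-toned

matched = exposure.match_histograms(candidate, reference, channel_axis=-1)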
+ image = image.to(device) + x = model.get_first_stage_encoding(model.encode_first_stage(image)) + + uncond = model.get_learned_conditioning(['']) + cond = model.get_learned_conditioning([prompt]) + + s_in = x.new_ones([x.shape[0]]) + dnw = K.external.CompVisDenoiser(model) + sigmas = dnw.get_sigmas(steps).flip(0) + + if verbose: + print(sigmas) + + for i in trange(1, len(sigmas)): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2) + cond_in = torch.cat([uncond, cond]) + + c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] + + if i == 1: + t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) + else: + t = dnw.sigma_to_t(sigma_in) + + eps = model.apply_model(x_in * c_in, t, cond=cond_in) + denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) + + denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale + + if i == 1: + d = (x - denoised) / (2 * sigmas[i]) + else: + d = (x - denoised) / sigmas[i - 1] + + if generation_callback is not None: + generation_callback(x, i) + + dt = sigmas[i] - sigmas[i - 1] + x = x + d * dt + + return x / sigmas[-1] + + +def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): + """Constructs an exponential noise schedule.""" + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + return append_zero(sigmas) + + +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): + """Constructs a continuous VP noise schedule.""" + t = torch.linspace(1, eps_s, n, device=device) + sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) + return append_zero(sigmas) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. 
+ for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + +class KDiffusionSampler: + def __init__(self, m, sampler): + self.model = m + self.model_wrap = K.external.CompVisDenoiser(m) + self.schedule = sampler + def get_sampler_name(self): + return self.schedule + def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None): + sigmas = self.model_wrap.get_sigmas(S) + x = x_T * sigmas[0] + model_wrap_cfg = CFGDenoiser(self.model_wrap) + samples_ddim = None + samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback) + # + return samples_ddim, None + + +@torch.no_grad() +def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + v = torch.randint_like(x, 2) * 2 - 1 + fevals = 0 + def ode_fn(sigma, x): + nonlocal fevals + with torch.enable_grad(): + x = x[0].detach().requires_grad_() + denoised = model(x, sigma * s_in, **extra_args) + d = to_d(x, sigma, denoised) + fevals += 1 + grad = torch.autograd.grad((d * v).sum(), x)[0] + d_ll = (v * grad).flatten(1).sum(1) + return d.detach(), d_ll + x_min = x, x.new_zeros([x.shape[0]]) + t = x.new_tensor([sigma_min, sigma_max]) + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + latent, delta_ll = sol[0][-1], sol[1][-1] + ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + return ll_prior + delta_ll, {'fevals': fevals} + + +def create_random_tensors(shape, seeds): + xs = [] + for seed in seeds: + torch.manual_seed(seed) + + # randn results depend on device; gpu and cpu get different results for same seed; + # the way I see it, it's better to do this on CPU, so that everyone gets same result; + # but the original script had it like this so i do not dare change it for now because + # it will break everyone's seeds. 
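# --- Illustrative aside (not part of the original patch): create_random_tensors above
# gets reproducible latents by re-seeding the RNG once per requested seed and stacking
# the draws; as its comment notes, the same seed produces different values on CPU and
# GPU, so the script keeps the original device rather than break existing seeds.
# A minimal demonstration of the reproducibility:
import torch

def seeded_noise(shape, seeds, device="cpu"):
    xs = []
    for seed in seeds:
        torch.manual_seed(seed)
        xs.append(torch.randn(shape, device=device))
    return torch.stack(xs)

a = seeded_noise((4, 64, 64), [1234, 5678])
b = seeded_noise((4, 64, 64), [1234, 5678])
assert torch.equal(a, b)   # same seeds -> identical tensors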
+ xs.append(torch.randn(shape, device=st.session_state['defaults'].general.gpu)) + x = torch.stack(xs) + return x + +def torch_gc(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + +def load_GFPGAN(): + model_name = 'GFPGANv1.3' + model_path = os.path.join(st.session_state['defaults'].general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + raise Exception("GFPGAN model not found at path "+model_path) + + sys.path.append(os.path.abspath(st.session_state['defaults'].general.GFPGAN_dir)) + from gfpgan import GFPGANer + + if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu')) + elif st.session_state['defaults'].general.extra_models_gpu: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gfpgan_gpu}")) + else: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}")) + return instance + +def load_RealESRGAN(model_name: str): + from basicsr.archs.rrdbnet_arch import RRDBNet + RealESRGAN_models = { + 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), + 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + } + + model_path = os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")): + raise Exception(model_name+".pth not found at path "+model_path) + + sys.path.append(os.path.abspath(st.session_state['defaults'].general.RealESRGAN_dir)) + from realesrgan import RealESRGANer + + if st.session_state['defaults'].general.esrgan_cpu or st.session_state['defaults'].general.extra_models_cpu: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half + instance.device = torch.device('cpu') + instance.model.to('cpu') + elif st.session_state['defaults'].general.extra_models_gpu: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.esrgan_gpu}")) + else: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}")) + instance.model.name = model_name + + return instance + +# +def load_LDSR(checking=False): + model_name = 'model' + yaml_name = 'project' + model_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', model_name + '.ckpt') + yaml_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', yaml_name + '.yaml') + if not os.path.isfile(model_path): + raise Exception("LDSR model not found at path "+model_path) + if not os.path.isfile(yaml_path): + raise Exception("LDSR 
model not found at path "+yaml_path) + if checking == True: + return True + + sys.path.append(os.path.abspath(st.session_state['defaults'].general.LDSR_dir)) + from LDSR import LDSR + LDSRObject = LDSR(model_path, yaml_path) + return LDSRObject + +# +LDSR = None +def try_loading_LDSR(model_name: str,checking=False): + global LDSR + if os.path.exists(st.session_state['defaults'].general.LDSR_dir): + try: + LDSR = load_LDSR(checking=True) # TODO: Should try to load both models before giving up + if checking == True: + print("Found LDSR") + return True + print("Latent Diffusion Super Sampling (LDSR) model loaded") + except Exception: + import traceback + print("Error loading LDSR:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + print("LDSR not found at path, please make sure you have cloned the LDSR repo to ./src/latent-diffusion/") + +#try_loading_LDSR('model',checking=True) + + +# Loads Stable Diffusion model by name +def load_sd_model(model_name: str) -> [any, any, any, any, any]: + ckpt_path = st.session_state.defaults.general.default_model_path + if model_name != st.session_state.defaults.general.default_model: + ckpt_path = os.path.join("models", "custom", f"{model_name}.ckpt") + + if st.session_state.defaults.general.optimized: + config = OmegaConf.load(st.session_state.defaults.general.optimized_config) + + sd = load_sd_from_config(ckpt_path) + li, lo = [], [] + for key, v_ in sd.items(): + sp = key.split('.') + if (sp[0]) == 'model': + if 'input_blocks' in sp: + li.append(key) + elif 'middle_block' in sp: + li.append(key) + elif 'time_embed' in sp: + li.append(key) + else: + lo.append(key) + for key in li: + sd['model1.' + key[6:]] = sd.pop(key) + for key in lo: + sd['model2.' + key[6:]] = sd.pop(key) + + device = torch.device(f"cuda:{st.session_state.defaults.general.gpu}") \ + if torch.cuda.is_available() else torch.device("cpu") + + model = instantiate_from_config(config.modelUNet) + _, _ = model.load_state_dict(sd, strict=False) + model.cuda() + model.eval() + model.turbo = st.session_state.defaults.general.optimized_turbo + + modelCS = instantiate_from_config(config.modelCondStage) + _, _ = modelCS.load_state_dict(sd, strict=False) + modelCS.cond_stage_model.device = device + modelCS.eval() + + modelFS = instantiate_from_config(config.modelFirstStage) + _, _ = modelFS.load_state_dict(sd, strict=False) + modelFS.eval() + + del sd + + if not st.session_state.defaults.general.no_half: + model = model.half() + modelCS = modelCS.half() + modelFS = modelFS.half() + + return config, device, model, modelCS, modelFS + else: + config = OmegaConf.load(st.session_state.defaults.general.default_model_config) + model = load_model_from_config(config, ckpt_path) + + device = torch.device(f"cuda:{st.session_state.defaults.general.gpu}") \ + if torch.cuda.is_available() else torch.device("cpu") + model = (model if st.session_state.defaults.general.no_half + else model.half()).to(device) + + return config, device, model, None, None + + +# @codedealer: No usages +def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): + #get global variables + global_vars = globals() + #check if m is in globals + if unload: + for m in models: + if m in global_vars: + #if it is, delete it + del global_vars[m] + if st.session_state['defaults'].general.optimized: + if m == 'model': + del global_vars[m+'FS'] + del global_vars[m+'CS'] + if m == 'model': + m = 'Stable Diffusion' + print('Unloaded ' + m) + if load: + for m in models: + if m not in 
global_vars or m in global_vars and type(global_vars[m]) == bool: + #if it isn't, load it + if m == 'GFPGAN': + global_vars[m] = load_GFPGAN() + elif m == 'model': + sdLoader = load_sd_from_config() + global_vars[m] = sdLoader[0] + if st.session_state['defaults'].general.optimized: + global_vars[m+'CS'] = sdLoader[1] + global_vars[m+'FS'] = sdLoader[2] + elif m == 'RealESRGAN': + global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name) + elif m == 'LDSR': + global_vars[m] = load_LDSR() + if m =='model': + m='Stable Diffusion' + print('Loaded ' + m) + torch_gc() + + +# +@retry(tries=5) +def generation_callback(img, i=0): + if "update_preview_frequency" not in st.session_state: + raise StopException + + try: + if i == 0: + if img['i']: i = img['i'] + except TypeError: + pass + + if i % int(st.session_state.update_preview_frequency) == 0 and st.session_state.update_preview and i > 0: + #print (img) + #print (type(img)) + # The following lines will convert the tensor we got on img to an actual image we can render on the UI. + # It can probably be done in a better way for someone who knows what they're doing. I don't. + #print (img,isinstance(img, torch.Tensor)) + if isinstance(img, torch.Tensor): + x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(img) + else: + # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: + # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} + x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(img["denoised"]) + + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + if x_samples_ddim.ndimension() == 4: + pil_images = [transforms.ToPILImage()(x.squeeze_(0)) for x in x_samples_ddim] + pil_image = image_grid(pil_images, 1) + else: + pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) + + # update image on the UI so we can see the progress + st.session_state["preview_image"].image(pil_image) + + # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, + # Dont worry, it doesnt affect the performance. 
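+    # Worked example of the percentage math below: with sampling_steps=50 and i=24,
+    # the txt2img branch renders "Running step: 25/50 50%".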
+ if st.session_state["generation_mode"] == "txt2img": + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") + else: + if st.session_state["generation_mode"] == "img2img": + round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) + percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) + st.session_state["progress_bar_text"].text( + f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") + else: + if st.session_state["generation_mode"] == "txt2vid": + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps}" + f"{percent if percent < 100 else 100}%") + + st.session_state["progress_bar"].progress(percent if percent < 100 else 100) + + +prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + [^:]+ # match one or more non ':' characters + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P # capture group for 'weight' + -?\\d+(?:\\.\\d+)? # match positive or negative decimal number + )? # end weight capture group, make optional + \\s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group +""", re.VERBOSE) + +# grabs all text up to the first occurrence of ':' as sub-prompt +# takes the value following ':' as weight +# if ':' has no value defined, defaults to 1.0 +# repeats until no text remaining +def split_weighted_subprompts(input_string, normalize=True): + parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)] + if not normalize: + return parsed_prompts + # this probably still doesn't handle negative weights very well + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + +def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995): + v0 = v0.detach().cpu().numpy() + v1 = v1.detach().cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + v2 = torch.from_numpy(v2).to(device) + + return v2 + +# +def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list): + """Find the optimal update_preview_frequency value maximizing + performance while minimizing the time between updates.""" + from statistics import mean + + previous_chunk_avg_speed = mean(previous_chunk_speed_list) + + previous_chunk_speed_list.append(current_chunk_speed) + current_chunk_avg_speed = 
+def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
+    v0 = v0.detach().cpu().numpy()
+    v1 = v1.detach().cpu().numpy()
+
+    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+    if np.abs(dot) > DOT_THRESHOLD:
+        v2 = (1 - t) * v0 + t * v1
+    else:
+        theta_0 = np.arccos(dot)
+        sin_theta_0 = np.sin(theta_0)
+        theta_t = theta_0 * t
+        sin_theta_t = np.sin(theta_t)
+        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+        s1 = sin_theta_t / sin_theta_0
+        v2 = s0 * v0 + s1 * v1
+
+    v2 = torch.from_numpy(v2).to(device)
+
+    return v2
+
+#
+def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list):
+    """Find the optimal update_preview_frequency value maximizing
+       performance while minimizing the time between updates."""
+    from statistics import mean
+
+    previous_chunk_avg_speed = mean(previous_chunk_speed_list)
+
+    previous_chunk_speed_list.append(current_chunk_speed)
+    current_chunk_avg_speed = mean(previous_chunk_speed_list)
+
+    if current_chunk_avg_speed >= previous_chunk_avg_speed:
+        #print(f"{current_chunk_speed} >= {previous_chunk_speed}")
+        update_preview_frequency_list.append(update_preview_frequency + 1)
+    else:
+        #print(f"{current_chunk_speed} <= {previous_chunk_speed}")
+        update_preview_frequency_list.append(update_preview_frequency - 1)
+
+    update_preview_frequency = round(mean(update_preview_frequency_list))
+
+    return current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list
+
+
+def get_font(fontsize):
+    fonts = ["arial.ttf", "DejaVuSans.ttf"]
+    for font_name in fonts:
+        try:
+            return ImageFont.truetype(font_name, fontsize)
+        except OSError:
+            pass
+
+    # ImageFont.load_default() is practically unusable as it only supports
+    # latin1, so raise an exception instead if no usable font was found
+    raise Exception(f"No usable font found (tried {', '.join(fonts)})")
+
+def load_embeddings(fp):
+    if fp is not None and hasattr(st.session_state["model"], "embedding_manager"):
+        st.session_state["model"].embedding_manager.load(fp['name'])
+
+def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
+    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
+
+    # separate token and the embeds
+    if learned_embeds_path.endswith('.pt'):
+        print(loaded_learned_embeds['string_to_token'])
+        trained_token = list(loaded_learned_embeds['string_to_token'].keys())[0]
+        embeds = list(loaded_learned_embeds['string_to_param'].values())[0]
+
+    elif learned_embeds_path.endswith('.bin'):
+        trained_token = list(loaded_learned_embeds.keys())[0]
+        embeds = loaded_learned_embeds[trained_token]
+
+    embeds = loaded_learned_embeds[trained_token]
+    # cast to dtype of text_encoder
+    dtype = text_encoder.get_input_embeddings().weight.dtype
+    embeds.to(dtype)
+
+    # add the token in tokenizer
+    token = token if token is not None else trained_token
+    num_added_tokens = tokenizer.add_tokens(token)
+
+    # resize the token embeddings
+    text_encoder.resize_token_embeddings(len(tokenizer))
+
+    # get the id for the token and assign the embeds
+    token_id = tokenizer.convert_tokens_to_ids(token)
+    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
+    return token
+
+def image_grid(imgs, batch_size, force_n_rows=None, captions=None):
+    #print (len(imgs))
+    if force_n_rows is not None:
+        rows = force_n_rows
+    elif st.session_state['defaults'].general.n_rows > 0:
+        rows = st.session_state['defaults'].general.n_rows
+    elif st.session_state['defaults'].general.n_rows == 0:
+        rows = batch_size
+    else:
+        rows = math.sqrt(len(imgs))
+        rows = round(rows)
+
+    cols = math.ceil(len(imgs) / rows)
+
+    w, h = imgs[0].size
+    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+
+    fnt = get_font(30)
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+        if captions and i < len(captions):
+            d = ImageDraw.Draw(grid)
+            size = d.textbbox((0, 0), captions[i], font=fnt, stroke_width=2)
+            d.multiline_text((i % cols * w + w / 2, i // cols * h + h - size[3]), captions[i],
+                             font=fnt, fill=(255, 255, 255), stroke_width=2, stroke_fill=(0, 0, 0),
+                             anchor="mm", align="center")
+
+    return grid
+
+def seed_to_int(s):
+    if type(s) is int:
+        return s
+    if s is None or s == '':
+        return random.randint(0, 2**32 - 1)
+    n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
+    while n >= 2**32:
+        n = n >> 32
+    return n
+
+#
+def draw_prompt_matrix(im, width, height, all_prompts):
+    def wrap(text, d, font, line_length):
+        lines = ['']
+        for word in text.split():
+            line = f'{lines[-1]} {word}'.strip()
+            if d.textlength(line, font=font) <= line_length:
+                lines[-1] = line
+            else:
+                lines.append(word)
+        return '\n'.join(lines)
+
+    def draw_texts(pos, x, y, texts, sizes):
+        for i, (text, size) in enumerate(zip(texts, sizes)):
+            active = pos & (1 << i) != 0
+
+            if not active:
+                text = '\u0336'.join(text) + '\u0336'
+
+            d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if
active else color_inactive, anchor="mm", align="center") + + y += size[1] + line_spacing + + fontsize = (width + height) // 25 + line_spacing = fontsize // 2 + fnt = get_font(fontsize) + color_active = (0, 0, 0) + color_inactive = (153, 153, 153) + + pad_top = height // 4 + pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0 + + cols = im.width // width + rows = im.height // height + + prompts = all_prompts[1:] + + result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white") + result.paste(im, (pad_left, pad_top)) + + d = ImageDraw.Draw(result) + + boundary = math.ceil(len(prompts) / 2) + prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]] + prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]] + + sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]] + sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]] + hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing + ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing + + for col in range(cols): + x = pad_left + width * col + width / 2 + y = pad_top / 2 - hor_text_height / 2 + + draw_texts(col, x, y, prompts_horiz, sizes_hor) + + for row in range(rows): + x = pad_left / 2 + y = pad_top + height * row + height / 2 - ver_text_height / 2 + + draw_texts(row, x, y, prompts_vert, sizes_ver) + + return result + +def check_prompt_length(prompt, comments): + """this function tests if prompt is too long, and if so, adds a message to comments""" + + tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer + max_length = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.max_length + + info = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, + return_overflowing_tokens=True, padding="max_length", return_tensors="pt") + ovf = info['overflowing_tokens'][0] + overflowing_count = ovf.shape[0] + if overflowing_count == 0: + return + + vocab = {v: k for k, v in tokenizer.get_vocab().items()} + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + + comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + +def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, model_name): + + filename_i = os.path.join(sample_path_i, filename) + + if st.session_state['defaults'].general.save_metadata or write_info_files: + # toggles differ for txt2img vs. 
img2img: + offset = 0 if init_img is None else 2 + toggles = [] + if prompt_matrix: + toggles.append(0) + if normalize_prompt_weights: + toggles.append(1) + if init_img is not None: + if uses_loopback: + toggles.append(2) + if uses_random_seed_loopback: + toggles.append(3) + if save_individual_images: + toggles.append(2 + offset) + if save_grid: + toggles.append(3 + offset) + if sort_samples: + toggles.append(4 + offset) + if write_info_files: + toggles.append(5 + offset) + if use_GFPGAN: + toggles.append(6 + offset) + metadata = \ + dict( + target="txt2img" if init_img is None else "img2img", + prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, + ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, + seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=st.session_state["loaded_model"]) + # Not yet any use for these, but they bloat up the files: + # info_dict["init_img"] = init_img + # info_dict["init_mask"] = init_mask + if init_img is not None: + metadata["denoising_strength"] = str(denoising_strength) + metadata["resize_mode"] = resize_mode + + if write_info_files: + with open(f"{filename_i}.yaml", "w", encoding="utf8") as f: + yaml.dump(metadata, f, allow_unicode=True, width=10000) + + if st.session_state['defaults'].general.save_metadata: + # metadata = { + # "SD:prompt": prompts[i], + # "SD:seed": str(seeds[i]), + # "SD:width": str(width), + # "SD:height": str(height), + # "SD:steps": str(steps), + # "SD:cfg_scale": str(cfg_scale), + # "SD:normalize_prompt_weights": str(normalize_prompt_weights), + # } + metadata = {"SD:" + k:v for (k,v) in metadata.items()} + + if save_ext == "png": + mdata = PngInfo() + for key in metadata: + mdata.add_text(key, str(metadata[key])) + image.save(f"{filename_i}.png", pnginfo=mdata) + else: + if jpg_sample: + image.save(f"{filename_i}.jpg", quality=save_quality, + optimize=True) + elif save_ext == "webp": + image.save(f"{filename_i}.{save_ext}", f"webp", quality=save_quality, + lossless=save_lossless) + else: + # not sure what file format this is + image.save(f"{filename_i}.{save_ext}", f"{save_ext}") + try: + exif_dict = piexif.load(f"{filename_i}.{save_ext}") + except: + exif_dict = { "Exif": dict() } + exif_dict["Exif"][piexif.ExifIFD.UserComment] = piexif.helper.UserComment.dump( + json.dumps(metadata), encoding="unicode") + piexif.insert(piexif.dump(exif_dict), f"{filename_i}.{save_ext}") + + +def get_next_sequence_number(path, prefix=''): + """ + Determines and returns the next sequence number to use when saving an + image in the specified directory. + + If a prefix is given, only consider files whose names start with that + prefix, and strip the prefix from filenames before extracting their + sequence number. + + The sequence starts at 0. 
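+
+    For example, a directory containing "00007-foo.png" and "00012-bar.png"
+    (with no prefix) yields 13, while an empty directory yields 0.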
+ """ + result = -1 + for p in Path(path).iterdir(): + if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix): + tmp = p.name[len(prefix):] + try: + result = max(int(tmp.split('-')[0]), result) + except ValueError: + pass + return result + 1 + + +def oxlamon_matrix(prompt, seed, n_iter, batch_size): + pattern = re.compile(r'(,\s){2,}') + + class PromptItem: + def __init__(self, text, parts, item): + self.text = text + self.parts = parts + if item: + self.parts.append( item ) + + def clean(txt): + return re.sub(pattern, ', ', txt) + + def getrowcount( txt ): + for data in re.finditer( ".*?\\((.*?)\\).*", txt ): + if data: + return len(data.group(1).split("|")) + break + return None + + def repliter( txt ): + for data in re.finditer( ".*?\\((.*?)\\).*", txt ): + if data: + r = data.span(1) + for item in data.group(1).split("|"): + yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip()) + break + + def iterlist( items ): + outitems = [] + for item in items: + for newitem, newpart in repliter(item.text): + outitems.append( PromptItem(newitem, item.parts.copy(), newpart) ) + + return outitems + + def getmatrix( prompt ): + dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ] + while True: + newdataitems = iterlist( dataitems ) + if len( newdataitems ) == 0: + return dataitems + dataitems = newdataitems + + def classToArrays( items, seed, n_iter ): + texts = [] + parts = [] + seeds = [] + + for item in items: + itemseed = seed + for i in range(n_iter): + texts.append( item.text ) + parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) ) + seeds.append( itemseed ) + itemseed += 1 + + return seeds, texts, parts + + all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter) + n_iter = math.ceil(len(all_prompts) / batch_size) + + needrows = getrowcount(prompt) + if needrows: + xrows = math.sqrt(len(all_prompts)) + xrows = round(xrows) + # if columns is to much + cols = math.ceil(len(all_prompts) / xrows) + if cols > needrows*4: + needrows *= 2 + + return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows + +# +def process_images( + outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, + n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, + ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, + mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, noise_mode=0, find_noise_steps=1, resize_mode=None, uses_loopback=False, + uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, + variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): + """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + assert prompt is not None + torch_gc() + # start time after garbage collection (or before?) 
+ start_time = time.time() + + # We will use this date here later for the folder name, need to start_time if not need + run_start_dt = datetime.datetime.now() + + mem_mon = MemUsageMonitor('MemMon') + mem_mon.start() + + if st.session_state.defaults.general.use_sd_concepts_library: + + prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt) + + if prompt_tokens: + # compviz + tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer + text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer + + # diffusers + #tokenizer = pipe.tokenizer + #text_encoder = pipe.text_encoder + + ext = ('pt', 'bin') + + if len(prompt_tokens) > 1: + for token_name in prompt_tokens: + embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name) + if os.path.exists(embedding_path): + for files in os.listdir(embedding_path): + if files.endswith(ext): + load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>") + else: + embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0]) + if os.path.exists(embedding_path): + for files in os.listdir(embedding_path): + if files.endswith(ext): + load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>") + + # + + + os.makedirs(outpath, exist_ok=True) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + + if not ("|" in prompt) and prompt.startswith("@"): + prompt = prompt[1:] + + negprompt = '' + if '###' in prompt: + prompt, negprompt = prompt.split('###', 1) + prompt = prompt.strip() + negprompt = negprompt.strip() + + comments = [] + + prompt_matrix_parts = [] + simple_templating = False + add_original_image = not (use_RealESRGAN or use_GFPGAN) + + if prompt_matrix: + if prompt.startswith("@"): + simple_templating = True + add_original_image = not (use_RealESRGAN or use_GFPGAN) + all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size) + else: + all_prompts = [] + prompt_matrix_parts = prompt.split("|") + combination_count = 2 ** (len(prompt_matrix_parts) - 1) + for combination_num in range(combination_count): + current = prompt_matrix_parts[0] + + for n, text in enumerate(prompt_matrix_parts[1:]): + if combination_num & (2 ** n) > 0: + current += ("" if text.strip().startswith(",") else ", ") + text + + all_prompts.append(current) + + n_iter = math.ceil(len(all_prompts) / batch_size) + all_seeds = len(all_prompts) * [seed] + + print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") + else: + + if not st.session_state['defaults'].general.no_verify_input: + try: + check_prompt_length(prompt, comments) + except: + import traceback + print("Error verifying input:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + all_prompts = batch_size * n_iter * [prompt] + all_seeds = [seed + x for x in range(len(all_prompts))] + + precision_scope = autocast if st.session_state['defaults'].general.precision == "autocast" else nullcontext + output_images = [] + grid_captions = [] + stats = [] + with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not st.session_state['defaults'].general.optimized else 
nullcontext()): + init_data = func_init() + tic = time.time() + + + # if variant_amount > 0.0 create noise from base seed + base_x = None + if variant_amount > 0.0: + target_seed_randomizer = seed_to_int('') # random seed + torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration) + base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed]) + # we don't want all_seeds to be sequential from starting seed with variants, + # since that makes the same variants each time, + # so we add target_seed_randomizer as a random offset + for si in range(len(all_seeds)): + all_seeds[si] += target_seed_randomizer + + for n in range(n_iter): + print(f"Iteration: {n+1}/{n_iter}") + prompts = all_prompts[n * batch_size:(n + 1) * batch_size] + captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size] + seeds = all_seeds[n * batch_size:(n + 1) * batch_size] + + print(prompt) + + if st.session_state['defaults'].general.optimized: + st.session_state.modelCS.to(st.session_state['defaults'].general.gpu) + + uc = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(len(prompts) * [negprompt]) + + if isinstance(prompts, tuple): + prompts = list(prompts) + + # split the prompt if it has : for weighting + # TODO for speed it might help to have this occur when all_prompts filled?? + weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights) + + # sub-prompt weighting used if more than 1 + if len(weighted_subprompts) > 1: + c = torch.zeros_like(uc) # i dont know if this is correct.. but it works + for i in range(0, len(weighted_subprompts)): + # note if alpha negative, it functions same as torch.sub + c = torch.add(c, (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1]) + else: # just behave like usual + c = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(prompts) + + + shape = [opt_C, height // opt_f, width // opt_f] + + if st.session_state['defaults'].general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + st.session_state.modelCS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + if noise_mode == 1 or noise_mode == 3: + # TODO params for find_noise_to_image + x = torch.cat(batch_size * [find_noise_for_image( + st.session_state["model"], st.session_state["device"], + init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True, + generation_callback=generation_callback, + )], dim=0) + else: + # we manually generate all input noises because each one should have a specific seed + x = create_random_tensors(shape, seeds=seeds) + + if variant_amount > 0.0: # we are making variants + # using variant_seed as sneaky toggle, + # when not None or '' use the variant_seed + # otherwise use seeds + if variant_seed != None and variant_seed != '': + specified_variant_seed = seed_to_int(variant_seed) + torch.manual_seed(specified_variant_seed) + seeds = [specified_variant_seed] + # finally, slerp base_x noise to target_x noise for creating a variant + x = slerp(st.session_state['defaults'].general.gpu, max(0.0, min(1.0, variant_amount)), base_x, x) + + samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) + + if 
st.session_state['defaults'].general.optimized: + st.session_state.modelFS.to(st.session_state['defaults'].general.gpu) + + x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + run_images = [] + for i, x_sample in enumerate(x_samples_ddim): + sanitized_prompt = slugify(prompts[i]) + + percent = i / len(x_samples_ddim) + st.session_state["progress_bar"].progress(percent if percent < 100 else 100) + + if sort_samples: + full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt) + + + sanitized_prompt = sanitized_prompt[:220-len(full_path)] + sample_path_i = os.path.join(sample_path, sanitized_prompt) + + #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}") + #print(os.path.join(os.getcwd(), sample_path_i)) + + os.makedirs(sample_path_i, exist_ok=True) + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}" + else: + full_path = os.path.join(os.getcwd(), sample_path) + sample_path_i = sample_path + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before + + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = x_sample.astype(np.uint8) + image = Image.fromarray(x_sample) + original_sample = x_sample + original_filename = filename + + st.session_state["preview_image"].image(image) + + if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN: + st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(x_samples_ddim))) + #skip_save = True # #287 >_> + torch_gc() + cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + gfpgan_image = Image.fromarray(gfpgan_sample) + gfpgan_filename = original_filename + '-gfpgan' + + save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, + uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta, + n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) + + output_images.append(gfpgan_image) #287 + run_images.append(gfpgan_image) + + if simple_templating: + grid_captions.append( captions[i] + "\ngfpgan" ) + + elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: + st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." 
% (i+1, len(x_samples_ddim))) + #skip_save = True # #287 >_> + torch_gc() + + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) + esrgan_filename = original_filename + '-esrgan4x' + esrgan_sample = output[:,:,::-1] + esrgan_image = Image.fromarray(esrgan_sample) + + #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, + #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) + + save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) + + output_images.append(esrgan_image) #287 + run_images.append(esrgan_image) + + if simple_templating: + grid_captions.append( captions[i] + "\nesrgan" ) + + elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None: + st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim))) + #skip_save = True # #287 >_> + torch_gc() + cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) + gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' + gfpgan_esrgan_sample = output[:,:,::-1] + gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) + + save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) + + output_images.append(gfpgan_esrgan_image) #287 + run_images.append(gfpgan_esrgan_image) + + if simple_templating: + grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) + else: + output_images.append(image) + run_images.append(image) + + if mask_restore and init_mask: + #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask) + init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) + init_mask = init_mask.convert('L') + init_img = init_img.convert('RGB') + image = image.convert('RGB') + + if use_RealESRGAN and st.session_state["RealESRGAN"] is not None: + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) 
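+                            # The RealESRGAN weights in memory (e.g. RealESRGAN_x4plus) differ from
+                            # the variant requested for this run, so reload them before upscaling the
+                            # init image and mask used for mask_restore below.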
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8)) + init_img = Image.fromarray(output) + init_img = init_img.convert('RGB') + + output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8)) + init_mask = Image.fromarray(output) + init_mask = init_mask.convert('L') + + image = Image.composite(init_img, image, init_mask) + + if save_individual_images: + save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, st.session_state["loaded_model"]) + + #if add_original_image or not simple_templating: + #output_images.append(image) + #if simple_templating: + #grid_captions.append( captions[i] ) + + if st.session_state['defaults'].general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + st.session_state.modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + if len(run_images) > 1: + preview_image = image_grid(run_images, n_iter) + else: + preview_image = run_images[0] + + # Constrain the final preview image to 1440x900 so we're not sending huge amounts of data + # to the browser + preview_image = constrain_image(preview_image, 1440, 900) + st.session_state["progress_bar_text"].text("Finished!") + st.session_state["preview_image"].image(preview_image) + + if prompt_matrix or save_grid: + if prompt_matrix: + if simple_templating: + grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions) + else: + grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) + try: + grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) + except: + import traceback + print("Error creating prompt_matrix text:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + grid = image_grid(output_images, batch_size) + + if grid and (batch_size > 1 or n_iter > 1): + output_images.insert(0, grid) + + grid_count = get_next_sequence_number(outpath, 'grid-') + grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}" + grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True) + + toc = time.time() + + mem_max_used, mem_total = mem_mon.read_and_stop() + time_diff = time.time()-start_time + + info = f""" + {prompt} + Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' 
if prompt_matrix else ''}""".strip() + stats = f''' + Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) + Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' + + for comment in comments: + info += "\n\n" + comment + + #mem_mon.stop() + #del mem_mon + torch_gc() + + return output_images, seed, info, stats + + +def resize_image(resize_mode, im, width, height): + LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) + if resize_mode == 0: + res = im.resize((width, height), resample=LANCZOS) + elif resize_mode == 1: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio > src_ratio else im.width * height // im.height + src_h = height if ratio <= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + else: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio < src_ratio else im.width * height // im.height + src_h = height if ratio >= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + + return res + +def constrain_image(img, max_width, max_height): + ratio = max(img.width / max_width, img.height / max_height) + if ratio <= 1: + return img + resampler = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) + resized = img.resize((int(img.width / ratio), int(img.height / ratio)), resample=resampler) + return resized diff --git a/scripts/stable_diffusion_pipeline.py b/scripts/stable_diffusion_pipeline.py new file mode 100644 index 0000000..6f4f794 --- /dev/null +++ b/scripts/stable_diffusion_pipeline.py @@ -0,0 +1,233 @@ +import inspect +import warnings +from tqdm.auto import tqdm +from typing import List, Optional, Union + +import torch +from diffusers import ModelMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import \ + StableDiffusionSafetyChecker +from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler, + PNDMScheduler) +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +class StableDiffusionPipeline(DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + scheduler = scheduler.set_format("pt") + 
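+        # set_format("pt") is the early diffusers API for keeping scheduler state as
+        # torch tensors rather than numpy arrays (later diffusers releases drop it).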
self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + text_embeddings: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + **kwargs, + ): + if "torch_device" in kwargs: + device = kwargs.pop("torch_device") + warnings.warn( + "`torch_device` is deprecated as an input argument to `__call__` and" + " will be removed in v0.3.0. Consider using `pipe.to(torch_device)`" + " instead." + ) + + # Set device as before (to be removed in 0.3.0) + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.to(device) + + if text_embeddings is None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + "`height` and `width` have to be divisible by 8 but are" + f" {height} and {width}." + ) + + # get prompt text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + else: + batch_size = text_embeddings.shape[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + # max_length = text_input.input_ids.shape[-1] + max_length = 77 # self.tokenizer.model_max_length + uncond_input = self.tokenizer( + [""] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(self.device) + )[0] + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + ) + else: + if latents.shape != latents_shape: + raise ValueError( + f"Unexpected latents shape, got {latents.shape}, expected" + f" {latents_shape}" + ) + latents = latents.to(self.device) + + # set timesteps + accepts_offset = "offset" in set( + inspect.signature(self.scheduler.set_timesteps).parameters.keys() + ) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=text_embeddings + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = self.scheduler.step( + noise_pred, i, latents, **extra_step_kwargs + )["prev_sample"] + else: + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + )["prev_sample"] + + # scale and decode the image latents with vae + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + safety_cheker_input = self.feature_extractor( + self.numpy_to_pil(image), return_tensors="pt" + ).to(self.device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_cheker_input.pixel_values + ) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + return {"sample": image, "nsfw_content_detected": has_nsfw_concept} + + def embed_text(self, text): + """Helper to embed some text""" + with torch.autocast("cuda"): + text_input = self.tokenizer( + text, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + with 
torch.no_grad(): + embed = self.text_encoder(text_input.input_ids.to(self.device))[0] + return embed + + +class NoCheck(ModelMixin): + """Can be used in place of safety checker. Use responsibly and at your own risk.""" + def __init__(self): + super().__init__() + self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3))) + + def forward(self, images=None, **kwargs): + return images, [False] diff --git a/scripts/stable_diffusion_walk.py b/scripts/stable_diffusion_walk.py new file mode 100644 index 0000000..1ce175d --- /dev/null +++ b/scripts/stable_diffusion_walk.py @@ -0,0 +1,218 @@ +import json +import subprocess +from pathlib import Path + +import numpy as np +import torch +from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler, + PNDMScheduler) +from diffusers import ModelMixin + +from stable_diffusion_pipeline import StableDiffusionPipeline + +pipeline = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + use_auth_token=True, + torch_dtype=torch.float16, + revision="fp16", +).to("cuda") + +default_scheduler = PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" +) +ddim_scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, +) +klms_scheduler = LMSDiscreteScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" +) +SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) + + +def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + """helper function to spherically interpolate two arrays v1 v2""" + + if not isinstance(v0, np.ndarray): + inputs_are_torch = True + input_device = v0.device + v0 = v0.cpu().numpy() + v1 = v1.cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + if inputs_are_torch: + v2 = torch.from_numpy(v2).to(input_device) + + return v2 + + +def make_video_ffmpeg(frame_dir, output_file_name='output.mp4', frame_filename="frame%06d.jpg", fps=30): + frame_ref_path = str(frame_dir / frame_filename) + video_path = str(frame_dir / output_file_name) + subprocess.call( + f"ffmpeg -r {fps} -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p" + f" {video_path}".split() + ) + return video_path + + +def walk( + prompts=["blueberry spaghetti", "strawberry spaghetti"], + seeds=[42, 123], + num_steps=5, + output_dir="dreams", + name="berry_good_spaghetti", + height=512, + width=512, + guidance_scale=7.5, + eta=0.0, + num_inference_steps=50, + do_loop=False, + make_video=False, + use_lerp_for_text=False, + scheduler="klms", # choices: default, ddim, klms + disable_tqdm=False, + upsample=False, + fps=30, +): + """Generate video frames/a video given a list of prompts and seeds. + + Args: + prompts (List[str], optional): List of . Defaults to ["blueberry spaghetti", "strawberry spaghetti"]. + seeds (List[int], optional): List of random seeds corresponding to given prompts. + num_steps (int, optional): Number of steps to walk. Increase this value to 60-200 for good results. Defaults to 5. + output_dir (str, optional): Root dir where images will be saved. Defaults to "dreams". 
+ name (str, optional): Sub directory of output_dir to save this run's files. Defaults to "berry_good_spaghetti". + height (int, optional): Height of image to generate. Defaults to 512. + width (int, optional): Width of image to generate. Defaults to 512. + guidance_scale (float, optional): Higher = more adherance to prompt. Lower = let model take the wheel. Defaults to 7.5. + eta (float, optional): ETA. Defaults to 0.0. + num_inference_steps (int, optional): Number of diffusion steps. Defaults to 50. + do_loop (bool, optional): Whether to loop from last prompt back to first. Defaults to False. + make_video (bool, optional): Whether to make a video or just save the images. Defaults to False. + use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to False. + scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms". + disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False. + upsample (bool, optional): If True, uses Real-ESRGAN to upsample images 4x. Requires it to be installed + which you can do by running: `pip install git+https://github.com/xinntao/Real-ESRGAN.git`. Defaults to False. + fps (int, optional): The frames per second (fps) that you want the video to use. Does nothing if make_video is False. Defaults to 30. + + Returns: + str: Path to video file saved if make_video=True, else None. + """ + if upsample: + from .upsampling import PipelineRealESRGAN + + upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') + + pipeline.set_progress_bar_config(disable=disable_tqdm) + + pipeline.scheduler = SCHEDULERS[scheduler] + + output_path = Path(output_dir) / name + output_path.mkdir(exist_ok=True, parents=True) + + # Write prompt info to file in output dir so we can keep track of what we did + prompt_config_path = output_path / 'prompt_config.json' + prompt_config_path.write_text( + json.dumps( + dict( + prompts=prompts, + seeds=seeds, + num_steps=num_steps, + name=name, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + do_loop=do_loop, + make_video=make_video, + use_lerp_for_text=use_lerp_for_text, + scheduler=scheduler + ), + indent=2, + sort_keys=False, + ) + ) + + assert len(prompts) == len(seeds) + + first_prompt, *prompts = prompts + embeds_a = pipeline.embed_text(first_prompt) + + first_seed, *seeds = seeds + latents_a = torch.randn( + (1, pipeline.unet.in_channels, height // 8, width // 8), + device=pipeline.device, + generator=torch.Generator(device=pipeline.device).manual_seed(first_seed), + ) + + if do_loop: + prompts.append(first_prompt) + seeds.append(first_seed) + + frame_index = 0 + for prompt, seed in zip(prompts, seeds): + # Text + embeds_b = pipeline.embed_text(prompt) + + # Latent Noise + latents_b = torch.randn( + (1, pipeline.unet.in_channels, height // 8, width // 8), + device=pipeline.device, + generator=torch.Generator(device=pipeline.device).manual_seed(seed), + ) + + for i, t in enumerate(np.linspace(0, 1, num_steps)): + do_print_progress = (i == 0) or ((frame_index + 1) % 20 == 0) + if do_print_progress: + print(f"COUNT: {frame_index+1}/{len(seeds)*num_steps}") + + if use_lerp_for_text: + embeds = torch.lerp(embeds_a, embeds_b, float(t)) + else: + embeds = slerp(float(t), embeds_a, embeds_b) + latents = slerp(float(t), latents_a, latents_b) + + with torch.autocast("cuda"): + im = pipeline( + latents=latents, + text_embeddings=embeds, + height=height, + 
width=width, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type='pil' if not upsample else 'numpy' + )["sample"][0] + + if upsample: + im = upsampling_pipeline(im) + + im.save(output_path / ("frame%06d.jpg" % frame_index)) + frame_index += 1 + + embeds_a = embeds_b + latents_a = latents_b + + if make_video: + return make_video_ffmpeg(output_path, f"{name}.mp4", fps=fps) + + +if __name__ == "__main__": + import fire + + fire.Fire(walk) diff --git a/scripts/textual_inversion.py b/scripts/textual_inversion.py new file mode 100644 index 0000000..3e5cc3e --- /dev/null +++ b/scripts/textual_inversion.py @@ -0,0 +1,57 @@ +# base webui import and utils. +from webui_streamlit import st +from sd_utils import * + +# streamlit imports + + +#other imports +#from transformers import CLIPTextModel, CLIPTokenizer + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +#def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): + + #loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") + + ## separate token and the embeds + #print (loaded_learned_embeds) + #trained_token = list(loaded_learned_embeds.keys())[0] + #embeds = loaded_learned_embeds[trained_token] + + ## cast to dtype of text_encoder + #dtype = text_encoder.get_input_embeddings().weight.dtype + #embeds.to(dtype) + + ## add the token in tokenizer + #token = token if token is not None else trained_token + #num_added_tokens = tokenizer.add_tokens(token) + #i = 1 + #while(num_added_tokens == 0): + #print(f"The tokenizer already contains the token {token}.") + #token = f"{token[:-1]}-{i}>" + #print(f"Attempting to add the token {token}.") + #num_added_tokens = tokenizer.add_tokens(token) + #i+=1 + + ## resize the token embeddings + #text_encoder.resize_token_embeddings(len(tokenizer)) + + ## get the id for the token and assign the embeds + #token_id = tokenizer.convert_tokens_to_ids(token) + #text_encoder.get_input_embeddings().weight.data[token_id] = embeds + #return token + +##def token_loader() +#learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", st.session_state.pipe.text_encoder, st.session_state.pipe.tokenizer, "*") +#model_content["token"] = learned_token +#models.append(model_content) + +model_id = "./models/custom/embeddings/" + +def layout(): + st.write("Textual Inversion") \ No newline at end of file diff --git a/scripts/txt2img.py b/scripts/txt2img.py new file mode 100644 index 0000000..6f74143 --- /dev/null +++ b/scripts/txt2img.py @@ -0,0 +1,368 @@ +# base webui import and utils. +from webui_streamlit import st +from sd_utils import * + +# streamlit imports +from streamlit import StopException +from streamlit.runtime.in_memory_file_manager import in_memory_file_manager +from streamlit.elements import image as STImage + +#other imports +import os +from typing import Union +from io import BytesIO +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. 
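+    # (The same warnings can also be silenced process-wide by exporting the
+    #  TRANSFORMERS_VERBOSITY=error environment variable before launching the UI;
+    #  both routes rely only on the standard `transformers` logging controls.)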
+ from transformers import logging + + logging.set_verbosity_error() +except: + pass + +class plugin_info(): + plugname = "txt2img" + description = "Text to Image" + isTab = True + displayPriority = 1 + + +if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): + GFPGAN_available = True +else: + GFPGAN_available = False + +if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")): + RealESRGAN_available = True +else: + RealESRGAN_available = False + +# +def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str, + n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None], + height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, + RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None, + variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True): + + outpath = st.session_state['defaults'].general.outdir_txt2img or st.session_state['defaults'].general.outdir or "outputs/txt2img-samples" + + seed = seed_to_int(seed) + + #prompt_matrix = 0 in toggles + #normalize_prompt_weights = 1 in toggles + #skip_save = 2 not in toggles + #save_grid = 3 not in toggles + #sort_samples = 4 in toggles + #write_info_files = 5 in toggles + #jpg_sample = 6 in toggles + #use_GFPGAN = 7 in toggles + #use_RealESRGAN = 8 in toggles + + if sampler_name == 'PLMS': + sampler = PLMSSampler(st.session_state["model"]) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(st.session_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(st.session_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(st.session_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(st.session_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + def init(): + pass + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback, + log_every_t=int(st.session_state.update_preview_frequency)) + + return samples_ddim + + #try: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=st.session_state["use_GFPGAN"], + use_RealESRGAN=st.session_state["use_RealESRGAN"], + 
realesrgan_model_name=realesrgan_model_name, + ddim_eta=ddim_eta, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg, + variant_amount=variant_amount, + variant_seed=variant_seed, + ) + + del sampler + + return output_images, seed, info, stats + + #except RuntimeError as e: + #err = e + #err_msg = f'CRASHED:


Please wait while the program restarts.' + #stats = err_msg + #return [], seed, 'err', stats + +def layout(): + with st.form("txt2img-inputs"): + st.session_state["generation_mode"] = "txt2img" + + input_col1, generate_col1 = st.columns([10,1]) + + with input_col1: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. + generate_col1.write("") + generate_col1.write("") + generate_button = generate_col1.form_submit_button("Generate") + + # creating the page layout using columns + col1, col2, col3 = st.columns([1,2,1], gap="large") + + with col1: + width = st.slider("Width:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.width, step=64) + height = st.slider("Height:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.height, step=64) + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") + seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") + + bs_slider_max_value = 5 + if st.session_state.defaults.general.optimized: + bs_slider_max_value = 100 + + batch_size = st.slider( + "Batch size", + min_value=1, + max_value=bs_slider_max_value, + value=st.session_state.defaults.txt2img.batch_size, + step=1, + help="How many images are at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + Default: 1") + + with st.expander("Preview Settings"): + st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2img.update_preview, + help="If enabled the image preview will be updated during the generation instead of at the end. \ + You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \ + By default this is enabled and the frequency is set to 1 step.") + + st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency, + help="Frequency in steps at which the the preview image is updated. By default the frequency \ + is set to 1 step.") + + with col2: + preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) + + with preview_tab: + #st.write("Image") + #Image for testing + #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') + #new_image = image.resize((175, 240)) + #preview_image = st.image(image) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. 
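+                    # (A minimal sketch of the placeholder pattern used below -- `slot` is a
+                    #  hypothetical name, not part of this patch:
+                    #      slot = st.empty()   # reserves a spot in the page layout
+                    #      slot.image(img)     # any later call overwrites that same spot
+                    #  keeping the placeholders in st.session_state lets the sampling
+                    #  callback refresh them from outside this function.)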
+ st.session_state["preview_image"] = st.empty() + + st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + message = st.empty() + + with col3: + # If we have custom models available on the "models/custom" + #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD + if st.session_state.CustomModel_available: + st.session_state.custom_model = st.selectbox("Custom Model:", st.session_state.custom_models, + index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model), + help="Select the model you want to use. This option is only available if you have custom models \ + on your 'models/custom' folder. The model name that will be shown here is the same as the name\ + the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") + + st.session_state.sampling_steps = st.slider("Sampling Steps", + value=st.session_state['defaults'].txt2img.sampling_steps, + min_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.lower, + max_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.upper, + step=st.session_state['defaults'].txt2img.slider_steps.sampling) + + sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + sampler_name = st.selectbox("Sampling method", sampler_name_list, + index=sampler_name_list.index(st.session_state['defaults'].txt2img.default_sampler), help="Sampling method to use. Default: k_euler") + + + + #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) + + #with basic_tab: + #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, + #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") + + with st.expander("Advanced"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].txt2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") + save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].txt2img.save_individual_images, help="Save each image generated before any filter or enhancement is applied.") + save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].txt2img.save_grid, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2img.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. 
When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files, help="Save a file next to the image with informartion about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.") + + if st.session_state["GFPGAN_available"]: + st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ + This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + st.session_state["use_GFPGAN"] = False + + if st.session_state["RealESRGAN_available"]: + st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2img.use_RealESRGAN, + help="Uses the RealESRGAN model to upscale the images after the generation.\ + This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") + st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + st.session_state["use_RealESRGAN"] = False + st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" + + variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") + galleryCont = st.empty() + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
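+            # (The load can be skipped on later clicks because Streamlit preserves
+            #  st.session_state across reruns; the idea, sketched with a hypothetical
+            #  loader name:
+            #      if "model" not in st.session_state:
+            #          st.session_state["model"] = load_model_from_config(...)
+            #  so subsequent presses of Generate reuse the already-loaded weights.)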
+ load_models(False, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"], + st.session_state["custom_model"]) + + + try: + # + output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, st.session_state["RealESRGAN_model"], batch_count, batch_size, + cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, + save_grid, group_by_prompt, save_as_jpg, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], + variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) + + message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") + + #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab'] + + #if 'latestImages' in st.session_state: + #for i in output_images: + ##push the new image to the list of latest images and remove the oldest one + ##remove the last index from the list\ + #st.session_state['latestImages'].pop() + ##add the new image to the start of the list + #st.session_state['latestImages'].insert(0, i) + #PlaceHolder.empty() + #with PlaceHolder.container(): + #col1, col2, col3 = st.columns(3) + #col1_cont = st.container() + #col2_cont = st.container() + #col3_cont = st.container() + #images = st.session_state['latestImages'] + #with col1_cont: + #with col1: + #[st.image(images[index]) for index in [0, 3, 6] if index < len(images)] + #with col2_cont: + #with col2: + #[st.image(images[index]) for index in [1, 4, 7] if index < len(images)] + #with col3_cont: + #with col3: + #[st.image(images[index]) for index in [2, 5, 8] if index < len(images)] + #historyGallery = st.empty() + + ## check if output_images length is the same as seeds length + #with gallery_tab: + #st.markdown(createHTMLGallery(output_images,seeds), unsafe_allow_html=True) + + + #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont] + + except (StopException, KeyError): + print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images) + +#on import run init +def createHTMLGallery(images,info): + html3 = """ + ' + return html3 \ No newline at end of file diff --git a/scripts/txt2vid.py b/scripts/txt2vid.py new file mode 100644 index 0000000..e1be209 --- /dev/null +++ b/scripts/txt2vid.py @@ -0,0 +1,780 @@ +# base webui import and utils. +from webui_streamlit import st +from sd_utils import * + +# streamlit imports +from streamlit import StopException +from streamlit.runtime.in_memory_file_manager import in_memory_file_manager +from streamlit.elements import image as STImage + +#other imports + +import os +from PIL import Image +import torch +import numpy as np +import time, inspect, timeit +import torch +from torch import autocast +from io import BytesIO +import imageio +from slugify import slugify + +# Temp imports + +# these are for testing txt2vid, should be removed and we should use things from our own code. 
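+# (For orientation, the smallest call these imports support looks roughly like the
+#  sketch below; the "sample" output key matches how pipeline results are indexed
+#  elsewhere in this patch, and the model id is this file's default weights_path:
+#      pipe = StableDiffusionPipeline.from_pretrained(
+#          "CompVis/stable-diffusion-v1-4", use_auth_token=True).to("cuda")
+#      image = pipe("blueberry spaghetti")["sample"][0]
+#  txt2vid() below drives the same tokenizer/text_encoder/unet/vae pieces step by step.)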
+from diffusers import StableDiffusionPipeline +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. + from transformers import logging + + logging.set_verbosity_error() +except: + pass + +class plugin_info(): + plugname = "txt2img" + description = "Text to Image" + isTab = True + displayPriority = 1 + + +if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): + GFPGAN_available = True +else: + GFPGAN_available = False + +if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].txt2vid.RealESRGAN_model}.pth")): + RealESRGAN_available = True +else: + RealESRGAN_available = False + +# +# ----------------------------------------------------------------------------- + +@torch.no_grad() +def diffuse( + pipe, + cond_embeddings, # text conditioning, should be (1, 77, 768) + cond_latents, # image conditioning, should be (1, 4, 64, 64) + num_inference_steps, + cfg_scale, + eta, + ): + + torch_device = cond_latents.get_device() + + # classifier guidance: add the unconditional embedding + max_length = cond_embeddings.shape[1] # 77 + uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt") + uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0] + text_embeddings = torch.cat([uncond_embeddings, cond_embeddings]) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(pipe.scheduler, LMSDiscreteScheduler): + cond_latents = cond_latents * pipe.scheduler.sigmas[0] + + # init the scheduler + accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + + step_counter = 0 + inference_counter = 0 + + if "current_chunk_speed" not in st.session_state: + st.session_state["current_chunk_speed"] = 0 + + if "previous_chunk_speed_list" not in st.session_state: + st.session_state["previous_chunk_speed_list"] = [0] + st.session_state["previous_chunk_speed_list"].append(st.session_state["current_chunk_speed"]) + + if "update_preview_frequency_list" not in st.session_state: + st.session_state["update_preview_frequency_list"] = [0] + st.session_state["update_preview_frequency_list"].append(st.session_state['defaults'].txt2vid.update_preview_frequency) + + + # diffuse! 
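+    # (What the loop below implements, schematically: classifier-free guidance runs the
+    #  UNet on an unconditional and a conditional copy of the latents and blends the two
+    #  noise predictions,
+    #      noise = noise_uncond + cfg_scale * (noise_text - noise_uncond)
+    #  before the scheduler takes the x_t -> x_t-1 step.)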
+ for i, t in enumerate(pipe.scheduler.timesteps): + start = timeit.default_timer() + + #status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}") + + # expand the latents for classifier free guidance + latent_model_input = torch.cat([cond_latents] * 2) + if isinstance(pipe.scheduler, LMSDiscreteScheduler): + sigma = pipe.scheduler.sigmas[i] + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + + # cfg + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(pipe.scheduler, LMSDiscreteScheduler): + cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"] + else: + cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"] + + #print (st.session_state["update_preview_frequency"]) + #update the preview image if it is enabled and the frequency matches the step_counter + if st.session_state['defaults'].txt2vid.update_preview: + step_counter += 1 + + if st.session_state['defaults'].txt2vid.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps: + if st.session_state.dynamic_preview_frequency: + st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["update_preview_frequency_list"]) + + #scale and decode the image latents with vae + cond_latents_2 = 1 / 0.18215 * cond_latents + image = pipe.vae.decode(cond_latents_2) + + # generate output numpy image as uint8 + image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0) + image = transforms.ToPILImage()(image.squeeze_(0)) + + st.session_state["preview_image"].image(image) + + step_counter = 0 + + duration = timeit.default_timer() - start + + st.session_state["current_chunk_speed"] = duration + + if duration >= 1: + speed = "s/it" + else: + speed = "it/s" + duration = 1 / duration + + if i > st.session_state.sampling_steps: + inference_counter += 1 + inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps)) + inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% " + else: + inference_progress = "" + + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames)) + + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} " + f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | " + f"Frame: 
{st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} " + f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}" + ) + st.session_state["progress_bar"].progress(percent if percent < 100 else 100) + + return image + +# +def txt2vid( + # -------------------------------------- + # args you probably want to change + prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about + gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on + #name:str = 'test', # name of this project, for the output directory + #rootdir:str = st.session_state['defaults'].general.outdir, + num_steps:int = 200, # number of steps between each pair of sampled points + max_frames:int = 10000, # number of frames to write and then exit the script + num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images + cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good + do_loop = False, + use_lerp_for_text = False, + seeds = None, + quality:int = 100, # for jpeg compression of the output images + eta:float = 0.0, + width:int = 256, + height:int = 256, + weights_path = "CompVis/stable-diffusion-v1-4", + scheduler="klms", # choices: default, ddim, klms + disable_tqdm = False, + #----------------------------------------------- + beta_start = 0.0001, + beta_end = 0.00012, + beta_schedule = "scaled_linear", + starting_image=None + ): + """ + prompt = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about + gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on + #name:str = 'test', # name of this project, for the output directory + #rootdir:str = st.session_state['defaults'].general.outdir, + num_steps:int = 200, # number of steps between each pair of sampled points + max_frames:int = 10000, # number of frames to write and then exit the script + num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images + cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good + do_loop = False, + use_lerp_for_text = False, + seed = None, + quality:int = 100, # for jpeg compression of the output images + eta:float = 0.0, + width:int = 256, + height:int = 256, + weights_path = "CompVis/stable-diffusion-v1-4", + scheduler="klms", # choices: default, ddim, klms + disable_tqdm = False, + beta_start = 0.0001, + beta_end = 0.00012, + beta_schedule = "scaled_linear" + """ + mem_mon = MemUsageMonitor('MemMon') + mem_mon.start() + + + seeds = seed_to_int(seeds) + + # We add an extra frame because most + # of the time the first frame is just the noise. 
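+    # (On the size asserts a few lines down: width and height must be multiples of 8
+    #  because the model denoises in a latent space downsampled by 8; the noise tensors
+    #  sampled later are shaped (1, unet.in_channels, height // 8, width // 8).)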
+ #max_frames +=1 + + assert torch.cuda.is_available() + assert height % 8 == 0 and width % 8 == 0 + torch.manual_seed(seeds) + torch_device = f"cuda:{gpu}" + + # init the output dir + sanitized_prompt = slugify(prompts) + + full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt) + + if len(full_path) > 220: + sanitized_prompt = sanitized_prompt[:220-len(full_path)] + full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt) + + os.makedirs(full_path, exist_ok=True) + + # Write prompt info to file in output dir so we can keep track of what we did + if st.session_state.write_info_files: + with open(os.path.join(full_path , f'{slugify(str(seeds))}_config.json' if len(prompts) > 1 else "prompts_config.json"), "w") as outfile: + outfile.write(json.dumps( + dict( + prompts = prompts, + gpu = gpu, + num_steps = num_steps, + max_frames = max_frames, + num_inference_steps = num_inference_steps, + cfg_scale = cfg_scale, + do_loop = do_loop, + use_lerp_for_text = use_lerp_for_text, + seeds = seeds, + quality = quality, + eta = eta, + width = width, + height = height, + weights_path = weights_path, + scheduler=scheduler, + disable_tqdm = disable_tqdm, + beta_start = beta_start, + beta_end = beta_end, + beta_schedule = beta_schedule + ), + indent=2, + sort_keys=False, + )) + + #print(scheduler) + default_scheduler = PNDMScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule + ) + # ------------------------------------------------------------------------------ + #Schedulers + ddim_scheduler = DDIMScheduler( + beta_start=beta_start, + beta_end=beta_end, + beta_schedule=beta_schedule, + clip_sample=False, + set_alpha_to_one=False, + ) + + klms_scheduler = LMSDiscreteScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule + ) + + SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) + + # ------------------------------------------------------------------------------ + st.session_state["progress_bar_text"].text("Loading models...") + + try: + if "model" in st.session_state: + del st.session_state["model"] + except: + pass + + #print (st.session_state["weights_path"] != weights_path) + + try: + if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path: + if st.session_state["weights_path"] != weights_path: + del st.session_state["weights_path"] + + st.session_state["weights_path"] = weights_path + st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained( + weights_path, + use_local_file=True, + use_auth_token=True, + torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, + revision="fp16" if not st.session_state['defaults'].general.no_half else None + ) + + st.session_state["pipe"].unet.to(torch_device) + st.session_state["pipe"].vae.to(torch_device) + st.session_state["pipe"].text_encoder.to(torch_device) + + if st.session_state.defaults.general.enable_attention_slicing: + st.session_state["pipe"].enable_attention_slicing() + if st.session_state.defaults.general.enable_minimal_memory_usage: + st.session_state["pipe"].enable_minimal_memory_usage() + + print("Tx2Vid Model Loaded") + else: + print("Tx2Vid Model already Loaded") + + except: + #del st.session_state["weights_path"] + #del st.session_state["pipe"] + + st.session_state["weights_path"] = weights_path + st.session_state["pipe"] = 
StableDiffusionPipeline.from_pretrained( + weights_path, + use_local_file=True, + use_auth_token=True, + torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, + revision="fp16" if not st.session_state['defaults'].general.no_half else None + ) + + st.session_state["pipe"].unet.to(torch_device) + st.session_state["pipe"].vae.to(torch_device) + st.session_state["pipe"].text_encoder.to(torch_device) + + if st.session_state.defaults.general.enable_attention_slicing: + st.session_state["pipe"].enable_attention_slicing() + + + print("Tx2Vid Model Loaded") + + st.session_state["pipe"].scheduler = SCHEDULERS[scheduler] + + # get the conditional text embeddings based on the prompt + text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt") + cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768] + + # + if st.session_state.defaults.general.use_sd_concepts_library: + + prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompts) + + if prompt_tokens: + # compviz + #tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer + #text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer + + # diffusers + tokenizer = st.session_state.pipe.tokenizer + text_encoder = st.session_state.pipe.text_encoder + + ext = ('pt', 'bin') + #print (prompt_tokens) + + if len(prompt_tokens) > 1: + for token_name in prompt_tokens: + embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name) + if os.path.exists(embedding_path): + for files in os.listdir(embedding_path): + if files.endswith(ext): + load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>") + else: + embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0]) + if os.path.exists(embedding_path): + for files in os.listdir(embedding_path): + if files.endswith(ext): + load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>") + + # sample a source + init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) + + if do_loop: + prompts = [prompts, prompts] + seeds = [seeds, seeds] + #first_seed, *seeds = seeds + #prompts.append(prompts) + #seeds.append(first_seed) + + + # iterate the loop + frames = [] + frame_index = 0 + + st.session_state["total_frames_avg_duration"] = [] + st.session_state["total_frames_avg_speed"] = [] + + try: + while frame_index < max_frames: + st.session_state["frame_duration"] = 0 + st.session_state["frame_speed"] = 0 + st.session_state["current_frame"] = frame_index + + # sample the destination + init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) + + for i, t in enumerate(np.linspace(0, 1, max_frames)): + start = timeit.default_timer() + print(f"COUNT: {frame_index+1}/{max_frames}") + + #if use_lerp_for_text: + #init = torch.lerp(init1, init2, float(t)) + #else: + #init = slerp(gpu, float(t), init1, init2) + + init = slerp(gpu, float(t), init1, init2) + + with autocast("cuda"): + image = 
diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta) + + #im = Image.fromarray(image) + outpath = os.path.join(full_path, 'frame%06d.png' % frame_index) + image.save(outpath, quality=quality) + + # send the image to the UI to update it + #st.session_state["preview_image"].image(im) + + #append the frames to the frames list so we can use them later. + frames.append(np.asarray(image)) + + #increase frame_index counter. + frame_index += 1 + + st.session_state["current_frame"] = frame_index + + duration = timeit.default_timer() - start + + if duration >= 1: + speed = "s/it" + else: + speed = "it/s" + duration = 1 / duration + + st.session_state["frame_duration"] = duration + st.session_state["frame_speed"] = speed + + init1 = init2 + + except StopException: + pass + + + if st.session_state['save_video']: + # write video to memory + #output = io.BytesIO() + #writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30) + try: + video_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples","temp.mp4") + writer = imageio.get_writer(video_path, fps=24) + for frame in frames: + writer.append_data(frame) + writer.close() + except: + print("Can't save video, skipping.") + + # show video preview on the UI + st.session_state["preview_video"].video(open(video_path, 'rb').read()) + + mem_max_used, mem_total = mem_mon.read_and_stop() + time_diff = time.time()- start + + info = f""" + {prompts} + Sampling Steps: {num_steps}, Sampler: {scheduler}, CFG scale: {cfg_scale}, Seed: {seeds}, Max Frames: {max_frames}""".strip() + stats = f''' + Took { round(time_diff, 2) }s total ({ round(time_diff/(max_frames),2) }s per image) + Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' + + return video_path, seeds, info, stats + +#on import run init +def createHTMLGallery(images,info): + html3 = """ + ' + return html3 +# +def layout(): + with st.form("txt2vid-inputs"): + st.session_state["generation_mode"] = "txt2vid" + + input_col1, generate_col1 = st.columns([10,1]) + with input_col1: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. 
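+        # (Background on the form pattern: widgets declared inside `with st.form(...)`
+        #  only deliver their values when the submit button fires, e.g. a minimal sketch:
+        #      with st.form("demo"):
+        #          prompt = st.text_input("Prompt")
+        #          go = st.form_submit_button("Generate")
+        #  which keeps widget changes from triggering a rerun until Generate is pressed.)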
+ generate_col1.write("") + generate_col1.write("") + generate_button = generate_col1.form_submit_button("Generate") + + # creating the page layout using columns + col1, col2, col3 = st.columns([1,2,1], gap="large") + + with col1: + width = st.slider("Width:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.width, step=64) + height = st.slider("Height:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.height, step=64) + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2vid.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") + + #uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], + #help="Upload an image which will be used for the image to image generation.") + seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2vid.seed, help=" The seed to use, if left blank a random seed will be generated.") + #batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2vid.batch_count, step=1, help="How many iterations or batches of images to generate in total.") + #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=st.session_state['defaults'].txt2vid.batch_size, step=1, + #help="How many images are at once in a batch.\ + #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + #Default: 1") + + st.session_state["max_frames"] = int(st.text_input("Max Frames:", value=st.session_state['defaults'].txt2vid.max_frames, help="Specify the max number of frames you want to generate.")) + + with st.expander("Preview Settings"): + st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2vid.update_preview, + help="If enabled the image preview will be updated during the generation instead of at the end. \ + You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \ + By default this is enabled and the frequency is set to 1 step.") + + st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2vid.update_preview_frequency, + help="Frequency in steps at which the the preview image is updated. By default the frequency \ + is set to 1 step.") + + # + + + + with col2: + preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) + + with preview_tab: + #st.write("Image") + #Image for testing + #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') + #new_image = image.resize((175, 240)) + #preview_image = st.image(image) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. 
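+                    # (These are the placeholders diffuse() writes into above, e.g.
+                    #      st.session_state["progress_bar"].progress(percent)
+                    #      st.session_state["preview_image"].image(image)
+                    #  so the preview image and progress text refresh in place per frame.)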
+ st.session_state["preview_image"] = st.empty() + + st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + #generate_video = st.empty() + st.session_state["preview_video"] = st.empty() + + message = st.empty() + + with gallery_tab: + st.write('Here should be the image gallery, if I could make a grid in streamlit.') + + with col3: + # If we have custom models available on the "models/custom" + #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD + if st.session_state["CustomModel_available"]: + custom_model = st.selectbox("Custom Model:", st.session_state["defaults"].txt2vid.custom_models_list, + index=st.session_state["defaults"].txt2vid.custom_models_list.index(st.session_state["defaults"].txt2vid.default_model), + help="Select the model you want to use. This option is only available if you have custom models \ + on your 'models/custom' folder. The model name that will be shown here is the same as the name\ + the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") + else: + custom_model = "CompVis/stable-diffusion-v1-4" + + #st.session_state["weights_path"] = custom_model + #else: + #custom_model = "CompVis/stable-diffusion-v1-4" + #st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}" + + st.session_state.sampling_steps = st.slider("Sampling Steps", + value=st.session_state['defaults'].txt2vid.sampling_steps, + min_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.lower, + max_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.upper, + step=st.session_state['defaults'].txt2vid.slider_steps.sampling, + help="Number of steps between each pair of sampled points") + st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps, min_value=10,step=10, max_value=500, + help="Higher values (e.g. 100, 200 etc) can create better images.") + + #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + #sampler_name = st.selectbox("Sampling method", sampler_name_list, + #index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler") + scheduler_name_list = ["klms", "ddim"] + scheduler_name = st.selectbox("Scheduler:", scheduler_name_list, + index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms") + + beta_scheduler_type_list = ["scaled_linear", "linear"] + beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list, + index=beta_scheduler_type_list.index(st.session_state['defaults'].txt2vid.beta_scheduler_type), help="Schedule Type to use. 
Default: linear") + + + #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) + + #with basic_tab: + #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, + #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") + + with st.expander("Advanced"): + st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2vid.separate_prompts, + help="Separate multiple prompts using the `|` character, and get all combinations of them.") + st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.", + value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") + st.session_state["save_individual_images"] = st.checkbox("Save individual images.", + value=st.session_state['defaults'].txt2vid.save_individual_images, help="Save each image generated before any filter or enhancement is applied.") + st.session_state["save_video"] = st.checkbox("Save video",value=st.session_state['defaults'].txt2vid.save_video, help="Save a video with all the images generated as frames at the end of the generation.") + st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.") + st.session_state["write_info_files"] = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2vid.write_info_files, + help="Save a file next to the image with informartion about the generation.") + st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency, + help="This option tries to find the best value at which we can update \ + the preview image during generation while minimizing the impact it has in performance. Default: True") + st.session_state["do_loop"] = st.checkbox("Do Loop", value=st.session_state['defaults'].txt2vid.do_loop, + help="Do loop") + st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.") + + if GFPGAN_available: + st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + st.session_state["use_GFPGAN"] = False + + if RealESRGAN_available: + st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN, + help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. 
Disable if you need the extra VRAM.") + st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + st.session_state["use_RealESRGAN"] = False + st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" + + st.session_state["variant_amount"] = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2vid.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2vid.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") + st.session_state["beta_start"] = st.slider("Beta Start:", value=st.session_state['defaults'].txt2vid.beta_start, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f") + st.session_state["beta_end"] = st.slider("Beta End:", value=st.session_state['defaults'].txt2vid.beta_end, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f") + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. + #load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model) + + try: + # run video generation + video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu, + num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames), + num_inference_steps=st.session_state.num_inference_steps, + cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"], + seeds=seed, quality=100, eta=0.0, width=width, + height=height, weights_path=custom_model, scheduler=scheduler_name, + disable_tqdm=False, beta_start=st.session_state["beta_start"], beta_end=st.session_state["beta_end"], + beta_schedule=beta_scheduler_type, starting_image=None) + + #message.success('Done!', icon="✅") + message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") + + #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab'] + + #if 'latestVideos' in st.session_state: + #for i in video: + ##push the new image to the list of latest images and remove the oldest one + ##remove the last index from the list\ + #st.session_state['latestVideos'].pop() + ##add the new image to the start of the list + #st.session_state['latestVideos'].insert(0, i) + #PlaceHolder.empty() + + #with PlaceHolder.container(): + #col1, col2, col3 = st.columns(3) + #col1_cont = st.container() + #col2_cont = st.container() + #col3_cont = st.container() + + #with col1_cont: + #with col1: + #st.image(st.session_state['latestVideos'][0]) + #st.image(st.session_state['latestVideos'][3]) + #st.image(st.session_state['latestVideos'][6]) + #with col2_cont: + #with col2: + #st.image(st.session_state['latestVideos'][1]) + #st.image(st.session_state['latestVideos'][4]) + #st.image(st.session_state['latestVideos'][7]) + #with col3_cont: + #with col3: + #st.image(st.session_state['latestVideos'][2]) + #st.image(st.session_state['latestVideos'][5]) + #st.image(st.session_state['latestVideos'][8]) + #historyGallery = st.empty() + + ## check if output_images length is the same as seeds length + #with gallery_tab: + #st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True) + + + #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont] + + 
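+            # (StopException is what Streamlit raises when the user stops a running
+            #  script, so the handler below treats it, together with a KeyError from a
+            #  run interrupted mid-load, as a normal way for generation to end rather
+            #  than as a crash.)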
except (StopException, KeyError): + print(f"Received Streamlit StopException") + + diff --git a/scripts/webui.py b/scripts/webui.py index dd64a4c..eb5d32f 100644 --- a/scripts/webui.py +++ b/scripts/webui.py @@ -2,8 +2,10 @@ import argparse, os, sys, glob, re import cv2 +from perlin import perlinNoise from frontend.frontend import draw_gradio_ui from frontend.job_manager import JobManager, JobInfo +from frontend.image_metadata import ImageMetadata from frontend.ui_functions import resize_image parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",) @@ -13,7 +15,7 @@ parser.add_argument("--defaults", type=str, help="path to configuration file pro parser.add_argument("--esrgan-cpu", action='store_true', help="run ESRGAN on cpu", default=False) parser.add_argument("--esrgan-gpu", type=int, help="run ESRGAN on specific gpu (overrides --gpu)", default=0) parser.add_argument("--extra-models-cpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on cpu", default=False) -parser.add_argument("--extra-models-gpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on cpu", default=False) +parser.add_argument("--extra-models-gpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on gpu", default=False) parser.add_argument("--gfpgan-cpu", action='store_true', help="run GFPGAN on cpu", default=False) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go parser.add_argument("--gfpgan-gpu", type=int, help="run GFPGAN on specific gpu (overrides --gpu) ", default=0) @@ -31,6 +33,7 @@ parser.add_argument("--outdir_img2img", type=str, nargs="?", help="dir to write parser.add_argument("--outdir_imglab", type=str, nargs="?", help="dir to write imglab results to (overrides --outdir)", default=None) parser.add_argument("--outdir_txt2img", type=str, nargs="?", help="dir to write txt2img results to (overrides --outdir)", default=None) parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None) +parser.add_argument("--filename_format", type=str, nargs="?", help="filenames format", default=None) parser.add_argument("--port", type=int, help="choose the port for the gradio webserver to use", default=7860) parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--realesrgan-dir", type=str, help="RealESRGAN directory", default=('./src/realesrgan' if os.path.exists('./src/realesrgan') else './RealESRGAN')) @@ -42,6 +45,7 @@ parser.add_argument("--skip-grid", action='store_true', help="do not save a grid parser.add_argument("--skip-save", action='store_true', help="do not save indiviual samples. 
For speed measurements.", default=False) parser.add_argument('--no-job-manager', action='store_true', help="Don't use the experimental job manager on top of gradio", default=False) parser.add_argument("--max-jobs", type=int, help="Maximum number of concurrent 'generate' commands", default=1) +parser.add_argument("--tiling", action='store_true', help="Generate tiling images", default=False) opt = parser.parse_args() #Should not be needed anymore @@ -66,16 +70,27 @@ import torch import torch.nn as nn import yaml import glob -from typing import List, Union, Dict +import copy +from typing import List, Union, Dict, Callable, Any, Optional from pathlib import Path from collections import namedtuple +from functools import partial + +# tell the user which GPU the code is actually using +if os.getenv("SD_WEBUI_DEBUG", 'False').lower() in ('true', '1', 'y'): + gpu_in_use = opt.gpu + # prioritize --esrgan-gpu and --gfpgan-gpu over --gpu, as stated in the option info + if opt.esrgan_gpu != opt.gpu: + gpu_in_use = opt.esrgan_gpu + elif opt.gfpgan_gpu != opt.gpu: + gpu_in_use = opt.gfpgan_gpu + print("Starting on GPU {selected_gpu_name}".format(selected_gpu_name=torch.cuda.get_device_name(gpu_in_use))) from contextlib import contextmanager, nullcontext from einops import rearrange, repeat from itertools import islice from omegaconf import OmegaConf -from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps -from PIL.PngImagePlugin import PngInfo +from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps, ImageChops from io import BytesIO import base64 import re @@ -84,6 +99,18 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.util import instantiate_from_config +# add global options to models +def patch_conv(**patch): + cls = torch.nn.Conv2d + init = cls.__init__ + def __init__(self, *args, **kwargs): + return init(self, *args, **kwargs, **patch) + cls.__init__ = __init__ + +if opt.tiling: + patch_conv(padding_mode='circular') + print("patched for tiling") + try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -92,6 +119,14 @@ try: except: pass +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from transformers import AutoFeatureExtractor + +# load safety model +safety_model_id = "CompVis/stable-diffusion-safety-checker" +safety_feature_extractor = None +safety_checker = None + # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') @@ -203,7 +238,16 @@ class MemUsageMonitor(threading.Thread): print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. 
\n") return print(f"[{self.name}] Recording max memory usage...\n") - handle = pynvml.nvmlDeviceGetHandleByIndex(opt.gpu) + # check if we're using a scoped-down GPU environment (pynvml does not listen to CUDA_VISIBLE_DEVICES) + # so that we can measure memory on the correct GPU + try: + isinstance(int(os.environ["CUDA_VISIBLE_DEVICES"]), int) + handle = pynvml.nvmlDeviceGetHandleByIndex(int(os.environ["CUDA_VISIBLE_DEVICES"])) + except (KeyError, ValueError) as pynvmlHandleError: + if os.getenv("SD_WEBUI_DEBUG", 'False').lower() in ('true', '1', 'y'): + print("[MemMon][WARNING]", pynvmlHandleError) + print("[MemMon][INFO]", "defaulting to monitoring memory on the default gpu (set via --gpu flag)") + handle = pynvml.nvmlDeviceGetHandleByIndex(opt.gpu) self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total while not self.stop_flag: m = pynvml.nvmlDeviceGetMemoryInfo(handle) @@ -264,15 +308,21 @@ class KDiffusionSampler: self.schedule = sampler def get_sampler_name(self): return self.schedule - def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T): + def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback: Callable = None ): sigmas = self.model_wrap.get_sigmas(S) x = x_T * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False) + samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback)) return samples_ddim, None + @classmethod + def img_callback_wrapper(cls, callback: Callable, *args): + ''' Converts a KDiffusion callback to the standard img_callback ''' + if callback: + arg_dict = args[0] + callback(image_sample=arg_dict['denoised'], iter_num=arg_dict['i']) def create_random_tensors(shape, seeds): xs = [] @@ -592,25 +642,18 @@ def check_prompt_length(prompt, comments): comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") -def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, -normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, -skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True): + +def save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False): + ''' saves the image according to selected parameters. Expects to find generation parameters on image, set by ImageMetadata.set_on_image() ''' + metadata = ImageMetadata.get_from_image(image) + if not skip_metadata and metadata is None: + print("No metadata passed in to save. 
Set metadata on the image before calling save_sample using the ImageMetadata.set_on_image() function.") + skip_metadata = True filename_i = os.path.join(sample_path_i, filename) if not jpg_sample: if opt.save_metadata and not skip_metadata: - metadata = PngInfo() - metadata.add_text("SD:prompt", prompts[i]) - metadata.add_text("SD:seed", str(seeds[i])) - metadata.add_text("SD:width", str(width)) - metadata.add_text("SD:height", str(height)) - metadata.add_text("SD:sampler_name", str(sampler_name)) - metadata.add_text("SD:steps", str(steps)) - metadata.add_text("SD:cfg_scale", str(cfg_scale)) - metadata.add_text("SD:normalize_prompt_weights", str(normalize_prompt_weights)) - if init_img is not None: - metadata.add_text("SD:denoising_strength", str(denoising_strength)) - metadata.add_text("SD:GFPGAN", str(use_GFPGAN and GFPGAN is not None)) - image.save(f"{filename_i}.png", pnginfo=metadata) + image.save(f"{filename_i}.png", pnginfo=metadata.as_png_info()) else: image.save(f"{filename_i}.png") else: @@ -621,7 +664,7 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin toggles = [] if prompt_matrix: toggles.append(0) - if normalize_prompt_weights: + if metadata.normalize_prompt_weights: toggles.append(1) if init_img is not None: if uses_loopback: @@ -638,14 +681,14 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin toggles.append(5 + offset) if write_sample_info_to_log_file: toggles.append(6+offset) - if use_GFPGAN: + if metadata.GFPGAN: toggles.append(7 + offset) info_dict = dict( target="txt2img" if init_img is None else "img2img", - prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, - ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, - seed=seeds[i], width=width, height=height + prompt=metadata.prompt, ddim_steps=metadata.steps, toggles=toggles, sampler_name=sampler_name, + ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=metadata.cfg_scale, + seed=metadata.seed, width=metadata.width, height=metadata.height ) if init_img is not None: # Not yet any use for these, but they bloat up the files: @@ -775,16 +818,95 @@ def oxlamon_matrix(prompt, seed, n_iter, batch_size): return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows +def perform_masked_image_restoration(image, init_img, init_mask, mask_blur_strength, mask_restore, use_RealESRGAN, RealESRGAN): + if not mask_restore: + return image + else: + init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) + init_mask = init_mask.convert('L') + init_img = init_img.convert('RGB') + image = image.convert('RGB') + if use_RealESRGAN and RealESRGAN is not None: + output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8)) + init_mask = Image.fromarray(output) + init_mask = init_mask.convert('L') + + output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8)) + init_img = Image.fromarray(output) + init_img = init_img.convert('RGB') + + image = Image.composite(init_img, image, init_mask) + + return image + + +def perform_color_correction(img_rgb, correction_target_lab, do_color_correction): + try: + from skimage import exposure + except: + print("Install scikit-image to perform color correction") + return img_rgb + + if not do_color_correction: return img_rgb + if correction_target_lab is None: return img_rgb + + return ( + Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(img_rgb), + cv2.COLOR_RGB2LAB + ), + 
correction_target_lab, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8") + ) + ) def process_images( outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, - n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, + n_iter, steps, cfg_scale, width, height, prompt_matrix, filter_nsfw, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, - keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, + keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False, uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False, - variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None): + variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None, do_color_correction=False, correction_target=None): """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + + def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + # load replacement of nsfw content + def load_replacement(x): + try: + hwc = x.shape + y = Image.open("images/nsfw.jpeg").convert("RGB").resize((hwc[1], hwc[0])) + y = (np.array(y)/255.0).astype(x.dtype) + assert y.shape == x.shape + return y + except Exception: + return x + + # check and replace nsfw content + def check_safety(x_image): + global safety_feature_extractor, safety_checker + if safety_feature_extractor is None: + safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") + x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + for i in range(len(has_nsfw_concept)): + if has_nsfw_concept[i]: + x_checked_image[i] = load_replacement(x_checked_image[i]) + return x_checked_image, has_nsfw_concept + prompt = prompt or '' torch_gc() # start time after garbage collection (or before?) @@ -804,6 +926,12 @@ def process_images( if not ("|" in prompt) and prompt.startswith("@"): prompt = prompt[1:] + negprompt = '' + if '###' in prompt: + prompt, negprompt = prompt.split('###', 1) + prompt = prompt.strip() + negprompt = negprompt.strip() + comments = [] prompt_matrix_parts = [] @@ -882,12 +1010,14 @@ def process_images( if job_info: job_info.job_status = f"Processing Iteration {n+1}/{n_iter}. 
Batch size {batch_size}" + job_info.rec_steps_imgs.clear() for idx,(p,s) in enumerate(zip(prompts,seeds)): job_info.job_status += f"\nItem {idx}: Seed {s}\nPrompt: {p}" + print(f"Current prompt: {p}") if opt.optimized: modelCS.to(device) - uc = (model if not opt.optimized else modelCS).get_learned_conditioning(len(prompts) * [""]) + uc = (model if not opt.optimized else modelCS).get_learned_conditioning(len(prompts) * [negprompt]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -912,7 +1042,7 @@ def process_images( while(torch.cuda.memory_allocated()/1e6 >= mem): time.sleep(1) - cur_variant_amount = variant_amount + cur_variant_amount = variant_amount if variant_amount == 0.0: # we manually generate all input noises because each one should have a specific seed x = create_random_tensors(shape, seeds=seeds) @@ -935,16 +1065,91 @@ def process_images( # finally, slerp base_x noise to target_x noise for creating a variant x = slerp(device, max(0.0, min(1.0, cur_variant_amount)), base_x, target_x) - samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) + # If optimized then use first stage for preview and store it on cpu until needed + if opt.optimized: + step_preview_model = modelFS + step_preview_model.cpu() + else: + step_preview_model = model + + def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int): + ''' Called from the sampler every iteration ''' + if job_info: + job_info.active_iteration_cnt = iter_num + record_periodic_image = job_info.rec_steps_enabled and (0 == iter_num % job_info.rec_steps_intrvl) + if record_periodic_image or job_info.refresh_active_image_requested.is_set(): + preview_start_time = time.time() + if opt.optimized: + step_preview_model.to(device) + + decoded_batch: List[torch.Tensor] = [] + # Break up batch to save VRAM + for sample in image_sample: + sample = sample[None, :] # expands the tensor as if it still had a batch dimension + decoded_sample = step_preview_model.decode_first_stage(sample)[0] + decoded_sample = torch.clamp((decoded_sample + 1.0) / 2.0, min=0.0, max=1.0) + decoded_sample = decoded_sample.cpu() + decoded_batch.append(decoded_sample) + + batch_size = len(decoded_batch) + + if opt.optimized: + step_preview_model.cpu() + + images: List[Image.Image] = [] + # Convert tensor to image (copied from code below) + for ddim in decoded_batch: + x_sample = 255. 
* rearrange(ddim.numpy(), 'c h w -> h w c') + x_sample = x_sample.astype(np.uint8) + image = Image.fromarray(x_sample) + images.append(image) + + caption = f"Iter {iter_num}" + grid = image_grid(images, len(images), force_n_rows=1, captions=[caption]*len(images)) + + # Save the images if recording steps, and append existing saved steps + if job_info.rec_steps_enabled: + gallery_img_size = tuple(int(0.25*dim) for dim in images[0].size) + job_info.rec_steps_imgs.append(grid.resize(gallery_img_size)) + + # Notify the requester that the image is updated + if job_info.refresh_active_image_requested.is_set(): + if job_info.rec_steps_enabled: + grid_rows = None if batch_size == 1 else len(job_info.rec_steps_imgs) + grid = image_grid(imgs=job_info.rec_steps_imgs[::-1], batch_size=1, force_n_rows=grid_rows) + job_info.active_image = grid + job_info.refresh_active_image_done.set() + job_info.refresh_active_image_requested.clear() + + preview_elapsed_timed = time.time() - preview_start_time + if preview_elapsed_timed / job_info.rec_steps_intrvl > 1: + print( + f"Warning: Preview generation is slowing image generation. It took {preview_elapsed_timed:.2f}s to generate progress images for batch of {batch_size} images!") + + # Interrupt current iteration? + if job_info.stop_cur_iter.is_set(): + job_info.stop_cur_iter.clear() + raise StopIteration() + + try: + samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name, img_callback=sample_iteration_callback) + except StopIteration: + print("Skipping iteration") + job_info.job_status = "Skipping iteration" + continue if opt.optimized: modelFS.to(device) + for i in range(len(samples_ddim)): + x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim[i].unsqueeze(0)) + x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + if filter_nsfw: + x_samples_ddim_numpy = x_sample.cpu().permute(0, 2, 3, 1).numpy() + x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy) + x_sample = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) - x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - for i, x_sample in enumerate(x_samples_ddim): sanitized_prompt = prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars}) if variant_seed != None and variant_seed != '': if variant_amount == 0.0: @@ -958,16 +1163,33 @@ def process_images( sample_path_i = os.path.join(sample_path, sanitized_prompt) os.makedirs(sample_path_i, exist_ok=True) base_count = get_next_sequence_number(sample_path_i) - filename = f"{base_count:05}-{steps}_{sampler_name}_{seed_used}_{cur_variant_amount:.2f}" + filename = opt.filename_format or "[STEPS]_[SAMPLER]_[SEED]_[VARIANT_AMOUNT]" else: sample_path_i = sample_path base_count = get_next_sequence_number(sample_path_i) - sanitized_prompt = sanitized_prompt - filename = f"{base_count:05}-{steps}_{sampler_name}_{seed_used}_{cur_variant_amount:.2f}_{sanitized_prompt}"[:128] #same as before + filename = opt.filename_format or "[STEPS]_[SAMPLER]_[SEED]_[VARIANT_AMOUNT]_[PROMPT]" - x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + #Add new filenames tags here + filename = f"{base_count:05}-" + filename + filename = filename.replace("[STEPS]", str(steps)) + filename = filename.replace("[CFG]", str(cfg_scale)) + filename = filename.replace("[PROMPT]", sanitized_prompt[:128]) + filename = filename.replace("[PROMPT_SPACES]", prompts[i].translate({ord(x): '' for x in invalid_filename_chars})[:128]) + filename = filename.replace("[WIDTH]", str(width)) + filename = filename.replace("[HEIGHT]", str(height)) + filename = filename.replace("[SAMPLER]", sampler_name) + filename = filename.replace("[SEED]", seed_used) + filename = filename.replace("[VARIANT_AMOUNT]", f"{cur_variant_amount:.2f}") + + x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c') x_sample = x_sample.astype(np.uint8) + metadata = ImageMetadata(prompt=prompts[i], seed=seeds[i], height=height, width=width, steps=steps, + cfg_scale=cfg_scale, normalize_prompt_weights=normalize_prompt_weights, denoising_strength=denoising_strength, + GFPGAN=use_GFPGAN ) image = Image.fromarray(x_sample) + image = perform_color_correction(image, correction_target, do_color_correction) + ImageMetadata.set_on_image(image, metadata) + original_sample = x_sample original_filename = filename if use_GFPGAN and GFPGAN is not None and not use_RealESRGAN: @@ -976,10 +1198,18 @@ def process_images( cropped_faces, restored_faces, restored_img = GFPGAN.enhance(original_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) gfpgan_sample = restored_img[:,:,::-1] gfpgan_image = Image.fromarray(gfpgan_sample) + gfpgan_image = perform_color_correction(gfpgan_image, correction_target, do_color_correction) + gfpgan_image = perform_masked_image_restoration( + gfpgan_image, init_img, init_mask, + mask_blur_strength, mask_restore, + use_RealESRGAN = False, RealESRGAN = None + ) + gfpgan_metadata = copy.copy(metadata) + gfpgan_metadata.GFPGAN = True + ImageMetadata.set_on_image( gfpgan_image, gfpgan_metadata ) gfpgan_filename = original_filename + '-gfpgan' - save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, -normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, -skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) + save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False) output_images.append(gfpgan_image) #287 #if simple_templating: # grid_captions.append( captions[i] + "\ngfpgan" ) @@ -991,9 +1221,15 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin esrgan_filename = original_filename + '-esrgan4x' esrgan_sample = output[:,:,::-1] esrgan_image = Image.fromarray(esrgan_sample) - save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, -normalize_prompt_weights, use_GFPGAN,write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, -skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, 
resize_mode, skip_metadata=True) + esrgan_image = perform_color_correction(esrgan_image, correction_target, do_color_correction) + esrgan_image = perform_masked_image_restoration( + esrgan_image, init_img, init_mask, + mask_blur_strength, mask_restore, + use_RealESRGAN, RealESRGAN + ) + ImageMetadata.set_on_image( esrgan_image, metadata ) + save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False) output_images.append(esrgan_image) #287 #if simple_templating: # grid_captions.append( captions[i] + "\nesrgan" ) @@ -1007,9 +1243,15 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' gfpgan_esrgan_sample = output[:,:,::-1] gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) - save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, -normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, -skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) + gfpgan_esrgan_image = perform_color_correction(gfpgan_esrgan_image, correction_target, do_color_correction) + gfpgan_esrgan_image = perform_masked_image_restoration( + gfpgan_esrgan_image, init_img, init_mask, + mask_blur_strength, mask_restore, + use_RealESRGAN, RealESRGAN + ) + ImageMetadata.set_on_image(gfpgan_esrgan_image, metadata) + save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, +skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False) output_images.append(gfpgan_esrgan_image) #287 #if simple_templating: # grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) @@ -1018,15 +1260,34 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin if imgProcessorTask == True: output_images.append(image) + image = perform_masked_image_restoration( + image, init_img, init_mask, + mask_blur_strength, mask_restore, + # RealESRGAN image already processed in if-case above. + use_RealESRGAN = False, RealESRGAN = None + ) + if not skip_save: - save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, -normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, + save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False) if add_original_image or not simple_templating: output_images.append(image) if simple_templating: grid_captions.append( captions[i] ) + # Save the progress images? 
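# A minimal, self-contained sketch of the step-preview idea used just below
# (illustrative only; the patch itself goes through image_grid() and JobInfo):
# keep a decoded preview every few sampler iterations and paste the collected
# previews side by side into one strip. The helper name make_step_grid is
# hypothetical.
from PIL import Image

def make_step_grid(step_images, thumb_size=(128, 128)):
    """Paste per-step preview images left-to-right into a single strip."""
    thumbs = [img.resize(thumb_size) for img in step_images]
    strip = Image.new("RGB", (thumb_size[0] * len(thumbs), thumb_size[1]))
    for idx, thumb in enumerate(thumbs):
        strip.paste(thumb, (idx * thumb_size[0], 0))
    return strip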
+ if job_info: + if job_info.rec_steps_enabled and (job_info.rec_steps_to_file or job_info.rec_steps_to_gallery): + steps_grid = image_grid(job_info.rec_steps_imgs, 1) + if job_info.rec_steps_to_gallery: + gallery_img_size = tuple(2*dim for dim in image.size) + output_images.append( steps_grid.resize( gallery_img_size ) ) + if job_info.rec_steps_to_file: + steps_grid_filename = f"{original_filename}_step_grid" + save_sample(steps_grid, sample_path_i, steps_grid_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, + skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False) + if opt.optimized: mem = torch.cuda.memory_allocated()/1e6 modelFS.to("cpu") @@ -1046,7 +1307,7 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin import traceback print("Error creating prompt_matrix text:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - elif batch_size > 1 or n_iter > 1: + elif len(output_images) > 0 and (batch_size > 1 or n_iter > 1): grid = image_grid(output_images, batch_size) if grid is not None: grid_count = get_next_sequence_number(outpath, 'grid-') @@ -1101,8 +1362,13 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], write_info_files = 5 in toggles write_to_one_file = 6 in toggles jpg_sample = 7 in toggles - use_GFPGAN = 8 in toggles - use_RealESRGAN = 9 in toggles + filter_nsfw = 8 in toggles + use_GFPGAN = 9 in toggles + use_RealESRGAN = 10 in toggles + + do_color_correction = False + correction_target = None + ModelLoader(['model'],True,False) if use_GFPGAN and not use_RealESRGAN: ModelLoader(['GFPGAN'],True,False) @@ -1134,8 +1400,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], def init(): pass - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x) + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=img_callback) return samples_ddim try: @@ -1155,6 +1421,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], width=width, height=height, prompt_matrix=prompt_matrix, + filter_nsfw=filter_nsfw, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, realesrgan_model_name=realesrgan_model_name, @@ -1168,6 +1435,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], variant_amount=variant_amount, variant_seed=variant_seed, job_info=job_info, + do_color_correction=do_color_correction, + correction_target=correction_target ) del sampler @@ -1225,7 +1494,15 @@ class Flagging(gr.FlaggingCallback): print("Logged:", filenames[0]) -def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str, +def blurArr(a,r=8): + 
im1=Image.fromarray((a*255).astype(np.int8),"L") + im2 = im1.filter(ImageFilter.GaussianBlur(radius = r)) + out= np.array(im2)/255 + return out + + + +def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, mask_restore: bool, ddim_steps: int, sampler_name: str, toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None): # print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode, @@ -1249,8 +1526,10 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren write_info_files = 7 in toggles write_sample_info_to_log_file = 8 in toggles jpg_sample = 9 in toggles - use_GFPGAN = 10 in toggles - use_RealESRGAN = 11 in toggles + do_color_correction = 10 in toggles + filter_nsfw = 11 in toggles + use_GFPGAN = 12 in toggles + use_RealESRGAN = 13 in toggles ModelLoader(['model'],True,False) if use_GFPGAN and not use_RealESRGAN: ModelLoader(['GFPGAN'],True,False) @@ -1279,10 +1558,12 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren if image_editor_mode == 'Mask': init_img = init_info_mask["image"] + init_img_transparency = ImageOps.invert(init_img.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') init_img = init_img.convert("RGB") init_img = resize_image(resize_mode, init_img, width, height) init_img = init_img.convert("RGB") init_mask = init_info_mask["mask"] + init_mask = ImageChops.lighter(init_img_transparency, init_mask.convert('L')).convert('RGBA') init_mask = init_mask.convert("RGB") init_mask = resize_image(resize_mode, init_mask, width, height) init_mask = init_mask.convert("RGB") @@ -1305,16 +1586,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren image = torch.from_numpy(image) mask_channel = None - if image_editor_mode == "Uncrop": - alpha = init_img.convert("RGBA") - alpha = resize_image(resize_mode, alpha, width // 8, height // 8) - mask_channel = alpha.split()[-1] - mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4)) - mask_channel = np.array(mask_channel) - mask_channel[mask_channel >= 255] = 255 - mask_channel[mask_channel < 255] = 0 - mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2)) - elif image_editor_mode == "Mask": + if image_editor_mode == "Mask": alpha = init_mask.convert("RGBA") alpha = resize_image(resize_mode, alpha, width // 8, height // 8) mask_channel = alpha.split()[1] @@ -1329,11 +1601,62 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren if opt.optimized: modelFS.to(device) - init_image = 2. * image - 1. + #let's try and find where init_image is 0's + #shape is probably (3,width,height)? 
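# The "Uncrop" branch below locates the pasted-in source image on the larger
# blank canvas by finding the rows and columns that contain any non-zero pixel.
# A reduced sketch of that bounding-box step (assumes a (C, H, W) array with at
# least one non-zero pixel; nonzero_bbox is a hypothetical helper name):
import numpy as np

def nonzero_bbox(image_chw: np.ndarray):
    """Return (row_start, row_end, col_start, col_end) of the non-zero region."""
    channel_max = image_chw.max(axis=0)               # collapse channels -> (H, W)
    rows = np.where(channel_max.max(axis=1) > 0)[0]   # rows with content
    cols = np.where(channel_max.max(axis=0) > 0)[0]   # columns with content
    return rows[0], rows[-1] + 1, cols[0], cols[-1] + 1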
+ + if image_editor_mode == "Uncrop": + _image=image.numpy()[0] + _mask=np.ones((_image.shape[1],_image.shape[2])) + + #compute bounding box + cmax=np.max(_image,axis=0) + rowmax=np.max(cmax,axis=0) + colmax=np.max(cmax,axis=1) + rowwhere=np.where(rowmax>0)[0] + colwhere=np.where(colmax>0)[0] + rowstart=rowwhere[0] + rowend=rowwhere[-1]+1 + colstart=colwhere[0] + colend=colwhere[-1]+1 + print('bounding box: ',rowstart,rowend,colstart,colend) + + #this is where noise will get added + PAD_IMG=16 + boundingbox=np.zeros(shape=(height,width)) + boundingbox[colstart+PAD_IMG:colend-PAD_IMG,rowstart+PAD_IMG:rowend-PAD_IMG]=1 + boundingbox=blurArr(boundingbox,4) + + #this is the mask for outpainting + PAD_MASK=24 + boundingbox2=np.zeros(shape=(height,width)) + boundingbox2[colstart+PAD_MASK:colend-PAD_MASK,rowstart+PAD_MASK:rowend-PAD_MASK]=1 + boundingbox2=blurArr(boundingbox2,4) + + #noise=np.random.randn(*_image.shape) + noise=np.array([perlinNoise(height,width,height/64,width/64) for i in range(3)]) + _mask*=1-boundingbox2 + + #convert 0,1 to -1,1 + _image = 2. * _image - 1. + + #add noise + boundingbox=np.tile(boundingbox,(3,1,1)) + _image=_image*boundingbox+noise*(1-boundingbox) + + #resize mask + _mask = np.array(resize_image(resize_mode, Image.fromarray(_mask*255), width // 8, height // 8))/255 + + #convert back to torch tensor + init_image=torch.from_numpy(np.expand_dims(_image,axis=0).astype(np.float32)).to(device) + mask=torch.from_numpy(_mask.astype(np.float32)).to(device) + + else: + init_image = 2. * image - 1. + init_image = init_image.to(device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_latent = (model if not opt.optimized else modelFS).get_first_stage_encoding((model if not opt.optimized else modelFS).encode_first_stage(init_image)) # move to latent space - + if opt.optimized: mem = torch.cuda.memory_allocated()/1e6 modelFS.to("cpu") @@ -1342,7 +1665,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren return init_latent, mask, - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None): t_enc_steps = t_enc obliterate = False if ddim_steps == t_enc_steps: @@ -1364,7 +1687,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback)) else: x0, z_mask = init_data @@ -1385,18 +1708,14 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren return samples_ddim - + correction_target = None if loopback: output_images, info = None, None history = [] initial_seed = None - do_color_correction = False - try: - from skimage import exposure - do_color_correction = True - except: - print("Install scikit-image to perform color correction on loopback") + # turn on 
color correction for loopback to prevent known issue of color drift + do_color_correction = True for i in range(n_iter): if do_color_correction and i == 0: @@ -1418,6 +1737,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren width=width, height=height, prompt_matrix=prompt_matrix, + filter_nsfw=filter_nsfw, use_GFPGAN=use_GFPGAN, use_RealESRGAN=False, # Forcefully disable upscaling when using loopback realesrgan_model_name=realesrgan_model_name, @@ -1428,6 +1748,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren init_mask=init_mask, keep_mask=keep_mask, mask_blur_strength=mask_blur_strength, + mask_restore=mask_restore, denoising_strength=denoising_strength, resize_mode=resize_mode, uses_loopback=loopback, @@ -1436,7 +1757,9 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren write_info_files=write_info_files, write_sample_info_to_log_file=write_sample_info_to_log_file, jpg_sample=jpg_sample, - job_info=job_info + job_info=job_info, + do_color_correction=do_color_correction, + correction_target=correction_target ) if initial_seed is None: @@ -1444,16 +1767,6 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren init_img = output_images[0] - if do_color_correction and correction_target is not None: - init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( - cv2.cvtColor( - np.asarray(init_img), - cv2.COLOR_RGB2LAB - ), - correction_target, - channel_axis=2 - ), cv2.COLOR_LAB2RGB).astype("uint8")) - if not random_seed_loopback: seed = seed + 1 else: @@ -1472,6 +1785,9 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren seed = initial_seed else: + if do_color_correction: + correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) + output_images, seed, info, stats = process_images( outpath=outpath, func_init=init, @@ -1488,6 +1804,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren width=width, height=height, prompt_matrix=prompt_matrix, + filter_nsfw=filter_nsfw, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, realesrgan_model_name=realesrgan_model_name, @@ -1498,13 +1815,16 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren keep_mask=keep_mask, mask_blur_strength=mask_blur_strength, denoising_strength=denoising_strength, + mask_restore=mask_restore, resize_mode=resize_mode, uses_loopback=loopback, sort_samples=sort_samples, write_info_files=write_info_files, write_sample_info_to_log_file=write_sample_info_to_log_file, jpg_sample=jpg_sample, - job_info=job_info + job_info=job_info, + do_color_correction=do_color_correction, + correction_target=correction_target ) del sampler @@ -1572,8 +1892,13 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to images = [] def processGFPGAN(image,strength): image = image.convert("RGB") + metadata = ImageMetadata.get_from_image(image) cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) result = Image.fromarray(restored_img) + if metadata: + metadata.GFPGAN = True + ImageMetadata.set_on_image(image, metadata) + if strength < 1.0: result = Image.blend(image, result, strength) @@ -1585,15 +1910,18 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to else: modelMode = imgproc_realesrgan_model_name image = image.convert("RGB") + 
metadata = ImageMetadata.get_from_image(image) RealESRGAN = load_RealESRGAN(modelMode) result, res = RealESRGAN.enhance(np.array(image, dtype=np.uint8)) result = Image.fromarray(result) + ImageMetadata.set_on_image(result, metadata) if 'x2' in imgproc_realesrgan_model_name: # downscale to 1/2 size result = result.resize((result.width//2, result.height//2), LANCZOS) return result def processGoBig(image): + metadata = ImageMetadata.get_from_image(image) result = processRealESRGAN(image,) if 'x4' in imgproc_realesrgan_model_name: #downscale to 1/2 size @@ -1638,6 +1966,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to init_img = result init_mask = None keep_mask = False + mask_restore = False assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' def init(): @@ -1663,7 +1992,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to return init_latent, - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None): if sampler_name != 'DDIM': x0, = init_data @@ -1673,7 +2002,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to xi = x0 + noise sigma_sched = sigmas[ddim_steps - t_enc - 1:] model_wrap_cfg = CFGDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback)) else: x0, = init_data sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) @@ -1774,6 +2103,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to width=width, height=height, prompt_matrix=None, + filter_nsfw=False, use_GFPGAN=None, use_RealESRGAN=None, realesrgan_model_name=None, @@ -1784,6 +2114,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to keep_mask=False, mask_blur_strength=None, denoising_strength=denoising_strength, + mask_restore=mask_restore, resize_mode=resize_mode, uses_loopback=False, sort_samples=True, @@ -1808,11 +2139,14 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to del sampler torch.cuda.empty_cache() + ImageMetadata.set_on_image(combined_image, metadata) return combined_image def processLDSR(image): + metadata = ImageMetadata.get_from_image(image) result = LDSR.superResolution(image,int(imgproc_ldsr_steps),str(imgproc_ldsr_pre_downSample),str(imgproc_ldsr_post_downSample)) - return result - + ImageMetadata.set_on_image(result, metadata) + return result + if image_batch != None: if image != None: @@ -1839,7 +2173,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to if 1 in imgproc_toggles: if imgproc_upscale_toggles == 0: ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models - ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models + ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models elif imgproc_upscale_toggles == 1: 
ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models ModelLoader(['RealESGAN','model'],True,False) # Load used models @@ -1851,10 +2185,14 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models ModelLoader(['RealESGAN','model'],True,False,imgproc_realesrgan_model_name) # Load used models for image in images: + metadata = ImageMetadata.get_from_image(image) if 0 in imgproc_toggles: #recheck if GFPGAN is loaded since it's the only model that can be loaded in the loop as well ModelLoader(['GFPGAN'],True,False) # Load used models image = processGFPGAN(image,imgproc_gfpgan_strength) + if metadata: + metadata.GFPGAN = True + ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'GFPGAN') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) @@ -1862,47 +2200,51 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to if 1 not in imgproc_toggles: output.append(image) - save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) + save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) if 1 in imgproc_toggles: if imgproc_upscale_toggles == 0: image = processRealESRGAN(image) + ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'RealESRGAN') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) outFilename = str(batchNumber)+'-'+'result' output.append(image) - save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) + save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) elif imgproc_upscale_toggles == 1: image = processGoBig(image) + ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'GoBig') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) outFilename = str(batchNumber)+'-'+'result' output.append(image) - save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) + save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) elif imgproc_upscale_toggles == 2: image = processLDSR(image) + ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'LDSR') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) outFilename = str(batchNumber)+'-'+'result' output.append(image) - save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) + save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) elif imgproc_upscale_toggles == 3: image = processGoBig(image) ModelLoader(['model','GFPGAN','RealESGAN'],False,True) # Unload unused models 
ModelLoader(['LDSR'],True,False) # Load used models image = processLDSR(image) + ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'GoLatent') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) outFilename = str(batchNumber)+'-'+'result' output.append(image) - save_sample(image, outpathDir, outFilename, None, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) + save_sample(image, outpathDir, outFilename, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) #LDSR is always unloaded to avoid memory issues #ModelLoader(['LDSR'],False,True) @@ -1952,10 +2294,13 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re def run_GFPGAN(image, strength): ModelLoader(['LDSR','RealESRGAN'],False,True) ModelLoader(['GFPGAN'],True,False) + metadata = ImageMetadata.get_from_image(image) image = image.convert("RGB") cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) res = Image.fromarray(restored_img) + metadata.GFPGAN = True + ImageMetadata.set_on_image(res, metadata) if strength < 1.0: res = Image.blend(image, res, strength) @@ -1968,10 +2313,12 @@ def run_RealESRGAN(image, model_name: str): if RealESRGAN.model.name != model_name: try_loading_RealESRGAN(model_name) + metadata = ImageMetadata.get_from_image(image) image = image.convert("RGB") output, img_mode = RealESRGAN.enhance(np.array(image, dtype=np.uint8)) res = Image.fromarray(output) + ImageMetadata.set_on_image(res, metadata) return res @@ -1997,6 +2344,7 @@ txt2img_toggles = [ 'Write sample info files', 'write sample info to log file', 'jpg samples', + 'Filter NSFW content', ] if GFPGAN is not None: @@ -2057,6 +2405,8 @@ img2img_toggles = [ 'Write sample info files', 'Write sample info to one file', 'jpg samples', + 'Color correction (always enabled on loopback mode)', + 'Filter NSFW content', ] # removed for now becuase of Image Lab implementation if GFPGAN is not None: @@ -2086,6 +2436,7 @@ img2img_defaults = { 'cfg_scale': 5.0, 'denoising_strength': 0.75, 'mask_mode': 0, + 'mask_restore': False, 'resize_mode': 0, 'seed': '', 'height': 512, @@ -2099,24 +2450,6 @@ if 'img2img' in user_defaults: img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']] img2img_image_mode = 'sketch' -def change_image_editor_mode(choice, cropped_image, resize_mode, width, height): - if choice == "Mask": - return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] - return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] - -def update_image_mask(cropped_image, resize_mode, width, height): - resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None - return gr.update(value=resized_cropped_image) - - - -def copy_img_to_upscale_esrgan(img): - update = gr.update(selected='realesrgan_tab') - image_data = re.sub('^data:image/.+;base64,', '', img) - processed_image = Image.open(BytesIO(base64.b64decode(image_data))) - return {'realesrgan_source': 
processed_image, 'tabs': update} - - help_text = """ ## Mask/Crop * The masking/cropping is very temperamental. @@ -2178,7 +2511,7 @@ class ServerLauncher(threading.Thread): 'inbrowser': opt.inbrowser, 'server_name': '0.0.0.0', 'server_port': opt.port, - 'share': opt.share, + 'share': opt.share, 'show_error': True } if not opt.share: diff --git a/scripts/webui_streamlit.py b/scripts/webui_streamlit.py index bd680da..15a2e2f 100644 --- a/scripts/webui_streamlit.py +++ b/scripts/webui_streamlit.py @@ -1,46 +1,31 @@ -import warnings +# base webui import and utils. import streamlit as st -from streamlit import StopException, StreamlitAPIException -import base64, cv2 -import argparse, os, sys, glob, re, random, datetime -from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps -from PIL.PngImagePlugin import PngInfo -import requests -from scipy import integrate -import torch -from torchdiffeq import odeint -from tqdm.auto import trange, tqdm +# streamlit imports +import streamlit_nested_layout + +#streamlit components section +from st_on_hover_tabs import on_hover_tabs + +#other imports + +import warnings +import os import k_diffusion as K -import math -import mimetypes -import numpy as np -import pynvml -import threading, asyncio -import time -import torch -from torch import autocast -from torchvision import transforms -import torch.nn as nn -import yaml -from typing import List, Union -from pathlib import Path -from tqdm import tqdm -from contextlib import contextmanager, nullcontext -from einops import rearrange, repeat -from itertools import islice from omegaconf import OmegaConf -from io import BytesIO -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.util import instantiate_from_config -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ - extract_into_tensor -from retry import retry -# we use python-slugify to make the filenames safe for windows and linux, its better than doing it manually -# install it with 'pip install python-slugify' -from slugify import slugify +from sd_utils import * +if not "defaults" in st.session_state: + st.session_state["defaults"] = {} + +st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml") + +if (os.path.exists("configs/webui/userconfig_streamlit.yaml")): + user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml") + st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults) + +# end of imports +#--------------------------------------------------------------------------------------------------------------- try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -53,1412 +38,18 @@ except: # remove some annoying deprecation warnings that show every now and then. warnings.filterwarnings("ignore", category=DeprecationWarning) -defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml") - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -# some of those options should not be changed at all because they would break the model, so I removed them from options. 
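# For orientation: opt_C and opt_f below are the latent channel count and the
# VAE downsampling factor of Stable Diffusion v1, so a WxH request is denoised
# in a latent of shape (C, H // f, W // f); a 512x512 image uses a 4x64x64
# latent. Changing either value breaks the checkpoint, which is why they are
# constants rather than user options. A tiny sketch of that relationship:
def latent_shape(width: int, height: int, channels: int = 4, factor: int = 8):
    return (channels, height // factor, width // factor)

assert latent_shape(512, 512) == (4, 64, 64)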
-opt_C = 4 -opt_f = 8 - -# should and will be moved to a settings menu in the UI at some point -grid_format = [s.lower() for s in defaults.general.grid_format.split(':')] -grid_lossless = False -grid_quality = 100 -if grid_format[0] == 'png': - grid_ext = 'png' - grid_format = 'png' -elif grid_format[0] in ['jpg', 'jpeg']: - grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 - grid_ext = 'jpg' - grid_format = 'jpeg' -elif grid_format[0] == 'webp': - grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 - grid_ext = 'webp' - grid_format = 'webp' - if grid_quality < 0: # e.g. webp:-100 for lossless mode - grid_lossless = True - grid_quality = abs(grid_quality) - # this should force GFPGAN and RealESRGAN onto the selected gpu as well -os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 -os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu) - -@retry(tries=5) -def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus"): - """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ - - print ("Loading models.") - - # Generate random run ID - # Used to link runs linked w/ continue_prev_run which is not yet implemented - # Use URL and filesystem safe version just in case. - st.session_state["run_id"] = base64.urlsafe_b64encode( - os.urandom(6) - ).decode("ascii") - - # check what models we want to use and if the they are already loaded. - - if use_GFPGAN: - if "GFPGAN" in st.session_state: - print("GFPGAN already loaded") - else: - # Load GFPGAN - if os.path.exists(defaults.general.GFPGAN_dir): - try: - st.session_state["GFPGAN"] = load_GFPGAN() - print("Loaded GFPGAN") - except Exception: - import traceback - print("Error loading GFPGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if "GFPGAN" in st.session_state: - del st.session_state["GFPGAN"] - - if use_RealESRGAN: - if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model: - print("RealESRGAN already loaded") - else: - #Load RealESRGAN - try: - # We first remove the variable in case it has something there, - # some errors can load the model incorrectly and leave things in memory. - del st.session_state["RealESRGAN"] - except KeyError: - pass - - if os.path.exists(defaults.general.RealESRGAN_dir): - # st.session_state is used for keeping the models in memory across multiple pages or runs. 
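# The caching pattern relied on here, reduced to its core (sketch meant to run
# inside a script launched with `streamlit run`; load_big_model is a
# hypothetical loader): look the object up in st.session_state first and only
# construct it when missing, so Streamlit reruns reuse the model in memory
# instead of reloading it from disk.
import streamlit as st

def get_cached_model(name, loader):
    if name not in st.session_state:
        st.session_state[name] = loader()
    return st.session_state[name]

# usage: esrgan = get_cached_model("RealESRGAN", lambda: load_big_model("RealESRGAN_x4plus"))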
- st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model) - print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name) - - else: - if "RealESRGAN" in st.session_state: - del st.session_state["RealESRGAN"] - - - if "model" in st.session_state: - print("Model already loaded") - else: - config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml") - model = load_model_from_config(config, defaults.general.ckpt) - - st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu") - st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] ) - - print("Model loaded.") - - -def load_model_from_config(config, ckpt, verbose=False): - - print(f"Loading model from {ckpt}") - - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -def load_sd_from_config(ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - return sd -# -@retry(tries=5) -def generation_callback(img, i=0): - - try: - if i == 0: - if img['i']: i = img['i'] - except TypeError: - pass - - - if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview: - #print (img) - #print (type(img)) - # The following lines will convert the tensor we got on img to an actual image we can render on the UI. - # It can probably be done in a better way for someone who knows what they're doing. I don't. - #print (img,isinstance(img, torch.Tensor)) - if isinstance(img, torch.Tensor): - x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img) - else: - # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: - # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} - x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"]) - - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) - - # update image on the UI so we can see the progress - st.session_state["preview_image"].image(pil_image) - - # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, - # Dont worry, it doesnt affect the performance. 
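# Sketch of the progress arithmetic used in this callback: the bar shows the
# current step over the effective step count, which for img2img is
# sampling_steps scaled by denoising_strength (clamped so it never exceeds 100%).
def progress_percent(step: int, total_steps: int) -> int:
    done = min(step + 1, total_steps)
    return min(100, int(100 * done / total_steps))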
- if st.session_state["generation_mode"] == "txt2img": - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") - else: - round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) - percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) - st.session_state["progress_bar_text"].text( - f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") - - st.session_state["progress_bar"].progress(percent if percent < 100 else 100) - - - -class MemUsageMonitor(threading.Thread): - stop_flag = False - max_usage = 0 - total = -1 - - def __init__(self, name): - threading.Thread.__init__(self) - self.name = name - - def run(self): - try: - pynvml.nvmlInit() - except: - print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n") - return - print(f"[{self.name}] Recording max memory usage...\n") - handle = pynvml.nvmlDeviceGetHandleByIndex(defaults.general.gpu) - self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total - while not self.stop_flag: - m = pynvml.nvmlDeviceGetMemoryInfo(handle) - self.max_usage = max(self.max_usage, m.used) - # print(self.max_usage) - time.sleep(0.1) - print(f"[{self.name}] Stopped recording.\n") - pynvml.nvmlShutdown() - - def read(self): - return self.max_usage, self.total - - def stop(self): - self.stop_flag = True - - def read_and_stop(self): - self.stop_flag = True - return self.max_usage, self.total - -class CFGMaskedDenoiser(nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi): - x_in = x - x_in = torch.cat([x_in] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - denoised = uncond + (cond - uncond) * cond_scale - - if mask is not None: - assert x0 is not None - img_orig = x0 - mask_inv = 1. - mask - denoised = (img_orig * mask_inv) + (mask * denoised) - - return denoised - -class CFGDenoiser(nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - return uncond + (cond - uncond) * cond_scale -def append_zero(x): - return torch.cat([x, x.new_zeros([1])]) -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] -def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): - """Constructs the noise schedule of Karras et al. 
(2022).""" - ramp = torch.linspace(0, 1, n) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return append_zero(sigmas).to(device) - - -def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): - """Constructs an exponential noise schedule.""" - sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() - return append_zero(sigmas) - - -def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): - """Constructs a continuous VP noise schedule.""" - t = torch.linspace(1, eps_s, n, device=device) - sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) - return append_zero(sigmas) - - -def to_d(x, sigma, denoised): - """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / append_dims(sigma, x.ndim) -def linear_multistep_coeff(order, t, i, j): - if order - 1 > i: - raise ValueError(f'Order {order} too high for step {i}') - def fn(tau): - prod = 1. - for k in range(order): - if j == k: - continue - prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) - return prod - return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] - -class KDiffusionSampler: - def __init__(self, m, sampler): - self.model = m - self.model_wrap = K.external.CompVisDenoiser(m) - self.schedule = sampler - def get_sampler_name(self): - return self.schedule - def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None): - sigmas = self.model_wrap.get_sigmas(S) - x = x_T * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model_wrap) - samples_ddim = None - samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, - extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback) - # - return samples_ddim, None - - -@torch.no_grad() -def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - v = torch.randint_like(x, 2) * 2 - 1 - fevals = 0 - def ode_fn(sigma, x): - nonlocal fevals - with torch.enable_grad(): - x = x[0].detach().requires_grad_() - denoised = model(x, sigma * s_in, **extra_args) - d = to_d(x, sigma, denoised) - fevals += 1 - grad = torch.autograd.grad((d * v).sum(), x)[0] - d_ll = (v * grad).flatten(1).sum(1) - return d.detach(), d_ll - x_min = x, x.new_zeros([x.shape[0]]) - t = x.new_tensor([sigma_min, sigma_max]) - sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') - latent, delta_ll = sol[0][-1], sol[1][-1] - ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) - return ll_prior + delta_ll, {'fevals': fevals} - - -def create_random_tensors(shape, seeds): - xs = [] - for seed in seeds: - torch.manual_seed(seed) - - # randn results depend on device; gpu and cpu get different results for same seed; - # the way I see it, it's better to do this on CPU, so that everyone gets same result; - # but the original script had it like this so i do not dare change it for now because - # it will break everyone's seeds. 
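(Aside: the comment above mentions that seeding on the CPU would make results device-independent, but leaves the behaviour unchanged so existing seeds keep working. A sketch of that CPU-seeded alternative, not what this script does:)

    import torch

    def create_random_tensors_cpu(shape, seeds, device="cuda"):
        # Seed and sample on the CPU so identical seeds reproduce identical
        # latents on any machine, then move the stacked noise to the target device.
        xs = []
        for seed in seeds:
            generator = torch.Generator(device="cpu").manual_seed(seed)
            xs.append(torch.randn(shape, generator=generator, device="cpu"))
        return torch.stack(xs).to(device)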
- xs.append(torch.randn(shape, device=defaults.general.gpu)) - x = torch.stack(xs) - return x - -def torch_gc(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - -def load_GFPGAN(): - model_name = 'GFPGANv1.3' - model_path = os.path.join(defaults.general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') - if not os.path.isfile(model_path): - raise Exception("GFPGAN model not found at path "+model_path) - - sys.path.append(os.path.abspath(defaults.general.GFPGAN_dir)) - from gfpgan import GFPGANer - - if defaults.general.gfpgan_cpu or defaults.general.extra_models_cpu: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu')) - elif defaults.general.extra_models_gpu: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gfpgan_gpu}')) - else: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gpu}')) - return instance - -def load_RealESRGAN(model_name: str): - from basicsr.archs.rrdbnet_arch import RRDBNet - RealESRGAN_models = { - 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), - 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - } - - model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') - if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")): - raise Exception(model_name+".pth not found at path "+model_path) - - sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir)) - from realesrgan import RealESRGANer - - if defaults.general.esrgan_cpu or defaults.general.extra_models_cpu: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half - instance.device = torch.device('cpu') - instance.model.to('cpu') - elif defaults.general.extra_models_gpu: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.esrgan_gpu}')) - else: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.gpu}')) - instance.model.name = model_name - - return instance - -prompt_parser = re.compile(""" - (?P # capture group for 'prompt' - [^:]+ # match one or more non ':' characters - ) # end 'prompt' - (?: # non-capture group - :+ # match one or more ':' characters - (?P # capture group for 'weight' - -?\\d+(?:\\.\\d+)? # match positive or negative decimal number - )? 
# end weight capture group, make optional - \\s* # strip spaces after weight - | # OR - $ # else, if no ':' then match end of line - ) # end non-capture group -""", re.VERBOSE) - -# grabs all text up to the first occurrence of ':' as sub-prompt -# takes the value following ':' as weight -# if ':' has no value defined, defaults to 1.0 -# repeats until no text remaining -def split_weighted_subprompts(input_string, normalize=True): - parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)] - if not normalize: - return parsed_prompts - # this probably still doesn't handle negative weights very well - weight_sum = sum(map(lambda x: x[1], parsed_prompts)) - return [(x[0], x[1] / weight_sum) for x in parsed_prompts] - -def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995): - v0 = v0.detach().cpu().numpy() - v1 = v1.detach().cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - v2 = torch.from_numpy(v2).to(device) - - return v2 - - -def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): - #get global variables - global_vars = globals() - #check if m is in globals - if unload: - for m in models: - if m in global_vars: - #if it is, delete it - del global_vars[m] - if defaults.general.optimized: - if m == 'model': - del global_vars[m+'FS'] - del global_vars[m+'CS'] - if m =='model': - m='Stable Diffusion' - print('Unloaded ' + m) - if load: - for m in models: - if m not in global_vars or m in global_vars and type(global_vars[m]) == bool: - #if it isn't, load it - if m == 'GFPGAN': - global_vars[m] = load_GFPGAN() - elif m == 'model': - sdLoader = load_sd_from_config() - global_vars[m] = sdLoader[0] - if defaults.general.optimized: - global_vars[m+'CS'] = sdLoader[1] - global_vars[m+'FS'] = sdLoader[2] - elif m == 'RealESRGAN': - global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name) - elif m == 'LDSR': - global_vars[m] = load_LDSR() - if m =='model': - m='Stable Diffusion' - print('Loaded ' + m) - torch_gc() - - - -def get_font(fontsize): - fonts = ["arial.ttf", "DejaVuSans.ttf"] - for font_name in fonts: - try: - return ImageFont.truetype(font_name, fontsize) - except OSError: - pass - - # ImageFont.load_default() is practically unusable as it only supports - # latin1, so raise an exception instead if no usable font was found - raise Exception(f"No usable font found (tried {', '.join(fonts)})") - -def load_embeddings(fp): - if fp is not None and hasattr(st.session_state["model"], "embedding_manager"): - st.session_state["model"].embedding_manager.load(fp['name']) - -def image_grid(imgs, batch_size, force_n_rows=None, captions=None): - #print (len(imgs)) - if force_n_rows is not None: - rows = force_n_rows - elif defaults.general.n_rows > 0: - rows = defaults.general.n_rows - elif defaults.general.n_rows == 0: - rows = batch_size - else: - rows = math.sqrt(len(imgs)) - rows = round(rows) - - cols = math.ceil(len(imgs) / rows) - - w, h = imgs[0].size - grid = Image.new('RGB', size=(cols * w, rows * h), color='black') - - fnt = get_font(30) - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - if 
captions and i= 2**32: - n = n >> 32 - return n - -def check_prompt_length(prompt, comments): - """this function tests if prompt is too long, and if so, adds a message to comments""" - - tokenizer = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer - max_length = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.max_length - - info = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, - return_overflowing_tokens=True, padding="max_length", return_tensors="pt") - ovf = info['overflowing_tokens'][0] - overflowing_count = ovf.shape[0] - if overflowing_count == 0: - return - - vocab = {v: k for k, v in tokenizer.get_vocab().items()} - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - - comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - -def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images): - - filename_i = os.path.join(sample_path_i, filename) - - if not jpg_sample: - if defaults.general.save_metadata: - metadata = PngInfo() - metadata.add_text("SD:prompt", prompts[i]) - metadata.add_text("SD:seed", str(seeds[i])) - metadata.add_text("SD:width", str(width)) - metadata.add_text("SD:height", str(height)) - metadata.add_text("SD:steps", str(steps)) - metadata.add_text("SD:cfg_scale", str(cfg_scale)) - metadata.add_text("SD:normalize_prompt_weights", str(normalize_prompt_weights)) - if init_img is not None: - metadata.add_text("SD:denoising_strength", str(denoising_strength)) - metadata.add_text("SD:GFPGAN", str(use_GFPGAN and st.session_state["GFPGAN"] is not None)) - image.save(f"{filename_i}.png", pnginfo=metadata) - else: - image.save(f"{filename_i}.png") - else: - image.save(f"{filename_i}.jpg", 'jpeg', quality=100, optimize=True) - - if write_info_files: - # toggles differ for txt2img vs. 
img2img: - offset = 0 if init_img is None else 2 - toggles = [] - if prompt_matrix: - toggles.append(0) - if normalize_prompt_weights: - toggles.append(1) - if init_img is not None: - if uses_loopback: - toggles.append(2) - if uses_random_seed_loopback: - toggles.append(3) - if save_individual_images: - toggles.append(2 + offset) - if save_grid: - toggles.append(3 + offset) - if sort_samples: - toggles.append(4 + offset) - if write_info_files: - toggles.append(5 + offset) - if use_GFPGAN: - toggles.append(6 + offset) - info_dict = dict( - target="txt2img" if init_img is None else "img2img", - prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, - ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, - seed=seeds[i], width=width, height=height - ) - if init_img is not None: - # Not yet any use for these, but they bloat up the files: - #info_dict["init_img"] = init_img - #info_dict["init_mask"] = init_mask - info_dict["denoising_strength"] = denoising_strength - info_dict["resize_mode"] = resize_mode - with open(f"{filename_i}.yaml", "w", encoding="utf8") as f: - yaml.dump(info_dict, f, allow_unicode=True, width=10000) - - # render the image on the frontend - st.session_state["preview_image"].image(image) - -def get_next_sequence_number(path, prefix=''): - """ - Determines and returns the next sequence number to use when saving an - image in the specified directory. - - If a prefix is given, only consider files whose names start with that - prefix, and strip the prefix from filenames before extracting their - sequence number. - - The sequence starts at 0. - """ - result = -1 - for p in Path(path).iterdir(): - if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix): - tmp = p.name[len(prefix):] - try: - result = max(int(tmp.split('-')[0]), result) - except ValueError: - pass - return result + 1 - - -def oxlamon_matrix(prompt, seed, n_iter, batch_size): - pattern = re.compile(r'(,\s){2,}') - - class PromptItem: - def __init__(self, text, parts, item): - self.text = text - self.parts = parts - if item: - self.parts.append( item ) - - def clean(txt): - return re.sub(pattern, ', ', txt) - - def getrowcount( txt ): - for data in re.finditer( ".*?\\((.*?)\\).*", txt ): - if data: - return len(data.group(1).split("|")) - break - return None - - def repliter( txt ): - for data in re.finditer( ".*?\\((.*?)\\).*", txt ): - if data: - r = data.span(1) - for item in data.group(1).split("|"): - yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip()) - break - - def iterlist( items ): - outitems = [] - for item in items: - for newitem, newpart in repliter(item.text): - outitems.append( PromptItem(newitem, item.parts.copy(), newpart) ) - - return outitems - - def getmatrix( prompt ): - dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ] - while True: - newdataitems = iterlist( dataitems ) - if len( newdataitems ) == 0: - return dataitems - dataitems = newdataitems - - def classToArrays( items, seed, n_iter ): - texts = [] - parts = [] - seeds = [] - - for item in items: - itemseed = seed - for i in range(n_iter): - texts.append( item.text ) - parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) ) - seeds.append( itemseed ) - itemseed += 1 - - return seeds, texts, parts - - all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter) - n_iter = math.ceil(len(all_prompts) / batch_size) - - needrows = getrowcount(prompt) - if needrows: - xrows = math.sqrt(len(all_prompts)) - xrows = 
round(xrows) - # if columns is to much - cols = math.ceil(len(all_prompts) / xrows) - if cols > needrows*4: - needrows *= 2 - - return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows - - -def process_images( - outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, - n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, - fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, - keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, - uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, - variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): - """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - assert prompt is not None - torch_gc() - # start time after garbage collection (or before?) - start_time = time.time() - - # We will use this date here later for the folder name, need to start_time if not need - run_start_dt = datetime.datetime.now() - - mem_mon = MemUsageMonitor('MemMon') - mem_mon.start() - - if hasattr(st.session_state["model"], "embedding_manager"): - load_embeddings(fp) - - os.makedirs(outpath, exist_ok=True) - - sample_path = os.path.join(outpath, "samples") - os.makedirs(sample_path, exist_ok=True) - - if not ("|" in prompt) and prompt.startswith("@"): - prompt = prompt[1:] - - comments = [] - - prompt_matrix_parts = [] - simple_templating = False - add_original_image = not (use_RealESRGAN or use_GFPGAN) - - if prompt_matrix: - if prompt.startswith("@"): - simple_templating = True - add_original_image = not (use_RealESRGAN or use_GFPGAN) - all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size) - else: - all_prompts = [] - prompt_matrix_parts = prompt.split("|") - combination_count = 2 ** (len(prompt_matrix_parts) - 1) - for combination_num in range(combination_count): - current = prompt_matrix_parts[0] - - for n, text in enumerate(prompt_matrix_parts[1:]): - if combination_num & (2 ** n) > 0: - current += ("" if text.strip().startswith(",") else ", ") + text - - all_prompts.append(current) - - n_iter = math.ceil(len(all_prompts) / batch_size) - all_seeds = len(all_prompts) * [seed] - - print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") - else: - - if not defaults.general.no_verify_input: - try: - check_prompt_length(prompt, comments) - except: - import traceback - print("Error verifying input:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - all_prompts = batch_size * n_iter * [prompt] - all_seeds = [seed + x for x in range(len(all_prompts))] - - precision_scope = autocast if defaults.general.precision == "autocast" else nullcontext - output_images = [] - grid_captions = [] - stats = [] - with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not defaults.general.optimized else nullcontext()): - init_data = func_init() - tic = time.time() - - - # if variant_amount > 0.0 create noise from base seed - base_x = None - if variant_amount > 0.0: - target_seed_randomizer = seed_to_int('') # random seed - torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration) - base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed]) - # we don't want all_seeds to be 
sequential from starting seed with variants, - # since that makes the same variants each time, - # so we add target_seed_randomizer as a random offset - for si in range(len(all_seeds)): - all_seeds[si] += target_seed_randomizer - - for n in range(n_iter): - print(f"Iteration: {n+1}/{n_iter}") - prompts = all_prompts[n * batch_size:(n + 1) * batch_size] - captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size] - seeds = all_seeds[n * batch_size:(n + 1) * batch_size] - - print(prompt) - - if defaults.general.optimized: - modelCS.to(defaults.general.gpu) - - uc = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(len(prompts) * [""]) - - if isinstance(prompts, tuple): - prompts = list(prompts) - - # split the prompt if it has : for weighting - # TODO for speed it might help to have this occur when all_prompts filled?? - weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights) - - # sub-prompt weighting used if more than 1 - if len(weighted_subprompts) > 1: - c = torch.zeros_like(uc) # i dont know if this is correct.. but it works - for i in range(0, len(weighted_subprompts)): - # note if alpha negative, it functions same as torch.sub - c = torch.add(c, (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1]) - else: # just behave like usual - c = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(prompts) - - - shape = [opt_C, height // opt_f, width // opt_f] - - if defaults.general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - modelCS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - if variant_amount == 0.0: - # we manually generate all input noises because each one should have a specific seed - x = create_random_tensors(shape, seeds=seeds) - - else: # we are making variants - # using variant_seed as sneaky toggle, - # when not None or '' use the variant_seed - # otherwise use seeds - if variant_seed != None and variant_seed != '': - specified_variant_seed = seed_to_int(variant_seed) - torch.manual_seed(specified_variant_seed) - seeds = [specified_variant_seed] - target_x = create_random_tensors(shape, seeds=seeds) - # finally, slerp base_x noise to target_x noise for creating a variant - x = slerp(defaults.general.gpu, max(0.0, min(1.0, variant_amount)), base_x, target_x) - - samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) - - if defaults.general.optimized: - modelFS.to(defaults.general.gpu) - - x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - for i, x_sample in enumerate(x_samples_ddim): - sanitized_prompt = slugify(prompts[i]) - - if sort_samples: - full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt) - - - sanitized_prompt = sanitized_prompt[:220-len(full_path)] - sample_path_i = os.path.join(sample_path, sanitized_prompt) - - #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}") - #print(os.path.join(os.getcwd(), sample_path_i)) - - os.makedirs(sample_path_i, exist_ok=True) - base_count = get_next_sequence_number(sample_path_i) - filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}" - else: - full_path = 
os.path.join(os.getcwd(), sample_path) - sample_path_i = sample_path - base_count = get_next_sequence_number(sample_path_i) - filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before - - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - x_sample = x_sample.astype(np.uint8) - image = Image.fromarray(x_sample) - original_sample = x_sample - original_filename = filename - - if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN: - #skip_save = True # #287 >_> - torch_gc() - cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - gfpgan_sample = restored_img[:,:,::-1] - gfpgan_image = Image.fromarray(gfpgan_sample) - gfpgan_filename = original_filename + '-gfpgan' - - save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, - uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta, - n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) - - output_images.append(gfpgan_image) #287 - if simple_templating: - grid_captions.append( captions[i] + "\ngfpgan" ) - - if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: - #skip_save = True # #287 >_> - torch_gc() - - if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: - #try_loading_RealESRGAN(realesrgan_model_name) - load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) - - output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) - esrgan_filename = original_filename + '-esrgan4x' - esrgan_sample = output[:,:,::-1] - esrgan_image = Image.fromarray(esrgan_sample) - - #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, - #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) - - save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) - - output_images.append(esrgan_image) #287 - if simple_templating: - grid_captions.append( captions[i] + "\nesrgan" ) - - if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None: - #skip_save = True # #287 >_> - torch_gc() - cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - gfpgan_sample = restored_img[:,:,::-1] - - if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: - #try_loading_RealESRGAN(realesrgan_model_name) - load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) - - output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) - 
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' - gfpgan_esrgan_sample = output[:,:,::-1] - gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) - - save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) - - output_images.append(gfpgan_esrgan_image) #287 - - if simple_templating: - grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) - - if save_individual_images: - save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images) - - if not use_GFPGAN or not use_RealESRGAN: - output_images.append(image) - - #if add_original_image or not simple_templating: - #output_images.append(image) - #if simple_templating: - #grid_captions.append( captions[i] ) - - if defaults.general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - modelFS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - if prompt_matrix or save_grid: - if prompt_matrix: - if simple_templating: - grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions) - else: - grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) - try: - grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) - except: - import traceback - print("Error creating prompt_matrix text:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - grid = image_grid(output_images, batch_size) - - if grid and (batch_size > 1 or n_iter > 1): - output_images.insert(0, grid) - - grid_count = get_next_sequence_number(outpath, 'grid-') - grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}" - grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True) - - toc = time.time() - - mem_max_used, mem_total = mem_mon.read_and_stop() - time_diff = time.time()-start_time - - info = f""" - {prompt} - Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' 
if prompt_matrix else ''}""".strip() - stats = f''' - Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) - Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' - - for comment in comments: - info += "\n\n" + comment - - #mem_mon.stop() - #del mem_mon - torch_gc() - - return output_images, seed, info, stats - - -def resize_image(resize_mode, im, width, height): - LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) - if resize_mode == 0: - res = im.resize((width, height), resample=LANCZOS) - elif resize_mode == 1: - ratio = width / height - src_ratio = im.width / im.height - - src_w = width if ratio > src_ratio else im.width * height // im.height - src_h = height if ratio <= src_ratio else im.height * width // im.width - - resized = im.resize((src_w, src_h), resample=LANCZOS) - res = Image.new("RGBA", (width, height)) - res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) - else: - ratio = width / height - src_ratio = im.width / im.height - - src_w = width if ratio < src_ratio else im.width * height // im.height - src_h = height if ratio >= src_ratio else im.height * width // im.width - - resized = im.resize((src_w, src_h), resample=LANCZOS) - res = Image.new("RGBA", (width, height)) - res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) - - if ratio < src_ratio: - fill_height = height // 2 - src_h // 2 - res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) - res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) - elif ratio > src_ratio: - fill_width = width // 2 - src_w // 2 - res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) - res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) - - return res - -def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, - ddim_steps: int = 50, sampler_name: str = 'DDIM', - n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, - seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, - variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, - write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", - separate_prompts:bool = False, normalize_prompt_weights:bool = True, - save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, - save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False, - random_seed_loopback: bool = False - ): - - outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples" - err = False - #loopback = False - #skip_save = False - seed = seed_to_int(seed) - - batch_size = 1 - - #prompt_matrix = 0 - #normalize_prompt_weights = 1 in toggles - #loopback = 2 in toggles - #random_seed_loopback = 3 in toggles - #skip_save = 4 not in toggles - #save_grid = 5 in toggles - #sort_samples = 6 in toggles - #write_info_files = 7 in toggles - #write_sample_info_to_log_file = 8 in toggles - #jpg_sample = 9 in toggles - #use_GFPGAN = 10 in toggles - #use_RealESRGAN = 11 in toggles - - if sampler_name == 'PLMS': - sampler = PLMSSampler(st.session_state["model"]) - elif 
sampler_name == 'DDIM': - sampler = DDIMSampler(st.session_state["model"]) - elif sampler_name == 'k_dpm_2_a': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') - elif sampler_name == 'k_dpm_2': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') - elif sampler_name == 'k_euler_a': - sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') - elif sampler_name == 'k_euler': - sampler = KDiffusionSampler(st.session_state["model"],'euler') - elif sampler_name == 'k_heun': - sampler = KDiffusionSampler(st.session_state["model"],'heun') - elif sampler_name == 'k_lms': - sampler = KDiffusionSampler(st.session_state["model"],'lms') - else: - raise Exception("Unknown sampler: " + sampler_name) - - init_img = init_info - init_mask = None - keep_mask = False - - assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - t_enc = int(denoising_strength * ddim_steps) - - def init(): - - image = init_img - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - - mask = None - if defaults.general.optimized: - modelFS.to(st.session_state["device"] ) - - init_image = 2. * image - 1. - init_image = init_image.to(st.session_state["device"]) - init_latent = (st.session_state["model"] if not defaults.general.optimized else modelFS).get_first_stage_encoding((st.session_state["model"] if not defaults.general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space - - if defaults.general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - modelFS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - return init_latent, mask, - - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): - t_enc_steps = t_enc - obliterate = False - if ddim_steps == t_enc_steps: - t_enc_steps = t_enc_steps - 1 - obliterate = True - - if sampler_name != 'DDIM': - x0, z_mask = init_data - - sigmas = sampler.model_wrap.get_sigmas(ddim_steps) - noise = x * sigmas[ddim_steps - t_enc_steps - 1] - - xi = x0 + noise - - # Obliterate masked image - if z_mask is not None and obliterate: - random = torch.randn(z_mask.shape, device=xi.device) - xi = (z_mask * noise) + ((1-z_mask) * xi) - - sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] - model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, - extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, - callback=generation_callback) - else: - - x0, z_mask = init_data - - sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) - z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] )) - - # Obliterate masked image - if z_mask is not None and obliterate: - random = torch.randn(z_mask.shape, device=z_enc.device) - z_enc = (z_mask * random) + ((1-z_mask) * z_enc) - - # decode it - samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=unconditional_conditioning, - z_mask=z_mask, x0=x0) - return samples_ddim - - - - if loopback: - output_images, info = None, None - history = [] - initial_seed = None - - do_color_correction = False - try: - from skimage import exposure - do_color_correction = True - except: - 
print("Install scikit-image to perform color correction on loopback") - - for i in range(1): - if do_color_correction and i == 0: - correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) - - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=1, - n_iter=n_iter, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback - realesrgan_model_name=RealESRGAN_model, - fp=fp, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - init_img=init_img, - init_mask=init_mask, - keep_mask=keep_mask, - mask_blur_strength=mask_blur_strength, - denoising_strength=denoising_strength, - resize_mode=resize_mode, - uses_loopback=loopback, - uses_random_seed_loopback=random_seed_loopback, - sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg - ) - - if initial_seed is None: - initial_seed = seed - - init_img = output_images[0] - - if do_color_correction and correction_target is not None: - init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( - cv2.cvtColor( - np.asarray(init_img), - cv2.COLOR_RGB2LAB - ), - correction_target, - channel_axis=2 - ), cv2.COLOR_LAB2RGB).astype("uint8")) - - if not random_seed_loopback: - seed = seed + 1 - else: - seed = seed_to_int(None) - - denoising_strength = max(denoising_strength * 0.95, 0.1) - history.append(init_img) - - output_images = history - seed = initial_seed - - else: - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=batch_size, - n_iter=n_iter, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN, - realesrgan_model_name=RealESRGAN_model, - fp=fp, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - init_img=init_img, - init_mask=init_mask, - keep_mask=keep_mask, - mask_blur_strength=2, - denoising_strength=denoising_strength, - resize_mode=resize_mode, - uses_loopback=loopback, - sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg - ) - - del sampler - - return output_images, seed, info, stats - -#@retry(RuntimeError, tries=3) -def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str, - n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None], - height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True, - save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, - save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, - RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None, - variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True): - - outpath = defaults.general.outdir_txt2img or defaults.general.outdir or "outputs/txt2img-samples" - - err = False - seed = seed_to_int(seed) - - #prompt_matrix = 0 in toggles - #normalize_prompt_weights = 1 in toggles - #skip_save = 2 not in 
toggles - #save_grid = 3 not in toggles - #sort_samples = 4 in toggles - #write_info_files = 5 in toggles - #jpg_sample = 6 in toggles - #use_GFPGAN = 7 in toggles - #use_RealESRGAN = 8 in toggles - - if sampler_name == 'PLMS': - sampler = PLMSSampler(st.session_state["model"]) - elif sampler_name == 'DDIM': - sampler = DDIMSampler(st.session_state["model"]) - elif sampler_name == 'k_dpm_2_a': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') - elif sampler_name == 'k_dpm_2': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') - elif sampler_name == 'k_euler_a': - sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') - elif sampler_name == 'k_euler': - sampler = KDiffusionSampler(st.session_state["model"],'euler') - elif sampler_name == 'k_heun': - sampler = KDiffusionSampler(st.session_state["model"],'heun') - elif sampler_name == 'k_lms': - sampler = KDiffusionSampler(st.session_state["model"],'lms') - else: - raise Exception("Unknown sampler: " + sampler_name) - - def init(): - pass - - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback, - log_every_t=int(defaults.general.update_preview_frequency)) - - return samples_ddim - - #try: - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=batch_size, - n_iter=n_iter, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN, - realesrgan_model_name=realesrgan_model_name, - fp=fp, - ddim_eta=ddim_eta, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg, - variant_amount=variant_amount, - variant_seed=variant_seed, - ) - - del sampler - - return output_images, seed, info, stats - - #except RuntimeError as e: - #err = e - #err_msg = f'CRASHED:
Please wait while the program restarts.' - #stats = err_msg - #return [], seed, 'err', stats - - - +#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 +#os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu) # functions to load css locally OR remotely starts here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection # TODO, maybe look into async loading the file especially for remote fetching def local_css(file_name): - with open(file_name) as f: - st.markdown(f'', unsafe_allow_html=True) + with open(file_name) as f: + st.markdown(f'', unsafe_allow_html=True) def remote_css(url): - st.markdown(f'', unsafe_allow_html=True) + st.markdown(f'', unsafe_allow_html=True) def load_css(isLocal, nameOrURL): if(isLocal): @@ -1466,289 +57,89 @@ def load_css(isLocal, nameOrURL): else: remote_css(nameOrURL) - -# main functions to define streamlit layout here def layout(): - - st.set_page_config(page_title="Stable Diffusion Playground", layout="wide", initial_sidebar_state="collapsed") + """Layout functions to define all the streamlit layout here.""" + st.set_page_config(page_title="Stable Diffusion Playground", layout="wide") with st.empty(): # load css as an external file, function has an option to local or remote url. Potential use when running from cloud infra that might not have access to local path. load_css(True, 'frontend/css/streamlit.main.css') - + # check if the models exist on their respective folders - if os.path.exists(os.path.join(defaults.general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): - GFPGAN_available = True + if os.path.exists(os.path.join(st.session_state["defaults"].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): + st.session_state["GFPGAN_available"] = True else: - GFPGAN_available = False + st.session_state["GFPGAN_available"] = False - if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")): - RealESRGAN_available = True + if os.path.exists(os.path.join(st.session_state["defaults"].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")): + st.session_state["RealESRGAN_available"] = True else: - RealESRGAN_available = False + st.session_state["RealESRGAN_available"] = False + + # Allow for custom models to be used instead of the default one, + # an example would be Waifu-Diffusion or any other fine tune of stable diffusion + st.session_state["custom_models"]:sorted = [] + for root, dirs, files in os.walk(os.path.join("models", "custom")): + for file in files: + if os.path.splitext(file)[1] == '.ckpt': + #fullpath = os.path.join(root, file) + #print(fullpath) + st.session_state["custom_models"].append(os.path.splitext(file)[0]) + #print (os.path.splitext(file)[0]) + + if len(st.session_state["custom_models"]) > 0: + st.session_state["CustomModel_available"] = True + st.session_state["custom_models"].append("Stable Diffusion v1.4") + else: + st.session_state["CustomModel_available"] = False with st.sidebar: - # we should use an expander and group things together when more options are added so the sidebar is not too messy. + # The global settings section will be moved to the Settings page. 
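(Aside: the custom-model discovery loop above walks models/custom for .ckpt files. A compact sketch of the same idea, assuming that folder layout and the same default label:)

    from pathlib import Path

    def list_custom_models(root="models/custom", default="Stable Diffusion v1.4"):
        # Collect the stem of every .ckpt under models/custom and offer the
        # default weights alongside them; an empty list means no custom models.
        names = sorted(p.stem for p in Path(root).rglob("*.ckpt"))
        return names + [default] if names else []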
#with st.expander("Global Settings:"): - st.write("Global Settings:") - defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview, - help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \ - Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.") - defaults.general.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency, - help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.") - - - - txt2img_tab, img2img_tab, txt2video, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", "Text-to-Video","Post-Processing"]) - - with txt2img_tab: - with st.form("txt2img-inputs"): - st.session_state["generation_mode"] = "txt2img" - - input_col1, generate_col1 = st.columns([10,1]) - with input_col1: - #prompt = st.text_area("Input Text","") - prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") - - # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. - generate_col1.write("") - generate_col1.write("") - generate_button = generate_col1.form_submit_button("Generate") - - # creating the page layout using columns - col1, col2, col3 = st.columns([1,2,1], gap="large") - - with col1: - width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.txt2img.width, step=64) - height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.txt2img.height, step=64) - cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") - seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.") - batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") - #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2img.batch_size, step=1, - #help="How many images are at once in a batch.\ - #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ - #Default: 1") - - with col2: - preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) - - with preview_tab: - #st.write("Image") - #Image for testing - #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') - #new_image = image.resize((175, 240)) - #preview_image = st.image(image) - - # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. 
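(Aside: the "empty container" comment above refers to a common Streamlit pattern. A minimal sketch, with the session_state keys taken from this file:)

    import streamlit as st

    def init_preview_slots():
        # st.empty() reserves a slot in the page; storing the handles in
        # session_state lets the sampler callback update them on later reruns.
        st.session_state["preview_image"] = st.empty()
        st.session_state["progress_bar_text"] = st.empty()
        st.session_state["progress_bar"] = st.empty()

    # Later, from inside generation_callback:
    #   st.session_state["preview_image"].image(pil_image)
    #   st.session_state["progress_bar"].progress(percent)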
- st.session_state["preview_image"] = st.empty() - - st.session_state["loading"] = st.empty() - - st.session_state["progress_bar_text"] = st.empty() - st.session_state["progress_bar"] = st.empty() - - message = st.empty() - - with gallery_tab: - st.write('Here should be the image gallery, if I could make a grid in streamlit.') - - with col3: - st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250) - - sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] - sampler_name = st.selectbox("Sampling method", sampler_name_list, - index=sampler_name_list.index(defaults.txt2img.default_sampler), help="Sampling method to use. Default: k_euler") - - - - #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) - - #with basic_tab: - #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, - #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") - - with st.expander("Advanced"): - separate_prompts = st.checkbox("Create Prompt Matrix.", value=False, help="Separate multiple prompts using the `|` character, and get all combinations of them.") - normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=True, help="Ensure the sum of all weights add up to 1.0") - save_individual_images = st.checkbox("Save individual images.", value=True, help="Save each image generated before any filter or enhancement is applied.") - save_grid = st.checkbox("Save grid",value=True, help="Save a grid with all the images generated into a single image.") - group_by_prompt = st.checkbox("Group results by prompt", value=True, - help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.") - write_info_files = st.checkbox("Write Info file", value=True, help="Save a file next to the image with informartion about the generation.") - save_as_jpg = st.checkbox("Save samples as jpg", value=False, help="Saves the images as jpg instead of png.") - - if GFPGAN_available: - use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") - else: - use_GFPGAN = False - - if RealESRGAN_available: - use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.txt2img.use_RealESRGAN, help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") - RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) - else: - use_RealESRGAN = False - RealESRGAN_model = "RealESRGAN_x4plus" - - variant_amount = st.slider("Variant Amount:", value=defaults.txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) - variant_seed = st.text_input("Variant Seed:", value=defaults.txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") - - - if generate_button: - #print("Loading models") - # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
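(Aside: a sketch of the load-once-on-first-Generate behaviour described in the comment above, assuming a `load_models`-style helper is passed in; the real code calls load_models directly:)

    import streamlit as st

    def ensure_models_loaded(load_fn, use_gfpgan, use_esrgan, esrgan_model):
        # Only take the slow checkpoint-loading path when nothing is cached yet;
        # session_state survives reruns, so later Generate clicks skip this.
        if "model" not in st.session_state:
            load_fn(False, use_gfpgan, use_esrgan, esrgan_model)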
- load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model) - - try: - output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1, - cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, - save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp, - variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) + #st.write("Global Settings:") + #defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview, + #help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \ + #Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.") + #st.session_state.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency, + #help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.") + + tabs = on_hover_tabs(tabName=['Stable Diffusion', "Textual Inversion","Model Manager","Settings"], + iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0) + + if tabs =='Stable Diffusion': + # txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab, concept_library_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", + # "Text-to-Video","Post-Processing", "Concept Library"]) + txt2img_tab, img2img_tab, txt2vid_tab = st.tabs( + ["Text-to-Image Unified", "Image-to-Image Unified", "Text-to-Video"] + ) + #with home_tab: + #from home import layout + #layout() + + with txt2img_tab: + from txt2img import layout + layout() + + with img2img_tab: + from img2img import layout + layout() + + with txt2vid_tab: + from txt2vid import layout + layout() + + # with concept_library_tab: + # from sd_concept_library import layout + # layout() + + # + elif tabs == 'Model Manager': + from ModelManager import layout + layout() - message.success('Done!', icon="✅") - - except (StopException, KeyError): - print(f"Received Streamlit StopException") - - # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. - # use the current col2 first tab to show the preview_img and update it as its generated. - #preview_image.image(output_images) - - with img2img_tab: - with st.form("img2img-inputs"): - st.session_state["generation_mode"] = "img2img" - - img2img_input_col, img2img_generate_col = st.columns([10,1]) - with img2img_input_col: - #prompt = st.text_area("Input Text","") - prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") - - # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. 
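(Aside: a minimal sketch of the Streamlit form pattern the comment above refers to: widgets declared inside st.form only reach the script when the submit button is pressed, so sliders can be adjusted without triggering reruns.)

    import streamlit as st

    with st.form("demo-inputs"):
        prompt = st.text_input("Input Text", "", placeholder="A corgi wearing a top hat as an oil painting.")
        submitted = st.form_submit_button("Generate")

    if submitted:
        st.write(f"Would generate: {prompt!r}")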
- img2img_generate_col.write("") - img2img_generate_col.write("") - generate_button = img2img_generate_col.form_submit_button("Generate") - - - # creating the page layout using columns - col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small") - - with col1_img2img_layout: - st.session_state["sampling_steps"] = st.slider("Sampling Steps", value=defaults.img2img.sampling_steps, min_value=1, max_value=250) - st.session_state["sampler_name"] = st.selectbox("Sampling method", ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"], - index=0, help="Sampling method to use. Default: k_lms") - - uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg"], - help="Upload an image which will be used for the image to image generation." - ) - - width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.img2img.width, step=64) - height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.img2img.height, step=64) - seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") - batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.img2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") - - # - with st.expander("Advanced"): - separate_prompts = st.checkbox("Create Prompt Matrix.", value=defaults.img2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.") - normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=defaults.img2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") - loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.") - random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed") - save_individual_images = st.checkbox("Save individual images.", value=True, help="Save each image generated before any filter or enhancement is applied.") - save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.") - group_by_prompt = st.checkbox("Group results by prompt", value=defaults.img2img.group_by_prompt, - help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.") - write_info_files = st.checkbox("Write Info file", value=True, help="Save a file next to the image with informartion about the generation.") - save_as_jpg = st.checkbox("Save samples as jpg", value=False, help="Saves the images as jpg instead of png.") - - if GFPGAN_available: - use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ - This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") - else: - use_GFPGAN = False - - if RealESRGAN_available: - use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.img2img.use_RealESRGAN, help="Uses the RealESRGAN model to upscale the images after the generation.\ - This greatly improve the quality and lets you have high resolution images but uses extra VRAM. 
Disable if you need the extra VRAM.") - RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) - else: - use_RealESRGAN = False - RealESRGAN_model = "RealESRGAN_x4plus" - - variant_amount = st.slider("Variant Amount:", value=defaults.img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) - variant_seed = st.text_input("Variant Seed:", value=defaults.img2img.variant_seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") - cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.img2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") - batch_size = st.slider("Batch size", min_value=1, max_value=100, value=defaults.img2img.batch_size, step=1, - help="How many images are at once in a batch.\ - It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ - Default: 1") - - st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, min_value=0.01, max_value=1.0, step=0.01) - - - with col2_img2img_layout: - editor_tab = st.tabs(["Editor"]) - - editor_image = st.empty() - st.session_state["editor_image"] = editor_image - - if uploaded_images: - image = Image.open(uploaded_images).convert('RGB') - #img_array = np.array(image) # if you want to pass it to OpenCV - new_img = image.resize((width, height)) - st.image(new_img) - - - with col3_img2img_layout: - result_tab = st.tabs(["Result"]) - - # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. - preview_image = st.empty() - st.session_state["preview_image"] = preview_image - - #st.session_state["loading"] = st.empty() - - st.session_state["progress_bar_text"] = st.empty() - st.session_state["progress_bar"] = st.empty() - - - message = st.empty() - - #if uploaded_images: - #image = Image.open(uploaded_images).convert('RGB') - ##img_array = np.array(image) # if you want to pass it to OpenCV - #new_img = image.resize((width, height)) - #st.image(new_img, use_column_width=True) - - - if generate_button: - #print("Loading models") - # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
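For reference, a minimal sketch of the placeholder pattern set up above (illustrative only; make_placeholders and update_progress are assumed helper names): st.empty() slots are created once and stored in st.session_state so a callback running during sampling can overwrite them in place instead of re-rendering the whole page.

import streamlit as st

def make_placeholders():
    # One reusable slot each for the live preview, the status line and the bar.
    st.session_state["preview_image"] = st.empty()
    st.session_state["progress_bar_text"] = st.empty()
    st.session_state["progress_bar"] = st.empty()

def update_progress(step: int, total: int):
    percent = min(int(100 * (step + 1) / total), 100)
    st.session_state["progress_bar_text"].text(f"Running step: {step + 1}/{total} {percent}%")
    st.session_state["progress_bar"].progress(percent)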
- load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model) - if uploaded_images: - image = Image.open(uploaded_images).convert('RGB') - new_img = image.resize((width, height)) - #img_array = np.array(image) # if you want to pass it to OpenCV - - try: - output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=st.session_state["sampling_steps"], - sampler_name=st.session_state["sampler_name"], n_iter=batch_count, - cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed, - seed=seed, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount, - ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=RealESRGAN_model, - separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, save_grid=save_grid, - group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN if not loopback else False, loopback=loopback - ) + # elif tabs == 'Textual Inversion': + # from textual_inversion import layout + # layout() - #show a message when the generation is complete. - message.success('Done!', icon="✅") - - except (StopException, KeyError): - print(f"Received Streamlit StopException") - - # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. - # use the current col2 first tab to show the preview_img and update it as its generated. - #preview_image.image(output_images, width=750) - - if __name__ == '__main__': layout() \ No newline at end of file diff --git a/scripts/webui_streamlit_old.py b/scripts/webui_streamlit_old.py new file mode 100644 index 0000000..ad1b9da --- /dev/null +++ b/scripts/webui_streamlit_old.py @@ -0,0 +1,2738 @@ +import warnings + +import piexif +import piexif.helper +import json + +import streamlit as st +from streamlit import StopException + +#streamlit components section +from st_on_hover_tabs import on_hover_tabs + +import base64, cv2 +import os, sys, re, random, datetime, timeit +from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps +from PIL.PngImagePlugin import PngInfo +from scipy import integrate +import pandas as pd +import torch +from torchdiffeq import odeint +import k_diffusion as K +import math +import mimetypes +import numpy as np +import pynvml +import threading +import time, inspect +import torch +from torch import autocast +from torchvision import transforms +import torch.nn as nn +import yaml +from typing import Union +from pathlib import Path +#from tqdm import tqdm +from contextlib import nullcontext +from einops import rearrange +from omegaconf import OmegaConf +from io import StringIO +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config + +from retry import retry + +# these are for testing txt2vid, should be removed and we should use things from our own code. 
+from diffusers import StableDiffusionPipeline +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler + +#will be used for saving and reading a video made by the txt2vid function +import imageio, io + +# we use python-slugify to make the filenames safe for windows and linux, its better than doing it manually +# install it with 'pip install python-slugify' +from slugify import slugify + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. + from transformers import logging + + logging.set_verbosity_error() +except: + pass + +# remove some annoying deprecation warnings that show every now and then. +warnings.filterwarnings("ignore", category=DeprecationWarning) + +defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml") +if (os.path.exists("configs/webui/userconfig_streamlit.yaml")): + user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml"); + defaults = OmegaConf.merge(defaults, user_defaults) + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +# some of those options should not be changed at all because they would break the model, so I removed them from options. +opt_C = 4 +opt_f = 8 + +# should and will be moved to a settings menu in the UI at some point +grid_format = [s.lower() for s in defaults.general.grid_format.split(':')] +grid_lossless = False +grid_quality = 100 +if grid_format[0] == 'png': + grid_ext = 'png' + grid_format = 'png' +elif grid_format[0] in ['jpg', 'jpeg']: + grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 + grid_ext = 'jpg' + grid_format = 'jpeg' +elif grid_format[0] == 'webp': + grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 + grid_ext = 'webp' + grid_format = 'webp' + if grid_quality < 0: # e.g. webp:-100 for lossless mode + grid_lossless = True + grid_quality = abs(grid_quality) + +# should and will be moved to a settings menu in the UI at some point +save_format = [s.lower() for s in defaults.general.save_format.split(':')] +save_lossless = False +save_quality = 100 +if save_format[0] == 'png': + save_ext = 'png' + save_format = 'png' +elif save_format[0] in ['jpg', 'jpeg']: + save_quality = int(save_format[1]) if len(save_format) > 1 else 100 + save_ext = 'jpg' + save_format = 'jpeg' +elif save_format[0] == 'webp': + save_quality = int(save_format[1]) if len(save_format) > 1 else 100 + save_ext = 'webp' + save_format = 'webp' + if save_quality < 0: # e.g. webp:-100 for lossless mode + save_lossless = True + save_quality = abs(save_quality) + +# this should force GFPGAN and RealESRGAN onto the selected gpu as well +os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 +os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu) + +@retry(tries=5) +def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus", + CustomModel_available=False, custom_model="Stable Diffusion v1.4"): + """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. 
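For reference, a minimal sketch of the "format:quality" parsing applied above to grid_format and save_format (illustrative only; parse_save_format is an assumed helper name). A negative quality only has meaning for webp, e.g. "webp:-100" selects lossless mode.

def parse_save_format(spec: str):
    parts = [s.lower() for s in spec.split(":")]
    fmt, quality, lossless = parts[0], 100, False
    if fmt in ("jpg", "jpeg", "webp") and len(parts) > 1:
        quality = int(parts[1])
    if fmt == "webp" and quality < 0:        # e.g. "webp:-100" -> lossless
        lossless, quality = True, abs(quality)
    ext = "jpg" if fmt in ("jpg", "jpeg") else fmt
    pil_format = "jpeg" if ext == "jpg" else fmt
    return ext, pil_format, quality, lossless

# parse_save_format("jpg:95")    -> ("jpg", "jpeg", 95, False)
# parse_save_format("webp:-100") -> ("webp", "webp", 100, True)
# parse_save_format("png")       -> ("png", "png", 100, False)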
""" + + print ("Loading models.") + + st.session_state["progress_bar_text"].text("Loading models...") + + # Generate random run ID + # Used to link runs linked w/ continue_prev_run which is not yet implemented + # Use URL and filesystem safe version just in case. + st.session_state["run_id"] = base64.urlsafe_b64encode( + os.urandom(6) + ).decode("ascii") + + # check what models we want to use and if the they are already loaded. + + if use_GFPGAN: + if "GFPGAN" in st.session_state: + print("GFPGAN already loaded") + else: + # Load GFPGAN + if os.path.exists(defaults.general.GFPGAN_dir): + try: + st.session_state["GFPGAN"] = load_GFPGAN() + print("Loaded GFPGAN") + except Exception: + import traceback + print("Error loading GFPGAN:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if "GFPGAN" in st.session_state: + del st.session_state["GFPGAN"] + + if use_RealESRGAN: + if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model: + print("RealESRGAN already loaded") + else: + #Load RealESRGAN + try: + # We first remove the variable in case it has something there, + # some errors can load the model incorrectly and leave things in memory. + del st.session_state["RealESRGAN"] + except KeyError: + pass + + if os.path.exists(defaults.general.RealESRGAN_dir): + # st.session_state is used for keeping the models in memory across multiple pages or runs. + st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model) + print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name) + + else: + if "RealESRGAN" in st.session_state: + del st.session_state["RealESRGAN"] + + + + if "model" in st.session_state: + if "model" in st.session_state and st.session_state["custom_model"] == custom_model: + print("Model already loaded") + else: + try: + del st.session_state["model"] + except KeyError: + pass + + config = OmegaConf.load(defaults.general.default_model_config) + + if custom_model == defaults.general.default_model: + model = load_model_from_config(config, defaults.general.default_model_path) + else: + model = load_model_from_config(config, os.path.join("models","custom", f"{custom_model}.ckpt")) + + st.session_state["custom_model"] = custom_model + st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu") + st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] ) + else: + config = OmegaConf.load(defaults.general.default_model_config) + + if custom_model == defaults.general.default_model: + model = load_model_from_config(config, defaults.general.default_model_path) + else: + model = load_model_from_config(config, os.path.join("models","custom", f"{custom_model}.ckpt")) + + st.session_state["custom_model"] = custom_model + st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu") + st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] ) + + print("Model loaded.") + + +def load_model_from_config(config, ckpt, verbose=False): + + print(f"Loading model from {ckpt}") + + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + 
print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + +def load_sd_from_config(ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + return sd +# +@retry(tries=5) +def generation_callback(img, i=0): + + try: + if i == 0: + if img['i']: i = img['i'] + except TypeError: + pass + + + if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview: + #print (img) + #print (type(img)) + # The following lines will convert the tensor we got on img to an actual image we can render on the UI. + # It can probably be done in a better way for someone who knows what they're doing. I don't. + #print (img,isinstance(img, torch.Tensor)) + if isinstance(img, torch.Tensor): + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img) + else: + # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: + # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"]) + + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) + + # update image on the UI so we can see the progress + st.session_state["preview_image"].image(pil_image) + + # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, + # Dont worry, it doesnt affect the performance. + if st.session_state["generation_mode"] == "txt2img": + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") + else: + if st.session_state["generation_mode"] == "img2img": + round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) + percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) + st.session_state["progress_bar_text"].text( + f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") + else: + if st.session_state["generation_mode"] == "txt2vid": + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps}" + f"{percent if percent < 100 else 100}%") + + st.session_state["progress_bar"].progress(percent if percent < 100 else 100) + + + +class MemUsageMonitor(threading.Thread): + stop_flag = False + max_usage = 0 + total = -1 + + def __init__(self, name): + threading.Thread.__init__(self) + self.name = name + + def run(self): + try: + pynvml.nvmlInit() + except: + print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. 
\n") + return + print(f"[{self.name}] Recording max memory usage...\n") + handle = pynvml.nvmlDeviceGetHandleByIndex(defaults.general.gpu) + self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total + while not self.stop_flag: + m = pynvml.nvmlDeviceGetMemoryInfo(handle) + self.max_usage = max(self.max_usage, m.used) + # print(self.max_usage) + time.sleep(0.1) + print(f"[{self.name}] Stopped recording.\n") + pynvml.nvmlShutdown() + + def read(self): + return self.max_usage, self.total + + def stop(self): + self.stop_flag = True + + def read_and_stop(self): + self.stop_flag = True + return self.max_usage, self.total + +class CFGMaskedDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi): + x_in = x + x_in = torch.cat([x_in] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + denoised = uncond + (cond - uncond) * cond_scale + + if mask is not None: + assert x0 is not None + img_orig = x0 + mask_inv = 1. - mask + denoised = (img_orig * mask_inv) + (mask * denoised) + + return denoised + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): + """Constructs the noise schedule of Karras et al. (2022).""" + ramp = torch.linspace(0, 1, n) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + + +def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): + """Constructs an exponential noise schedule.""" + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + return append_zero(sigmas) + + +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): + """Constructs a continuous VP noise schedule.""" + t = torch.linspace(1, eps_s, n, device=device) + sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) + return append_zero(sigmas) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. 
+ for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + +class KDiffusionSampler: + def __init__(self, m, sampler): + self.model = m + self.model_wrap = K.external.CompVisDenoiser(m) + self.schedule = sampler + def get_sampler_name(self): + return self.schedule + def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None): + sigmas = self.model_wrap.get_sigmas(S) + x = x_T * sigmas[0] + model_wrap_cfg = CFGDenoiser(self.model_wrap) + samples_ddim = None + samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback) + # + return samples_ddim, None + + +@torch.no_grad() +def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + v = torch.randint_like(x, 2) * 2 - 1 + fevals = 0 + def ode_fn(sigma, x): + nonlocal fevals + with torch.enable_grad(): + x = x[0].detach().requires_grad_() + denoised = model(x, sigma * s_in, **extra_args) + d = to_d(x, sigma, denoised) + fevals += 1 + grad = torch.autograd.grad((d * v).sum(), x)[0] + d_ll = (v * grad).flatten(1).sum(1) + return d.detach(), d_ll + x_min = x, x.new_zeros([x.shape[0]]) + t = x.new_tensor([sigma_min, sigma_max]) + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + latent, delta_ll = sol[0][-1], sol[1][-1] + ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + return ll_prior + delta_ll, {'fevals': fevals} + + +def create_random_tensors(shape, seeds): + xs = [] + for seed in seeds: + torch.manual_seed(seed) + + # randn results depend on device; gpu and cpu get different results for same seed; + # the way I see it, it's better to do this on CPU, so that everyone gets same result; + # but the original script had it like this so i do not dare change it for now because + # it will break everyone's seeds. 
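For reference, a minimal sketch of the per-seed noise generation described in the comment above (illustrative only; random_latents is an assumed helper name): torch.manual_seed is called once per requested seed so the same seeds always reproduce the same latent noise, and the per-image tensors are stacked into one batch. As the comment notes, torch.randn gives different results on CPU and GPU, so the device used is effectively part of what a seed reproduces.

import torch

def random_latents(shape, seeds, device="cpu"):
    # shape is per image, e.g. [4, height // 8, width // 8]; one tensor per seed.
    xs = []
    for seed in seeds:
        torch.manual_seed(seed)
        xs.append(torch.randn(shape, device=device))
    return torch.stack(xs)   # -> (len(seeds), *shape)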
+ xs.append(torch.randn(shape, device=defaults.general.gpu)) + x = torch.stack(xs) + return x + +def torch_gc(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + +def load_GFPGAN(): + model_name = 'GFPGANv1.3' + model_path = os.path.join(defaults.general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + raise Exception("GFPGAN model not found at path "+model_path) + + sys.path.append(os.path.abspath(defaults.general.GFPGAN_dir)) + from gfpgan import GFPGANer + + if defaults.general.gfpgan_cpu or defaults.general.extra_models_cpu: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu')) + elif defaults.general.extra_models_gpu: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gfpgan_gpu}')) + else: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gpu}')) + return instance + +def load_RealESRGAN(model_name: str): + from basicsr.archs.rrdbnet_arch import RRDBNet + RealESRGAN_models = { + 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), + 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + } + + model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")): + raise Exception(model_name+".pth not found at path "+model_path) + + sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir)) + from realesrgan import RealESRGANer + + if defaults.general.esrgan_cpu or defaults.general.extra_models_cpu: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half + instance.device = torch.device('cpu') + instance.model.to('cpu') + elif defaults.general.extra_models_gpu: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.esrgan_gpu}')) + else: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.gpu}')) + instance.model.name = model_name + + return instance + +prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + [^:]+ # match one or more non ':' characters + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P # capture group for 'weight' + -?\\d+(?:\\.\\d+)? # match positive or negative decimal number + )? 
# end weight capture group, make optional + \\s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group +""", re.VERBOSE) + +# grabs all text up to the first occurrence of ':' as sub-prompt +# takes the value following ':' as weight +# if ':' has no value defined, defaults to 1.0 +# repeats until no text remaining +def split_weighted_subprompts(input_string, normalize=True): + parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)] + if not normalize: + return parsed_prompts + # this probably still doesn't handle negative weights very well + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + +def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995): + v0 = v0.detach().cpu().numpy() + v1 = v1.detach().cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + v2 = torch.from_numpy(v2).to(device) + + return v2 + + +def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed, update_preview_frequency): + """Find the optimal update_preview_frequency value maximizing + performance while minimizing the time between updates.""" + if current_chunk_speed >= previous_chunk_speed: + #print(f"{current_chunk_speed} >= {previous_chunk_speed}") + update_preview_frequency +=1 + previous_chunk_speed = current_chunk_speed + else: + #print(f"{current_chunk_speed} <= {previous_chunk_speed}") + update_preview_frequency -=1 + previous_chunk_speed = current_chunk_speed + + return current_chunk_speed, previous_chunk_speed, update_preview_frequency + +# ----------------------------------------------------------------------------- + +@torch.no_grad() +def diffuse( + pipe, + cond_embeddings, # text conditioning, should be (1, 77, 768) + cond_latents, # image conditioning, should be (1, 4, 64, 64) + num_inference_steps, + cfg_scale, + eta, + ): + + torch_device = cond_latents.get_device() + + # classifier guidance: add the unconditional embedding + max_length = cond_embeddings.shape[1] # 77 + uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt") + uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0] + text_embeddings = torch.cat([uncond_embeddings, cond_embeddings]) + + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas + if isinstance(pipe.scheduler, LMSDiscreteScheduler): + cond_latents = cond_latents * pipe.scheduler.sigmas[0] + + # init the scheduler + accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
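For reference, a minimal sketch of the spherical interpolation (slerp) defined above and used to blend base noise with variant noise (illustrative only, NumPy-only version): nearly parallel vectors fall back to plain linear interpolation, otherwise the blend follows the arc between them so the result keeps a Gaussian-like magnitude.

import numpy as np

def slerp_np(t: float, v0: np.ndarray, v1: np.ndarray, dot_threshold: float = 0.9995):
    dot = np.sum(v0 * v1) / (np.linalg.norm(v0) * np.linalg.norm(v1))
    if np.abs(dot) > dot_threshold:          # nearly colinear: lerp is enough
        return (1 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)
    theta_t = theta_0 * t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1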
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + + step_counter = 0 + inference_counter = 0 + current_chunk_speed = 0 + previous_chunk_speed = 0 + + # diffuse! + for i, t in enumerate(pipe.scheduler.timesteps): + start = timeit.default_timer() + + #status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}") + + # expand the latents for classifier free guidance + latent_model_input = torch.cat([cond_latents] * 2) + if isinstance(pipe.scheduler, LMSDiscreteScheduler): + sigma = pipe.scheduler.sigmas[i] + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + + # predict the noise residual + noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + + # cfg + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if isinstance(pipe.scheduler, LMSDiscreteScheduler): + cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"] + else: + cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"] + + #print (st.session_state["update_preview_frequency"]) + #update the preview image if it is enabled and the frequency matches the step_counter + if defaults.general.update_preview: + step_counter += 1 + + if st.session_state.dynamic_preview_frequency: + current_chunk_speed, previous_chunk_speed, defaults.general.update_preview_frequency = optimize_update_preview_frequency( + current_chunk_speed, previous_chunk_speed, defaults.general.update_preview_frequency) + + if defaults.general.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps: + #scale and decode the image latents with vae + cond_latents_2 = 1 / 0.18215 * cond_latents + image_2 = pipe.vae.decode(cond_latents_2) + + # generate output numpy image as uint8 + image_2 = (image_2 / 2 + 0.5).clamp(0, 1) + image_2 = image_2.cpu().permute(0, 2, 3, 1).numpy() + image_2 = (image_2[0] * 255).astype(np.uint8) + + st.session_state["preview_image"].image(image_2) + + step_counter = 0 + + duration = timeit.default_timer() - start + + current_chunk_speed = duration + + if duration >= 1: + speed = "s/it" + else: + speed = "it/s" + duration = 1 / duration + + if i > st.session_state.sampling_steps: + inference_counter += 1 + inference_percent = int(100 * float(inference_counter if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps)) + inference_progress = f"{inference_counter if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% " + else: + inference_progress = "" + + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames)) + + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} " 
+ f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | " + f"Frame: {st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} " + f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}" + ) + st.session_state["progress_bar"].progress(percent if percent < 100 else 100) + + # scale and decode the image latents with vae + cond_latents = 1 / 0.18215 * cond_latents + image = pipe.vae.decode(cond_latents) + + # generate output numpy image as uint8 + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image[0] * 255).astype(np.uint8) + + return image + + +def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): + #get global variables + global_vars = globals() + #check if m is in globals + if unload: + for m in models: + if m in global_vars: + #if it is, delete it + del global_vars[m] + if defaults.general.optimized: + if m == 'model': + del global_vars[m+'FS'] + del global_vars[m+'CS'] + if m =='model': + m='Stable Diffusion' + print('Unloaded ' + m) + if load: + for m in models: + if m not in global_vars or m in global_vars and type(global_vars[m]) == bool: + #if it isn't, load it + if m == 'GFPGAN': + global_vars[m] = load_GFPGAN() + elif m == 'model': + sdLoader = load_sd_from_config() + global_vars[m] = sdLoader[0] + if defaults.general.optimized: + global_vars[m+'CS'] = sdLoader[1] + global_vars[m+'FS'] = sdLoader[2] + elif m == 'RealESRGAN': + global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name) + elif m == 'LDSR': + global_vars[m] = load_LDSR() + if m =='model': + m='Stable Diffusion' + print('Loaded ' + m) + torch_gc() + + + +def get_font(fontsize): + fonts = ["arial.ttf", "DejaVuSans.ttf"] + for font_name in fonts: + try: + return ImageFont.truetype(font_name, fontsize) + except OSError: + pass + + # ImageFont.load_default() is practically unusable as it only supports + # latin1, so raise an exception instead if no usable font was found + raise Exception(f"No usable font found (tried {', '.join(fonts)})") + +def load_embeddings(fp): + if fp is not None and hasattr(st.session_state["model"], "embedding_manager"): + st.session_state["model"].embedding_manager.load(fp['name']) + +def image_grid(imgs, batch_size, force_n_rows=None, captions=None): + #print (len(imgs)) + if force_n_rows is not None: + rows = force_n_rows + elif defaults.general.n_rows > 0: + rows = defaults.general.n_rows + elif defaults.general.n_rows == 0: + rows = batch_size + else: + rows = math.sqrt(len(imgs)) + rows = round(rows) + + cols = math.ceil(len(imgs) / rows) + + w, h = imgs[0].size + grid = Image.new('RGB', size=(cols * w, rows * h), color='black') + + fnt = get_font(30) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + if captions and i= 2**32: + n = n >> 32 + return n + +def check_prompt_length(prompt, comments): + """this function tests if prompt is too long, and if so, adds a message to comments""" + + tokenizer = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer + max_length = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.max_length + + info = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, 
max_length=max_length, + return_overflowing_tokens=True, padding="max_length", return_tensors="pt") + ovf = info['overflowing_tokens'][0] + overflowing_count = ovf.shape[0] + if overflowing_count == 0: + return + + vocab = {v: k for k, v in tokenizer.get_vocab().items()} + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + + comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + +def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images): + + filename_i = os.path.join(sample_path_i, filename) + + if defaults.general.save_metadata or write_info_files: + # toggles differ for txt2img vs. img2img: + offset = 0 if init_img is None else 2 + toggles = [] + if prompt_matrix: + toggles.append(0) + if normalize_prompt_weights: + toggles.append(1) + if init_img is not None: + if uses_loopback: + toggles.append(2) + if uses_random_seed_loopback: + toggles.append(3) + if save_individual_images: + toggles.append(2 + offset) + if save_grid: + toggles.append(3 + offset) + if sort_samples: + toggles.append(4 + offset) + if write_info_files: + toggles.append(5 + offset) + if use_GFPGAN: + toggles.append(6 + offset) + metadata = \ + dict( + target="txt2img" if init_img is None else "img2img", + prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, + ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, + seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights) + # Not yet any use for these, but they bloat up the files: + # info_dict["init_img"] = init_img + # info_dict["init_mask"] = init_mask + if init_img is not None: + metadata["denoising_strength"] = str(denoising_strength) + metadata["resize_mode"] = resize_mode + + if write_info_files: + with open(f"{filename_i}.yaml", "w", encoding="utf8") as f: + yaml.dump(metadata, f, allow_unicode=True, width=10000) + + if defaults.general.save_metadata: + # metadata = { + # "SD:prompt": prompts[i], + # "SD:seed": str(seeds[i]), + # "SD:width": str(width), + # "SD:height": str(height), + # "SD:steps": str(steps), + # "SD:cfg_scale": str(cfg_scale), + # "SD:normalize_prompt_weights": str(normalize_prompt_weights), + # } + metadata = {"SD:" + k:v for (k,v) in metadata.items()} + + if save_ext == "png": + mdata = PngInfo() + for key in metadata: + mdata.add_text(key, str(metadata[key])) + image.save(f"{filename_i}.png", pnginfo=mdata) + else: + if jpg_sample: + image.save(f"{filename_i}.jpg", quality=save_quality, + optimize=True) + elif save_ext == "webp": + image.save(f"{filename_i}.{save_ext}", f"webp", quality=save_quality, + lossless=save_lossless) + else: + # not sure what file format this is + image.save(f"{filename_i}.{save_ext}", f"{save_ext}") + try: + exif_dict = piexif.load(f"{filename_i}.{save_ext}") + except: + exif_dict = { "Exif": dict() } + exif_dict["Exif"][piexif.ExifIFD.UserComment] = piexif.helper.UserComment.dump( + json.dumps(metadata), encoding="unicode") + piexif.insert(piexif.dump(exif_dict), f"{filename_i}.{save_ext}") + + # render the image on the frontend + 
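For reference, a minimal sketch of the metadata embedding done in save_sample above (illustrative only; save_png_with_metadata is an assumed helper name): for PNG output every "SD:" key is written as a text chunk through Pillow's PngInfo, while the JPEG/WebP branch above instead stores the same dictionary as an EXIF UserComment via piexif.

from PIL import Image
from PIL.PngImagePlugin import PngInfo

def save_png_with_metadata(image: Image.Image, path: str, metadata: dict):
    info = PngInfo()
    for key, value in metadata.items():
        info.add_text(key, str(value))       # e.g. "SD:prompt", "SD:seed", ...
    image.save(path, pnginfo=info)

# save_png_with_metadata(img, "00001-sample.png", {"SD:prompt": "a corgi", "SD:seed": "42"})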
st.session_state["preview_image"].image(image) + +def get_next_sequence_number(path, prefix=''): + """ + Determines and returns the next sequence number to use when saving an + image in the specified directory. + + If a prefix is given, only consider files whose names start with that + prefix, and strip the prefix from filenames before extracting their + sequence number. + + The sequence starts at 0. + """ + result = -1 + for p in Path(path).iterdir(): + if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix): + tmp = p.name[len(prefix):] + try: + result = max(int(tmp.split('-')[0]), result) + except ValueError: + pass + return result + 1 + + +def oxlamon_matrix(prompt, seed, n_iter, batch_size): + pattern = re.compile(r'(,\s){2,}') + + class PromptItem: + def __init__(self, text, parts, item): + self.text = text + self.parts = parts + if item: + self.parts.append( item ) + + def clean(txt): + return re.sub(pattern, ', ', txt) + + def getrowcount( txt ): + for data in re.finditer( ".*?\\((.*?)\\).*", txt ): + if data: + return len(data.group(1).split("|")) + break + return None + + def repliter( txt ): + for data in re.finditer( ".*?\\((.*?)\\).*", txt ): + if data: + r = data.span(1) + for item in data.group(1).split("|"): + yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip()) + break + + def iterlist( items ): + outitems = [] + for item in items: + for newitem, newpart in repliter(item.text): + outitems.append( PromptItem(newitem, item.parts.copy(), newpart) ) + + return outitems + + def getmatrix( prompt ): + dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ] + while True: + newdataitems = iterlist( dataitems ) + if len( newdataitems ) == 0: + return dataitems + dataitems = newdataitems + + def classToArrays( items, seed, n_iter ): + texts = [] + parts = [] + seeds = [] + + for item in items: + itemseed = seed + for i in range(n_iter): + texts.append( item.text ) + parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) ) + seeds.append( itemseed ) + itemseed += 1 + + return seeds, texts, parts + + all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter) + n_iter = math.ceil(len(all_prompts) / batch_size) + + needrows = getrowcount(prompt) + if needrows: + xrows = math.sqrt(len(all_prompts)) + xrows = round(xrows) + # if columns is to much + cols = math.ceil(len(all_prompts) / xrows) + if cols > needrows*4: + needrows *= 2 + + return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows + + +import find_noise_for_image +import matched_noise + + +def process_images( + outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, + n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, + fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, + mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, noise_mode=0, find_noise_steps=1, resize_mode=None, uses_loopback=False, + uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, + variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): + """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + assert prompt is not None + torch_gc() + # start time after garbage collection (or before?) 
+ start_time = time.time() + + # We will use this date here later for the folder name, need to start_time if not need + run_start_dt = datetime.datetime.now() + + mem_mon = MemUsageMonitor('MemMon') + mem_mon.start() + + if hasattr(st.session_state["model"], "embedding_manager"): + load_embeddings(fp) + + os.makedirs(outpath, exist_ok=True) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + + if not ("|" in prompt) and prompt.startswith("@"): + prompt = prompt[1:] + + comments = [] + + prompt_matrix_parts = [] + simple_templating = False + add_original_image = not (use_RealESRGAN or use_GFPGAN) + + if prompt_matrix: + if prompt.startswith("@"): + simple_templating = True + add_original_image = not (use_RealESRGAN or use_GFPGAN) + all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size) + else: + all_prompts = [] + prompt_matrix_parts = prompt.split("|") + combination_count = 2 ** (len(prompt_matrix_parts) - 1) + for combination_num in range(combination_count): + current = prompt_matrix_parts[0] + + for n, text in enumerate(prompt_matrix_parts[1:]): + if combination_num & (2 ** n) > 0: + current += ("" if text.strip().startswith(",") else ", ") + text + + all_prompts.append(current) + + n_iter = math.ceil(len(all_prompts) / batch_size) + all_seeds = len(all_prompts) * [seed] + + print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") + else: + + if not defaults.general.no_verify_input: + try: + check_prompt_length(prompt, comments) + except: + import traceback + print("Error verifying input:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + all_prompts = batch_size * n_iter * [prompt] + all_seeds = [seed + x for x in range(len(all_prompts))] + + precision_scope = autocast if defaults.general.precision == "autocast" else nullcontext + output_images = [] + grid_captions = [] + stats = [] + with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not defaults.general.optimized else nullcontext()): + init_data = func_init() + tic = time.time() + + + # if variant_amount > 0.0 create noise from base seed + base_x = None + if variant_amount > 0.0: + target_seed_randomizer = seed_to_int('') # random seed + torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration) + base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed]) + # we don't want all_seeds to be sequential from starting seed with variants, + # since that makes the same variants each time, + # so we add target_seed_randomizer as a random offset + for si in range(len(all_seeds)): + all_seeds[si] += target_seed_randomizer + + for n in range(n_iter): + print(f"Iteration: {n+1}/{n_iter}") + prompts = all_prompts[n * batch_size:(n + 1) * batch_size] + captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size] + seeds = all_seeds[n * batch_size:(n + 1) * batch_size] + + print(prompt) + + if defaults.general.optimized: + modelCS.to(defaults.general.gpu) + + uc = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(len(prompts) * [""]) + + if isinstance(prompts, tuple): + prompts = list(prompts) + + # split the prompt if it has : for weighting + # TODO for speed it might help to have this occur when all_prompts filled?? 
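For reference, a minimal sketch of the weighted sub-prompt conditioning computed in the block that follows (illustrative only; weighted_conditioning is an assumed helper and get_cond stands in for the model's get_learned_conditioning): each "prompt:weight" piece contributes its embedding scaled by its normalized weight, which is what the torch.add(..., alpha=weight) loop below accumulates.

def weighted_conditioning(get_cond, weighted_subprompts):
    # weighted_subprompts: list of (text, weight) pairs; with normalization on,
    # the weights sum to 1.0 so the combined embedding keeps a comparable scale.
    c = None
    for text, weight in weighted_subprompts:
        emb = get_cond(text)
        c = emb * weight if c is None else c + emb * weight
    return c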
+ weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights) + + # sub-prompt weighting used if more than 1 + if len(weighted_subprompts) > 1: + c = torch.zeros_like(uc) # i dont know if this is correct.. but it works + for i in range(0, len(weighted_subprompts)): + # note if alpha negative, it functions same as torch.sub + c = torch.add(c, (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1]) + else: # just behave like usual + c = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(prompts) + + + shape = [opt_C, height // opt_f, width // opt_f] + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelCS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + if noise_mode == 1 or noise_mode == 3: + # TODO params for find_noise_to_image + x = torch.cat(batch_size * [find_noise_for_image.find_noise_for_image( + st.session_state["model"], st.session_state["device"], + init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True, + generation_callback=generation_callback, + )], dim=0) + else: + # we manually generate all input noises because each one should have a specific seed + x = create_random_tensors(shape, seeds=seeds) + + if variant_amount > 0.0: # we are making variants + # using variant_seed as sneaky toggle, + # when not None or '' use the variant_seed + # otherwise use seeds + if variant_seed != None and variant_seed != '': + specified_variant_seed = seed_to_int(variant_seed) + torch.manual_seed(specified_variant_seed) + seeds = [specified_variant_seed] + # finally, slerp base_x noise to target_x noise for creating a variant + x = slerp(defaults.general.gpu, max(0.0, min(1.0, variant_amount)), base_x, x) + + samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) + + if defaults.general.optimized: + modelFS.to(defaults.general.gpu) + + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for i, x_sample in enumerate(x_samples_ddim): + sanitized_prompt = slugify(prompts[i]) + + if sort_samples: + full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt) + + + sanitized_prompt = sanitized_prompt[:220-len(full_path)] + sample_path_i = os.path.join(sample_path, sanitized_prompt) + + #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}") + #print(os.path.join(os.getcwd(), sample_path_i)) + + os.makedirs(sample_path_i, exist_ok=True) + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}" + else: + full_path = os.path.join(os.getcwd(), sample_path) + sample_path_i = sample_path + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before + + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = x_sample.astype(np.uint8) + image = Image.fromarray(x_sample) + original_sample = x_sample + original_filename = filename + + if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN: + #skip_save = True # #287 >_> + torch_gc() + cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + gfpgan_image = Image.fromarray(gfpgan_sample) + gfpgan_filename = original_filename + '-gfpgan' + + save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, + uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta, + n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(gfpgan_image) #287 + if simple_templating: + grid_captions.append( captions[i] + "\ngfpgan" ) + + if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: + #skip_save = True # #287 >_> + torch_gc() + + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) + esrgan_filename = original_filename + '-esrgan4x' + esrgan_sample = output[:,:,::-1] + esrgan_image = Image.fromarray(esrgan_sample) + + #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, + #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) + + save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(esrgan_image) #287 + if simple_templating: + grid_captions.append( captions[i] + "\nesrgan" ) + + if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None: + #skip_save = True # #287 >_> + torch_gc() + cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) + gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' + gfpgan_esrgan_sample = output[:,:,::-1] + gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) + + save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, 
height, steps, cfg_scale, + normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(gfpgan_esrgan_image) #287 + + if simple_templating: + grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) + + if mask_restore and init_mask: + #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask) + init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) + init_mask = init_mask.convert('L') + init_img = init_img.convert('RGB') + image = image.convert('RGB') + + if use_RealESRGAN and st.session_state["RealESRGAN"] is not None: + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8)) + init_img = Image.fromarray(output) + init_img = init_img.convert('RGB') + + output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8)) + init_mask = Image.fromarray(output) + init_mask = init_mask.convert('L') + + image = Image.composite(init_img, image, init_mask) + + if save_individual_images: + save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images) + + if not use_GFPGAN or not use_RealESRGAN: + output_images.append(image) + + #if add_original_image or not simple_templating: + #output_images.append(image) + #if simple_templating: + #grid_captions.append( captions[i] ) + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + if prompt_matrix or save_grid: + if prompt_matrix: + if simple_templating: + grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions) + else: + grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) + try: + grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) + except: + import traceback + print("Error creating prompt_matrix text:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + grid = image_grid(output_images, batch_size) + + if grid and (batch_size > 1 or n_iter > 1): + output_images.insert(0, grid) + + grid_count = get_next_sequence_number(outpath, 'grid-') + grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}" + grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True) + + toc = time.time() + + mem_max_used, mem_total = mem_mon.read_and_stop() + time_diff = time.time()-start_time + + info = f""" + {prompt} + Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and 
st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip() + stats = f''' + Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) + Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' + + for comment in comments: + info += "\n\n" + comment + + #mem_mon.stop() + #del mem_mon + torch_gc() + + return output_images, seed, info, stats + + +def resize_image(resize_mode, im, width, height): + LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) + if resize_mode == 0: + res = im.resize((width, height), resample=LANCZOS) + elif resize_mode == 1: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio > src_ratio else im.width * height // im.height + src_h = height if ratio <= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + else: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio < src_ratio else im.width * height // im.height + src_h = height if ratio >= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + + return res + +import skimage + +def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, + mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM', + n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, + seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, + variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, + write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", + separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False, + random_seed_loopback: bool = False + ): + + outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples" + err = False + #loopback = False + #skip_save = False + seed = seed_to_int(seed) + + batch_size = 1 + + #prompt_matrix = 0 + #normalize_prompt_weights = 1 in toggles + #loopback = 2 in toggles + #random_seed_loopback = 3 in toggles + #skip_save = 4 not in toggles + #save_grid = 5 in toggles + #sort_samples = 6 in toggles + #write_info_files = 7 in toggles + #write_sample_info_to_log_file = 8 in toggles + 
#jpg_sample = 9 in toggles + #use_GFPGAN = 10 in toggles + #use_RealESRGAN = 11 in toggles + + if sampler_name == 'PLMS': + sampler = PLMSSampler(st.session_state["model"]) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(st.session_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(st.session_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(st.session_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(st.session_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + def process_init_mask(init_mask: Image): + if init_mask.mode == "RGBA": + init_mask = init_mask.convert('RGBA') + background = Image.new('RGBA', init_mask.size, (0, 0, 0)) + init_mask = Image.alpha_composite(background, init_mask) + init_mask = init_mask.convert('RGB') + return init_mask + + init_img = init_info + init_mask = None + if mask_mode == 0: + if init_info_mask: + init_mask = process_init_mask(init_info_mask) + elif mask_mode == 1: + if init_info_mask: + init_mask = process_init_mask(init_info_mask) + init_mask = ImageOps.invert(init_mask) + elif mask_mode == 2: + init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1') + init_mask = init_img_transparency + init_mask = init_mask.convert("RGB") + init_mask = resize_image(resize_mode, init_mask, width, height) + init_mask = init_mask.convert("RGB") + + assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(denoising_strength * ddim_steps) + + if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None: + noise_q = 0.99 + color_variation = 0.0 + mask_blend_factor = 1.0 + + np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing + np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64) + np_mask_rgb -= np.min(np_mask_rgb) + np_mask_rgb /= np.max(np_mask_rgb) + np_mask_rgb = 1. - np_mask_rgb + np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64) + blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) + blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) + #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants + #np_mask_rgb = np_mask_rgb + blurred + np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.) + np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.) + + noise_rgb = matched_noise.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation) + blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor) + noised = noise_rgb[:] + blend_mask_rgb **= (2.) + noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb + + np_mask_grey = np.sum(np_mask_rgb, axis=2)/3. 
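+ # np_mask_grey collapses the blend mask to a single grey channel; pixels where it is ~0 lie fully outside the mask
+ # and are used just below as the reference region when histogram-matching the noised image back to the source colours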
+ ref_mask = np_mask_grey < 1e-3 + + all_mask = np.ones((height, width), dtype=bool) + noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1) + + init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB") + st.session_state["editor_image"].image(init_img) # debug + + def init(): + image = init_img.convert('RGB') + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + + mask_channel = None + if init_mask: + alpha = resize_image(resize_mode, init_mask, width // 8, height // 8) + mask_channel = alpha.split()[-1] + + mask = None + if mask_channel is not None: + mask = np.array(mask_channel).astype(np.float32) / 255.0 + mask = (1 - mask) + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) + mask = torch.from_numpy(mask).to(st.session_state["device"]) + + if defaults.general.optimized: + modelFS.to(st.session_state["device"] ) + + init_image = 2. * image - 1. + init_image = init_image.to(st.session_state["device"]) + init_latent = (st.session_state["model"] if not defaults.general.optimized else modelFS).get_first_stage_encoding((st.session_state["model"] if not defaults.general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + return init_latent, mask, + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + t_enc_steps = t_enc + obliterate = False + if ddim_steps == t_enc_steps: + t_enc_steps = t_enc_steps - 1 + obliterate = True + + if sampler_name != 'DDIM': + x0, z_mask = init_data + + sigmas = sampler.model_wrap.get_sigmas(ddim_steps) + noise = x * sigmas[ddim_steps - t_enc_steps - 1] + + xi = x0 + noise + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=xi.device) + xi = (z_mask * noise) + ((1-z_mask) * xi) + + sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] + model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, + callback=generation_callback) + else: + + x0, z_mask = init_data + + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) + z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] )) + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=z_enc.device) + z_enc = (z_mask * random) + ((1-z_mask) * z_enc) + + # decode it + samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, + z_mask=z_mask, x0=x0) + return samples_ddim + + + + if loopback: + output_images, info = None, None + history = [] + initial_seed = None + + do_color_correction = False + try: + from skimage import exposure + do_color_correction = True + except: + print("Install scikit-image to perform color correction on loopback") + + for i in range(n_iter): + if do_color_correction and i == 0: + correction_target = 
cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) + + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=1, + n_iter=1, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback + realesrgan_model_name=RealESRGAN_model, + fp=fp, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + mask_blur_strength=mask_blur_strength, + mask_restore=mask_restore, + denoising_strength=denoising_strength, + noise_mode=noise_mode, + find_noise_steps=find_noise_steps, + resize_mode=resize_mode, + uses_loopback=loopback, + uses_random_seed_loopback=random_seed_loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + if initial_seed is None: + initial_seed = seed + + init_img = output_images[0] + + if do_color_correction and correction_target is not None: + init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(init_img), + cv2.COLOR_RGB2LAB + ), + correction_target, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8")) + + if not random_seed_loopback: + seed = seed + 1 + else: + seed = seed_to_int(None) + + denoising_strength = max(denoising_strength * 0.95, 0.1) + history.append(init_img) + + output_images = history + seed = initial_seed + + else: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=RealESRGAN_model, + fp=fp, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + mask_blur_strength=mask_blur_strength, + denoising_strength=denoising_strength, + noise_mode=noise_mode, + find_noise_steps=find_noise_steps, + mask_restore=mask_restore, + resize_mode=resize_mode, + uses_loopback=loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + del sampler + + return output_images, seed, info, stats + +@retry((RuntimeError, KeyError) , tries=3) +def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str, + n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None], + height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, + RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None, + variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True): + + outpath = defaults.general.outdir_txt2img or defaults.general.outdir or "outputs/txt2img-samples" + + err = False + seed = seed_to_int(seed) + + #prompt_matrix = 0 in toggles + #normalize_prompt_weights = 1 in toggles + #skip_save = 2 
not in toggles + #save_grid = 3 not in toggles + #sort_samples = 4 in toggles + #write_info_files = 5 in toggles + #jpg_sample = 6 in toggles + #use_GFPGAN = 7 in toggles + #use_RealESRGAN = 8 in toggles + + if sampler_name == 'PLMS': + sampler = PLMSSampler(st.session_state["model"]) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(st.session_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(st.session_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(st.session_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(st.session_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + def init(): + pass + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback, + log_every_t=int(defaults.general.update_preview_frequency)) + + return samples_ddim + + #try: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=realesrgan_model_name, + fp=fp, + ddim_eta=ddim_eta, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg, + variant_amount=variant_amount, + variant_seed=variant_seed, + ) + + del sampler + + return output_images, seed, info, stats + + #except RuntimeError as e: + #err = e + #err_msg = f'CRASHED:


Please wait while the program restarts.' + #stats = err_msg + #return [], seed, 'err', stats + + +# +def txt2vid( + # -------------------------------------- + # args you probably want to change + prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about + gpu:int = defaults.general.gpu, # id of the gpu to run on + #name:str = 'test', # name of this project, for the output directory + #rootdir:str = defaults.general.outdir, + num_steps:int = 200, # number of steps between each pair of sampled points + max_frames:int = 10000, # number of frames to write and then exit the script + num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images + cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good + do_loop = False, + use_lerp_for_text = False, + seeds = None, + quality:int = 100, # for jpeg compression of the output images + eta:float = 0.0, + width:int = 256, + height:int = 256, + weights_path = "CompVis/stable-diffusion-v1-4", + scheduler="klms", # choices: default, ddim, klms + disable_tqdm = False, + #----------------------------------------------- + beta_start = 0.0001, + beta_end = 0.00012, + beta_schedule = "scaled_linear" + ): + """ + prompt = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about + gpu:int = defaults.general.gpu, # id of the gpu to run on + #name:str = 'test', # name of this project, for the output directory + #rootdir:str = defaults.general.outdir, + num_steps:int = 200, # number of steps between each pair of sampled points + max_frames:int = 10000, # number of frames to write and then exit the script + num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images + cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good + do_loop = False, + use_lerp_for_text = False, + seed = None, + quality:int = 100, # for jpeg compression of the output images + eta:float = 0.0, + width:int = 256, + height:int = 256, + weights_path = "CompVis/stable-diffusion-v1-4", + scheduler="klms", # choices: default, ddim, klms + disable_tqdm = False, + beta_start = 0.0001, + beta_end = 0.00012, + beta_schedule = "scaled_linear" + """ + mem_mon = MemUsageMonitor('MemMon') + mem_mon.start() + + + seeds = seed_to_int(seeds) + + # We add an extra frame because most + # of the time the first frame is just the noise. 
+ max_frames +=1 + + assert torch.cuda.is_available() + assert height % 8 == 0 and width % 8 == 0 + torch.manual_seed(seeds) + torch_device = f"cuda:{gpu}" + + # init the output dir + sanitized_prompt = slugify(prompts) + + full_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples", "samples", sanitized_prompt) + + if len(full_path) > 220: + sanitized_prompt = sanitized_prompt[:220-len(full_path)] + full_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples", "samples", sanitized_prompt) + + os.makedirs(full_path, exist_ok=True) + + # Write prompt info to file in output dir so we can keep track of what we did + if st.session_state.write_info_files: + with open(os.path.join(full_path , f'{slugify(str(seeds))}_config.json' if len(prompts) > 1 else "prompts_config.json"), "w") as outfile: + outfile.write(json.dumps( + dict( + prompts = prompts, + gpu = gpu, + num_steps = num_steps, + max_frames = max_frames, + num_inference_steps = num_inference_steps, + cfg_scale = cfg_scale, + do_loop = do_loop, + use_lerp_for_text = use_lerp_for_text, + seeds = seeds, + quality = quality, + eta = eta, + width = width, + height = height, + weights_path = weights_path, + scheduler=scheduler, + disable_tqdm = disable_tqdm, + beta_start = beta_start, + beta_end = beta_end, + beta_schedule = beta_schedule + ), + indent=2, + sort_keys=False, + )) + + #print(scheduler) + default_scheduler = PNDMScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule + ) + # ------------------------------------------------------------------------------ + #Schedulers + ddim_scheduler = DDIMScheduler( + beta_start=beta_start, + beta_end=beta_end, + beta_schedule=beta_schedule, + clip_sample=False, + set_alpha_to_one=False, + ) + + klms_scheduler = LMSDiscreteScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule + ) + + SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) + + # ------------------------------------------------------------------------------ + + #if weights_path == "Stable Diffusion v1.4": + #weights_path = "CompVis/stable-diffusion-v1-4" + #else: + #weights_path = os.path.join("./models", "custom", f"{weights_path}.ckpt") + + try: + if "model" in st.session_state: + del st.session_state["model"] + except: + pass + + #print (st.session_state["weights_path"] != weights_path) + + try: + if not st.session_state["pipe"] or st.session_state["weights_path"] != weights_path: + if st.session_state["weights_path"] != weights_path: + del st.session_state["weights_path"] + + st.session_state["weights_path"] = weights_path + st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained( + weights_path, + use_local_file=True, + use_auth_token=True, + #torch_dtype=torch.float16 if not defaults.general.no_half else None, + revision="fp16" if not defaults.general.no_half else None + ) + + st.session_state["pipe"].unet.to(torch_device) + st.session_state["pipe"].vae.to(torch_device) + st.session_state["pipe"].text_encoder.to(torch_device) + print("Tx2Vid Model Loaded") + else: + print("Tx2Vid Model already Loaded") + + except: + #del st.session_state["weights_path"] + #del st.session_state["pipe"] + + st.session_state["weights_path"] = weights_path + st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained( + weights_path, + use_local_file=True, + use_auth_token=True, + #torch_dtype=torch.float16 if not defaults.general.no_half else None, + revision="fp16" if not defaults.general.no_half 
else None + ) + + st.session_state["pipe"].unet.to(torch_device) + st.session_state["pipe"].vae.to(torch_device) + st.session_state["pipe"].text_encoder.to(torch_device) + print("Tx2Vid Model Loaded") + + st.session_state["pipe"].scheduler = SCHEDULERS[scheduler] + + # get the conditional text embeddings based on the prompt + text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt") + cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768] + + # sample a source + init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) + + if do_loop: + prompts = [prompts, prompts] + seeds = [seeds, seeds] + #first_seed, *seeds = seeds + #prompts.append(prompts) + #seeds.append(first_seed) + + + # iterate the loop + frames = [] + frame_index = 0 + + st.session_state["frame_total_duration"] = 0 + st.session_state["frame_total_speed"] = 0 + + try: + while frame_index < max_frames: + st.session_state["frame_duration"] = 0 + st.session_state["frame_speed"] = 0 + st.session_state["current_frame"] = frame_index + + # sample the destination + init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) + + for i, t in enumerate(np.linspace(0, 1, num_steps)): + start = timeit.default_timer() + print(f"COUNT: {frame_index+1}/{num_steps}") + + #if use_lerp_for_text: + #init = torch.lerp(init1, init2, float(t)) + #else: + #init = slerp(gpu, float(t), init1, init2) + + init = slerp(gpu, float(t), init1, init2) + + with autocast("cuda"): + image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta) + + im = Image.fromarray(image) + outpath = os.path.join(full_path, 'frame%06d.png' % frame_index) + im.save(outpath, quality=quality) + + # send the image to the UI to update it + #st.session_state["preview_image"].image(im) + + #append the frames to the frames list so we can use them later. + frames.append(np.asarray(im)) + + #increase frame_index counter. 
+ frame_index += 1 + + st.session_state["current_frame"] = frame_index + + duration = timeit.default_timer() - start + + if duration >= 1: + speed = "s/it" + else: + speed = "it/s" + duration = 1 / duration + + st.session_state["frame_duration"] = duration + st.session_state["frame_speed"] = speed + + init1 = init2 + + except StopException: + pass + + + if st.session_state['save_video']: + # write video to memory + #output = io.BytesIO() + #writer = imageio.get_writer(os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30) + try: + video_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples","temp.mp4") + writer = imageio.get_writer(video_path, fps=24) + for frame in frames: + writer.append_data(frame) + writer.close() + except: + print("Can't save video, skipping.") + + # show video preview on the UI + st.session_state["preview_video"].video(open(video_path, 'rb').read()) + + mem_max_used, mem_total = mem_mon.read_and_stop() + time_diff = time.time()- start + + info = f""" + {prompts} + Sampling Steps: {num_steps}, Sampler: {scheduler}, CFG scale: {cfg_scale}, Seed: {seeds}, Max Frames: {max_frames}""".strip() + stats = f''' + Took { round(time_diff, 2) }s total ({ round(time_diff/(max_frames),2) }s per image) + Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' + + return im, seeds, info, stats + + +# functions to load css locally OR remotely start here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection +# TODO, maybe look into async loading the file especially for remote fetching +def local_css(file_name): + with open(file_name) as f: + st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True) + +def remote_css(url): + st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True) + +def load_css(isLocal, nameOrURL): + if(isLocal): + local_css(nameOrURL) + else: + remote_css(nameOrURL) + + +# main functions to define streamlit layout here +def layout(): + + st.set_page_config(page_title="Stable Diffusion Playground", layout="wide") + + with st.empty(): + # load css as an external file; the function can take either a local path or a remote url. Potential use when running from cloud infra that might not have access to a local path. + load_css(True, 'frontend/css/streamlit.main.css') + + # check if the models exist in their respective folders + if os.path.exists(os.path.join(defaults.general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): + GFPGAN_available = True + else: + GFPGAN_available = False + + if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")): + RealESRGAN_available = True + else: + RealESRGAN_available = False + + # Allow for custom models to be used instead of the default one, + # an example would be Waifu-Diffusion or any other fine tune of stable diffusion + custom_models:sorted = [] + for root, dirs, files in os.walk(os.path.join("models", "custom")): + for file in files: + if os.path.splitext(file)[1] == '.ckpt': + fullpath = os.path.join(root, file) + #print(fullpath) + custom_models.append(os.path.splitext(file)[0]) + #print (os.path.splitext(file)[0]) + + if len(custom_models) > 0: + CustomModel_available = True + custom_models.append("Stable Diffusion v1.4") + else: + CustomModel_available = False + + with st.sidebar: + # The global settings section will be moved to the Settings page.
+ #with st.expander("Global Settings:"): + #st.write("Global Settings:") + #defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview, + #help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \ + #Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.") + #st.session_state.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency, + #help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.") + + tabs = on_hover_tabs(tabName=['Stable Diffusion', "Textual Inversion","Model Manager","Settings"], + iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0) + + + if tabs =='Stable Diffusion': + txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", + "Text-to-Video","Post-Processing"]) + with txt2img_tab: + with st.form("txt2img-inputs"): + st.session_state["generation_mode"] = "txt2img" + + input_col1, generate_col1 = st.columns([10,1]) + + with input_col1: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. + generate_col1.write("") + generate_col1.write("") + generate_button = generate_col1.form_submit_button("Generate") + + # creating the page layout using columns + col1, col2, col3 = st.columns([1,2,1], gap="large") + + with col1: + width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.txt2img.width, step=64) + height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.txt2img.height, step=64) + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") + seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") + #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2img.batch_size, step=1, + #help="How many images are at once in a batch.\ + #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + #Default: 1") + + with st.expander("Preview Settings"): + st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.txt2img.update_preview, + help="If enabled the image preview will be updated during the generation instead of at the end. \ + You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \ + By default this is enabled and the frequency is set to 1 step.") + + st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.txt2img.update_preview_frequency, + help="Frequency in steps at which the the preview image is updated. 
By default the frequency \ + is set to 1 step.") + + with col2: + preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) + + with preview_tab: + #st.write("Image") + #Image for testing + #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') + #new_image = image.resize((175, 240)) + #preview_image = st.image(image) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + st.session_state["preview_image"] = st.empty() + st.session_state["preview_video"] = st.empty() + + st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + message = st.empty() + + with gallery_tab: + st.write('Here should be the image gallery, if I could make a grid in streamlit.') + + with col3: + # If we have custom models available on the "models/custom" + #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD + if CustomModel_available: + custom_model = st.selectbox("Custom Model:", custom_models, + index=custom_models.index(defaults.general.default_model), + help="Select the model you want to use. This option is only available if you have custom models \ + on your 'models/custom' folder. The model name that will be shown here is the same as the name\ + the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") + else: + custom_model = "Stable Diffusion v1.4" + + st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250) + + sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + sampler_name = st.selectbox("Sampling method", sampler_name_list, + index=sampler_name_list.index(defaults.txt2img.default_sampler), help="Sampling method to use. Default: k_euler") + + + + #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) + + #with basic_tab: + #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, + #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") + + with st.expander("Advanced"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=False, + help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", + value=defaults.txt2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") + save_individual_images = st.checkbox("Save individual images.", value=defaults.txt2img.save_individual_images, + help="Save each image generated before any filter or enhancement is applied.") + save_grid = st.checkbox("Save grid",value=defaults.txt2img.save_grid, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=defaults.txt2img.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. 
\ + When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=defaults.txt2img.write_info_files, + help="Save a file next to the image with information about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=defaults.txt2img.save_as_jpg, help="Saves the images as jpg instead of png.") + + if GFPGAN_available: + use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.txt2img.use_GFPGAN, + help="Uses the GFPGAN model to improve faces after the generation. This greatly improves the quality and \ + consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + use_GFPGAN = False + + if RealESRGAN_available: + use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.txt2img.use_RealESRGAN, + help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improves the \ + quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") + RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + use_RealESRGAN = False + RealESRGAN_model = "RealESRGAN_x4plus" + + variant_amount = st.slider("Variant Amount:", value=defaults.txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=defaults.txt2img.seed, + help="The seed to use when generating a variant, if left blank a random seed will be generated.") + + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time; they won't be reloaded after that, so don't worry. + load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, CustomModel_available, custom_model) + + try: + output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1, + cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, + save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp, + variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) + + message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") + + except KeyError: + output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1, + cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, + save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp, + variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) + + message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") + + except (StopException): + print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation, but it would be better if it were moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as it's generated.
+ #preview_image.image(output_images) + + with img2img_tab: + with st.form("img2img-inputs"): + st.session_state["generation_mode"] = "img2img" + + img2img_input_col, img2img_generate_col = st.columns([10,1]) + with img2img_input_col: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. + img2img_generate_col.write("") + img2img_generate_col.write("") + generate_button = img2img_generate_col.form_submit_button("Generate") + + + # creating the page layout using columns + col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small") + + with col1_img2img_layout: + # If we have custom models available on the "models/custom" + #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD + if CustomModel_available: + custom_model = st.selectbox("Custom Model:", custom_models, + index=custom_models.index(defaults.general.default_model), + help="Select the model you want to use. This option is only available if you have custom models \ + on your 'models/custom' folder. The model name that will be shown here is the same as the name\ + the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") + else: + custom_model = "Stable Diffusion v1.4" + + st.session_state["sampling_steps"] = st.slider("Sampling Steps", value=defaults.img2img.sampling_steps, min_value=1, max_value=500) + st.session_state["sampler_name"] = st.selectbox("Sampling method", + ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"], + index=sampler_name_list.index(defaults.img2img.sampler_name), + help="Sampling method to use.") + + mask_mode_list = ["Mask", "Inverted mask", "Image alpha"] + mask_mode = st.selectbox("Mask Mode", mask_mode_list, + help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\ + \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent." 
+ ) + mask_mode = mask_mode_list.index(mask_mode) + + width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.img2img.width, step=64) + height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.img2img.height, step=64) + seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"] + noise_mode = st.selectbox( + "Noise Mode", noise_mode_list, + help="" + ) + noise_mode = noise_mode_list.index(noise_mode) + find_noise_steps = st.slider("Find Noise Steps", value=100, min_value=1, max_value=500) + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.img2img.batch_count, step=1, + help="How many iterations or batches of images to generate in total.") + + # + with st.expander("Advanced"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=defaults.img2img.separate_prompts, + help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=defaults.img2img.normalize_prompt_weights, + help="Ensure the sum of all weights add up to 1.0") + loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.") + random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed") + save_individual_images = st.checkbox("Save individual images.", value=defaults.img2img.save_individual_images, + help="Save each image generated before any filter or enhancement is applied.") + save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=defaults.img2img.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. \ + When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=defaults.img2img.write_info_files, + help="Save a file next to the image with informartion about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=defaults.img2img.save_as_jpg, help="Saves the images as jpg instead of png.") + + if GFPGAN_available: + use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ + This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + use_GFPGAN = False + + if RealESRGAN_available: + use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.img2img.use_RealESRGAN, + help="Uses the RealESRGAN model to upscale the images after the generation.\ + This greatly improve the quality and lets you have high resolution images but uses extra VRAM. 
Disable if you need the extra VRAM.") + RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + use_RealESRGAN = False + RealESRGAN_model = "RealESRGAN_x4plus" + + variant_amount = st.slider("Variant Amount:", value=defaults.img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=defaults.img2img.variant_seed, + help="The seed to use when generating a variant, if left blank a random seed will be generated.") + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.img2img.cfg_scale, step=0.5, + help="How strongly the image should follow the prompt.") + batch_size = st.slider("Batch size", min_value=1, max_value=100, value=defaults.img2img.batch_size, step=1, + help="How many images to process at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish \ + generation as more images are generated at once.\ + Default: 1") + + st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, + min_value=0.01, max_value=1.0, step=0.01) + + with st.expander("Preview Settings"): + st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.img2img.update_preview, + help="If enabled the image preview will be updated during the generation instead of at the end. \ + You can use the Update Preview Frequency option below to customize how frequently it's updated. \ + By default this is enabled and the frequency is set to 1 step.") + + st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.img2img.update_preview_frequency, + help="Frequency in steps at which the preview image is updated.
By default the frequency \ + is set to 1 step.") + + with col2_img2img_layout: + editor_tab = st.tabs(["Editor"]) + + editor_image = st.empty() + st.session_state["editor_image"] = editor_image + + refresh_button = st.form_submit_button("Refresh") + + masked_image_holder = st.empty() + image_holder = st.empty() + + uploaded_images = st.file_uploader( + "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], + help="Upload an image which will be used for the image to image generation.", + ) + if uploaded_images: + image = Image.open(uploaded_images).convert('RGBA') + new_img = image.resize((width, height)) + image_holder.image(new_img) + + mask_holder = st.empty() + + uploaded_masks = st.file_uploader( + "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], + help="Upload an mask image which will be used for masking the image to image generation.", + ) + if uploaded_masks: + mask = Image.open(uploaded_masks) + if mask.mode == "RGBA": + mask = mask.convert('RGBA') + background = Image.new('RGBA', mask.size, (0, 0, 0)) + mask = Image.alpha_composite(background, mask) + mask = mask.resize((width, height)) + mask_holder.image(mask) + + if uploaded_images and uploaded_masks: + if mask_mode != 2: + final_img = new_img.copy() + alpha_layer = mask.convert('L') + strength = st.session_state["denoising_strength"] + if mask_mode == 0: + alpha_layer = ImageOps.invert(alpha_layer) + alpha_layer = alpha_layer.point(lambda a: a * strength) + alpha_layer = ImageOps.invert(alpha_layer) + elif mask_mode == 1: + alpha_layer = alpha_layer.point(lambda a: a * strength) + alpha_layer = ImageOps.invert(alpha_layer) + + final_img.putalpha(alpha_layer) + + with masked_image_holder.container(): + st.text("Masked Image Preview") + st.image(final_img) + + + with col3_img2img_layout: + result_tab = st.tabs(["Result"]) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + preview_image = st.empty() + st.session_state["preview_image"] = preview_image + + #st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + + message = st.empty() + + #if uploaded_images: + #image = Image.open(uploaded_images).convert('RGB') + ##img_array = np.array(image) # if you want to pass it to OpenCV + #new_img = image.resize((width, height)) + #st.image(new_img, use_column_width=True) + + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
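+ # the GFPGAN/RealESRGAN flags and the currently selected custom model are passed so the matching checkpoints are cached in st.session_state before generation starts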
+ load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, CustomModel_available, custom_model) + if uploaded_images: + image = Image.open(uploaded_images).convert('RGBA') + new_img = image.resize((width, height)) + #img_array = np.array(image) # if you want to pass it to OpenCV + new_mask = None + if uploaded_masks: + mask = Image.open(uploaded_masks).convert('RGBA') + new_mask = mask.resize((width, height)) + + try: + output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode, ddim_steps=st.session_state["sampling_steps"], + sampler_name=st.session_state["sampler_name"], n_iter=batch_count, + cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed, + seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount, + ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=RealESRGAN_model, + separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, save_grid=save_grid, + group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN if not loopback else False, loopback=loopback + ) + + #show a message when the generation is complete. + message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") + + except (StopException, KeyError): + print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images, width=750) + + with txt2vid_tab: + with st.form("txt2vid-inputs"): + st.session_state["generation_mode"] = "txt2vid" + + input_col1, generate_col1 = st.columns([10,1]) + with input_col1: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. 
+ generate_col1.write("") + generate_col1.write("") + generate_button = generate_col1.form_submit_button("Generate") + + # creating the page layout using columns + col1, col2, col3 = st.columns([1,2,1], gap="large") + + with col1: + width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.txt2vid.width, step=64) + height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2vid.height, step=64) + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2vid.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") + seed = st.text_input("Seed:", value=defaults.txt2vid.seed, help=" The seed to use, if left blank a random seed will be generated.") + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2vid.batch_count, step=1, help="How many iterations or batches of images to generate in total.") + #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2vid.batch_size, step=1, + #help="How many images are at once in a batch.\ + #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + #Default: 1") + + st.session_state["max_frames"] = int(st.text_input("Max Frames:", value=defaults.txt2vid.max_frames, help="Specify the max number of frames you want to generate.")) + + with st.expander("Preview Settings"): + st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.txt2vid.update_preview, + help="If enabled the image preview will be updated during the generation instead of at the end. \ + You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \ + By default this is enabled and the frequency is set to 1 step.") + + st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.txt2vid.update_preview_frequency, + help="Frequency in steps at which the the preview image is updated. By default the frequency \ + is set to 1 step.") + with col2: + preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) + + with preview_tab: + #st.write("Image") + #Image for testing + #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') + #new_image = image.resize((175, 240)) + #preview_image = st.image(image) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + st.session_state["preview_image"] = st.empty() + + st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + generate_video = st.empty() + st.session_state["preview_video"] = st.empty() + + message = st.empty() + + with gallery_tab: + st.write('Here should be the image gallery, if I could make a grid in streamlit.') + + with col3: + # If we have custom models available on the "models/custom" + #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD + #if CustomModel_available: + custom_model = st.selectbox("Custom Model:", defaults.txt2vid.custom_models_list, + index=defaults.txt2vid.custom_models_list.index(defaults.txt2vid.default_model), + help="Select the model you want to use. 
This option is only available if you have custom models \ + on your 'models/custom' folder. The model name that will be shown here is the same as the name\ + the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") + + #st.session_state["weights_path"] = custom_model + #else: + #custom_model = "CompVis/stable-diffusion-v1-4" + #st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}" + + st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2vid.sampling_steps, min_value=10, step=10, max_value=500, + help="Number of steps between each pair of sampled points") + st.session_state.num_inference_steps = st.slider("Inference Steps:", value=defaults.txt2vid.num_inference_steps, min_value=10,step=10, max_value=500, + help="Higher values (e.g. 100, 200 etc) can create better images.") + + #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + #sampler_name = st.selectbox("Sampling method", sampler_name_list, + #index=sampler_name_list.index(defaults.txt2vid.default_sampler), help="Sampling method to use. Default: k_euler") + scheduler_name_list = ["klms", "ddim"] + scheduler_name = st.selectbox("Scheduler:", scheduler_name_list, + index=scheduler_name_list.index(defaults.txt2vid.scheduler_name), help="Scheduler to use. Default: klms") + + beta_scheduler_type_list = ["scaled_linear", "linear"] + beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list, + index=beta_scheduler_type_list.index(defaults.txt2vid.beta_scheduler_type), help="Schedule Type to use. Default: linear") + + + #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) + + #with basic_tab: + #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, + #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") + + with st.expander("Advanced"): + st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=defaults.txt2vid.separate_prompts, + help="Separate multiple prompts using the `|` character, and get all combinations of them.") + st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.", + value=defaults.txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") + st.session_state["save_individual_images"] = st.checkbox("Save individual images.", + value=defaults.txt2vid.save_individual_images, help="Save each image generated before any filter or enhancement is applied.") + st.session_state["save_video"] = st.checkbox("Save video",value=defaults.txt2vid.save_video, help="Save a video with all the images generated as frames at the end of the generation.") + st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=defaults.txt2vid.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. 
When using a prompt matrix each prompt combination will have its own folder.") + st.session_state["write_info_files"] = st.checkbox("Write Info file", value=defaults.txt2vid.write_info_files, + help="Save a file next to the image with informartion about the generation.") + st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=defaults.txt2vid.dynamic_preview_frequency, + help="This option tries to find the best value at which we can update \ + the preview image during generation while minimizing the impact it has in performance. Default: True") + st.session_state["do_loop"] = st.checkbox("Do Loop", value=defaults.txt2vid.do_loop, + help="Do loop") + st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=defaults.txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.") + + if GFPGAN_available: + st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=defaults.txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + st.session_state["use_GFPGAN"] = False + + if RealESRGAN_available: + st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=defaults.txt2vid.use_RealESRGAN, + help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") + st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + st.session_state["use_RealESRGAN"] = False + st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" + + st.session_state["variant_amount"] = st.slider("Variant Amount:", value=defaults.txt2vid.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=defaults.txt2vid.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") + st.session_state["beta_start"] = st.slider("Beta Start:", value=defaults.txt2vid.beta_start, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f") + st.session_state["beta_end"] = st.slider("Beta End:", value=defaults.txt2vid.beta_end, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f") + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
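+ # note: txt2vid builds its own diffusers StableDiffusionPipeline inside txt2vid(), so the usual load_models call stays commented out here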
+ #load_models(False, False, False, RealESRGAN_model, CustomModel_available=CustomModel_available, custom_model=custom_model) + + # run video generation + image, seed, info, stats = txt2vid(prompts=prompt, gpu=defaults.general.gpu, + num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames), + num_inference_steps=st.session_state.num_inference_steps, + cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"], + seeds=seed, quality=100, eta=0.0, width=width, + height=height, weights_path=custom_model, scheduler=scheduler_name, + disable_tqdm=False, beta_start=st.session_state["beta_start"], beta_end=st.session_state["beta_end"], + beta_schedule=beta_scheduler_type) + + #message.success('Done!', icon="✅") + message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") + + #except (StopException, KeyError): + #print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images) + + # + elif tabs == 'Model Manager': + #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="") + + csvString = f""" + ,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media + ,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth + ,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth + ,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth + ,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt + ,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt + """ + colms = st.columns((1, 3, 5, 5)) + columns = ["№",'Model Name','Save Location','Download Link'] + + # Convert String into StringIO + csvStringIO = StringIO(csvString) + df = pd.read_csv(csvStringIO, sep=",", header=None, names=columns) + + for col, field_name in zip(colms, columns): + # table header + col.write(field_name) + + for x, model_name in enumerate(df["Model Name"]): + col1, col2, col3, col4 = st.columns((1, 3, 4, 6)) + col1.write(x) # index + col2.write(df['Model Name'][x]) + col3.write(df['Save Location'][x]) + col4.write(df['Download Link'][x]) + + + elif tabs == 'Settings': + import Settings + + st.write("Settings") + +if __name__ == '__main__': + layout() diff --git a/setup.py b/setup.py index a24d541..0e768e1 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup, find_packages setup( - name='latent-diffusion', + name='sd-webui', version='0.0.1', description='', packages=find_packages(), diff --git a/webui.sh b/webui.sh index ea7028f..7b07a49 100755 --- a/webui.sh +++ b/webui.sh @@ -37,7 +37,7 @@ if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then ENV_UPDATED=1 elif [[ ! -z $CONDA_FORCE_UPDATE && $CONDA_FORCE_UPDATE == "true" ]] || (( $ENV_MODIFIED > $ENV_MODIFIED_CACHED )); then echo "Updating conda env: ${ENV_NAME} ..." 
- conda env update --file $ENV_FILE --prune + PIP_EXISTS_ACTION=w conda env update --file $ENV_FILE --prune ENV_UPDATED=1 fi @@ -56,4 +56,4 @@ if [ ! -e "models/ldm/stable-diffusion-v1/model.ckpt" ]; then exit 1 fi -python scripts/relauncher.py \ No newline at end of file +python scripts/relauncher.py From 165893bc3610a52cb47498cf007a1864bf188b8a Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Sun, 18 Sep 2022 14:08:43 +0100 Subject: [PATCH 27/27] Revert "The Merge (#1201)" (#1208) This reverts commit a797312183188867d7b4bdc7f1a1263017540a48. --- .dockerignore | 3 - .env_docker.example | 10 +- .gitignore | 9 +- Dockerfile | 9 +- README.md | 6 +- configs/webui/webui.yaml | 11 +- configs/webui/webui_streamlit.yaml | 190 +- docker-compose.yml | 5 +- docker-reset.sh | 9 +- entrypoint.sh | 79 +- environment.yaml | 60 +- frontend/css/streamlit.main.css | 116 +- frontend/frontend.py | 80 +- frontend/image_metadata.py | 57 - frontend/job_manager.py | 219 +- frontend/ui_functions.py | 4 +- images/nsfw.jpeg | Bin 25276 -> 0 bytes ldm/modules/attention.py | 105 +- ldm/modules/diffusionmodules/model.py | 145 +- ldm/modules/diffusionmodules/util.py | 5 +- scripts/DeforumStableDiffusion.py | 1312 ------------ scripts/ModelManager.py | 46 - scripts/Settings.py | 5 - scripts/home.py | 216 -- scripts/img2img.py | 592 ------ scripts/imglab.py | 161 -- scripts/perlin.py | 48 - scripts/relauncher.py | 4 - scripts/sd_utils.py | 1728 ---------------- scripts/stable_diffusion_pipeline.py | 233 --- scripts/stable_diffusion_walk.py | 218 -- scripts/textual_inversion.py | 57 - scripts/txt2img.py | 368 ---- scripts/txt2vid.py | 780 ------- scripts/webui.py | 577 ++---- scripts/webui_streamlit.py | 1809 +++++++++++++++- scripts/webui_streamlit_old.py | 2738 ------------------------- setup.py | 2 +- webui.sh | 4 +- 39 files changed, 2109 insertions(+), 9911 deletions(-) delete mode 100644 .dockerignore mode change 100755 => 100644 docker-reset.sh delete mode 100644 frontend/image_metadata.py delete mode 100644 images/nsfw.jpeg delete mode 100644 scripts/DeforumStableDiffusion.py delete mode 100644 scripts/ModelManager.py delete mode 100644 scripts/Settings.py delete mode 100644 scripts/home.py delete mode 100644 scripts/img2img.py delete mode 100644 scripts/imglab.py delete mode 100644 scripts/perlin.py delete mode 100644 scripts/sd_utils.py delete mode 100644 scripts/stable_diffusion_pipeline.py delete mode 100644 scripts/stable_diffusion_walk.py delete mode 100644 scripts/textual_inversion.py delete mode 100644 scripts/txt2img.py delete mode 100644 scripts/txt2vid.py delete mode 100644 scripts/webui_streamlit_old.py diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index 4ac12df..0000000 --- a/.dockerignore +++ /dev/null @@ -1,3 +0,0 @@ -models/ -outputs/ -src/ diff --git a/.env_docker.example b/.env_docker.example index 51eb059..5a34945 100644 --- a/.env_docker.example +++ b/.env_docker.example @@ -6,13 +6,9 @@ CONDA_FORCE_UPDATE=false # (useful to set to false after you're sure the model files are already in place) VALIDATE_MODELS=true -# Automatically relaunch the webui on crashes +#Automatically relaunch the webui on crashes WEBUI_RELAUNCH=true -# Which webui to launch -# WEBUI_SCRIPT=webui_streamlit.py -WEBUI_SCRIPT=webui.py - -# Pass cli arguments to webui.py e.g: -# WEBUI_ARGS=--optimized --extra-models-cpu --gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1 +#Pass cli arguments to webui.py e.g: +#WEBUI_ARGS=--gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1 WEBUI_ARGS= diff --git 
a/.gitignore b/.gitignore index b014154..4b30236 100644 --- a/.gitignore +++ b/.gitignore @@ -47,21 +47,16 @@ MANIFEST .env_updated condaenv.*.requirements.txt -# Visual Studio directories -.vs/ -.vscode/ # =========================================================================== # # Repo-specific # =========================================================================== # -/configs/webui/userconfig_streamlit.yaml /custom-conda-path.txt /src/* -/outputs -/model_cache +/outputs/* /log/**/*.png /log/log.csv /flagged/* /gfpgan/* /models/* -z_version_env.tmp +z_version_env.tmp \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 2b061b0..8d5ecb4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,10 +1,6 @@ FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04 -ENV DEBIAN_FRONTEND=noninteractive \ - PYTHONUNBUFFERED=1 \ - PYTHONIOENCODING=UTF-8 \ - CONDA_DIR=/opt/conda - +ENV DEBIAN_FRONTEND=noninteractive WORKDIR /sd SHELL ["/bin/bash", "-c"] @@ -15,6 +11,7 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* # Install miniconda +ENV CONDA_DIR /opt/conda RUN wget -O ~/miniconda.sh -q --show-progress --progress=bar:force https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ /bin/bash ~/miniconda.sh -b -p $CONDA_DIR && \ rm ~/miniconda.sh @@ -23,7 +20,7 @@ ENV PATH=$CONDA_DIR/bin:$PATH # Install font for prompt matrix COPY /data/DejaVuSans.ttf /usr/share/fonts/truetype/ -EXPOSE 7860 8501 +EXPOSE 7860 COPY ./entrypoint.sh /sd/ ENTRYPOINT /sd/entrypoint.sh diff --git a/README.md b/README.md index f5d96ba..36d1dc0 100644 --- a/README.md +++ b/README.md @@ -46,8 +46,8 @@ Features: * Gradio GUI: Idiot-proof, fully featured frontend for both txt2img and img2img generation * No more manually typing parameters, now all you have to do is write your prompt and adjust sliders -* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second -* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Boosts the resolution of images with a built-in RealESRGAN option +* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui#gfpgan)Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second +* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui#realesrgan) Boosts the resolution of images with a built-in RealESRGAN option * :computer: esrgan/gfpgan on cpu support :computer: * Textual inversion 🔥: [info](https://textual-inversion.github.io/) - requires enabling, see [here](https://github.com/hlky/sd-enable-textual-inversion), script works as usual without it enabled * Advanced img2img editor :art: :fire: :art: @@ -106,7 +106,7 @@ that are not in original script. ### GFPGAN Lets you improve faces in pictures using the GFPGAN model. There is a checkbox in every tab to use GFPGAN at 100%, and -also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strong the effect is. +also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strongthe effect is. 
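Illustrative aside (not part of the patch): the GFPGAN face-correction step described above can be reproduced standalone with the upstream gfpgan package. The model path below is the save location listed in the Model Manager tab earlier in this patch series; treating the strength slider as a straight blend between the original and the restored image is an assumption made for this sketch, not something the patch states.

import cv2
from gfpgan import GFPGANer

# Sketch only: assumes GFPGANv1.3.pth was downloaded to the location this repo uses.
restorer = GFPGANer(model_path="./src/gfpgan/experiments/pretrained_models/GFPGANv1.3.pth",
                    upscale=1, arch="clean", channel_multiplier=2, bg_upsampler=None)

img = cv2.imread("face.png", cv2.IMREAD_COLOR)
# enhance() returns (cropped_faces, restored_faces, restored_img); only the full restored image is used here.
_, _, restored = restorer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)

# Assumed meaning of the strength slider: blend restored pixels over the original (100 = fully restored).
strength = 100
out = cv2.addWeighted(restored, strength / 100, img, 1 - strength / 100, 0)
cv2.imwrite("face_gfpgan.png", out)
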
![](images/GFPGAN.png) diff --git a/configs/webui/webui.yaml b/configs/webui/webui.yaml index 25d222b..b7bd258 100644 --- a/configs/webui/webui.yaml +++ b/configs/webui/webui.yaml @@ -12,9 +12,8 @@ txt2img: # 5: Write sample info files # 6: write sample info to log file # 7: jpg samples - # 8: Filter NSFW content - # 9: Fix faces using GFPGAN - # 10: Upscale images using RealESRGAN + # 8: Fix faces using GFPGAN + # 9: Upscale images using RealESRGAN toggles: [1, 2, 3, 4, 5] sampler_name: k_lms ddim_eta: 0.0 # legacy name, applies to all algorithms. @@ -41,10 +40,8 @@ img2img: # 6: Sort samples by prompt # 7: Write sample info files # 8: jpg samples - # 9: Color correction - # 10: Filter NSFW content - # 11: Fix faces using GFPGAN - # 12: Upscale images using Real-ESRGAN + # 9: Fix faces using GFPGAN + # 10: Upscale images using Real-ESRGAN toggles: [1, 4, 5, 6, 7] sampler_name: k_lms ddim_eta: 0.0 diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml index 394494c..84263bd 100644 --- a/configs/webui/webui_streamlit.yaml +++ b/configs/webui/webui_streamlit.yaml @@ -1,19 +1,14 @@ # UI defaults configuration file. It is automatically loaded if located at configs/webui/webui_streamlit.yaml. # Any changes made here will be available automatically on the web app without having to stop it. -# You may add overrides in a file named "userconfig_streamlit.yaml" in this folder, which can contain any subset -# of the properties below. general: gpu: 0 outdir: outputs - default_model: "Stable Diffusion v1.4" - default_model_config: "configs/stable-diffusion/v1-inference.yaml" - default_model_path: "models/ldm/stable-diffusion-v1/model.ckpt" - use_sd_concepts_library: True - sd_concepts_library_folder: "models/custom/sd-concepts-library" + ckpt: "models/ldm/stable-diffusion-v1/model.ckpt" + fp: + name: 'embeddings/alex/embeddings_gs-11000.pt' GFPGAN_dir: "./src/gfpgan" RealESRGAN_dir: "./src/realesrgan" RealESRGAN_model: "RealESRGAN_x4plus" - LDSR_dir: "./src/latent-diffusion" outdir_txt2img: outputs/txt2img-samples outdir_img2img: outputs/img2img-samples gfpgan_cpu: False @@ -21,161 +16,88 @@ general: extra_models_cpu: False extra_models_gpu: False save_metadata: True - save_format: "png" skip_grid: False skip_save: False grid_format: "jpg:95" n_rows: -1 no_verify_input: False no_half: False - use_float16: False precision: "autocast" optimized: False optimized_turbo: False - optimized_config: "optimizedSD/v1-inference.yaml" - enable_attention_slicing: False - enable_minimal_memory_usage : False update_preview: True - update_preview_frequency: 5 + update_preview_frequency: 1 txt2img: prompt: height: 512 width: 512 - cfg_scale: 7.5 + cfg_scale: 5.0 seed: "" batch_count: 1 batch_size: 1 - sampling_steps: 30 - default_sampler: "k_euler" + sampling_steps: 50 + default_sampler: "k_lms" separate_prompts: False - update_preview: True - update_preview_frequency: 5 normalize_prompt_weights: True save_individual_images: True save_grid: True group_by_prompt: True save_as_jpg: False - use_GFPGAN: False - use_RealESRGAN: False + use_GFPGAN: True + use_RealESRGAN: True RealESRGAN_model: "RealESRGAN_x4plus" variant_amount: 0.0 variant_seed: "" - write_info_files: True - slider_steps: { - sampling: 1 - } - slider_bounds: { - sampling: { - lower: 1, - upper: 150 - } - } - -txt2vid: - default_model: "CompVis/stable-diffusion-v1-4" - custom_models_list: ["CompVis/stable-diffusion-v1-4", "naclbit/trinart_stable_diffusion_v2", "hakurei/waifu-diffusion", "osanseviero/BigGAN-deep-128"] - prompt: - 
height: 512 - width: 512 - cfg_scale: 7.5 - seed: "" - batch_count: 1 - batch_size: 1 - sampling_steps: 30 - num_inference_steps: 200 - default_sampler: "k_euler" - scheduler_name: "klms" - separate_prompts: False - update_preview: True - update_preview_frequency: 5 - dynamic_preview_frequency: True - normalize_prompt_weights: True - save_individual_images: True - save_video: True - group_by_prompt: True - write_info_files: True - do_loop: False - save_as_jpg: False - use_GFPGAN: False - use_RealESRGAN: False - RealESRGAN_model: "RealESRGAN_x4plus" - variant_amount: 0.0 - variant_seed: "" - beta_start: 0.00085 - beta_end: 0.012 - beta_scheduler_type: "linear" - max_frames: 1000 - slider_steps: { - sampling: 1 - } - slider_bounds: { - sampling: { - lower: 1, - upper: 150 - } - } img2img: - prompt: - sampling_steps: 30 - # Adding an int to toggles enables the corresponding feature. - # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them) - # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0) - # 2: Loopback (use images from previous batch when creating next batch) - # 3: Random loopback seed - # 4: Save individual images - # 5: Save grid - # 6: Sort samples by prompt - # 7: Write sample info files - # 8: jpg samples - # 9: Fix faces using GFPGAN - # 10: Upscale images using Real-ESRGAN - sampler_name: "k_euler" - denoising_strength: 0.75 - # 0: Keep masked area - # 1: Regenerate only masked area - mask_mode: 0 - mask_restore: False - # 0: Just resize - # 1: Crop and resize - # 2: Resize and fill - resize_mode: 0 - # Leave blank for random seed: - seed: "" - ddim_eta: 0.0 - cfg_scale: 7.5 - batch_count: 1 - batch_size: 1 - height: 512 - width: 512 - # Textual inversion embeddings file path: - fp: "" - loopback: True - random_seed_loopback: True - separate_prompts: False - update_preview: True - update_preview_frequency: 5 - normalize_prompt_weights: True - save_individual_images: True - save_grid: True - group_by_prompt: True - save_as_jpg: False - use_GFPGAN: False - use_RealESRGAN: False - RealESRGAN_model: "RealESRGAN_x4plus" - variant_amount: 0.0 - variant_seed: "" - write_info_files: True - slider_steps: { - sampling: 1 - } - slider_bounds: { - sampling: { - lower: 1, - upper: 150 - } - } + prompt: + sampling_steps: 50 + # Adding an int to toggles enables the corresponding feature. 
+ # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them) + # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0) + # 2: Loopback (use images from previous batch when creating next batch) + # 3: Random loopback seed + # 4: Save individual images + # 5: Save grid + # 6: Sort samples by prompt + # 7: Write sample info files + # 8: jpg samples + # 9: Fix faces using GFPGAN + # 10: Upscale images using Real-ESRGAN + sampler_name: k_lms + denoising_strength: 0.45 + # 0: Keep masked area + # 1: Regenerate only masked area + mask_mode: 0 + # 0: Just resize + # 1: Crop and resize + # 2: Resize and fill + resize_mode: 0 + # Leave blank for random seed: + seed: "" + ddim_eta: 0.0 + cfg_scale: 5.0 + batch_count: 1 + batch_size: 1 + height: 512 + width: 512 + # Textual inversion embeddings file path: + fp: "" + loopback: True + random_seed_loopback: True + separate_prompts: False + normalize_prompt_weights: True + save_individual_images: True + save_grid: True + group_by_prompt: True + save_as_jpg: False + use_GFPGAN: True + use_RealESRGAN: True + RealESRGAN_model: "RealESRGAN_x4plus" + variant_amount: 0.0 + variant_seed: "" gfpgan: strength: 100 + diff --git a/docker-compose.yml b/docker-compose.yml index f378963..968df1c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.3' services: stable-diffusion: - container_name: sd-webui + container_name: sd build: context: . dockerfile: Dockerfile @@ -12,7 +12,6 @@ services: volumes: - .:/sd - ./outputs:/sd/outputs - - ./model_cache:/sd/model_cache - conda_env:/opt/conda - root_profile:/root ports: @@ -22,7 +21,7 @@ services: resources: reservations: devices: - - capabilities: [ gpu ] + - capabilities: [gpu] volumes: conda_env: diff --git a/docker-reset.sh b/docker-reset.sh old mode 100755 new mode 100644 index 5042026..3ca3158 --- a/docker-reset.sh +++ b/docker-reset.sh @@ -10,13 +10,12 @@ echo $(pwd) read -p "Is the directory above correct to run reset on? 
(y/n) " -n 1 DIRCONFIRM if [[ $DIRCONFIRM =~ ^[Yy]$ ]]; then docker compose down - docker image rm stable-diffusion-webui_stable-diffusion:latest - docker volume rm stable-diffusion-webui_conda_env - docker volume rm stable-diffusion-webui_root_profile + docker image rm stable-diffusion_stable-diffusion:latest + docker volume rm stable-diffusion_conda_env + docker volume rm stable-diffusion_root_profile echo "Remove ./src" sudo rm -rf src - sudo rm -rf gfpgan - sudo rm -rf sd_webui.egg-info + sudo rm -rf latent_diffusion.egg-info sudo rm .env_updated else echo "Exited without resetting" diff --git a/entrypoint.sh b/entrypoint.sh index e130ea0..21ab01e 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -3,36 +3,26 @@ # Starts the gui inside the docker container using the conda env # -# set -x - -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cd $SCRIPT_DIR -export PYTHONPATH=$SCRIPT_DIR - -MODEL_DIR="${SCRIPT_DIR}/model_cache" # Array of model files to pre-download # local filename # local path in container (no trailing slash) # download URL # sha256sum MODEL_FILES=( - 'model.ckpt models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556' - 'GFPGANv1.3.pth src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70' - 'RealESRGAN_x4plus.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1' - 'RealESRGAN_x4plus_anime_6B.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da' - 'project.yaml src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1 9d6ad53c5dafeb07200fb712db14b813b527edd262bc80ea136777bdb41be2ba' - 'model.ckpt src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1 c209caecac2f97b4bb8f4d726b70ac2ac9b35904b7fc99801e1f5e61f9210c13' + 'model.ckpt /sd/models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556' + 'GFPGANv1.3.pth /sd/src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70' + 'RealESRGAN_x4plus.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1' + 'RealESRGAN_x4plus_anime_6B.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da' ) # Conda environment installs/updates # @see https://github.com/ContinuumIO/docker-images/issues/89#issuecomment-467287039 ENV_NAME="ldm" -ENV_FILE="${SCRIPT_DIR}/environment.yaml" +ENV_FILE="/sd/environment.yaml" ENV_UPDATED=0 ENV_MODIFIED=$(date -r $ENV_FILE "+%s") -ENV_MODIFED_FILE="${SCRIPT_DIR}/.env_updated" 
+ENV_MODIFED_FILE="/sd/.env_updated" if [[ -f $ENV_MODIFED_FILE ]]; then ENV_MODIFIED_CACHED=$(<${ENV_MODIFED_FILE}); else ENV_MODIFIED_CACHED=0; fi -export PIP_EXISTS_ACTION=w # Create/update conda env if needed if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then @@ -61,67 +51,54 @@ conda info | grep active # Function to checks for valid hash for model files and download/replaces if invalid or does not exist validateDownloadModel() { local file=$1 - local path="${SCRIPT_DIR}/${2}" + local path=$2 local url=$3 local hash=$4 echo "checking ${file}..." - sha256sum --check --status <<< "${hash} ${MODEL_DIR}/${file}.${hash}" + sha256sum --check --status <<< "${hash} ${path}/${file}" if [[ $? == "1" ]]; then echo "Downloading: ${url} please wait..." mkdir -p ${path} - wget --output-document=${MODEL_DIR}/${file}.${hash} --no-verbose --show-progress --progress=dot:giga ${url} - ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file} - if [[ -e "${path}/${file}" ]]; then - echo "saved ${file}" - else - echo "error saving ${path}/${file}!" - exit 1 - fi + wget --output-document=${path}/${file} --no-verbose --show-progress --progress=dot:giga ${url} + echo "saved ${file}" else - if [[ ! -e ${path}/${file} || ! -L ${path}/${file} ]]; then - mkdir -p ${path} - ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file} - echo -e "linked valid ${file}\n" - else - echo -e "${file} is valid!\n" - fi + echo -e "${file} is valid!\n" fi } # Validate model files -echo "Validating model files..." -for models in "${MODEL_FILES[@]}"; do - model=($models) - if [[ ! -e ${model[1]}/${model[0]} || ! -L ${model[1]}/${model[0]} || -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then +if [[ -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then + echo "Validating model files..." + for models in "${MODEL_FILES[@]}"; do + model=($models) validateDownloadModel ${model[0]} ${model[1]} ${model[2]} ${model[3]} - fi -done - -# Launch web gui -if [[ ! -z $WEBUI_SCRIPT && $WEBUI_SCRIPT == "webui_streamlit.py" ]]; then - launch_command="streamlit run scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS" -else - launch_command="python scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS" + done +fi + +# Launch web gui +cd /sd + +if [[ -z $WEBUI_ARGS ]]; then + launch_message="entrypoint.sh: Launching..." +else + launch_message="entrypoint.sh: Launching with arguments ${WEBUI_ARGS}" fi -launch_message="entrypoint.sh: Run ${launch_command}..." if [[ -z $WEBUI_RELAUNCH || $WEBUI_RELAUNCH == "true" ]]; then n=0 while true; do - echo $launch_message + echo $launch_message if (( $n > 0 )); then echo "Relaunch count: ${n}" fi - - $launch_command - + python -u scripts/webui.py $WEBUI_ARGS echo "entrypoint.sh: Process is ending. Relaunching in 0.5s..." ((n++)) sleep 0.5 done else echo $launch_message - $launch_command + python -u scripts/webui.py $WEBUI_ARGS fi diff --git a/environment.yaml b/environment.yaml index b5bd8e3..5bb3bf8 100644 --- a/environment.yaml +++ b/environment.yaml @@ -3,47 +3,39 @@ channels: - pytorch - defaults dependencies: - - cudatoolkit=11.3 - git - - numpy=1.22.3 - - pip=20.3 - python=3.8.5 + - pip=20.3 + - cudatoolkit=11.3 - pytorch=1.11.0 - - scikit-image=0.19.2 - torchvision=0.12.0 + - numpy=1.19.2 - pip: - - -e . 
+ - albumentations==0.4.3 + - opencv-python==4.1.2.30 + - opencv-python-headless==4.1.2.30 + - pudb==2019.2 + - imageio==2.9.0 + - imageio-ffmpeg==0.4.2 + - pytorch-lightning==1.4.2 + - omegaconf==2.1.1 + - test-tube>=0.7.5 + - einops==0.3.0 + - torch-fidelity==0.3.0 + - transformers==4.19.2 + - torchmetrics==0.6.0 + - kornia==0.6 + - gradio==3.1.6 + - accelerate==0.12.0 + - pynvml==11.4.1 + - basicsr>=1.3.4.0 + - facexlib>=0.2.3 + - python-slugify>=6.1.2 + - streamlit>=1.12.2 + - retry>=0.9.2 - -e git+https://github.com/CompVis/taming-transformers#egg=taming-transformers - -e git+https://github.com/openai/CLIP#egg=clip - -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN - -e git+https://github.com/xinntao/Real-ESRGAN#egg=realesrgan - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion - - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion - - accelerate==0.12.0 - - albumentations==0.4.3 - - basicsr>=1.3.4.0 - - diffusers==0.3.0 - - einops==0.3.0 - - facexlib>=0.2.3 - - gradio==3.1.6 - - imageio-ffmpeg==0.4.2 - - imageio==2.9.0 - - kornia==0.6 - - omegaconf==2.1.1 - - opencv-python-headless==4.6.0.66 - - pandas==1.4.3 - - piexif==1.1.3 - - pudb==2019.2 - - pynvml==11.4.1 - - python-slugify>=6.1.2 - - pytorch-lightning==1.4.2 - - retry>=0.9.2 - - streamlit>=1.12.2 - - streamlit-on-Hover-tabs==1.0.1 - - streamlit-option-menu==0.3.2 - - streamlit_nested_layout - - test-tube>=0.7.5 - - tensorboard - - torch-fidelity==0.3.0 - - torchmetrics==0.6.0 - - transformers==4.19.2 + - -e . \ No newline at end of file diff --git a/frontend/css/streamlit.main.css b/frontend/css/streamlit.main.css index a11d21d..4e11b77 100644 --- a/frontend/css/streamlit.main.css +++ b/frontend/css/streamlit.main.css @@ -1,111 +1,15 @@ -/*********************************************************** -* Additional CSS for streamlit builtin components * -************************************************************/ - -/* Tab name (e.g. 
Text-to-Image) */ -button[data-baseweb="tab"] { - font-size: 25px; //improve legibility -} - -/* Image Container (only appear after run finished) */ -.css-du1fp8 { - justify-content: center; //center the image, especially better looks in wide screen -} - -/* Streamlit header */ -.css-1avcm0n { - background-color: transparent; -} - -/* Main streamlit container (below header) */ .css-18e3th9 { - padding-top: 2rem; //reduce the empty spaces + padding-top: 2rem; + padding-bottom: 10rem; + padding-left: 5rem; + padding-right: 5rem; } - -/* @media only for widescreen, to ensure enough space to see all */ -@media (min-width: 1024px) { - /* Main streamlit container (below header) */ - .css-18e3th9 { - padding-top: 0px; //reduce the empty spaces, can go fully to the top on widescreen devices - } +.css-1d391kg { + padding-top: 3.5rem; + padding-right: 1rem; + padding-bottom: 3.5rem; + padding-left: 1rem; } - -/*********************************************************** -* Additional CSS for streamlit custom/3rd party components * -************************************************************/ -/* For stream_on_hover */ -section[data-testid="stSidebar"] > div:nth-of-type(1) { - background-color: #111; -} - -button[kind="header"] { - background-color: transparent; - color: rgb(180, 167, 141); -} - -@media (hover) { - /* header element */ - header[data-testid="stHeader"] { - /* display: none;*/ /*suggested behavior by streamlit hover components*/ - pointer-events: none; /* disable interaction of the transparent background */ - } - - /* The button on the streamlit navigation menu */ - button[kind="header"] { - /* display: none;*/ /*suggested behavior by streamlit hover components*/ - pointer-events: auto; /* enable interaction of the button even if parents intereaction disabled */ - } - - /* added to avoid main sectors (all element to the right of sidebar from) moving */ - section[data-testid="stSidebar"] { - width: 3.5% !important; - min-width: 3.5% !important; - } - - /* The navigation menu specs and size */ - section[data-testid="stSidebar"] > div { - height: 100%; - width: 2% !important; - min-width: 100% !important; - position: relative; - z-index: 1; - top: 0; - left: 0; - background-color: #111; - overflow-x: hidden; - transition: 0.5s ease-in-out; - padding-top: 0px; - white-space: nowrap; - } - - /* The navigation menu open and close on hover and size */ - section[data-testid="stSidebar"] > div:hover { - width: 300px !important; - } -} - -@media (max-width: 272px) { - section[data-testid="stSidebar"] > div { - width: 15rem; - } -} - -/*********************************************************** -* Additional CSS for other elements -************************************************************/ button[data-baseweb="tab"] { - font-size: 20px; + font-size: 25px; } - -@media (min-width: 1200px){ -h1 { - font-size: 1.75rem; -} -} -#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 { - width: 50rem; - align-self: center; -} -div.gallery:hover { - border: 1px solid #777; -} \ No newline at end of file diff --git a/frontend/frontend.py b/frontend/frontend.py index 94c76c9..29d3c50 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -3,8 +3,6 @@ from frontend.css_and_js import css, js, call_JS, js_parse_prompt, js_copy_txt2i from frontend.job_manager import JobManager import frontend.ui_functions as uifn import uuid -import torch - def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda x: x, txt2img_defaults={}, @@ -38,11 +36,8 @@ def 
draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda value=txt2img_defaults['cfg_scale'], elem_id='cfg_slider') txt2img_seed = gr.Textbox(label="Seed (blank to randomize)", lines=1, max_lines=1, value=txt2img_defaults["seed"]) - txt2img_batch_size = gr.Slider(minimum=1, maximum=50, step=1, - label='Images per batch', - value=txt2img_defaults['batch_size']) txt2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1, - label='Number of batches to generate', + label='Number of images to generate', value=txt2img_defaults['n_iter']) txt2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None @@ -56,15 +51,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda gr.Markdown( "Select an image from the gallery, then click one of the buttons below to perform an action.") with gr.Row(elem_id='txt2img_actions_row'): - gr.Button("Copy to clipboard").click( - fn=None, - inputs=output_txt2img_gallery, - outputs=[], - _js=call_JS( - "copyImageFromGalleryToClipboard", - fromId="txt2img_gallery_output" - ) - ) + gr.Button("Copy to clipboard").click(fn=None, + inputs=output_txt2img_gallery, + outputs=[], + # _js=js_copy_to_clipboard( 'txt2img_gallery_output') + ) output_txt2img_copy_to_input_btn = gr.Button("Push to img2img") output_txt2img_to_imglab = gr.Button("Send to Lab", visible=True) @@ -100,6 +91,9 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda with gr.TabItem('Advanced'): txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles, value=txt2img_toggle_defaults, type="index") + txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1, + label='Batch size (how many images are in a batch; memory-hungry)', + value=txt2img_defaults['batch_size']) txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model', choices=['RealESRGAN_x4plus', 'RealESRGAN_x4plus_anime_6B'], @@ -130,27 +124,20 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda inputs=txt2img_inputs, outputs=txt2img_outputs ) - use_queue = False - else: - use_queue = True txt2img_btn.click( txt2img_func, txt2img_inputs, - txt2img_outputs, - api_name='txt2img', - queue=use_queue + txt2img_outputs ) txt2img_prompt.submit( txt2img_func, txt2img_inputs, - txt2img_outputs, - queue=use_queue + txt2img_outputs ) - txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) - txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) - txt2img_dimensions_info_text_box.value = uifn.update_dimensions_info(txt2img_width.value, txt2img_height.value) + # txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) + # txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box) # Temporarily disable prompt parsing until memory issues could be solved # See #676 @@ -202,9 +189,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda with gr.TabItem("Editor Options"): with gr.Row(): # disable Uncrop for now - choices=["Mask", "Crop", "Uncrop"] - #choices=["Mask", "Crop"] - img2img_image_editor_mode = gr.Radio(choices=choices, + # choices=["Mask", "Crop", "Uncrop"] + img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"], label="Image Editor Mode", value="Mask", 
elem_id='edit_mode_select', visible=True) @@ -213,13 +199,9 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=True) - img2img_mask_restore = gr.Checkbox(label="Only modify regenerated parts of image", - value=img2img_defaults['mask_restore'], - visible=True) - - img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1, + img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", - value=3, visible=True) + value=3, visible=False) img2img_resize = gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", @@ -311,7 +293,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda img2img_height ], [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask, - img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_restore] + img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength] ) # img2img_image_editor_mode.change( @@ -352,8 +334,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda ) img2img_func = img2img - img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength, - img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles, + img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, + img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_image_editor, img2img_image_mask, img2img_embeddings] @@ -367,16 +349,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda inputs=img2img_inputs, outputs=img2img_outputs, ) - use_queue = False - else: - use_queue = True img2img_btn_mask.click( img2img_func, img2img_inputs, - img2img_outputs, - api_name="img2img", - queue=use_queue + img2img_outputs ) def img2img_submit_params(): @@ -406,7 +383,6 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda outputs=img2img_dimensions_info_text_box) img2img_height.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height], outputs=img2img_dimensions_info_text_box) - img2img_dimensions_info_text_box.value = uifn.update_dimensions_info(img2img_width.value, img2img_height.value) with gr.TabItem("Image Lab", id='imgproc_tab'): gr.Markdown("Post-process results") @@ -421,7 +397,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda # value=gfpgan_defaults['strength']) # select folder with images to process with gr.TabItem('Batch Process'): - imgproc_folder = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") + imgproc_folder = gr.File(label="Batch Process", file_count="multiple", source="upload", + interactive=True, type="file") imgproc_pngnfo = gr.Textbox(label="PNG Metadata", placeholder="PngNfo", visible=False, max_lines=5) with gr.Row(): @@ -563,7 +540,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda imgproc_width, imgproc_cfg, imgproc_denoising, imgproc_seed, imgproc_gfpgan_strength, imgproc_ldsr_steps, imgproc_ldsr_pre_downSample, imgproc_ldsr_post_downSample], - [imgproc_output], api_name="imgproc") + [imgproc_output]) imgproc_source.change( uifn.get_png_nfo, @@ -654,12 +631,11 
@@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda """ gr.HTML("""
-

For help and advanced usage guides, visit the Project Wiki

-

Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the main repository. - If you would like to contribute to development or test bleeding edge builds, you can visit the developement repository.

-

Device ID {current_device_index}: {current_device_name}
{total_device_count} total devices

+

For help and advanced usage guides, visit the Project Wiki

+

Stable Diffusion WebUI is an open-source project. + If you would like to contribute to development or test bleeding edge builds, use the dev branch.

- """.format(current_device_name=torch.cuda.get_device_name(), current_device_index=torch.cuda.current_device(), total_device_count=torch.cuda.device_count())) + """) # Hack: Detect the load event on the frontend # Won't be needed in the next version of gradio # See the relevant PR: https://github.com/gradio-app/gradio/pull/2108 diff --git a/frontend/image_metadata.py b/frontend/image_metadata.py deleted file mode 100644 index 8448088..0000000 --- a/frontend/image_metadata.py +++ /dev/null @@ -1,57 +0,0 @@ -''' Class to store image generation parameters to be stored as metadata in the image''' -from __future__ import annotations -from dataclasses import dataclass, asdict -from typing import Dict, Optional -from PIL import Image -from PIL.PngImagePlugin import PngInfo -import copy - -@dataclass -class ImageMetadata: - prompt: str = None - seed: str = None - width: str = None - height: str = None - steps: str = None - cfg_scale: str = None - normalize_prompt_weights: str = None - denoising_strength: str = None - GFPGAN: str = None - - def as_png_info(self) -> PngInfo: - info = PngInfo() - for key, value in self.as_dict().items(): - info.add_text(key, value) - return info - - def as_dict(self) -> Dict[str, str]: - return {f"SD:{key}": str(value) for key, value in asdict(self).items() if value is not None} - - @classmethod - def set_on_image(cls, image: Image, metadata: ImageMetadata) -> None: - ''' Sets metadata on image, in both text form and as an ImageMetadata object ''' - if metadata: - image.info = metadata.as_dict() - else: - metadata = ImageMetadata() - image.info["ImageMetadata"] = copy.copy(metadata) - - @classmethod - def get_from_image(cls, image: Image) -> Optional[ImageMetadata]: - ''' Gets metadata from an image, first looking for an ImageMetadata, - then if not found tries to construct one from the info ''' - metadata = image.info.get("ImageMetadata", None) - if not metadata: - found_metadata = False - metadata = ImageMetadata() - for key, value in image.info.items(): - if key.lower().startswith("sd:"): - key = key[3:] - if f"{key}" in metadata.__dict__: - metadata.__dict__[key] = value - found_metadata = True - if not found_metadata: - metadata = None - if not metadata: - print("Couldn't find metadata on image") - return metadata diff --git a/frontend/job_manager.py b/frontend/job_manager.py index 026742f..8eda8d9 100644 --- a/frontend/job_manager.py +++ b/frontend/job_manager.py @@ -1,7 +1,7 @@ ''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations ''' from __future__ import annotations import gradio as gr -from gradio.components import Component, Gallery, Slider +from gradio.components import Component, Gallery from threading import Event, Timer from typing import Callable, List, Dict, Tuple, Optional, Any from dataclasses import dataclass, field @@ -9,7 +9,6 @@ from functools import partial from PIL.Image import Image import uuid import traceback -import time @dataclass(eq=True, frozen=True) @@ -31,21 +30,9 @@ class JobInfo: session_key: str job_token: Optional[int] = None images: List[Image] = field(default_factory=list) - active_image: Image = None - rec_steps_enabled: bool = False - rec_steps_imgs: List[Image] = field(default_factory=list) - rec_steps_intrvl: int = None - rec_steps_to_gallery: bool = False - rec_steps_to_file: bool = False should_stop: Event = field(default_factory=Event) - refresh_active_image_requested: Event = field(default_factory=Event) - refresh_active_image_done: Event = 
field(default_factory=Event) - stop_cur_iter: Event = field(default_factory=Event) - active_iteration_cnt: int = field(default_factory=int) job_status: str = field(default_factory=str) finished: bool = False - started: bool = False - timestamp: float = None removed_output_idxs: List[int] = field(default_factory=list) @@ -89,7 +76,7 @@ class JobManagerUi: ''' return self._job_manager._wrap_func( func=func, inputs=inputs, outputs=outputs, - job_ui=self + refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text ) _refresh_btn: gr.Button @@ -97,19 +84,10 @@ class JobManagerUi: _status_text: gr.Textbox _stop_all_session_btn: gr.Button _free_done_sessions_btn: gr.Button - _active_image: gr.Image - _active_image_stop_btn: gr.Button - _active_image_refresh_btn: gr.Button - _rec_steps_intrvl_sldr: gr.Slider - _rec_steps_checkbox: gr.Checkbox - _save_rec_steps_to_gallery_chkbx: gr.Checkbox - _save_rec_steps_to_file_chkbx: gr.Checkbox _job_manager: JobManager class JobManager: - JOB_MAX_START_TIME = 5.0 # How long can a job be stuck 'starting' before assuming it isn't running - def __init__(self, max_jobs: int): self._max_jobs: int = max_jobs self._avail_job_tokens: List[Any] = list(range(max_jobs)) @@ -124,23 +102,11 @@ class JobManager: ''' assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context" with gr.Tabs(): - with gr.TabItem("Job Controls"): + with gr.TabItem("Current Session"): with gr.Row(): - stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary") - refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary") + stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary") + refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary") status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False) - with gr.Row(): - active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary") - active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary") - active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image") - with gr.TabItem("Batch Progress Settings"): - with gr.Row(): - record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid") - record_steps_interval_slider = gr.Slider( - value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1) - with gr.Row() as record_steps_box: - steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery") - steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File") with gr.TabItem("Maintenance"): with gr.Row(): gr.Markdown( @@ -152,15 +118,9 @@ class JobManager: free_done_sessions_btn = gr.Button( "Clear Finished Jobs", elem_id="clear_finished", variant="secondary" ) - return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text, _stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn, - _active_image=active_image, _active_image_stop_btn=active_image_stop_btn, - _active_image_refresh_btn=active_image_refresh_btn, - _rec_steps_checkbox=record_steps_checkbox, - _save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox, - _save_rec_steps_to_file_chkbx=steps_to_file_checkbox, - _rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self) + _job_manager=self) def clear_all_finished_jobs(self): ''' Removes all currently finished jobs, across all sessions. 
@@ -174,7 +134,6 @@ class JobManager: for session in self._sessions.values(): for job in session.jobs.values(): job.should_stop.set() - job.stop_cur_iter.set() def _get_job_token(self, block: bool = False) -> Optional[int]: ''' Attempts to acquire a job token, optionally blocking until available ''' @@ -216,26 +175,6 @@ class JobManager: job_info.should_stop.set() return "Stopping after current batch finishes" - def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]: - ''' Updates information from the active iteration ''' - session_info, job_info = self._get_call_info(func_key, session_key) - if job_info is None: - return [None, f"Session {session_key} was not running function {func_key}"] - - job_info.refresh_active_image_requested.set() - if job_info.refresh_active_image_done.wait(timeout=20.0): - job_info.refresh_active_image_done.clear() - return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"] - return [gr.Image.update(visible=False), "Timed out getting image"] - - def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]: - ''' Marks that the active iteration should be stopped''' - session_info, job_info = self._get_call_info(func_key, session_key) - if job_info is None: - return [None, f"Session {session_key} was not running function {func_key}"] - job_info.stop_cur_iter.set() - return [gr.Image.update(visible=False), "Stopping current iteration"] - def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]: ''' Helper to get the SessionInfo and JobInfo. ''' session_info = self._sessions.get(session_key, None) @@ -268,22 +207,19 @@ class JobManager: def _pre_call_func( self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button, - status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button, - session_key: str) -> List[Component]: + status_text: gr.Textbox, session_key: str) -> List[Component]: ''' Called when a job is about to start ''' session_info, job_info = self._get_call_info(func_key, session_key) # If we didn't already get a token then queue up for one if job_info.job_token is None: - job_info.job_token = self._get_job_token(block=True) + job_info.token = self._get_job_token(block=True) # Buttons don't seem to update unless value is set on them as well... return {output_dummy_obj: triggerChangeEvent(), refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value), stop_btn: gr.Button.update(variant="primary", value=stop_btn.value), - status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"), - active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value), - active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value), + status_text: gr.Textbox.update(value="Generation has started. 
Click 'Refresh' for updates") } def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]: @@ -292,19 +228,12 @@ class JobManager: if session_info is None or job_info is None: return [] - job_info.started = True try: - if job_info.should_stop.is_set(): - raise Exception(f"Job {job_info} requested a stop before execution began") outputs = job_info.func(*job_info.inputs, job_info=job_info) except Exception as e: job_info.job_status = f"Error: {e}" print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}") - raise - finally: - job_info.finished = True - session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key) - self._release_job_token(job_info.job_token) + outputs = [] # Filter the function output for any removed outputs filtered_output = [] @@ -312,6 +241,11 @@ class JobManager: if idx not in job_info.removed_output_idxs: filtered_output.append(output) + job_info.finished = True + session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key) + + self._release_job_token(job_info.job_token) + # The wrapper added a dummy JSON output. Append a random text string # to fire the dummy objects 'change' event to notify that the job is done filtered_output.append(triggerChangeEvent()) @@ -320,16 +254,12 @@ class JobManager: def _post_call_func( self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button, - status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button, - session_key: str) -> List[Component]: + status_text: gr.Textbox, session_key: str) -> List[Component]: ''' Called when a job completes ''' return {output_dummy_obj: triggerChangeEvent(), refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value), stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value), - status_text: gr.Textbox.update(value="Generation has finished!"), - active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value), - active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value), - active_image: gr.Image.update(visible=False) + status_text: gr.Textbox.update(value="Generation has finished!") } def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]: @@ -340,17 +270,21 @@ class JobManager: if session_info is None or job_info is None: return [] + if job_info.finished: + session_info.finished_jobs.pop(func_key) + return job_info.images - def _wrap_func(self, func: Callable, inputs: List[Component], - outputs: List[Component], - job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]: + def _wrap_func( + self, func: Callable, inputs: List[Component], outputs: List[Component], + refresh_btn: gr.Button = None, stop_btn: gr.Button = None, + status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]: ''' handles JobManageUI's wrap_func''' assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context" # Create a unique key for this job - func_key = FuncKey(job_id=uuid.uuid4().hex, func=func) + func_key = FuncKey(job_id=uuid.uuid4(), func=func) # Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/) if self._session_key is None: @@ -368,59 +302,31 @@ class JobManager: del outputs[idx] break + # Add the session key to the inputs + inputs += [self._session_key] + # Create dummy objects update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject") 
update_gallery_obj.change( partial(self._update_gallery_event, func_key), [self._session_key], - [gallery_comp], - queue=False + [gallery_comp] ) - if job_ui._refresh_btn: - job_ui._refresh_btn.variant = 'secondary' - job_ui._refresh_btn.click( + if refresh_btn: + refresh_btn.variant = 'secondary' + refresh_btn.click( partial(self._refresh_func, func_key), [self._session_key], - [update_gallery_obj, job_ui._status_text], - queue=False + [update_gallery_obj, status_text] ) - if job_ui._stop_btn: - job_ui._stop_btn.variant = 'secondary' - job_ui._stop_btn.click( + if stop_btn: + stop_btn.variant = 'secondary' + stop_btn.click( partial(self._stop_wrapped_func, func_key), [self._session_key], - [job_ui._status_text], - queue=False - ) - - if job_ui._active_image and job_ui._active_image_refresh_btn: - job_ui._active_image_refresh_btn.click( - partial(self._refresh_cur_iter_func, func_key), - [self._session_key], - [job_ui._active_image, job_ui._status_text], - queue=False - ) - - if job_ui._active_image_stop_btn: - job_ui._active_image_stop_btn.click( - partial(self._stop_cur_iter_func, func_key), - [self._session_key], - [job_ui._active_image, job_ui._status_text], - queue=False - ) - - if job_ui._stop_all_session_btn: - job_ui._stop_all_session_btn.click( - self.stop_all_jobs, [], [], - queue=False - ) - - if job_ui._free_done_sessions_btn: - job_ui._free_done_sessions_btn.click( - self.clear_all_finished_jobs, [], [], - queue=False + [status_text] ) # (ab)use gr.JSON to forward events. @@ -437,8 +343,7 @@ class JobManager: # Since some parameters are optional it makes sense to use the 'dict' return value type, which requires # the Component as a key... so group together the UI components that the event listeners are going to update # to make it easy to append to function calls and outputs - job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text, - job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn] + job_ui_params = [refresh_btn, stop_btn, status_text] job_ui_outputs = [comp for comp in job_ui_params if comp is not None] # Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call, @@ -447,70 +352,44 @@ class JobManager: post_call_dummyobj.change( partial(self._post_call_func, func_key, update_gallery_obj, *job_ui_params), [self._session_key], - [update_gallery_obj] + job_ui_outputs, - queue=False + [update_gallery_obj] + job_ui_outputs ) call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_runCall") call_dummyobj.change( partial(self._call_func, func_key), [self._session_key], - outputs + [post_call_dummyobj], - queue=False + outputs + [post_call_dummyobj] ) pre_call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_preCall") pre_call_dummyobj.change( partial(self._pre_call_func, func_key, call_dummyobj, *job_ui_params), [self._session_key], - [call_dummyobj] + job_ui_outputs, - queue=False + [call_dummyobj] + job_ui_outputs ) - # Add any components that we want the runtime values for - added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx, - job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr] - # Now replace the original function with one that creates a JobInfo and triggers the dummy obj - def wrapped_func(*wrapped_inputs): - # Remove the added_inputs (pop opposite order of list) - wrapped_inputs = list(wrapped_inputs) - rec_steps_interval: int = wrapped_inputs.pop() - save_rec_steps_file: bool = 
wrapped_inputs.pop() - save_rec_steps_grid: bool = wrapped_inputs.pop() - record_steps_enabled: bool = wrapped_inputs.pop() - session_key: str = wrapped_inputs.pop() - job_inputs = tuple(wrapped_inputs) + def wrapped_func(*inputs): + session_key = inputs[-1] + inputs = inputs[:-1] # Get or create a session for this key session_info = self._sessions.setdefault(session_key, SessionInfo()) # Is this session already running this job? if func_key in session_info.jobs: - job_info = session_info.jobs[func_key] - # If the job seems stuck in 'starting' then go ahead and toss it - if not job_info.started and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME: - job_info.should_stop.set() - job_info.stop_cur_iter.set() - session_info.jobs.pop(func_key) - return {job_ui._status_text: "Canceled possibly hung job. Try again"} - return {job_ui._status_text: "This session is already running that function!"} - - # Is this a new run of a previously finished job? Clear old info - if func_key in session_info.finished_jobs: - session_info.finished_jobs.pop(func_key) + return {status_text: "This session is already running that function!"} job_token = self._get_job_token(block=False) - job = JobInfo( - inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key, - job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval, - rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file, timestamp=time.time()) + job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key, + job_token=job_token) session_info.jobs[func_key] = job ret = {pre_call_dummyobj: triggerChangeEvent()} if job_token is None: - ret[job_ui._status_text] = "Job is queued" + ret[status_text] = "Job is queued" return ret - return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text] + return wrapped_func, inputs, [pre_call_dummyobj, status_text] diff --git a/frontend/ui_functions.py b/frontend/ui_functions.py index ee6af8d..6557841 100644 --- a/frontend/ui_functions.py +++ b/frontend/ui_functions.py @@ -9,10 +9,10 @@ import re def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height): if choice == "Mask": update_image_result = update_image_mask(cropped_image, resize_mode, width, height) - return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)] + return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height) - return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)] + return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] def update_image_mask(cropped_image, resize_mode, width, height): resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None diff --git a/images/nsfw.jpeg b/images/nsfw.jpeg deleted file mode 100644 index 
0ecf3b68b55af111bf8a73f5436add752b1c95c3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 25276 [~25 KB of base85-encoded JPEG data not reproduced here]
za$&f_X=5d^>|KSEjnK|me6xo4k}rEqZcLTNXS41xFD$4aBS~Kcy1G}i9kIJ(uVPOP z$+D}=x_`7=J}H~L`#r%Uh_BcBKP(!%=rZA;8Yid}YAN1;LEU{~ zxu9ZaRn8@8Hony0_94gIE2ru`3arJBhj!MpTU zKP{drck-DjU(=qtw38H*qP6uCP1|H8g2d}pbiB(iJoXk&c$p{HG<4Hjyeevs*uu@- zax~+I-zvNXO|MooE@dgCB{w;3MOXpGLBSdv02fbk?b5ZK*s9T)M8u{d{O?{K6B( zr-oIdMa?2R&;d)SW~Sz{h2=Woyu2&$JK{0xT{yF?YkWGl?L7JM*1!(4K;6TSQq4GY zhn%|JU2lt9$k+Q6n@uU;#^hX-5j*Ex$m?`{Z|L*X{KGk0r*3$Osw7`uy`ubO4VEX# zYirGZvgoI-$q}f~W@6mu^Ag;KDsKy4lf7%s4%QAx>qp;J+^wl$Q|nq2(|i4Do676q zjCGW!eoy1rD{S=?JK67C5R&*-ci!XpgL>nz(YXtYmnaVFyMS@BTOXh*fPXU1>X8z)Kk zI=HYptw6TK#42oI?ci3~p(dAq%s1H2d8$4a*3>tN5+z z*tTms6EAlt<*A~gkimx#zI)olRv@)l0-z&LN@M-z! G_WuEesGD~H diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 4485c1e..f4eff39 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -7,8 +7,6 @@ from einops import rearrange, repeat from ldm.modules.diffusionmodules.util import checkpoint -import psutil - def exists(val): return val is not None @@ -169,98 +167,30 @@ class CrossAttention(nn.Module): nn.Dropout(dropout) ) - if torch.cuda.is_available(): - self.einsum_op = self.einsum_op_cuda - else: - self.mem_total = psutil.virtual_memory().total / (1024**3) - self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2 - - def einsum_op_compvis(self, q, k, v, r1): - s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - r1 = einsum('b i j, b j d -> b i d', s2, v) - del s2 - return r1 - - def einsum_op_mps_v1(self, q, k, v, r1): - if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - r1 = self.einsum_op_compvis(q, k, v, r1) - else: - slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - s2 = s1.softmax(dim=-1, dtype=r1.dtype) - del s1 - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - return r1 - - def einsum_op_mps_v2(self, q, k, v, r1): - if self.mem_total >= 8 and q.shape[1] <= 4096: - r1 = self.einsum_op_compvis(q, k, v, r1) - else: - slice_size = 1 - for i in range(0, q.shape[0], slice_size): - end = min(q.shape[0], i + slice_size) - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale - s2 = s1.softmax(dim=-1, dtype=r1.dtype) - del s1 - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - return r1 - - def einsum_op_cuda(self, q, k, v, r1): - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' - f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = min(q.shape[1], i + slice_size) - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - s2 = s1.softmax(dim=-1, dtype=r1.dtype) - del s1 - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - return r1 - def forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) - del x k = self.to_k(context) v = self.to_v(context) - del context q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - r1 = self.einsum_op(q, k, v, r1) - del q, k, v - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - return self.to_out(r2) + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) class BasicTransformerBlock(nn.Module): @@ -279,10 +209,9 @@ class BasicTransformerBlock(nn.Module): return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) def _forward(self, x, context=None): - x = x.contiguous() if x.device.type == 'mps' else x - x += self.attn1(self.norm1(x)) - x += self.attn2(self.norm2(x), context=context) - x += self.ff(self.norm3(x)) + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x return x diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index dbbb325..533e589 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,5 +1,4 @@ # pytorch_diffusion + derived encoder decoder -import gc import math import torch import torch.nn as nn @@ -120,30 +119,18 @@ class ResnetBlock(nn.Module): padding=0) def forward(self, x, temb): - h1 = x - h2 = self.norm1(h1) - del h1 - - h3 = nonlinearity(h2) - del h2 - - h4 = self.conv1(h3) - del h3 + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) if temb is not None: - h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h5 = self.norm2(h4) - del h4 - - h6 = nonlinearity(h5) - del h5 - - h7 = self.dropout(h6) - del h6 - - h8 = self.conv2(h7) - del h7 + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -151,7 +138,7 @@ class ResnetBlock(nn.Module): else: x = self.nin_shortcut(x) - return x + h8 + return x+h class LinAttnBlock(LinearAttention): @@ -191,65 +178,28 @@ class AttnBlock(nn.Module): def forward(self, x): h_ = x h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) + q = self.q(h_) + k = self.k(h_) v = self.v(h_) # compute attention - b, c, h, w = q1.shape + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = 
torch.nn.functional.softmax(w_, dim=2) - q2 = q1.reshape(b, c, h*w) - del q1 + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) - q = q2.permute(0, 2, 1) # b,hw,c - del q2 + h_ = self.proj_out(h_) - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 + return x+h_ def make_attn(in_channels, attn_type="vanilla"): @@ -590,54 +540,31 @@ class Decoder(nn.Module): temb = None # z to block_in - h1 = self.conv_in(z) + h = self.conv_in(z) # middle - h2 = self.mid.block_1(h1, temb) - del h1 - - h3 = self.mid.attn_1(h2) - del h2 - - h = self.mid.block_2(h3, temb) - del h3 - - # prepare for up sampling - gc.collect() - torch.cuda.empty_cache() + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - t = h - h = self.up[i_level].attn[i_block](t) - del t - + h = self.up[i_level].attn[i_block](h) if i_level != 0: - t = h - h = self.up[i_level].upsample(t) - del t + h = self.up[i_level].upsample(h) # end if self.give_pre_end: return h - h1 = self.norm_out(h) - del h - - h2 = nonlinearity(h1) - del h1 - - h = self.conv_out(h2) - del h2 - + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) if self.tanh_out: - t = h - h = torch.tanh(t) - del t - + h = torch.tanh(h) return h diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index f872ba0..a952e6c 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -54,8 +54,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep # assert ddim_timesteps.shape[0] == num_ddim_timesteps # add one to get the final alpha values right (the ones from first scale to data during sampling) - # steps_out = ddim_timesteps + 1 # removed due to some issues when reaching 1000 - steps_out = np.where(ddim_timesteps != 999, ddim_timesteps+1, ddim_timesteps) + steps_out = ddim_timesteps + 1 if verbose: print(f'Selected timesteps for ddim sampler: {steps_out}') return steps_out @@ 
-265,4 +264,4 @@ class HybridConditioner(nn.Module): def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/scripts/DeforumStableDiffusion.py b/scripts/DeforumStableDiffusion.py deleted file mode 100644 index cd88539..0000000 --- a/scripts/DeforumStableDiffusion.py +++ /dev/null @@ -1,1312 +0,0 @@ -#Deforum Stable Diffusion v0.4 -#Stable Diffusion by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the Stability.ai Team. K Diffusion by Katherine Crowson. You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from HuggingFace. - -#Notebook by deforum -#Local Version by DGSpitzer 大谷的游戏创作小屋 - -import os, time -def get_output_folder(output_path, batch_folder): - out_path = os.path.join(output_path,time.strftime('%Y-%m')) - if batch_folder != "": - out_path = os.path.join(out_path, batch_folder) - os.makedirs(out_path, exist_ok=True) - return out_path - - -def main(): - - import argparse - parser = argparse.ArgumentParser() - - parser.add_argument( - "--settings", - type=str, - default="./examples/runSettings_StillImages.txt", - help="Settings file", - ) - - parser.add_argument( - "--enable_animation_mode", - default=False, - action='store_true', - help="Enable animation mode settings", - ) - - opt = parser.parse_args() - - #@markdown **Model and Output Paths** - # ask for the link - print("Local Path Variables:\n") - - models_path = "./models" #@param {type:"string"} - output_path = "./output" #@param {type:"string"} - - #@markdown **Google Drive Path Variables (Optional)** - mount_google_drive = False #@param {type:"boolean"} - force_remount = False - - - - - if mount_google_drive: - from google.colab import drive # type: ignore - try: - drive_path = "/content/drive" - drive.mount(drive_path,force_remount=force_remount) - models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"} - output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"} - models_path = models_path_gdrive - output_path = output_path_gdrive - except: - print("...error mounting drive or with drive path variables") - print("...reverting to default path variables") - - import os - os.makedirs(models_path, exist_ok=True) - os.makedirs(output_path, exist_ok=True) - - print(f"models_path: {models_path}") - print(f"output_path: {output_path}") - - - - #@markdown **Python Definitions** - import IPython - import json - from IPython import display - - import gc, math, os, pathlib, shutil, subprocess, sys, time - import cv2 - import numpy as np - import pandas as pd - import random - import requests - import torch, torchvision - import torch.nn as nn - import torchvision.transforms as T - import torchvision.transforms.functional as TF - from contextlib import contextmanager, nullcontext - from einops import rearrange, repeat - from itertools import islice - from omegaconf import OmegaConf - from PIL import Image - from pytorch_lightning import seed_everything - from skimage.exposure import match_histograms - from torchvision.utils import make_grid - from tqdm import tqdm, trange - from types import SimpleNamespace - from torch import autocast - - sys.path.extend([ - 'src/taming-transformers', - 'src/clip', - 
'stable-diffusion/', - 'k-diffusion', - 'pytorch3d-lite', - 'AdaBins', - 'MiDaS', - ]) - - import py3d_tools as p3d - - from helpers import DepthModel, sampler_fn - from k_diffusion.external import CompVisDenoiser - from ldm.util import instantiate_from_config - from ldm.models.diffusion.ddim import DDIMSampler - from ldm.models.diffusion.plms import PLMSSampler - - - #Read settings files - def load_args(path): - with open(path, "r") as f: - loaded_args = json.load(f)#, ensure_ascii=False, indent=4) - return loaded_args - - master_args = load_args(opt.settings) - - - def sanitize(prompt): - whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ') - tmp = ''.join(filter(whitelist.__contains__, prompt)) - return tmp.replace(' ', '_') - - def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx): - angle = keys.angle_series[frame_idx] - zoom = keys.zoom_series[frame_idx] - translation_x = keys.translation_x_series[frame_idx] - translation_y = keys.translation_y_series[frame_idx] - - center = (args.W // 2, args.H // 2) - trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]]) - rot_mat = cv2.getRotationMatrix2D(center, angle, zoom) - trans_mat = np.vstack([trans_mat, [0,0,1]]) - rot_mat = np.vstack([rot_mat, [0,0,1]]) - xform = np.matmul(rot_mat, trans_mat) - - return cv2.warpPerspective( - prev_img_cv2, - xform, - (prev_img_cv2.shape[1], prev_img_cv2.shape[0]), - borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE - ) - - def anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx): - TRANSLATION_SCALE = 1.0/200.0 # matches Disco - translate_xyz = [ - -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE, - keys.translation_y_series[frame_idx] * TRANSLATION_SCALE, - -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE - ] - rotate_xyz = [ - math.radians(keys.rotation_3d_x_series[frame_idx]), - math.radians(keys.rotation_3d_y_series[frame_idx]), - math.radians(keys.rotation_3d_z_series[frame_idx]) - ] - rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0) - result = transform_image_3d(prev_img_cv2, depth, rot_mat, translate_xyz, anim_args) - torch.cuda.empty_cache() - return result - - def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor: - return sample + torch.randn(sample.shape, device=sample.device) * noise_amt - - def load_img(path, shape, use_alpha_as_mask=False): - # use_alpha_as_mask: Read the alpha channel of the image as the mask image - if path.startswith('http://') or path.startswith('https://'): - image = Image.open(requests.get(path, stream=True).raw) - else: - image = Image.open(path) - - if use_alpha_as_mask: - image = image.convert('RGBA') - else: - image = image.convert('RGB') - - image = image.resize(shape, resample=Image.LANCZOS) - - mask_image = None - if use_alpha_as_mask: - # Split alpha channel into a mask_image - red, green, blue, alpha = Image.Image.split(image) - mask_image = alpha.convert('L') - image = image.convert('RGB') - - image = np.array(image).astype(np.float16) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - image = 2.*image - 1. 
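- # the image tensor is now scaled from [0, 1] to [-1, 1], the input range the first-stage (VAE) encoder expects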
- - return image, mask_image - - def load_mask_latent(mask_input, shape): - # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object - # shape (list-like len(4)): shape of the image to match, usually latent_image.shape - - if isinstance(mask_input, str): # mask input is probably a file name - if mask_input.startswith('http://') or mask_input.startswith('https://'): - mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA') - else: - mask_image = Image.open(mask_input).convert('RGBA') - elif isinstance(mask_input, Image.Image): - mask_image = mask_input - else: - raise Exception("mask_input must be a PIL image or a file name") - - mask_w_h = (shape[-1], shape[-2]) - mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS) - mask = mask.convert("L") - return mask - - def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0): - # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object - # shape (list-like len(4)): shape of the image to match, usually latent_image.shape - # mask_brightness_adjust (non-negative float): amount to adjust brightness of the iamge, - # 0 is black, 1 is no adjustment, >1 is brighter - # mask_contrast_adjust (non-negative float): amount to adjust contrast of the image, - # 0 is a flat grey image, 1 is no adjustment, >1 is more contrast - - mask = load_mask_latent(mask_input, mask_shape) - - # Mask brightness/contrast adjustments - if mask_brightness_adjust != 1: - mask = TF.adjust_brightness(mask, mask_brightness_adjust) - if mask_contrast_adjust != 1: - mask = TF.adjust_contrast(mask, mask_contrast_adjust) - - # Mask image to array - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask,(4,1,1)) - mask = np.expand_dims(mask,axis=0) - mask = torch.from_numpy(mask) - - if args.invert_mask: - mask = ( (mask - 0.5) * -1) + 0.5 - - mask = np.clip(mask,0,1) - return mask - - def maintain_colors(prev_img, color_match_sample, mode): - if mode == 'Match Frame 0 RGB': - return match_histograms(prev_img, color_match_sample, multichannel=True) - elif mode == 'Match Frame 0 HSV': - prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV) - color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV) - matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True) - return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB) - else: # Match Frame 0 LAB - prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB) - color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB) - matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True) - return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB) - - - def make_callback(sampler_name, dynamic_threshold=None, static_threshold=None, mask=None, init_latent=None, sigmas=None, sampler=None, masked_noise_modifier=1.0): - # Creates the callback function to be passed into the samplers - # The callback function is applied to the image at each step - def dynamic_thresholding_(img, threshold): - # Dynamic thresholding from Imagen paper (May 2022) - s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim))) - s = np.max(np.append(s,1.0)) - torch.clamp_(img, -1*s, s) - torch.FloatTensor.div_(img, s) - - # Callback for samplers in the k-diffusion repo, called thus: - # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - def k_callback_(args_dict): - if dynamic_threshold is not None: - 
dynamic_thresholding_(args_dict['x'], dynamic_threshold) - if static_threshold is not None: - torch.clamp_(args_dict['x'], -1*static_threshold, static_threshold) - if mask is not None: - init_noise = init_latent + noise * args_dict['sigma'] - is_masked = torch.logical_and(mask >= mask_schedule[args_dict['i']], mask != 0 ) - new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1) - args_dict['x'].copy_(new_img) - - # Function that is called on the image (img) and step (i) at each step - def img_callback_(img, i): - # Thresholding functions - if dynamic_threshold is not None: - dynamic_thresholding_(img, dynamic_threshold) - if static_threshold is not None: - torch.clamp_(img, -1*static_threshold, static_threshold) - if mask is not None: - i_inv = len(sigmas) - i - 1 - init_noise = sampler.stochastic_encode(init_latent, torch.tensor([i_inv]*batch_size).to(device), noise=noise) - is_masked = torch.logical_and(mask >= mask_schedule[i], mask != 0 ) - new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1) - img.copy_(new_img) - - if init_latent is not None: - noise = torch.randn_like(init_latent, device=device) * masked_noise_modifier - if sigmas is not None and len(sigmas) > 0: - mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas)) - elif len(sigmas) == 0: - mask = None # no mask needed if no steps (usually happens because strength==1.0) - if sampler_name in ["plms","ddim"]: - # Callback function formated for compvis latent diffusion samplers - if mask is not None: - assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable" - batch_size = init_latent.shape[0] - - callback = img_callback_ - else: - # Default callback function uses k-diffusion sampler variables - callback = k_callback_ - - return callback - - def sample_from_cv2(sample: np.ndarray) -> torch.Tensor: - sample = ((sample.astype(float) / 255.0) * 2) - 1 - sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16) - sample = torch.from_numpy(sample) - return sample - - def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray: - sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32) - sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1) - sample_int8 = (sample_f32 * 255) - return sample_int8.astype(type) - - def transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate, anim_args): - # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion - w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] - - aspect_ratio = float(w)/float(h) - near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov - persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device) - persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device) - - # range of [-1,1] is important to torch grid_sample's padding handling - y,x = torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device)) - z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device) - xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1) - - xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] - xyz_new_cam_xy = 
persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] - - offset_xy = xyz_new_cam_xy - xyz_old_cam_xy - # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation. - identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0) - # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs. - coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False) - offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0) - - image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device) - new_image = torch.nn.functional.grid_sample( - image_tensor.add(1/512 - 0.0001).unsqueeze(0), - offset_coords_2d, - mode=anim_args.sampling_mode, - padding_mode=anim_args.padding_mode, - align_corners=False - ) - - # convert back to cv2 style numpy array - result = rearrange( - new_image.squeeze().clamp(0,255), - 'c h w -> h w c' - ).cpu().numpy().astype(prev_img_cv2.dtype) - return result - - def generate(args, return_latent=False, return_sample=False, return_c=False): - seed_everything(args.seed) - os.makedirs(args.outdir, exist_ok=True) - - sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model) - model_wrap = CompVisDenoiser(model) - batch_size = args.n_samples - prompt = args.prompt - assert prompt is not None - data = [batch_size * [prompt]] - precision_scope = autocast if args.precision == "autocast" else nullcontext - - init_latent = None - mask_image = None - init_image = None - if args.init_latent is not None: - init_latent = args.init_latent - elif args.init_sample is not None: - with precision_scope("cuda"): - init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample)) - elif args.use_init and args.init_image != None and args.init_image != '': - init_image, mask_image = load_img(args.init_image, - shape=(args.W, args.H), - use_alpha_as_mask=args.use_alpha_as_mask) - init_image = init_image.to(device) - init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) - with precision_scope("cuda"): - init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space - - if not args.use_init and args.strength > 0 and args.strength_0_no_init: - print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.") - print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n") - args.strength = 0 - - # Mask functions - if args.use_mask: - assert args.mask_file is not None or mask_image is not None, "use_mask==True: An mask image is required for a mask. Please enter a mask_file or use an init image with an alpha channel" - assert args.use_init, "use_mask==True: use_init is required for a mask" - assert init_latent is not None, "use_mask==True: An latent init image is required for a mask" - - mask = prepare_mask(args.mask_file if mask_image is None else mask_image, - init_latent.shape, - args.mask_contrast_adjust, - args.mask_brightness_adjust) - - if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask: - raise Warning("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.") - - mask = mask.to(device) - mask = repeat(mask, '1 ... 
-> b ...', b=batch_size) - else: - mask = None - - t_enc = int((1.0-args.strength) * args.steps) - - # Noise schedule for the k-diffusion samplers (used for masking) - k_sigmas = model_wrap.get_sigmas(args.steps) - k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:] - - if args.sampler in ['plms','ddim']: - sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False) - - callback = make_callback(sampler_name=args.sampler, - dynamic_threshold=args.dynamic_threshold, - static_threshold=args.static_threshold, - mask=mask, - init_latent=init_latent, - sigmas=k_sigmas, - sampler=sampler) - - results = [] - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - for prompts in data: - uc = None - if args.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - - if args.init_c != None: - c = args.init_c - - if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]: - samples = sampler_fn( - c=c, - uc=uc, - args=args, - model_wrap=model_wrap, - init_latent=init_latent, - t_enc=t_enc, - device=device, - cb=callback) - else: - # args.sampler == 'plms' or args.sampler == 'ddim': - if init_latent is not None and args.strength > 0: - z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) - else: - z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device) - if args.sampler == 'ddim': - samples = sampler.decode(z_enc, - c, - t_enc, - unconditional_guidance_scale=args.scale, - unconditional_conditioning=uc, - img_callback=callback) - elif args.sampler == 'plms': # no "decode" function in plms, so use "sample" - shape = [args.C, args.H // args.f, args.W // args.f] - samples, _ = sampler.sample(S=args.steps, - conditioning=c, - batch_size=args.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=args.scale, - unconditional_conditioning=uc, - eta=args.ddim_eta, - x_T=z_enc, - img_callback=callback) - else: - raise Exception(f"Sampler {args.sampler} not recognised.") - - if return_latent: - results.append(samples.clone()) - - x_samples = model.decode_first_stage(samples) - if return_sample: - results.append(x_samples.clone()) - - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - - if return_c: - results.append(c.clone()) - - for x_sample in x_samples: - x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - image = Image.fromarray(x_sample.astype(np.uint8)) - results.append(image) - return results - - #@markdown **Select and Load Model** - - model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"] - model_checkpoint = "sd-v1-4.ckpt" #@param ["custom","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt"] - custom_config_path = "" #@param {type:"string"} - custom_checkpoint_path = "" #@param {type:"string"} - - load_on_run_all = True #@param {type: 'boolean'} - half_precision = True # check - check_sha256 = True #@param {type:"boolean"} - - model_map = { - "sd-v1-4-full-ema.ckpt": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'}, - "sd-v1-4.ckpt": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'}, - "sd-v1-3-full-ema.ckpt": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'}, - "sd-v1-3.ckpt": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'}, - "sd-v1-2-full-ema.ckpt": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'}, - "sd-v1-2.ckpt": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'}, - "sd-v1-1-full-ema.ckpt": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'}, - "sd-v1-1.ckpt": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'} - } - - # config path - ckpt_config_path = custom_config_path if model_config == "custom" else os.path.join(models_path, model_config) - if os.path.exists(ckpt_config_path): - print(f"{ckpt_config_path} exists") - else: - ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml" - print(f"Using config: {ckpt_config_path}") - - # checkpoint path or download - ckpt_path = custom_checkpoint_path if model_checkpoint == "custom" else os.path.join(models_path, model_checkpoint) - ckpt_valid = True - if os.path.exists(ckpt_path): - print(f"{ckpt_path} exists") - else: - print(f"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}") - ckpt_valid = False - - if check_sha256 and model_checkpoint != "custom" and ckpt_valid: - import hashlib - print("\n...checking sha256") - with open(ckpt_path, "rb") as f: - bytes = f.read() - hash = hashlib.sha256(bytes).hexdigest() - del bytes - if model_map[model_checkpoint]["sha256"] == hash: - print("hash is correct\n") - else: - print("hash in not correct\n") - ckpt_valid = False - - if ckpt_valid: - print(f"Using ckpt: {ckpt_path}") - - def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True): - map_location = "cuda" #@param ["cpu", "cuda"] - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location=map_location) - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - if half_precision: - model = model.half().to(device) - else: - model = model.to(device) - model.eval() - return model - - if load_on_run_all and ckpt_valid: - local_config = OmegaConf.load(f"{ckpt_config_path}") - model = load_model_from_config(local_config, f"{ckpt_path}", 
half_precision=half_precision) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - model = model.to(device) - - - def DeforumAnimArgs(): - - #@markdown ####**Animation:** - if opt.enable_animation_mode == True: - animation_mode = master_args["animation_mode"] #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} - max_frames = master_args["max_frames"] #@param {type:"number"} - border = master_args["border"] #@param ['wrap', 'replicate'] {type:'string'} - - #@markdown ####**Motion Parameters:** - angle = master_args["angle"]#@param {type:"string"} - zoom = master_args["zoom"] #@param {type:"string"} - translation_x = master_args["translation_x"] #@param {type:"string"} - translation_y = master_args["translation_y"] #@param {type:"string"} - translation_z = master_args["translation_z"] #@param {type:"string"} - rotation_3d_x = master_args["rotation_3d_x"] #@param {type:"string"} - rotation_3d_y = master_args["rotation_3d_y"] #@param {type:"string"} - rotation_3d_z = master_args["rotation_3d_z"] #@param {type:"string"} - noise_schedule = master_args["noise_schedule"] #@param {type:"string"} - strength_schedule = master_args["strength_schedule"] #@param {type:"string"} - contrast_schedule = master_args["contrast_schedule"] #@param {type:"string"} - - #@markdown ####**Coherence:** - color_coherence = master_args["color_coherence"] #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} - diffusion_cadence = master_args["diffusion_cadence"] #@param ['1','2','3','4','5','6','7','8'] {type:'string'} - - #@markdown #### Depth Warping - use_depth_warping = master_args["use_depth_warping"] #@param {type:"boolean"} - midas_weight = master_args["midas_weight"] #@param {type:"number"} - near_plane = master_args["near_plane"] - far_plane = master_args["far_plane"] - fov = master_args["fov"] #@param {type:"number"} - padding_mode = master_args["padding_mode"] #@param ['border', 'reflection', 'zeros'] {type:'string'} - sampling_mode = master_args["sampling_mode"] #@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} - save_depth_maps = master_args["save_depth_maps"] #@param {type:"boolean"} - - #@markdown ####**Video Input:** - video_init_path = master_args["video_init_path"] #@param {type:"string"} - extract_nth_frame = master_args["extract_nth_frame"] #@param {type:"number"} - - #@markdown ####**Interpolation:** - interpolate_key_frames = master_args["interpolate_key_frames"] #@param {type:"boolean"} - interpolate_x_frames = master_args["interpolate_x_frames"] #@param {type:"number"} - - #@markdown ####**Resume Animation:** - resume_from_timestring = master_args["resume_from_timestring"] #@param {type:"boolean"} - resume_timestring = master_args["resume_timestring"] #@param {type:"string"} - else: - #@markdown ####**Still image mode:** - animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} - max_frames = 10 #@param {type:"number"} - border = 'wrap' #@param ['wrap', 'replicate'] {type:'string'} - - #@markdown ####**Motion Parameters:** - angle = "0:(0)"#@param {type:"string"} - zoom = "0:(1.04)"#@param {type:"string"} - translation_x = "0:(0)"#@param {type:"string"} - translation_y = "0:(2)"#@param {type:"string"} - translation_z = "0:(0.5)"#@param {type:"string"} - rotation_3d_x = "0:(0)"#@param {type:"string"} - rotation_3d_y = "0:(0)"#@param {type:"string"} - rotation_3d_z = "0:(0)"#@param {type:"string"} - noise_schedule = "0: (0.02)"#@param 
{type:"string"} - strength_schedule = "0: (0.6)"#@param {type:"string"} - contrast_schedule = "0: (1.0)"#@param {type:"string"} - - #@markdown ####**Coherence:** - color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} - diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'} - - #@markdown #### Depth Warping - use_depth_warping = True #@param {type:"boolean"} - midas_weight = 0.3#@param {type:"number"} - near_plane = 200 - far_plane = 10000 - fov = 40#@param {type:"number"} - padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'} - sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} - save_depth_maps = False #@param {type:"boolean"} - - #@markdown ####**Video Input:** - video_init_path ='./input/video_in.mp4'#@param {type:"string"} - extract_nth_frame = 1#@param {type:"number"} - - #@markdown ####**Interpolation:** - interpolate_key_frames = True #@param {type:"boolean"} - interpolate_x_frames = 100 #@param {type:"number"} - - #@markdown ####**Resume Animation:** - resume_from_timestring = False #@param {type:"boolean"} - resume_timestring = "20220829210106" #@param {type:"string"} - - return locals() - - class DeformAnimKeys(): - def __init__(self, anim_args): - self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle)) - self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom)) - self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x)) - self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y)) - self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z)) - self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x)) - self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y)) - self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z)) - self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule)) - self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule)) - self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule)) - - - def get_inbetweens(key_frames, integer=False, interp_method='Linear'): - key_frame_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) - - for i, value in key_frames.items(): - key_frame_series[i] = value - key_frame_series = key_frame_series.astype(float) - - if interp_method == 'Cubic' and len(key_frames.items()) <= 3: - interp_method = 'Quadratic' - if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: - interp_method = 'Linear' - - key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] - key_frame_series[anim_args.max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] - key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both') - if integer: - return key_frame_series.astype(int) - return key_frame_series - - def parse_key_frames(string, prompt_parser=None): - import re - pattern = r'((?P[0-9]+):[\s]*[\(](?P[\S\s]*?)[\)])' - frames = dict() - for match_object in re.finditer(pattern, string): - frame = int(match_object.groupdict()['frame']) - param = match_object.groupdict()['param'] - if prompt_parser: - frames[frame] = prompt_parser(param) - else: - frames[frame] = param - if frames == {} and len(string) != 0: - raise 
RuntimeError('Key Frame string not correctly formatted') - return frames - - #Prompt will be put in here: for example: - ''' - prompts = [ - "a beaufiful young girl holding a flower, art by huang guangjian and gil elvgren and sachin teng, trending on artstation", - "a beaufiful young girl holding a flower, art by greg rutkowski and alphonse mucha, trending on artstation", - #"the third prompt I don't want it I commented it with an", - ] - - animation_prompts = { - 0: "amazing alien landscape with lush vegetation and colourful galaxy foreground, digital art, breathtaking, golden ratio, extremely detailed, hyper - detailed, establishing shot, hyperrealistic, cinematic lighting, particles, unreal engine, simon stalenhag, rendered by beeple, makoto shinkai, syd meade, kentaro miura, jean giraud, environment concept, artstation, octane render, 8k uhd image", - 50: "desolate landscape fill with giant flowers, moody :: by James Jean, Jeff Koons, Dan McPharlin Daniel Merrian :: ornate, dynamic, particulate, rich colors, intricate, elegant, highly detailed, centered, artstation, smooth, sharp focus, octane render, 3d", - } - ''' - - #Replace by text file - prompts = master_args["prompts"] - - if opt.enable_animation_mode: - animation_prompts = master_args["animation_prompts"] - else: - animation_prompts = {} - - - - def DeforumArgs(): - - #@markdown **Image Settings** - W = master_args["width"] #@param - H = master_args["height"] #@param - W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 - - #@markdown **Sampling Settings** - seed = master_args["seed"] #@param - sampler = master_args["sampler"] #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"] - steps = master_args["steps"] #@param - scale = master_args["scale"] #@param - ddim_eta = master_args["ddim_eta"] #@param - dynamic_threshold = None - static_threshold = None - - #@markdown **Save & Display Settings** - save_samples = True #@param {type:"boolean"} - save_settings = True #@param {type:"boolean"} - display_samples = True #@param {type:"boolean"} - - #@markdown **Batch Settings** - n_batch = master_args["n_batch"] #@param - batch_name = master_args["batch_name"] #@param {type:"string"} - filename_format = master_args["filename_format"] #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"] - seed_behavior = master_args["seed_behavior"] #@param ["iter","fixed","random"] - make_grid = False #@param {type:"boolean"} - grid_rows = 2 #@param - outdir = get_output_folder(output_path, batch_name) - - #@markdown **Init Settings** - use_init = master_args["use_init"] #@param {type:"boolean"} - strength = master_args["strength"] #@param {type:"number"} - init_image = master_args["init_image"] #@param {type:"string"} - strength_0_no_init = True # Set the strength to 0 automatically when no init image is used - # Whiter areas of the mask are areas that change more - use_mask = master_args["use_mask"] #@param {type:"boolean"} - use_alpha_as_mask = master_args["use_alpha_as_mask"] # use the alpha channel of the init image as the mask - mask_file = master_args["mask_file"] #@param {type:"string"} - invert_mask = master_args["invert_mask"] #@param {type:"boolean"} - # Adjust mask image, 1.0 is no adjustment. Should be positive numbers. 
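- # values above 1.0 brighten or add contrast, values between 0 and 1 darken or flatten the mask (see prepare_mask above)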
- mask_brightness_adjust = 1.0 #@param {type:"number"} - mask_contrast_adjust = 1.0 #@param {type:"number"} - - n_samples = 1 # doesnt do anything - precision = 'autocast' - C = 4 - f = 8 - - prompt = "" - timestring = "" - init_latent = None - init_sample = None - init_c = None - - return locals() - - - - def next_seed(args): - if args.seed_behavior == 'iter': - args.seed += 1 - elif args.seed_behavior == 'fixed': - pass # always keep seed the same - else: - args.seed = random.randint(0, 2**32) - return args.seed - - def render_image_batch(args): - args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)} - - # create output folder for the batch - os.makedirs(args.outdir, exist_ok=True) - if args.save_settings or args.save_samples: - print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*") - - # save settings for the batch - if args.save_settings: - filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") - with open(filename, "w+", encoding="utf-8") as f: - dictlist = dict(args.__dict__) - del dictlist['master_args'] - json.dump(dictlist, f, ensure_ascii=False, indent=4) - - index = 0 - - # function for init image batching - init_array = [] - if args.use_init: - if args.init_image == "": - raise FileNotFoundError("No path was given for init_image") - if args.init_image.startswith('http://') or args.init_image.startswith('https://'): - init_array.append(args.init_image) - elif not os.path.isfile(args.init_image): - if args.init_image[-1] != "/": # avoids path error by adding / to end if not there - args.init_image += "/" - for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array - if image.split(".")[-1] in ("png", "jpg", "jpeg"): - init_array.append(args.init_image + image) - else: - init_array.append(args.init_image) - else: - init_array = [""] - - # when doing large batches don't flood browser with images - clear_between_batches = args.n_batch >= 32 - - for iprompt, prompt in enumerate(prompts): - args.prompt = prompt - print(f"Prompt {iprompt+1} of {len(prompts)}") - print(f"{args.prompt}") - - all_images = [] - - for batch_index in range(args.n_batch): - if clear_between_batches and batch_index % 32 == 0: - display.clear_output(wait=True) - print(f"Batch {batch_index+1} of {args.n_batch}") - - for image in init_array: # iterates the init images - args.init_image = image - results = generate(args) - for image in results: - if args.make_grid: - all_images.append(T.functional.pil_to_tensor(image)) - if args.save_samples: - if args.filename_format == "{timestring}_{index}_{prompt}.png": - filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png" - else: - filename = f"{args.timestring}_{index:05}_{args.seed}.png" - image.save(os.path.join(args.outdir, filename)) - if args.display_samples: - display.display(image) - index += 1 - args.seed = next_seed(args) - - #print(len(all_images)) - if args.make_grid: - grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows)) - grid = rearrange(grid, 'c h w -> h w c').cpu().numpy() - filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png" - grid_image = Image.fromarray(grid.astype(np.uint8)) - grid_image.save(os.path.join(args.outdir, filename)) - display.clear_output(wait=True) - display.display(grid_image) - - - def render_animation(args, anim_args): - # animations use key framed prompts - args.prompts = animation_prompts - - # expand key frame strings to values - keys = DeformAnimKeys(anim_args) - - # resume animation - start_frame = 0 - if 
anim_args.resume_from_timestring: - for tmp in os.listdir(args.outdir): - if tmp.split("_")[0] == anim_args.resume_timestring: - start_frame += 1 - start_frame = start_frame - 1 - - # create output folder for the batch - os.makedirs(args.outdir, exist_ok=True) - print(f"Saving animation frames to {args.outdir}") - - # save settings for the batch - settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") - with open(settings_filename, "w+", encoding="utf-8") as f: - s = {**dict(args.__dict__), **dict(anim_args.__dict__)} - del s['master_args'] - del s['opt'] - json.dump(s, f, ensure_ascii=False, indent=4) - - # resume from timestring - if anim_args.resume_from_timestring: - args.timestring = anim_args.resume_timestring - - # expand prompts out to per-frame - prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) - for i, prompt in animation_prompts.items(): - prompt_series[int(i)] = prompt - prompt_series = prompt_series.ffill().bfill() - - # check for video inits - using_vid_init = anim_args.animation_mode == 'Video Input' - - # load depth model for 3D - predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps - if predict_depths: - depth_model = DepthModel(device) - depth_model.load_midas(models_path) - if anim_args.midas_weight < 1.0: - depth_model.load_adabins() - else: - depth_model = None - anim_args.save_depth_maps = False - - # state for interpolating between diffusion steps - turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence) - turbo_prev_image, turbo_prev_frame_idx = None, 0 - turbo_next_image, turbo_next_frame_idx = None, 0 - - # resume animation - prev_sample = None - color_match_sample = None - if anim_args.resume_from_timestring: - last_frame = start_frame-1 - if turbo_steps > 1: - last_frame -= last_frame%turbo_steps - path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png") - img = cv2.imread(path) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - prev_sample = sample_from_cv2(img) - if anim_args.color_coherence != 'None': - color_match_sample = img - if turbo_steps > 1: - turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame - turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx - start_frame = last_frame+turbo_steps - - args.n_samples = 1 - frame_idx = start_frame - while frame_idx < anim_args.max_frames: - print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}") - noise = keys.noise_schedule_series[frame_idx] - strength = keys.strength_schedule_series[frame_idx] - contrast = keys.contrast_schedule_series[frame_idx] - depth = None - - # emit in-between frames - if turbo_steps > 1: - tween_frame_start_idx = max(0, frame_idx-turbo_steps) - for tween_frame_idx in range(tween_frame_start_idx, frame_idx): - tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx) - print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}") - - advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx - advance_next = tween_frame_idx > turbo_next_frame_idx - - if depth_model is not None: - assert(turbo_next_image is not None) - depth = depth_model.predict(turbo_next_image, anim_args) - - if anim_args.animation_mode == '2D': - if advance_prev: - turbo_prev_image = anim_frame_warp_2d(turbo_prev_image, args, anim_args, keys, tween_frame_idx) - if advance_next: - turbo_next_image = 
anim_frame_warp_2d(turbo_next_image, args, anim_args, keys, tween_frame_idx) - else: # '3D' - if advance_prev: - turbo_prev_image = anim_frame_warp_3d(turbo_prev_image, depth, anim_args, keys, tween_frame_idx) - if advance_next: - turbo_next_image = anim_frame_warp_3d(turbo_next_image, depth, anim_args, keys, tween_frame_idx) - turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx - - if turbo_prev_image is not None and tween < 1.0: - img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween - else: - img = turbo_next_image - - filename = f"{args.timestring}_{tween_frame_idx:05}.png" - cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR)) - if anim_args.save_depth_maps: - depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth) - if turbo_next_image is not None: - prev_sample = sample_from_cv2(turbo_next_image) - - # apply transforms to previous frame - if prev_sample is not None: - if anim_args.animation_mode == '2D': - prev_img = anim_frame_warp_2d(sample_to_cv2(prev_sample), args, anim_args, keys, frame_idx) - else: # '3D' - prev_img_cv2 = sample_to_cv2(prev_sample) - depth = depth_model.predict(prev_img_cv2, anim_args) if depth_model else None - prev_img = anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx) - - # apply color matching - if anim_args.color_coherence != 'None': - if color_match_sample is None: - color_match_sample = prev_img.copy() - else: - prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence) - - # apply scaling - contrast_sample = prev_img * contrast - # apply frame noising - noised_sample = add_noise(sample_from_cv2(contrast_sample), noise) - - # use transformed previous frame as init for current - args.use_init = True - if half_precision: - args.init_sample = noised_sample.half().to(device) - else: - args.init_sample = noised_sample.to(device) - args.strength = max(0.0, min(1.0, strength)) - - # grab prompt for current frame - args.prompt = prompt_series[frame_idx] - print(f"{args.prompt} {args.seed}") - - # grab init image for current frame - if using_vid_init: - init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:04}.jpg") - print(f"Using video init frame {init_frame}") - args.init_image = init_frame - - # sample the diffusion model - sample, image = generate(args, return_latent=False, return_sample=True) - if not using_vid_init: - prev_sample = sample - - if turbo_steps > 1: - turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx - turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx - frame_idx += turbo_steps - else: - filename = f"{args.timestring}_{frame_idx:05}.png" - image.save(os.path.join(args.outdir, filename)) - if anim_args.save_depth_maps: - if depth is None: - depth = depth_model.predict(sample_to_cv2(sample), anim_args) - depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth) - frame_idx += 1 - - display.clear_output(wait=True) - display.display(image) - - args.seed = next_seed(args) - - def render_input_video(args, anim_args): - # create a folder for the video input frames to live in - video_in_frame_path = os.path.join(args.outdir, 'inputframes') - os.makedirs(video_in_frame_path, exist_ok=True) - - # save the video frames from input video - print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...") - try: - for f in 
pathlib.Path(video_in_frame_path).glob('*.jpg'): - f.unlink() - except: - pass - vf = r'select=not(mod(n\,'+str(anim_args.extract_nth_frame)+'))' - subprocess.run([ - 'ffmpeg', '-i', f'{anim_args.video_init_path}', - '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', - '-loglevel', 'error', '-stats', - os.path.join(video_in_frame_path, '%04d.jpg') - ], stdout=subprocess.PIPE).stdout.decode('utf-8') - - # determine max frames from length of input frames - anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')]) - - args.use_init = True - print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}") - render_animation(args, anim_args) - - def render_interpolation(args, anim_args): - # animations use key framed prompts - args.prompts = animation_prompts - - # create output folder for the batch - os.makedirs(args.outdir, exist_ok=True) - print(f"Saving animation frames to {args.outdir}") - - # save settings for the batch - settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") - with open(settings_filename, "w+", encoding="utf-8") as f: - s = {**dict(args.__dict__), **dict(anim_args.__dict__)} - del s['master_args'] - del s['opt'] - json.dump(s, f, ensure_ascii=False, indent=4) - - # Interpolation Settings - args.n_samples = 1 - args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available - prompts_c_s = [] # cache all the text embeddings - - print(f"Preparing for interpolation of the following...") - - for i, prompt in animation_prompts.items(): - args.prompt = prompt - - # sample the diffusion model - results = generate(args, return_c=True) - c, image = results[0], results[1] - prompts_c_s.append(c) - - # display.clear_output(wait=True) - display.display(image) - - args.seed = next_seed(args) - - display.clear_output(wait=True) - print(f"Interpolation start...") - - frame_idx = 0 - - if anim_args.interpolate_key_frames: - for i in range(len(prompts_c_s)-1): - dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0] - if dist_frames <= 0: - print("key frames duplicated or reversed. 
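# Standalone sketch of the frame-extraction step in render_input_video above: ffmpeg keeps every Nth
# frame of the init video via a select filter and renumbers the survivors with vfr sync. The paths
# and extract_nth_frame value are hypothetical, and ffmpeg is assumed to be on PATH.
import os
import pathlib
import subprocess

video_init_path = "input.mp4"
video_in_frame_path = os.path.join("outputs", "inputframes")
extract_nth_frame = 2

os.makedirs(video_in_frame_path, exist_ok=True)
vf = r"select=not(mod(n\," + str(extract_nth_frame) + "))"
subprocess.run([
    "ffmpeg", "-i", video_init_path,
    "-vf", vf, "-vsync", "vfr", "-q:v", "2",
    "-loglevel", "error", "-stats",
    os.path.join(video_in_frame_path, "%04d.jpg"),
], check=True)

max_frames = len(list(pathlib.Path(video_in_frame_path).glob("*.jpg")))
print(f"extracted {max_frames} frames")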
interpolation skipped.") - return - else: - for j in range(dist_frames): - # interpolate the text embedding - prompt1_c = prompts_c_s[i] - prompt2_c = prompts_c_s[i+1] - args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames)) - - # sample the diffusion model - results = generate(args) - image = results[0] - - filename = f"{args.timestring}_{frame_idx:05}.png" - image.save(os.path.join(args.outdir, filename)) - frame_idx += 1 - - display.clear_output(wait=True) - display.display(image) - - args.seed = next_seed(args) - - else: - for i in range(len(prompts_c_s)-1): - for j in range(anim_args.interpolate_x_frames+1): - # interpolate the text embedding - prompt1_c = prompts_c_s[i] - prompt2_c = prompts_c_s[i+1] - args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1))) - - # sample the diffusion model - results = generate(args) - image = results[0] - - filename = f"{args.timestring}_{frame_idx:05}.png" - image.save(os.path.join(args.outdir, filename)) - frame_idx += 1 - - display.clear_output(wait=True) - display.display(image) - - args.seed = next_seed(args) - - # generate the last prompt - args.init_c = prompts_c_s[-1] - results = generate(args) - image = results[0] - filename = f"{args.timestring}_{frame_idx:05}.png" - image.save(os.path.join(args.outdir, filename)) - - display.clear_output(wait=True) - display.display(image) - args.seed = next_seed(args) - - #clear init_c - args.init_c = None - - - args = SimpleNamespace(**DeforumArgs()) - anim_args = SimpleNamespace(**DeforumAnimArgs()) - - args.timestring = time.strftime('%Y%m%d%H%M%S') - args.strength = max(0.0, min(1.0, args.strength)) - - if args.seed == -1: - args.seed = random.randint(0, 2**32 - 1) - if not args.use_init: - args.init_image = None - if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'): - print(f"Init images aren't supported with PLMS yet, switching to KLMS") - args.sampler = 'klms' - if args.sampler != 'ddim': - args.ddim_eta = 0 - - if anim_args.animation_mode == 'None': - anim_args.max_frames = 1 - elif anim_args.animation_mode == 'Video Input': - args.use_init = True - - # clean up unused memory - gc.collect() - torch.cuda.empty_cache() - - # dispatch to appropriate renderer - if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D': - render_animation(args, anim_args) - elif anim_args.animation_mode == 'Video Input': - render_input_video(args, anim_args) - elif anim_args.animation_mode == 'Interpolation': - render_interpolation(args, anim_args) - else: - render_image_batch(args) - - - skip_video_for_run_all = False #@param {type: 'boolean'} - fps = 12 #@param {type:"number"} - #@markdown **Manual Settings** - use_manual_settings = False #@param {type:"boolean"} - image_path = "./output/out_%05d.png" #@param {type:"string"} - mp4_path = "./output/out_%05d.mp4" #@param {type:"string"} - - - if skip_video_for_run_all == True or opt.enable_animation_mode == False: - print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it') - else: - import os - import subprocess - from base64 import b64encode - - print(f"{image_path} -> {mp4_path}") - - if use_manual_settings: - max_frames = "200" #@param {type:"string"} - else: - image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png") - mp4_path = os.path.join(args.outdir, f"{args.timestring}.mp4") - max_frames = str(anim_args.max_frames) - - # make video - cmd = [ - 'ffmpeg', - '-y', - '-vcodec', 'png', - '-r', str(fps), - 
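# Sketch of the prompt-interpolation step above: the cached text conditionings of two consecutive
# keyframed prompts are linearly blended, and each blend is used as init_c for one frame. The random
# tensors stand in for real CLIP conditionings; their shape is an assumption.
import torch

def lerp_conditioning(c1, c2, t: float):
    # equivalent to prompt1_c.add(prompt2_c.sub(prompt1_c).mul(t))
    return c1 + (c2 - c1) * t

prompt1_c = torch.randn(1, 77, 768)
prompt2_c = torch.randn(1, 77, 768)
dist_frames = 10
for j in range(dist_frames):
    init_c = lerp_conditioning(prompt1_c, prompt2_c, j / dist_frames)
    # init_c would be handed to generate() for frame j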
'-start_number', str(0), - '-i', image_path, - '-frames:v', max_frames, - '-c:v', 'libx264', - '-vf', - f'fps={fps}', - '-pix_fmt', 'yuv420p', - '-crf', '17', - '-preset', 'veryfast', - mp4_path - ] - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - if process.returncode != 0: - print(stderr) - raise RuntimeError(stderr) - - mp4 = open(mp4_path,'rb').read() - data_url = "data:video/mp4;base64," + b64encode(mp4).decode() - display.display( display.HTML(f'') ) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/scripts/ModelManager.py b/scripts/ModelManager.py deleted file mode 100644 index 983f85b..0000000 --- a/scripts/ModelManager.py +++ /dev/null @@ -1,46 +0,0 @@ -# base webui import and utils. -from webui_streamlit import st -from sd_utils import * - -# streamlit imports - - -#other imports -import pandas as pd -from io import StringIO - -# Temp imports - - -# end of imports -#--------------------------------------------------------------------------------------------------------------- - -def layout(): - #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="") - - csvString = f""" - ,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media - ,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth - ,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth - ,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth - ,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt - ,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt - ,Stable Diffusion Concept Library , ./models/customsd-concepts-library , https://github.com/sd-webui/sd-concepts-library - """ - colms = st.columns((1, 3, 5, 5)) - columns = ["№",'Model Name','Save Location','Download Link'] - - # Convert String into StringIO - csvStringIO = StringIO(csvString) - df = pd.read_csv(csvStringIO, sep=",", header=None, names=columns) - - for col, field_name in zip(colms, columns): - # table header - col.write(field_name) - - for x, model_name in enumerate(df["Model Name"]): - col1, col2, col3, col4 = st.columns((1, 3, 4, 6)) - col1.write(x) # index - col2.write(df['Model Name'][x]) - col3.write(df['Save Location'][x]) - col4.write(df['Download Link'][x]) \ No newline at end of file diff --git a/scripts/Settings.py b/scripts/Settings.py deleted file mode 100644 index a1e21ec..0000000 --- a/scripts/Settings.py +++ /dev/null @@ -1,5 +0,0 @@ -from webui_streamlit import st - -# The global settings section will be moved to the Settings page. -#with st.expander("Global Settings:"): -st.write("Global Settings:") diff --git a/scripts/home.py b/scripts/home.py deleted file mode 100644 index 2702fcc..0000000 --- a/scripts/home.py +++ /dev/null @@ -1,216 +0,0 @@ -# base webui import and utils. 
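# The ModelManager page above renders its model table by parsing an inline CSV string with pandas.
# A trimmed sketch of that pattern; the example.com URLs below are placeholders, not the real
# download links.
from io import StringIO
import pandas as pd

csvString = """
,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://example.com/sd-v1-4.ckpt
,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://example.com/GFPGANv1.3.pth
"""
columns = ["№", "Model Name", "Save Location", "Download Link"]
df = pd.read_csv(StringIO(csvString), sep=",", header=None, names=columns)

for x, model_name in enumerate(df["Model Name"]):
    print(x, model_name.strip(), "->", df["Download Link"][x].strip())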
-from webui_streamlit import st -from sd_utils import * - -# streamlit imports - - -#other imports - -# Temp imports - - -# end of imports -#--------------------------------------------------------------------------------------------------------------- - -import os -from PIL import Image - -try: - # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging - - logging.set_verbosity_error() -except: - pass - -class plugin_info(): - plugname = "home" - description = "Home" - isTab = True - displayPriority = 0 - -def getLatestGeneratedImagesFromPath(): - #get the latest images from the generated images folder - #get the path to the generated images folder - generatedImagesPath = os.path.join(os.getcwd(),'outputs') - #get all the files from the folders and subfolders - files = [] - #get the latest 10 images from the output folder without walking the subfolders - for r, d, f in os.walk(generatedImagesPath): - for file in f: - if '.png' in file: - files.append(os.path.join(r, file)) - #sort the files by date - files.sort(reverse=True, key=os.path.getmtime) - latest = files[:90] - latest.reverse() - - # reverse the list so the latest images are first and truncate to - # a reasonable number of images, 10 pages worth - return [Image.open(f) for f in latest] - -def get_images_from_lexica(): - #scrape images from lexica.art - #get the html from the page - #get the html with cookies and javascript - apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D' - #REST API call - # - from requests_html import HTMLSession - session = HTMLSession() - - response = session.get(apiEndpoint) - #req = requests.Session() - #req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36' - #response = req.get(apiEndpoint) - print(response.status_code) - print(response.text) - #get the json from the response - #json = response.json() - #get the prompts from the json - print(response) - #session = requests.Session() - #parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False) - #print(parseEndpointJson) - #print('test2') - #page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'}) - #parse the html - #soup = BeautifulSoup(page.content, 'html.parser') - #find all the images - #print(soup) - #images = soup.find_all('alt-image') - #create a list to store the image urls - image_urls = [] - #loop through the images - for image in images: - #get the url - image_url = image['src'] - #add it to the list - image_urls.append('http://www.lexica.art/'+image_url) - #return the list - print(image_urls) - return image_urls - -def layout(): - #streamlit home page layout - #center the title - st.markdown("

Welcome, let's make some 🎨", unsafe_allow_html=True) - #make a gallery of images - #st.markdown("

Gallery", unsafe_allow_html=True) - #create a gallery of images using columns - #col1, col2, col3 = st.columns(3) - #load the images - #create 3 columns - # create a tab for the gallery - #st.markdown("

Gallery", unsafe_allow_html=True) - #st.markdown("

Gallery

", unsafe_allow_html=True) - - history_tab, discover_tabs = st.tabs(["History","Discover"]) - - latestImages = getLatestGeneratedImagesFromPath() - st.session_state['latestImages'] = latestImages - - with history_tab: - ##--------------------------------------------------------- - ## image slideshow test - ## Number of entries per screen - #slideshow_N = 9 - #slideshow_page_number = 0 - #slideshow_last_page = len(latestImages) // slideshow_N - - ## Add a next button and a previous button - - #slideshow_prev, slideshow_image_col , slideshow_next = st.columns([1, 10, 1]) - - #with slideshow_image_col: - #slideshow_image = st.empty() - - #slideshow_image.image(st.session_state['latestImages'][0]) - - #current_image = 0 - - #if slideshow_next.button("Next", key=1): - ##print (current_image+1) - #current_image = current_image+1 - #slideshow_image.image(st.session_state['latestImages'][current_image+1]) - #if slideshow_prev.button("Previous", key=0): - ##print ([current_image-1]) - #current_image = current_image-1 - #slideshow_image.image(st.session_state['latestImages'][current_image - 1]) - - - #--------------------------------------------------------- - - # image gallery - # Number of entries per screen - gallery_N = 9 - if not "galleryPage" in st.session_state: - st.session_state["galleryPage"] = 0 - gallery_last_page = len(latestImages) // gallery_N - - # Add a next button and a previous button - - gallery_prev, gallery_refresh, gallery_pagination , gallery_next = st.columns([2, 2, 8, 1]) - - # the pagination doesnt work for now so its better to disable the buttons. - - if gallery_refresh.button("Refresh", key=4): - st.session_state["galleryPage"] = 0 - - if gallery_next.button("Next", key=3): - - if st.session_state["galleryPage"] + 1 > gallery_last_page: - st.session_state["galleryPage"] = 0 - else: - st.session_state["galleryPage"] += 1 - - if gallery_prev.button("Previous", key=2): - - if st.session_state["galleryPage"] - 1 < 0: - st.session_state["galleryPage"] = gallery_last_page - else: - st.session_state["galleryPage"] -= 1 - - print(st.session_state["galleryPage"]) - # Get start and end indices of the next page of the dataframe - gallery_start_idx = st.session_state["galleryPage"] * gallery_N - gallery_end_idx = (1 + st.session_state["galleryPage"]) * gallery_N - - #--------------------------------------------------------- - - placeholder = st.empty() - - #populate the 3 images per column - with placeholder.container(): - col1, col2, col3 = st.columns(3) - col1_cont = st.container() - col2_cont = st.container() - col3_cont = st.container() - - print (len(st.session_state['latestImages'])) - images = list(reversed(st.session_state['latestImages']))[gallery_start_idx:(gallery_start_idx+gallery_N)] - - with col1_cont: - with col1: - [st.image(images[index]) for index in [0, 3, 6] if index < len(images)] - with col2_cont: - with col2: - [st.image(images[index]) for index in [1, 4, 7] if index < len(images)] - with col3_cont: - with col3: - [st.image(images[index]) for index in [2, 5, 8] if index < len(images)] - - - st.session_state['historyTab'] = [history_tab,col1,col2,col3,placeholder,col1_cont,col2_cont,col3_cont] - - with discover_tabs: - st.markdown("

Soon :)", unsafe_allow_html=True) - - #display the images - #add a button to the gallery - #st.markdown("

Try it out

", unsafe_allow_html=True) - #create a button to the gallery - #if st.button("Try it out"): - #if the button is clicked, go to the gallery - #st.experimental_rerun() diff --git a/scripts/img2img.py b/scripts/img2img.py deleted file mode 100644 index 142fe81..0000000 --- a/scripts/img2img.py +++ /dev/null @@ -1,592 +0,0 @@ -# base webui import and utils. -from webui_streamlit import st -from sd_utils import * - -# streamlit imports -from streamlit import StopException - -#other imports -import cv2 -from PIL import Image, ImageOps -import torch -import k_diffusion as K -import numpy as np -import time -import torch -import skimage -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler -# Temp imports - - -# end of imports -#--------------------------------------------------------------------------------------------------------------- - - -try: - # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging - - logging.set_verbosity_error() -except: - pass - -def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, - mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM', - n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, - seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, - variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, - write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", - separate_prompts:bool = False, normalize_prompt_weights:bool = True, - save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, - save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False, - random_seed_loopback: bool = False - ): - - outpath = st.session_state['defaults'].general.outdir_img2img or st.session_state['defaults'].general.outdir or "outputs/img2img-samples" - #err = False - #loopback = False - #skip_save = False - seed = seed_to_int(seed) - - batch_size = 1 - - #prompt_matrix = 0 - #normalize_prompt_weights = 1 in toggles - #loopback = 2 in toggles - #random_seed_loopback = 3 in toggles - #skip_save = 4 not in toggles - #save_grid = 5 in toggles - #sort_samples = 6 in toggles - #write_info_files = 7 in toggles - #write_sample_info_to_log_file = 8 in toggles - #jpg_sample = 9 in toggles - #use_GFPGAN = 10 in toggles - #use_RealESRGAN = 11 in toggles - - if sampler_name == 'PLMS': - sampler = PLMSSampler(st.session_state["model"]) - elif sampler_name == 'DDIM': - sampler = DDIMSampler(st.session_state["model"]) - elif sampler_name == 'k_dpm_2_a': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') - elif sampler_name == 'k_dpm_2': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') - elif sampler_name == 'k_euler_a': - sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') - elif sampler_name == 'k_euler': - sampler = KDiffusionSampler(st.session_state["model"],'euler') - elif sampler_name == 'k_heun': - sampler = KDiffusionSampler(st.session_state["model"],'heun') - elif sampler_name == 'k_lms': - sampler = KDiffusionSampler(st.session_state["model"],'lms') - else: - raise Exception("Unknown sampler: " + sampler_name) - - def process_init_mask(init_mask: Image): - if 
init_mask.mode == "RGBA": - init_mask = init_mask.convert('RGBA') - background = Image.new('RGBA', init_mask.size, (0, 0, 0)) - init_mask = Image.alpha_composite(background, init_mask) - init_mask = init_mask.convert('RGB') - return init_mask - - init_img = init_info - init_mask = None - if mask_mode == 0: - if init_info_mask: - init_mask = process_init_mask(init_info_mask) - elif mask_mode == 1: - if init_info_mask: - init_mask = process_init_mask(init_info_mask) - init_mask = ImageOps.invert(init_mask) - elif mask_mode == 2: - init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1') - init_mask = init_img_transparency - init_mask = init_mask.convert("RGB") - init_mask = resize_image(resize_mode, init_mask, width, height) - init_mask = init_mask.convert("RGB") - - assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - t_enc = int(denoising_strength * ddim_steps) - - if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None: - noise_q = 0.99 - color_variation = 0.0 - mask_blend_factor = 1.0 - - np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing - np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64) - np_mask_rgb -= np.min(np_mask_rgb) - np_mask_rgb /= np.max(np_mask_rgb) - np_mask_rgb = 1. - np_mask_rgb - np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64) - blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) - blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) - #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants - #np_mask_rgb = np_mask_rgb + blurred - np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.) - np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.) - - noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation) - blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor) - noised = noise_rgb[:] - blend_mask_rgb **= (2.) - noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb - - np_mask_grey = np.sum(np_mask_rgb, axis=2)/3. - ref_mask = np_mask_grey < 1e-3 - - all_mask = np.ones((height, width), dtype=bool) - noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1) - - init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB") - st.session_state["editor_image"].image(init_img) # debug - - def init(): - image = init_img.convert('RGB') - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - - mask_channel = None - if init_mask: - alpha = resize_image(resize_mode, init_mask, width // 8, height // 8) - mask_channel = alpha.split()[-1] - - mask = None - if mask_channel is not None: - mask = np.array(mask_channel).astype(np.float32) / 255.0 - mask = (1 - mask) - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) - mask = torch.from_numpy(mask).to(st.session_state["device"]) - - if st.session_state['defaults'].general.optimized: - st.session_state.modelFS.to(st.session_state["device"] ) - - init_image = 2. * image - 1. 
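# Sketch of the conversion done in init() above: the init image becomes a batched NCHW float tensor
# rescaled from [0, 1] to [-1, 1] before it is encoded to latent space. The 64x64 test image is a
# placeholder.
import numpy as np
import torch
from PIL import Image

def pil_to_model_input(pil_img):
    image = np.array(pil_img.convert("RGB")).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)        # HWC -> 1xCxHxW
    return 2.0 * torch.from_numpy(image) - 1.0       # [0,1] -> [-1,1]

print(pil_to_model_input(Image.new("RGB", (64, 64), "gray")).shape)  # torch.Size([1, 3, 64, 64])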
- init_image = init_image.to(st.session_state["device"]) - init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space - - if st.session_state['defaults'].general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - st.session_state.modelFS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - return init_latent, mask, - - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): - t_enc_steps = t_enc - obliterate = False - if ddim_steps == t_enc_steps: - t_enc_steps = t_enc_steps - 1 - obliterate = True - - if sampler_name != 'DDIM': - x0, z_mask = init_data - - sigmas = sampler.model_wrap.get_sigmas(ddim_steps) - noise = x * sigmas[ddim_steps - t_enc_steps - 1] - - xi = x0 + noise - - # Obliterate masked image - if z_mask is not None and obliterate: - random = torch.randn(z_mask.shape, device=xi.device) - xi = (z_mask * noise) + ((1-z_mask) * xi) - - sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] - model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, - extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, - callback=generation_callback) - else: - - x0, z_mask = init_data - - sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) - z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] )) - - # Obliterate masked image - if z_mask is not None and obliterate: - random = torch.randn(z_mask.shape, device=z_enc.device) - z_enc = (z_mask * random) + ((1-z_mask) * z_enc) - - # decode it - samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=unconditional_conditioning, - z_mask=z_mask, x0=x0) - return samples_ddim - - - - if loopback: - output_images, info = None, None - history = [] - initial_seed = None - - do_color_correction = False - try: - from skimage import exposure - do_color_correction = True - except: - print("Install scikit-image to perform color correction on loopback") - - for i in range(n_iter): - if do_color_correction and i == 0: - correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) - - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=1, - n_iter=1, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback - realesrgan_model_name=RealESRGAN_model, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - init_img=init_img, - init_mask=init_mask, - mask_blur_strength=mask_blur_strength, - mask_restore=mask_restore, - denoising_strength=denoising_strength, - noise_mode=noise_mode, - find_noise_steps=find_noise_steps, - resize_mode=resize_mode, - uses_loopback=loopback, - uses_random_seed_loopback=random_seed_loopback, - 
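# Sketch of how the k-diffusion branch of sample() above seeds img2img: only the last t_enc steps are
# run, so the init latent is noised with the sigma of the first step that will actually be sampled.
# The latent shape and the linspace schedule are stand-ins for the real
# sampler.model_wrap.get_sigmas(ddim_steps).
import torch

ddim_steps = 50
denoising_strength = 0.8
t_enc_steps = int(denoising_strength * ddim_steps)          # 40

x0 = torch.randn(1, 4, 64, 64)                              # init latent (hypothetical shape)
x = torch.randn_like(x0)                                    # unit gaussian noise
sigmas = torch.linspace(10.0, 0.1, ddim_steps)              # placeholder noise schedule

noise = x * sigmas[ddim_steps - t_enc_steps - 1]
xi = x0 + noise                                             # starting point handed to the sampler
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]         # schedule for the remaining steps
print(t_enc_steps, tuple(xi.shape), len(sigma_sched))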
sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg - ) - - if initial_seed is None: - initial_seed = seed - - input_image = init_img - init_img = output_images[0] - - if do_color_correction and correction_target is not None: - init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( - cv2.cvtColor( - np.asarray(init_img), - cv2.COLOR_RGB2LAB - ), - correction_target, - channel_axis=2 - ), cv2.COLOR_LAB2RGB).astype("uint8")) - if mask_restore is True and init_mask is not None: - color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) - color_mask = color_mask.convert('L') - source_image = input_image.convert('RGB') - target_image = init_img.convert('RGB') - - init_img = Image.composite(source_image, target_image, color_mask) - - if not random_seed_loopback: - seed = seed + 1 - else: - seed = seed_to_int(None) - - denoising_strength = max(denoising_strength * 0.95, 0.1) - history.append(init_img) - - output_images = history - seed = initial_seed - - else: - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=batch_size, - n_iter=n_iter, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN, - realesrgan_model_name=RealESRGAN_model, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - init_img=init_img, - init_mask=init_mask, - mask_blur_strength=mask_blur_strength, - denoising_strength=denoising_strength, - noise_mode=noise_mode, - find_noise_steps=find_noise_steps, - mask_restore=mask_restore, - resize_mode=resize_mode, - uses_loopback=loopback, - sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg - ) - - del sampler - - return output_images, seed, info, stats - -# - - -def layout(): - with st.form("img2img-inputs"): - st.session_state["generation_mode"] = "img2img" - - img2img_input_col, img2img_generate_col = st.columns([10,1]) - with img2img_input_col: - #prompt = st.text_area("Input Text","") - prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") - - # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. - img2img_generate_col.write("") - img2img_generate_col.write("") - generate_button = img2img_generate_col.form_submit_button("Generate") - - - # creating the page layout using columns - col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small") - - with col1_img2img_layout: - # If we have custom models available on the "models/custom" - #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD - if st.session_state["CustomModel_available"]: - st.session_state["custom_model"] = st.selectbox("Custom Model:", st.session_state["custom_models"], - index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model), - help="Select the model you want to use. This option is only available if you have custom models \ - on your 'models/custom' folder. 
The model name that will be shown here is the same as the name\ - the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ - will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") - else: - st.session_state["custom_model"] = "Stable Diffusion v1.4" - - - st.session_state["sampling_steps"] = st.slider("Sampling Steps", - value=st.session_state['defaults'].img2img.sampling_steps, - min_value=st.session_state['defaults'].img2img.slider_bounds.sampling.lower, - max_value=st.session_state['defaults'].img2img.slider_bounds.sampling.upper, - step=st.session_state['defaults'].img2img.slider_steps.sampling) - - sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] - st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list, - index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.") - - mask_mode_list = ["Mask", "Inverted mask", "Image alpha"] - mask_mode = st.selectbox("Mask Mode", mask_mode_list, - help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\ - \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent." - ) - mask_mode = mask_mode_list.index(mask_mode) - - width = st.slider("Width:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.width, step=64) - height = st.slider("Height:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.height, step=64) - seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") - noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"] - noise_mode = st.selectbox( - "Noise Mode", noise_mode_list, - help="" - ) - noise_mode = noise_mode_list.index(noise_mode) - find_noise_steps = st.slider("Find Noise Steps", value=100, min_value=1, max_value=500) - batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_count, step=1, - help="How many iterations or batches of images to generate in total.") - - # - with st.expander("Advanced"): - separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts, - help="Separate multiple prompts using the `|` character, and get all combinations of them.") - normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights, - help="Ensure the sum of all weights add up to 1.0") - loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from previous batch when creating next batch.") - random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Random loopback seed") - img2img_mask_restore = st.checkbox("Only modify regenerated parts of image", - value=st.session_state['defaults'].img2img.mask_restore, - help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail") - save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images, - help="Save each image generated before any filter or enhancement is applied.") - save_grid = 
st.checkbox("Save grid",value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.") - group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt, - help="Saves all the images with the same prompt into the same folder. \ - When using a prompt matrix each prompt combination will have its own folder.") - write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files, - help="Save a file next to the image with informartion about the generation.") - save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.") - - if st.session_state["GFPGAN_available"]: - use_GFPGAN = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ - This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") - else: - use_GFPGAN = False - - if st.session_state["RealESRGAN_available"]: - st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN, - help="Uses the RealESRGAN model to upscale the images after the generation.\ - This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") - st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) - else: - st.session_state["use_RealESRGAN"] = False - st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" - - variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) - variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed, - help="The seed to use when generating a variant, if left blank a random seed will be generated.") - cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].img2img.cfg_scale, step=0.5, - help="How strongly the image should follow the prompt.") - batch_size = st.slider("Batch size", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_size, step=1, - help="How many images are at once in a batch.\ - It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish \ - generation as more images are generated at once.\ - Default: 1") - - st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength, - min_value=0.01, max_value=1.0, step=0.01) - - with st.expander("Preview Settings"): - st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].img2img.update_preview, - help="If enabled the image preview will be updated during the generation instead of at the end. \ - You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. 
\ - By default this is enabled and the frequency is set to 1 step.") - - st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].img2img.update_preview_frequency, - help="Frequency in steps at which the the preview image is updated. By default the frequency \ - is set to 1 step.") - - with col2_img2img_layout: - editor_tab = st.tabs(["Editor"]) - - editor_image = st.empty() - st.session_state["editor_image"] = editor_image - - st.form_submit_button("Refresh") - - masked_image_holder = st.empty() - image_holder = st.empty() - - uploaded_images = st.file_uploader( - "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], - help="Upload an image which will be used for the image to image generation.", - ) - if uploaded_images: - image = Image.open(uploaded_images).convert('RGBA') - new_img = image.resize((width, height)) - image_holder.image(new_img) - - mask_holder = st.empty() - - uploaded_masks = st.file_uploader( - "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], - help="Upload an mask image which will be used for masking the image to image generation.", - ) - if uploaded_masks: - mask = Image.open(uploaded_masks) - if mask.mode == "RGBA": - mask = mask.convert('RGBA') - background = Image.new('RGBA', mask.size, (0, 0, 0)) - mask = Image.alpha_composite(background, mask) - mask = mask.resize((width, height)) - mask_holder.image(mask) - - if uploaded_images and uploaded_masks: - if mask_mode != 2: - final_img = new_img.copy() - alpha_layer = mask.convert('L') - strength = st.session_state["denoising_strength"] - if mask_mode == 0: - alpha_layer = ImageOps.invert(alpha_layer) - alpha_layer = alpha_layer.point(lambda a: a * strength) - alpha_layer = ImageOps.invert(alpha_layer) - elif mask_mode == 1: - alpha_layer = alpha_layer.point(lambda a: a * strength) - alpha_layer = ImageOps.invert(alpha_layer) - - final_img.putalpha(alpha_layer) - - with masked_image_holder.container(): - st.text("Masked Image Preview") - st.image(final_img) - - - with col3_img2img_layout: - result_tab = st.tabs(["Result"]) - - # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. - preview_image = st.empty() - st.session_state["preview_image"] = preview_image - - #st.session_state["loading"] = st.empty() - - st.session_state["progress_bar_text"] = st.empty() - st.session_state["progress_bar"] = st.empty() - - - message = st.empty() - - #if uploaded_images: - #image = Image.open(uploaded_images).convert('RGB') - ##img_array = np.array(image) # if you want to pass it to OpenCV - #new_img = image.resize((width, height)) - #st.image(new_img, use_column_width=True) - - - if generate_button: - #print("Loading models") - # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
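# Sketch of the masked-image preview built above: the uploaded mask becomes an alpha channel whose
# opacity is scaled by the denoising strength, so the preview shows how strongly each region will be
# repainted. invert_mask=True corresponds to the plain "Mask" mode; the solid test images are
# placeholders.
from PIL import Image, ImageOps

def masked_preview(img, mask, strength, invert_mask=True):
    preview = img.convert("RGBA")
    alpha = mask.convert("L")
    if invert_mask:
        alpha = ImageOps.invert(alpha)
    alpha = alpha.point(lambda a: int(a * strength))
    alpha = ImageOps.invert(alpha)
    preview.putalpha(alpha)
    return preview

demo = masked_preview(Image.new("RGB", (64, 64), "blue"),
                      Image.new("L", (64, 64), 255), strength=0.6)
print(demo.mode, demo.getpixel((0, 0)))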
- load_models(False, use_GFPGAN, st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"], - st.session_state["custom_model"]) - - if uploaded_images: - image = Image.open(uploaded_images).convert('RGBA') - new_img = image.resize((width, height)) - #img_array = np.array(image) # if you want to pass it to OpenCV - new_mask = None - if uploaded_masks: - mask = Image.open(uploaded_masks).convert('RGBA') - new_mask = mask.resize((width, height)) - - try: - output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode, - mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"], - sampler_name=st.session_state["sampler_name"], n_iter=batch_count, - cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed, - seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width, - height=height, variant_amount=variant_amount, - ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=st.session_state["RealESRGAN_model"], - separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, save_grid=save_grid, - group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN, - use_RealESRGAN=st.session_state["use_RealESRGAN"] if not loopback else False, loopback=loopback - ) - - #show a message when the generation is complete. - message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") - - except (StopException, KeyError): - print(f"Received Streamlit StopException") - - # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. - # use the current col2 first tab to show the preview_img and update it as its generated. - #preview_image.image(output_images, width=750) - -#on import run init diff --git a/scripts/imglab.py b/scripts/imglab.py deleted file mode 100644 index eb09c6a..0000000 --- a/scripts/imglab.py +++ /dev/null @@ -1,161 +0,0 @@ -# base webui import and utils. -from webui_streamlit import st -from sd_utils import * - -#home plugin -import os -from PIL import Image -#from bs4 import BeautifulSoup -from streamlit.runtime.in_memory_file_manager import in_memory_file_manager -from streamlit.elements import image as STImage - -# Temp imports - - -# end of imports -#--------------------------------------------------------------------------------------------------------------- - -try: - # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. 
- from transformers import logging - - logging.set_verbosity_error() -except: - pass - -class plugin_info(): - plugname = "imglab" - description = "Image Lab" - isTab = True - displayPriority = 3 - -def getLatestGeneratedImagesFromPath(): - #get the latest images from the generated images folder - #get the path to the generated images folder - generatedImagesPath = os.path.join(os.getcwd(),'outputs') - #get all the files from the folders and subfolders - files = [] - #get the laest 10 images from the output folder without walking the subfolders - for r, d, f in os.walk(generatedImagesPath): - for file in f: - if '.png' in file: - files.append(os.path.join(r, file)) - #sort the files by date - files.sort(key=os.path.getmtime) - #reverse the list so the latest images are first - for f in files: - img = Image.open(f) - files[files.index(f)] = img - #get the latest 10 files - #get all the files with the .png or .jpg extension - #sort files by date - #get the latest 10 files - latestFiles = files[-10:] - #reverse the list - latestFiles.reverse() - return latestFiles - -def getImagesFromLexica(): - #scrape images from lexica.art - #get the html from the page - #get the html with cookies and javascript - apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D' - #REST API call - # - from requests_html import HTMLSession - session = HTMLSession() - - response = session.get(apiEndpoint) - #req = requests.Session() - #req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36' - #response = req.get(apiEndpoint) - print(response.status_code) - print(response.text) - #get the json from the response - #json = response.json() - #get the prompts from the json - print(response) - #session = requests.Session() - #parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False) - #print(parseEndpointJson) - #print('test2') - #page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'}) - #parse the html - #soup = BeautifulSoup(page.content, 'html.parser') - #find all the images - #print(soup) - #images = soup.find_all('alt-image') - #create a list to store the image urls - image_urls = [] - #loop through the images - for image in images: - #get the url - image_url = image['src'] - #add it to the list - image_urls.append('http://www.lexica.art/'+image_url) - #return the list - print(image_urls) - return image_urls -def changeImage(): - #change the image in the image holder - #check if the file is not empty - if len(st.session_state['uploaded_file']) > 0: - #read the file - print('test2') - uploaded = st.session_state['uploaded_file'][0].read() - #show the image in the image holder - st.session_state['previewImg'].empty() - st.session_state['previewImg'].image(uploaded,use_column_width=True) -def createHTMLGallery(images): - html3 = """ - ' - return html3 -def layout(): - - col1, col2 = st.columns(2) - with col1: - st.session_state['uploaded_file'] = st.file_uploader("Choose an image or images", type=["png", "jpg", "jpeg", "webp"],accept_multiple_files=True,on_change=changeImage) - if 'previewImg' not in st.session_state: - st.session_state['previewImg'] = st.empty() - else: - if len(st.session_state['uploaded_file']) > 0: - st.session_state['previewImg'].empty() - st.session_state['previewImg'].image(st.session_state['uploaded_file'][0],use_column_width=True) - else: - 
st.session_state['previewImg'] = st.empty() - diff --git a/scripts/perlin.py b/scripts/perlin.py deleted file mode 100644 index 327a994..0000000 --- a/scripts/perlin.py +++ /dev/null @@ -1,48 +0,0 @@ -import numpy as np - -def perlin(x, y, seed=0): - # permutation table - np.random.seed(seed) - p = np.arange(256, dtype=int) - np.random.shuffle(p) - p = np.stack([p, p]).flatten() - # coordinates of the top-left - xi, yi = x.astype(int), y.astype(int) - # internal coordinates - xf, yf = x - xi, y - yi - # fade factors - u, v = fade(xf), fade(yf) - # noise components - n00 = gradient(p[p[xi] + yi], xf, yf) - n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1) - n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1) - n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf) - # combine noises - x1 = lerp(n00, n10, u) - x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01 - return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here - -def lerp(a, b, x): - "linear interpolation" - return a + x * (b - a) - -def fade(t): - "6t^5 - 15t^4 + 10t^3" - return 6 * t**5 - 15 * t**4 + 10 * t**3 - -def gradient(h, x, y): - "grad converts h to the right gradient vector and return the dot product with (x,y)" - vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]]) - g = vectors[h % 4] - return g[:, :, 0] * x + g[:, :, 1] * y - -lin = np.linspace(0, 5, 100, endpoint=False) -x, y = np.meshgrid(lin, lin) - - - -def perlinNoise(height,width,octavesx=5,octavesy=5,seed=None): - linx = np.linspace(0,octavesx,width,endpoint=False) - liny = np.linspace(0,octavesy,height,endpoint=False) - x,y = np.meshgrid(linx,liny) - return perlin(x,y,seed=seed) \ No newline at end of file diff --git a/scripts/relauncher.py b/scripts/relauncher.py index 457d539..7179d7f 100644 --- a/scripts/relauncher.py +++ b/scripts/relauncher.py @@ -19,8 +19,6 @@ optimized_turbo = False # Creates a public xxxxx.gradio.app share link to allow others to use your interface (requires properly forwarded ports to work correctly) share = False -# Generate tiling images -tiling = False # Enter other `--arguments` you wish to use - Must be entered as a `--argument ` syntax additional_arguments = "" @@ -39,8 +37,6 @@ if optimized_turbo == True: common_arguments += "--optimized-turbo " if optimized == True: common_arguments += "--optimized " -if tiling == True: - common_arguments += "--tiling " if share == True: common_arguments += "--share " diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py deleted file mode 100644 index 6983edb..0000000 --- a/scripts/sd_utils.py +++ /dev/null @@ -1,1728 +0,0 @@ -# base webui import and utils. 
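# Usage sketch for the perlin.py helper deleted above. perlinNoise() lays an octavesx-by-octavesy
# grid of coordinates over the requested image size and evaluates classic Perlin gradient noise at
# each pixel. The commented calls assume the perlin()/perlinNoise() functions shown above are
# available (e.g. saved as perlin.py next to this sketch).
import numpy as np
# from perlin import perlinNoise                      # the module removed in this diff

# the coordinate grid perlinNoise() builds internally:
height, width, octavesx, octavesy = 256, 256, 5, 5
linx = np.linspace(0, octavesx, width, endpoint=False)
liny = np.linspace(0, octavesy, height, endpoint=False)
x, y = np.meshgrid(linx, liny)
print(x.shape)                                        # (256, 256): one (x, y) pair per pixel
# noise_field = perlinNoise(height, width, octavesx, octavesy, seed=1)   # values roughly in [-1, 1]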
-from webui_streamlit import st - - -# streamlit imports -from streamlit import StopException -#other imports - -import warnings -import json - -import base64 -import os, sys, re, random, datetime, time, math, glob -from PIL import Image, ImageFont, ImageDraw, ImageFilter -from PIL.PngImagePlugin import PngInfo -from scipy import integrate -import torch -from torchdiffeq import odeint -import k_diffusion as K -import math -import mimetypes -import numpy as np -import pynvml -import threading -import torch -from torch import autocast -from torchvision import transforms -import torch.nn as nn -from omegaconf import OmegaConf -import yaml -from pathlib import Path -from contextlib import nullcontext -from einops import rearrange -from ldm.util import instantiate_from_config -from retry import retry -from slugify import slugify -import skimage -import piexif -import piexif.helper -from tqdm import trange - -# Temp imports - - -# end of imports -#--------------------------------------------------------------------------------------------------------------- - -try: - # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging - - logging.set_verbosity_error() -except: - pass - -# remove some annoying deprecation warnings that show every now and then. -warnings.filterwarnings("ignore", category=DeprecationWarning) - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -# some of those options should not be changed at all because they would break the model, so I removed them from options. -opt_C = 4 -opt_f = 8 - -if not "defaults" in st.session_state: - st.session_state["defaults"] = {} - -st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml") - -if (os.path.exists("configs/webui/userconfig_streamlit.yaml")): - user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml") - st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults) - - -# should and will be moved to a settings menu in the UI at some point -grid_format = [s.lower() for s in st.session_state["defaults"].general.grid_format.split(':')] -grid_lossless = False -grid_quality = 100 -if grid_format[0] == 'png': - grid_ext = 'png' - grid_format = 'png' -elif grid_format[0] in ['jpg', 'jpeg']: - grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 - grid_ext = 'jpg' - grid_format = 'jpeg' -elif grid_format[0] == 'webp': - grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 - grid_ext = 'webp' - grid_format = 'webp' - if grid_quality < 0: # e.g. webp:-100 for lossless mode - grid_lossless = True - grid_quality = abs(grid_quality) - -# should and will be moved to a settings menu in the UI at some point -save_format = [s.lower() for s in st.session_state["defaults"].general.save_format.split(':')] -save_lossless = False -save_quality = 100 -if save_format[0] == 'png': - save_ext = 'png' - save_format = 'png' -elif save_format[0] in ['jpg', 'jpeg']: - save_quality = int(save_format[1]) if len(save_format) > 1 else 100 - save_ext = 'jpg' - save_format = 'jpeg' -elif save_format[0] == 'webp': - save_quality = int(save_format[1]) if len(save_format) > 1 else 100 - save_ext = 'webp' - save_format = 'webp' - if save_quality < 0: # e.g. 
webp:-100 for lossless mode - save_lossless = True - save_quality = abs(save_quality) - -# this should force GFPGAN and RealESRGAN onto the selected gpu as well -os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 -os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu) - -@retry(tries=5) -def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus", - CustomModel_available=False, custom_model="Stable Diffusion v1.4"): - """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ - - print ("Loading models.") - - st.session_state["progress_bar_text"].text("Loading models...") - - # Generate random run ID - # Used to link runs linked w/ continue_prev_run which is not yet implemented - # Use URL and filesystem safe version just in case. - st.session_state["run_id"] = base64.urlsafe_b64encode( - os.urandom(6) - ).decode("ascii") - - # check what models we want to use and if the they are already loaded. - - if use_GFPGAN: - if "GFPGAN" in st.session_state: - print("GFPGAN already loaded") - else: - # Load GFPGAN - if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir): - try: - st.session_state["GFPGAN"] = load_GFPGAN() - print("Loaded GFPGAN") - except Exception: - import traceback - print("Error loading GFPGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if "GFPGAN" in st.session_state: - del st.session_state["GFPGAN"] - - if use_RealESRGAN: - if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model: - print("RealESRGAN already loaded") - else: - #Load RealESRGAN - try: - # We first remove the variable in case it has something there, - # some errors can load the model incorrectly and leave things in memory. - del st.session_state["RealESRGAN"] - except KeyError: - pass - - if os.path.exists(st.session_state["defaults"].general.RealESRGAN_dir): - # st.session_state is used for keeping the models in memory across multiple pages or runs. - st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model) - print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name) - - else: - if "RealESRGAN" in st.session_state: - del st.session_state["RealESRGAN"] - - if "model" in st.session_state: - if "model" in st.session_state and st.session_state["loaded_model"] == custom_model: - # TODO: check if the optimized mode was changed? 
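# The load_models() logic above keeps heavyweight models (GFPGAN, RealESRGAN, the SD checkpoint) in
# st.session_state so Streamlit reruns reuse them instead of reloading checkpoints from disk. A
# framework-free sketch of that pattern; the dict and the stub loader stand in for st.session_state
# and the real load_* functions.
_model_cache = {}                      # stands in for st.session_state

def get_or_load(name, loader):
    if name in _model_cache:
        print(f"{name} already loaded")
    else:
        _model_cache[name] = loader()
    return _model_cache[name]

def load_gfpgan_stub():
    print("Loading GFPGAN...")
    return object()                    # placeholder for the real model object

get_or_load("GFPGAN", load_gfpgan_stub)   # loads
get_or_load("GFPGAN", load_gfpgan_stub)   # reuses the cached instance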
- print("Model already loaded") - - return - else: - try: - del st.session_state.model - del st.session_state.modelCS - del st.session_state.modelFS - del st.session_state.loaded_model - except KeyError: - pass - - # At this point the model is either - # is not loaded yet or have been evicted: - # load new model into memory - st.session_state.custom_model = custom_model - - config, device, model, modelCS, modelFS = load_sd_model(custom_model) - - st.session_state.device = device - st.session_state.model = model - st.session_state.modelCS = modelCS - st.session_state.modelFS = modelFS - st.session_state.loaded_model = custom_model - - if st.session_state.defaults.general.enable_attention_slicing: - st.session_state.model.enable_attention_slicing() - - if st.session_state.defaults.general.enable_minimal_memory_usage: - st.session_state.model.enable_minimal_memory_usage() - - print("Model loaded.") - - -def load_model_from_config(config, ckpt, verbose=False): - - print(f"Loading model from {ckpt}") - - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - - -def load_sd_from_config(ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - return sd - - -class MemUsageMonitor(threading.Thread): - stop_flag = False - max_usage = 0 - total = -1 - - def __init__(self, name): - threading.Thread.__init__(self) - self.name = name - - def run(self): - try: - pynvml.nvmlInit() - except: - print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n") - return - print(f"[{self.name}] Recording max memory usage...\n") - # Missing context - #handle = pynvml.nvmlDeviceGetHandleByIndex(st.session_state['defaults'].general.gpu) - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total - while not self.stop_flag: - m = pynvml.nvmlDeviceGetMemoryInfo(handle) - self.max_usage = max(self.max_usage, m.used) - # print(self.max_usage) - time.sleep(0.1) - print(f"[{self.name}] Stopped recording.\n") - pynvml.nvmlShutdown() - - def read(self): - return self.max_usage, self.total - - def stop(self): - self.stop_flag = True - - def read_and_stop(self): - self.stop_flag = True - return self.max_usage, self.total - -class CFGMaskedDenoiser(nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi): - x_in = x - x_in = torch.cat([x_in] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - denoised = uncond + (cond - uncond) * cond_scale - - if mask is not None: - assert x0 is not None - img_orig = x0 - mask_inv = 1. 
- mask - denoised = (img_orig * mask_inv) + (mask * denoised) - - return denoised - -class CFGDenoiser(nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - return uncond + (cond - uncond) * cond_scale -def append_zero(x): - return torch.cat([x, x.new_zeros([1])]) -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] -def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): - """Constructs the noise schedule of Karras et al. (2022).""" - ramp = torch.linspace(0, 1, n) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return append_zero(sigmas).to(device) - -# -# helper fft routines that keep ortho normalization and auto-shift before and after fft -def _fft2(data): - if data.ndim > 2: # has channels - out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) - for c in range(data.shape[2]): - c_data = data[:,:,c] - out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho") - out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c]) - else: # one channel - out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) - out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho") - out_fft[:,:] = np.fft.ifftshift(out_fft[:,:]) - - return out_fft - -def _ifft2(data): - if data.ndim > 2: # has channels - out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128) - for c in range(data.shape[2]): - c_data = data[:,:,c] - out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho") - out_ifft[:,:,c] = np.fft.ifftshift(out_ifft[:,:,c]) - else: # one channel - out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128) - out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho") - out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:]) - - return out_ifft - -def _get_gaussian_window(width, height, std=3.14, mode=0): - - window_scale_x = float(width / min(width, height)) - window_scale_y = float(height / min(width, height)) - - window = np.zeros((width, height)) - x = (np.arange(width) / width * 2. - 1.) * window_scale_x - for y in range(height): - fy = (y / height * 2. - 1.) * window_scale_y - if mode == 0: - window[:, y] = np.exp(-(x**2+fy**2) * std) - else: - window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian - - return window - -def _get_masked_window_rgb(np_mask_grey, hardness=1.): - np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3)) - if hardness != 1.: - hardened = np_mask_grey[:] ** hardness - else: - hardened = np_mask_grey[:] - for c in range(3): - np_mask_rgb[:,:,c] = hardened[:] - return np_mask_rgb - -def get_matched_noise(_np_src_image, np_mask_rgb, noise_q, color_variation): - """ - Explanation: - Getting good results in/out-painting with stable diffusion can be challenging. 
- Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data - in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well. - Provided here is my take on a potential solution to this problem. - - By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src. - Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence - by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can - be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color - or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently) - This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area. - The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle. - - Carefully managing color and brightness with histogram matching is also essential to achieving good coherence. - - noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.) - color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01) - This code is provided as is under the Unlicense (https://unlicense.org/) - Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz]. - - Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/) - This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot) - - """ - - global DEBUG_MODE - global TMP_ROOT_PATH - - width = _np_src_image.shape[0] - height = _np_src_image.shape[1] - num_channels = _np_src_image.shape[2] - - np_src_image = _np_src_image[:] * (1. - np_mask_rgb) - np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.) - np_src_grey = (np.sum(np_src_image, axis=2)/3.) 
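The Fourier-shaping idea described in the docstring above can be illustrated in isolation. The sketch below is illustrative only and not part of this patch (it assumes a single grey channel and skips the masking, windowing and histogram matching the real function does): it gives white noise the magnitude spectrum of a reference image so that the noise roughly shares the reference's distribution of feature scales and orientations.

import numpy as np

def shape_noise_like(reference_grey, noise_q=1.0, rng=None):
    # Illustrative sketch: borrow the magnitude spectrum of a reference image
    # so random noise shares its distribution of feature scales/orientations.
    rng = np.random.default_rng() if rng is None else rng
    ref_fft = np.fft.fft2(np.fft.fftshift(reference_grey), norm="ortho")
    ref_mag = np.absolute(ref_fft)                       # feature statistics of the reference
    noise = rng.random(reference_grey.shape)
    noise_fft = np.fft.fft2(np.fft.fftshift(noise), norm="ortho")
    noise_phase = noise_fft / np.maximum(np.absolute(noise_fft), 1e-12)
    shaped_fft = noise_phase * ref_mag ** noise_q        # keep the random phase, reuse the magnitudes
    shaped = np.real(np.fft.ifft2(np.fft.ifftshift(shaped_fft), norm="ortho"))
    shaped -= shaped.min()
    return shaped / max(shaped.max(), 1e-12)             # normalised to [0, 1]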
- all_mask = np.ones((width, height), dtype=bool) - img_mask = np_mask_grey > 1e-6 - ref_mask = np_mask_grey < 1e-3 - - windowed_image = _np_src_image * (1.-_get_masked_window_rgb(np_mask_grey)) - windowed_image /= np.max(windowed_image) - windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color - #windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area - #_save_debug_img(windowed_image, "windowed_src_img") - - src_fft = _fft2(windowed_image) # get feature statistics from masked src img - src_dist = np.absolute(src_fft) - src_phase = src_fft / src_dist - #_save_debug_img(src_dist, "windowed_src_dist") - - noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise - noise_rgb = np.random.random_sample((width, height, num_channels)) - noise_grey = (np.sum(noise_rgb, axis=2)/3.) - noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter - for c in range(num_channels): - noise_rgb[:,:,c] += (1. - color_variation) * noise_grey - - noise_fft = _fft2(noise_rgb) - for c in range(num_channels): - noise_fft[:,:,c] *= noise_window - noise_rgb = np.real(_ifft2(noise_fft)) - shaped_noise_fft = _fft2(noise_rgb) - shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping - - brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now - contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2. - - # scikit-image is used for histogram matching, very convenient! - shaped_noise = np.real(_ifft2(shaped_noise_fft)) - shaped_noise -= np.min(shaped_noise) - shaped_noise /= np.max(shaped_noise) - shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1) - shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb - #_save_debug_img(shaped_noise, "shaped_noise") - - matched_noise = np.zeros((width, height, num_channels)) - matched_noise = shaped_noise[:] - #matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1) - #matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb - - #_save_debug_img(matched_noise, "matched_noise") - - """ - todo: - color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled - """ - - return np.clip(matched_noise, 0., 1.) - - -# -def find_noise_for_image(model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None): - image = np.array(init_image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - image = 2. * image - 1. 
- image = image.to(device) - x = model.get_first_stage_encoding(model.encode_first_stage(image)) - - uncond = model.get_learned_conditioning(['']) - cond = model.get_learned_conditioning([prompt]) - - s_in = x.new_ones([x.shape[0]]) - dnw = K.external.CompVisDenoiser(model) - sigmas = dnw.get_sigmas(steps).flip(0) - - if verbose: - print(sigmas) - - for i in trange(1, len(sigmas)): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2) - cond_in = torch.cat([uncond, cond]) - - c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] - - if i == 1: - t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) - else: - t = dnw.sigma_to_t(sigma_in) - - eps = model.apply_model(x_in * c_in, t, cond=cond_in) - denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2) - - denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale - - if i == 1: - d = (x - denoised) / (2 * sigmas[i]) - else: - d = (x - denoised) / sigmas[i - 1] - - if generation_callback is not None: - generation_callback(x, i) - - dt = sigmas[i] - sigmas[i - 1] - x = x + d * dt - - return x / sigmas[-1] - - -def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): - """Constructs an exponential noise schedule.""" - sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() - return append_zero(sigmas) - - -def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): - """Constructs a continuous VP noise schedule.""" - t = torch.linspace(1, eps_s, n, device=device) - sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) - return append_zero(sigmas) - - -def to_d(x, sigma, denoised): - """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / append_dims(sigma, x.ndim) -def linear_multistep_coeff(order, t, i, j): - if order - 1 > i: - raise ValueError(f'Order {order} too high for step {i}') - def fn(tau): - prod = 1. 
- for k in range(order): - if j == k: - continue - prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) - return prod - return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] - -class KDiffusionSampler: - def __init__(self, m, sampler): - self.model = m - self.model_wrap = K.external.CompVisDenoiser(m) - self.schedule = sampler - def get_sampler_name(self): - return self.schedule - def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None): - sigmas = self.model_wrap.get_sigmas(S) - x = x_T * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model_wrap) - samples_ddim = None - samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, - extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback) - # - return samples_ddim, None - - -@torch.no_grad() -def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - v = torch.randint_like(x, 2) * 2 - 1 - fevals = 0 - def ode_fn(sigma, x): - nonlocal fevals - with torch.enable_grad(): - x = x[0].detach().requires_grad_() - denoised = model(x, sigma * s_in, **extra_args) - d = to_d(x, sigma, denoised) - fevals += 1 - grad = torch.autograd.grad((d * v).sum(), x)[0] - d_ll = (v * grad).flatten(1).sum(1) - return d.detach(), d_ll - x_min = x, x.new_zeros([x.shape[0]]) - t = x.new_tensor([sigma_min, sigma_max]) - sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') - latent, delta_ll = sol[0][-1], sol[1][-1] - ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) - return ll_prior + delta_ll, {'fevals': fevals} - - -def create_random_tensors(shape, seeds): - xs = [] - for seed in seeds: - torch.manual_seed(seed) - - # randn results depend on device; gpu and cpu get different results for same seed; - # the way I see it, it's better to do this on CPU, so that everyone gets same result; - # but the original script had it like this so i do not dare change it for now because - # it will break everyone's seeds. 
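The seeding caveat in the comment above is easy to demonstrate: PyTorch's CPU and CUDA generators produce different streams for the same seed, so the device on which randn runs is effectively part of the seed contract. A small check, illustrative only and not part of the patch:

import torch

shape = (4, 64, 64)
torch.manual_seed(42)
cpu_noise = torch.randn(shape)                      # drawn by the CPU generator
if torch.cuda.is_available():
    torch.manual_seed(42)                           # seeds the CPU and all CUDA generators
    gpu_noise = torch.randn(shape, device="cuda")   # drawn by the CUDA generator
    # Same seed, different generator: the tensors generally do not match,
    # which is why moving this call between devices would change everyone's seeds.
    print(torch.allclose(cpu_noise, gpu_noise.cpu()))  # expected: False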
- xs.append(torch.randn(shape, device=st.session_state['defaults'].general.gpu)) - x = torch.stack(xs) - return x - -def torch_gc(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - -def load_GFPGAN(): - model_name = 'GFPGANv1.3' - model_path = os.path.join(st.session_state['defaults'].general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') - if not os.path.isfile(model_path): - raise Exception("GFPGAN model not found at path "+model_path) - - sys.path.append(os.path.abspath(st.session_state['defaults'].general.GFPGAN_dir)) - from gfpgan import GFPGANer - - if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu')) - elif st.session_state['defaults'].general.extra_models_gpu: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gfpgan_gpu}")) - else: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}")) - return instance - -def load_RealESRGAN(model_name: str): - from basicsr.archs.rrdbnet_arch import RRDBNet - RealESRGAN_models = { - 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), - 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - } - - model_path = os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') - if not os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")): - raise Exception(model_name+".pth not found at path "+model_path) - - sys.path.append(os.path.abspath(st.session_state['defaults'].general.RealESRGAN_dir)) - from realesrgan import RealESRGANer - - if st.session_state['defaults'].general.esrgan_cpu or st.session_state['defaults'].general.extra_models_cpu: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half - instance.device = torch.device('cpu') - instance.model.to('cpu') - elif st.session_state['defaults'].general.extra_models_gpu: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.esrgan_gpu}")) - else: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}")) - instance.model.name = model_name - - return instance - -# -def load_LDSR(checking=False): - model_name = 'model' - yaml_name = 'project' - model_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', model_name + '.ckpt') - yaml_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', yaml_name + '.yaml') - if not os.path.isfile(model_path): - raise Exception("LDSR model not found at path "+model_path) - if not os.path.isfile(yaml_path): - raise Exception("LDSR 
model not found at path "+yaml_path) - if checking == True: - return True - - sys.path.append(os.path.abspath(st.session_state['defaults'].general.LDSR_dir)) - from LDSR import LDSR - LDSRObject = LDSR(model_path, yaml_path) - return LDSRObject - -# -LDSR = None -def try_loading_LDSR(model_name: str,checking=False): - global LDSR - if os.path.exists(st.session_state['defaults'].general.LDSR_dir): - try: - LDSR = load_LDSR(checking=True) # TODO: Should try to load both models before giving up - if checking == True: - print("Found LDSR") - return True - print("Latent Diffusion Super Sampling (LDSR) model loaded") - except Exception: - import traceback - print("Error loading LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - print("LDSR not found at path, please make sure you have cloned the LDSR repo to ./src/latent-diffusion/") - -#try_loading_LDSR('model',checking=True) - - -# Loads Stable Diffusion model by name -def load_sd_model(model_name: str) -> [any, any, any, any, any]: - ckpt_path = st.session_state.defaults.general.default_model_path - if model_name != st.session_state.defaults.general.default_model: - ckpt_path = os.path.join("models", "custom", f"{model_name}.ckpt") - - if st.session_state.defaults.general.optimized: - config = OmegaConf.load(st.session_state.defaults.general.optimized_config) - - sd = load_sd_from_config(ckpt_path) - li, lo = [], [] - for key, v_ in sd.items(): - sp = key.split('.') - if (sp[0]) == 'model': - if 'input_blocks' in sp: - li.append(key) - elif 'middle_block' in sp: - li.append(key) - elif 'time_embed' in sp: - li.append(key) - else: - lo.append(key) - for key in li: - sd['model1.' + key[6:]] = sd.pop(key) - for key in lo: - sd['model2.' + key[6:]] = sd.pop(key) - - device = torch.device(f"cuda:{st.session_state.defaults.general.gpu}") \ - if torch.cuda.is_available() else torch.device("cpu") - - model = instantiate_from_config(config.modelUNet) - _, _ = model.load_state_dict(sd, strict=False) - model.cuda() - model.eval() - model.turbo = st.session_state.defaults.general.optimized_turbo - - modelCS = instantiate_from_config(config.modelCondStage) - _, _ = modelCS.load_state_dict(sd, strict=False) - modelCS.cond_stage_model.device = device - modelCS.eval() - - modelFS = instantiate_from_config(config.modelFirstStage) - _, _ = modelFS.load_state_dict(sd, strict=False) - modelFS.eval() - - del sd - - if not st.session_state.defaults.general.no_half: - model = model.half() - modelCS = modelCS.half() - modelFS = modelFS.half() - - return config, device, model, modelCS, modelFS - else: - config = OmegaConf.load(st.session_state.defaults.general.default_model_config) - model = load_model_from_config(config, ckpt_path) - - device = torch.device(f"cuda:{st.session_state.defaults.general.gpu}") \ - if torch.cuda.is_available() else torch.device("cpu") - model = (model if st.session_state.defaults.general.no_half - else model.half()).to(device) - - return config, device, model, None, None - - -# @codedealer: No usages -def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): - #get global variables - global_vars = globals() - #check if m is in globals - if unload: - for m in models: - if m in global_vars: - #if it is, delete it - del global_vars[m] - if st.session_state['defaults'].general.optimized: - if m == 'model': - del global_vars[m+'FS'] - del global_vars[m+'CS'] - if m == 'model': - m = 'Stable Diffusion' - print('Unloaded ' + m) - if load: - for m in models: - if m not in 
global_vars or m in global_vars and type(global_vars[m]) == bool: - #if it isn't, load it - if m == 'GFPGAN': - global_vars[m] = load_GFPGAN() - elif m == 'model': - sdLoader = load_sd_from_config() - global_vars[m] = sdLoader[0] - if st.session_state['defaults'].general.optimized: - global_vars[m+'CS'] = sdLoader[1] - global_vars[m+'FS'] = sdLoader[2] - elif m == 'RealESRGAN': - global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name) - elif m == 'LDSR': - global_vars[m] = load_LDSR() - if m =='model': - m='Stable Diffusion' - print('Loaded ' + m) - torch_gc() - - -# -@retry(tries=5) -def generation_callback(img, i=0): - if "update_preview_frequency" not in st.session_state: - raise StopException - - try: - if i == 0: - if img['i']: i = img['i'] - except TypeError: - pass - - if i % int(st.session_state.update_preview_frequency) == 0 and st.session_state.update_preview and i > 0: - #print (img) - #print (type(img)) - # The following lines will convert the tensor we got on img to an actual image we can render on the UI. - # It can probably be done in a better way for someone who knows what they're doing. I don't. - #print (img,isinstance(img, torch.Tensor)) - if isinstance(img, torch.Tensor): - x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(img) - else: - # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: - # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} - x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(img["denoised"]) - - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - if x_samples_ddim.ndimension() == 4: - pil_images = [transforms.ToPILImage()(x.squeeze_(0)) for x in x_samples_ddim] - pil_image = image_grid(pil_images, 1) - else: - pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) - - # update image on the UI so we can see the progress - st.session_state["preview_image"].image(pil_image) - - # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, - # Dont worry, it doesnt affect the performance. 
- if st.session_state["generation_mode"] == "txt2img": - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") - else: - if st.session_state["generation_mode"] == "img2img": - round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) - percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) - st.session_state["progress_bar_text"].text( - f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") - else: - if st.session_state["generation_mode"] == "txt2vid": - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps}" - f"{percent if percent < 100 else 100}%") - - st.session_state["progress_bar"].progress(percent if percent < 100 else 100) - - -prompt_parser = re.compile(""" - (?P<prompt> # capture group for 'prompt' - [^:]+ # match one or more non ':' characters - ) # end 'prompt' - (?: # non-capture group - :+ # match one or more ':' characters - (?P<weight> # capture group for 'weight' - -?\\d+(?:\\.\\d+)? # match positive or negative decimal number - )? # end weight capture group, make optional - \\s* # strip spaces after weight - | # OR - $ # else, if no ':' then match end of line - ) # end non-capture group -""", re.VERBOSE) - -# grabs all text up to the first occurrence of ':' as sub-prompt -# takes the value following ':' as weight -# if ':' has no value defined, defaults to 1.0 -# repeats until no text remaining -def split_weighted_subprompts(input_string, normalize=True): - parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)] - if not normalize: - return parsed_prompts - # this probably still doesn't handle negative weights very well - weight_sum = sum(map(lambda x: x[1], parsed_prompts)) - return [(x[0], x[1] / weight_sum) for x in parsed_prompts] - -def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995): - v0 = v0.detach().cpu().numpy() - v1 = v1.detach().cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - v2 = torch.from_numpy(v2).to(device) - - return v2 - -# -def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list): - """Find the optimal update_preview_frequency value maximizing - performance while minimizing the time between updates.""" - from statistics import mean - - previous_chunk_avg_speed = mean(previous_chunk_speed_list) - - previous_chunk_speed_list.append(current_chunk_speed) - current_chunk_avg_speed =
mean(previous_chunk_speed_list) - - if current_chunk_avg_speed >= previous_chunk_avg_speed: - #print(f"{current_chunk_speed} >= {previous_chunk_speed}") - update_preview_frequency_list.append(update_preview_frequency + 1) - else: - #print(f"{current_chunk_speed} <= {previous_chunk_speed}") - update_preview_frequency_list.append(update_preview_frequency - 1) - - update_preview_frequency = round(mean(update_preview_frequency_list)) - - return current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list - - -def get_font(fontsize): - fonts = ["arial.ttf", "DejaVuSans.ttf"] - for font_name in fonts: - try: - return ImageFont.truetype(font_name, fontsize) - except OSError: - pass - - # ImageFont.load_default() is practically unusable as it only supports - # latin1, so raise an exception instead if no usable font was found - raise Exception(f"No usable font found (tried {', '.join(fonts)})") - -def load_embeddings(fp): - if fp is not None and hasattr(st.session_state["model"], "embedding_manager"): - st.session_state["model"].embedding_manager.load(fp['name']) - -def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): - loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") - - # separate token and the embeds - if learned_embeds_path.endswith('.pt'): - print(loaded_learned_embeds['string_to_token']) - trained_token = list(loaded_learned_embeds['string_to_token'].keys())[0] - embeds = list(loaded_learned_embeds['string_to_param'].values())[0] - - elif learned_embeds_path.endswith('.bin'): - trained_token = list(loaded_learned_embeds.keys())[0] - embeds = loaded_learned_embeds[trained_token] - - embeds = loaded_learned_embeds[trained_token] - # cast to dtype of text_encoder - dtype = text_encoder.get_input_embeddings().weight.dtype - embeds.to(dtype) - - # add the token in tokenizer - token = token if token is not None else trained_token - num_added_tokens = tokenizer.add_tokens(token) - - # resize the token embeddings - text_encoder.resize_token_embeddings(len(tokenizer)) - - # get the id for the token and assign the embeds - token_id = tokenizer.convert_tokens_to_ids(token) - text_encoder.get_input_embeddings().weight.data[token_id] = embeds - return token - -def image_grid(imgs, batch_size, force_n_rows=None, captions=None): - #print (len(imgs)) - if force_n_rows is not None: - rows = force_n_rows - elif st.session_state['defaults'].general.n_rows > 0: - rows = st.session_state['defaults'].general.n_rows - elif st.session_state['defaults'].general.n_rows == 0: - rows = batch_size - else: - rows = math.sqrt(len(imgs)) - rows = round(rows) - - cols = math.ceil(len(imgs) / rows) - - w, h = imgs[0].size - grid = Image.new('RGB', size=(cols * w, rows * h), color='black') - - fnt = get_font(30) - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - if captions and i < len(captions): - d = ImageDraw.Draw(grid) - size = d.multiline_textbbox((0, 0), captions[i], font=fnt) - d.multiline_text((i % cols * w + w / 2, i // cols * h + h - size[3]), captions[i], font=fnt, fill=(255, 255, 255), anchor="mm", align="center") - - return grid - -def seed_to_int(s): - if type(s) is int: - return s - if s is None or s == '': - return random.randint(0, 2**32 - 1) - n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1)) - while n >= 2**32: - n = n >> 32 - return n - -# -def draw_prompt_matrix(im, width, height, all_prompts): - def wrap(text, d, font, line_length): - lines = [''] - for word in text.split(): - line = f'{lines[-1]} {word}'.strip() - if d.textlength(line, font=font) <= line_length: - lines[-1] = line - else: - lines.append(word) - return '\n'.join(lines) - - def draw_texts(pos, x, y, texts, sizes): - for i, (text, size) in enumerate(zip(texts, sizes)): - active = pos & (1 << i) != 0 - - if not active: - text = '\u0336'.join(text) + '\u0336' - - d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if
active else color_inactive, anchor="mm", align="center") - - y += size[1] + line_spacing - - fontsize = (width + height) // 25 - line_spacing = fontsize // 2 - fnt = get_font(fontsize) - color_active = (0, 0, 0) - color_inactive = (153, 153, 153) - - pad_top = height // 4 - pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0 - - cols = im.width // width - rows = im.height // height - - prompts = all_prompts[1:] - - result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white") - result.paste(im, (pad_left, pad_top)) - - d = ImageDraw.Draw(result) - - boundary = math.ceil(len(prompts) / 2) - prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]] - prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]] - - sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]] - sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]] - hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing - ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing - - for col in range(cols): - x = pad_left + width * col + width / 2 - y = pad_top / 2 - hor_text_height / 2 - - draw_texts(col, x, y, prompts_horiz, sizes_hor) - - for row in range(rows): - x = pad_left / 2 - y = pad_top + height * row + height / 2 - ver_text_height / 2 - - draw_texts(row, x, y, prompts_vert, sizes_ver) - - return result - -def check_prompt_length(prompt, comments): - """this function tests if prompt is too long, and if so, adds a message to comments""" - - tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer - max_length = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.max_length - - info = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, - return_overflowing_tokens=True, padding="max_length", return_tensors="pt") - ovf = info['overflowing_tokens'][0] - overflowing_count = ovf.shape[0] - if overflowing_count == 0: - return - - vocab = {v: k for k, v in tokenizer.get_vocab().items()} - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - - comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - -def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, model_name): - - filename_i = os.path.join(sample_path_i, filename) - - if st.session_state['defaults'].general.save_metadata or write_info_files: - # toggles differ for txt2img vs. 
img2img: - offset = 0 if init_img is None else 2 - toggles = [] - if prompt_matrix: - toggles.append(0) - if normalize_prompt_weights: - toggles.append(1) - if init_img is not None: - if uses_loopback: - toggles.append(2) - if uses_random_seed_loopback: - toggles.append(3) - if save_individual_images: - toggles.append(2 + offset) - if save_grid: - toggles.append(3 + offset) - if sort_samples: - toggles.append(4 + offset) - if write_info_files: - toggles.append(5 + offset) - if use_GFPGAN: - toggles.append(6 + offset) - metadata = \ - dict( - target="txt2img" if init_img is None else "img2img", - prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, - ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, - seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=st.session_state["loaded_model"]) - # Not yet any use for these, but they bloat up the files: - # info_dict["init_img"] = init_img - # info_dict["init_mask"] = init_mask - if init_img is not None: - metadata["denoising_strength"] = str(denoising_strength) - metadata["resize_mode"] = resize_mode - - if write_info_files: - with open(f"{filename_i}.yaml", "w", encoding="utf8") as f: - yaml.dump(metadata, f, allow_unicode=True, width=10000) - - if st.session_state['defaults'].general.save_metadata: - # metadata = { - # "SD:prompt": prompts[i], - # "SD:seed": str(seeds[i]), - # "SD:width": str(width), - # "SD:height": str(height), - # "SD:steps": str(steps), - # "SD:cfg_scale": str(cfg_scale), - # "SD:normalize_prompt_weights": str(normalize_prompt_weights), - # } - metadata = {"SD:" + k:v for (k,v) in metadata.items()} - - if save_ext == "png": - mdata = PngInfo() - for key in metadata: - mdata.add_text(key, str(metadata[key])) - image.save(f"{filename_i}.png", pnginfo=mdata) - else: - if jpg_sample: - image.save(f"{filename_i}.jpg", quality=save_quality, - optimize=True) - elif save_ext == "webp": - image.save(f"{filename_i}.{save_ext}", f"webp", quality=save_quality, - lossless=save_lossless) - else: - # not sure what file format this is - image.save(f"{filename_i}.{save_ext}", f"{save_ext}") - try: - exif_dict = piexif.load(f"{filename_i}.{save_ext}") - except: - exif_dict = { "Exif": dict() } - exif_dict["Exif"][piexif.ExifIFD.UserComment] = piexif.helper.UserComment.dump( - json.dumps(metadata), encoding="unicode") - piexif.insert(piexif.dump(exif_dict), f"{filename_i}.{save_ext}") - - -def get_next_sequence_number(path, prefix=''): - """ - Determines and returns the next sequence number to use when saving an - image in the specified directory. - - If a prefix is given, only consider files whose names start with that - prefix, and strip the prefix from filenames before extracting their - sequence number. - - The sequence starts at 0. 
- """ - result = -1 - for p in Path(path).iterdir(): - if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix): - tmp = p.name[len(prefix):] - try: - result = max(int(tmp.split('-')[0]), result) - except ValueError: - pass - return result + 1 - - -def oxlamon_matrix(prompt, seed, n_iter, batch_size): - pattern = re.compile(r'(,\s){2,}') - - class PromptItem: - def __init__(self, text, parts, item): - self.text = text - self.parts = parts - if item: - self.parts.append( item ) - - def clean(txt): - return re.sub(pattern, ', ', txt) - - def getrowcount( txt ): - for data in re.finditer( ".*?\\((.*?)\\).*", txt ): - if data: - return len(data.group(1).split("|")) - break - return None - - def repliter( txt ): - for data in re.finditer( ".*?\\((.*?)\\).*", txt ): - if data: - r = data.span(1) - for item in data.group(1).split("|"): - yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip()) - break - - def iterlist( items ): - outitems = [] - for item in items: - for newitem, newpart in repliter(item.text): - outitems.append( PromptItem(newitem, item.parts.copy(), newpart) ) - - return outitems - - def getmatrix( prompt ): - dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ] - while True: - newdataitems = iterlist( dataitems ) - if len( newdataitems ) == 0: - return dataitems - dataitems = newdataitems - - def classToArrays( items, seed, n_iter ): - texts = [] - parts = [] - seeds = [] - - for item in items: - itemseed = seed - for i in range(n_iter): - texts.append( item.text ) - parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) ) - seeds.append( itemseed ) - itemseed += 1 - - return seeds, texts, parts - - all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter) - n_iter = math.ceil(len(all_prompts) / batch_size) - - needrows = getrowcount(prompt) - if needrows: - xrows = math.sqrt(len(all_prompts)) - xrows = round(xrows) - # if columns is to much - cols = math.ceil(len(all_prompts) / xrows) - if cols > needrows*4: - needrows *= 2 - - return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows - -# -def process_images( - outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, - n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, - ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, - mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, noise_mode=0, find_noise_steps=1, resize_mode=None, uses_loopback=False, - uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, - variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): - """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - assert prompt is not None - torch_gc() - # start time after garbage collection (or before?) 
- start_time = time.time() - - # We will use this date here later for the folder name, need to start_time if not need - run_start_dt = datetime.datetime.now() - - mem_mon = MemUsageMonitor('MemMon') - mem_mon.start() - - if st.session_state.defaults.general.use_sd_concepts_library: - - prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt) - - if prompt_tokens: - # compviz - tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer - text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer - - # diffusers - #tokenizer = pipe.tokenizer - #text_encoder = pipe.text_encoder - - ext = ('pt', 'bin') - - if len(prompt_tokens) > 1: - for token_name in prompt_tokens: - embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name) - if os.path.exists(embedding_path): - for files in os.listdir(embedding_path): - if files.endswith(ext): - load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>") - else: - embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0]) - if os.path.exists(embedding_path): - for files in os.listdir(embedding_path): - if files.endswith(ext): - load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>") - - # - - - os.makedirs(outpath, exist_ok=True) - - sample_path = os.path.join(outpath, "samples") - os.makedirs(sample_path, exist_ok=True) - - if not ("|" in prompt) and prompt.startswith("@"): - prompt = prompt[1:] - - negprompt = '' - if '###' in prompt: - prompt, negprompt = prompt.split('###', 1) - prompt = prompt.strip() - negprompt = negprompt.strip() - - comments = [] - - prompt_matrix_parts = [] - simple_templating = False - add_original_image = not (use_RealESRGAN or use_GFPGAN) - - if prompt_matrix: - if prompt.startswith("@"): - simple_templating = True - add_original_image = not (use_RealESRGAN or use_GFPGAN) - all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size) - else: - all_prompts = [] - prompt_matrix_parts = prompt.split("|") - combination_count = 2 ** (len(prompt_matrix_parts) - 1) - for combination_num in range(combination_count): - current = prompt_matrix_parts[0] - - for n, text in enumerate(prompt_matrix_parts[1:]): - if combination_num & (2 ** n) > 0: - current += ("" if text.strip().startswith(",") else ", ") + text - - all_prompts.append(current) - - n_iter = math.ceil(len(all_prompts) / batch_size) - all_seeds = len(all_prompts) * [seed] - - print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") - else: - - if not st.session_state['defaults'].general.no_verify_input: - try: - check_prompt_length(prompt, comments) - except: - import traceback - print("Error verifying input:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - all_prompts = batch_size * n_iter * [prompt] - all_seeds = [seed + x for x in range(len(all_prompts))] - - precision_scope = autocast if st.session_state['defaults'].general.precision == "autocast" else nullcontext - output_images = [] - grid_captions = [] - stats = [] - with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not st.session_state['defaults'].general.optimized else 
nullcontext()): - init_data = func_init() - tic = time.time() - - - # if variant_amount > 0.0 create noise from base seed - base_x = None - if variant_amount > 0.0: - target_seed_randomizer = seed_to_int('') # random seed - torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration) - base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed]) - # we don't want all_seeds to be sequential from starting seed with variants, - # since that makes the same variants each time, - # so we add target_seed_randomizer as a random offset - for si in range(len(all_seeds)): - all_seeds[si] += target_seed_randomizer - - for n in range(n_iter): - print(f"Iteration: {n+1}/{n_iter}") - prompts = all_prompts[n * batch_size:(n + 1) * batch_size] - captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size] - seeds = all_seeds[n * batch_size:(n + 1) * batch_size] - - print(prompt) - - if st.session_state['defaults'].general.optimized: - st.session_state.modelCS.to(st.session_state['defaults'].general.gpu) - - uc = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(len(prompts) * [negprompt]) - - if isinstance(prompts, tuple): - prompts = list(prompts) - - # split the prompt if it has : for weighting - # TODO for speed it might help to have this occur when all_prompts filled?? - weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights) - - # sub-prompt weighting used if more than 1 - if len(weighted_subprompts) > 1: - c = torch.zeros_like(uc) # i dont know if this is correct.. but it works - for i in range(0, len(weighted_subprompts)): - # note if alpha negative, it functions same as torch.sub - c = torch.add(c, (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1]) - else: # just behave like usual - c = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(prompts) - - - shape = [opt_C, height // opt_f, width // opt_f] - - if st.session_state['defaults'].general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - st.session_state.modelCS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - if noise_mode == 1 or noise_mode == 3: - # TODO params for find_noise_to_image - x = torch.cat(batch_size * [find_noise_for_image( - st.session_state["model"], st.session_state["device"], - init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True, - generation_callback=generation_callback, - )], dim=0) - else: - # we manually generate all input noises because each one should have a specific seed - x = create_random_tensors(shape, seeds=seeds) - - if variant_amount > 0.0: # we are making variants - # using variant_seed as sneaky toggle, - # when not None or '' use the variant_seed - # otherwise use seeds - if variant_seed != None and variant_seed != '': - specified_variant_seed = seed_to_int(variant_seed) - torch.manual_seed(specified_variant_seed) - seeds = [specified_variant_seed] - # finally, slerp base_x noise to target_x noise for creating a variant - x = slerp(st.session_state['defaults'].general.gpu, max(0.0, min(1.0, variant_amount)), base_x, x) - - samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) - - if 
st.session_state['defaults'].general.optimized: - st.session_state.modelFS.to(st.session_state['defaults'].general.gpu) - - x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - run_images = [] - for i, x_sample in enumerate(x_samples_ddim): - sanitized_prompt = slugify(prompts[i]) - - percent = i / len(x_samples_ddim) - st.session_state["progress_bar"].progress(percent if percent < 100 else 100) - - if sort_samples: - full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt) - - - sanitized_prompt = sanitized_prompt[:220-len(full_path)] - sample_path_i = os.path.join(sample_path, sanitized_prompt) - - #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}") - #print(os.path.join(os.getcwd(), sample_path_i)) - - os.makedirs(sample_path_i, exist_ok=True) - base_count = get_next_sequence_number(sample_path_i) - filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}" - else: - full_path = os.path.join(os.getcwd(), sample_path) - sample_path_i = sample_path - base_count = get_next_sequence_number(sample_path_i) - filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before - - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - x_sample = x_sample.astype(np.uint8) - image = Image.fromarray(x_sample) - original_sample = x_sample - original_filename = filename - - st.session_state["preview_image"].image(image) - - if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN: - st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(x_samples_ddim))) - #skip_save = True # #287 >_> - torch_gc() - cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - gfpgan_sample = restored_img[:,:,::-1] - gfpgan_image = Image.fromarray(gfpgan_sample) - gfpgan_filename = original_filename + '-gfpgan' - - save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, - uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta, - n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) - - output_images.append(gfpgan_image) #287 - run_images.append(gfpgan_image) - - if simple_templating: - grid_captions.append( captions[i] + "\ngfpgan" ) - - elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: - st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." 
% (i+1, len(x_samples_ddim))) - #skip_save = True # #287 >_> - torch_gc() - - if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: - #try_loading_RealESRGAN(realesrgan_model_name) - load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) - - output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) - esrgan_filename = original_filename + '-esrgan4x' - esrgan_sample = output[:,:,::-1] - esrgan_image = Image.fromarray(esrgan_sample) - - #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, - #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) - - save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) - - output_images.append(esrgan_image) #287 - run_images.append(esrgan_image) - - if simple_templating: - grid_captions.append( captions[i] + "\nesrgan" ) - - elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None: - st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim))) - #skip_save = True # #287 >_> - torch_gc() - cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - gfpgan_sample = restored_img[:,:,::-1] - - if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: - #try_loading_RealESRGAN(realesrgan_model_name) - load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) - - output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) - gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' - gfpgan_esrgan_sample = output[:,:,::-1] - gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) - - save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"]) - - output_images.append(gfpgan_esrgan_image) #287 - run_images.append(gfpgan_esrgan_image) - - if simple_templating: - grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) - else: - output_images.append(image) - run_images.append(image) - - if mask_restore and init_mask: - #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask) - init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) - init_mask = init_mask.convert('L') - init_img = init_img.convert('RGB') - image = image.convert('RGB') - - if use_RealESRGAN and st.session_state["RealESRGAN"] is not None: - if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: - #try_loading_RealESRGAN(realesrgan_model_name) 
- load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) - - output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8)) - init_img = Image.fromarray(output) - init_img = init_img.convert('RGB') - - output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8)) - init_mask = Image.fromarray(output) - init_mask = init_mask.convert('L') - - image = Image.composite(init_img, image, init_mask) - - if save_individual_images: - save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, st.session_state["loaded_model"]) - - #if add_original_image or not simple_templating: - #output_images.append(image) - #if simple_templating: - #grid_captions.append( captions[i] ) - - if st.session_state['defaults'].general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - st.session_state.modelFS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - if len(run_images) > 1: - preview_image = image_grid(run_images, n_iter) - else: - preview_image = run_images[0] - - # Constrain the final preview image to 1440x900 so we're not sending huge amounts of data - # to the browser - preview_image = constrain_image(preview_image, 1440, 900) - st.session_state["progress_bar_text"].text("Finished!") - st.session_state["preview_image"].image(preview_image) - - if prompt_matrix or save_grid: - if prompt_matrix: - if simple_templating: - grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions) - else: - grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) - try: - grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) - except: - import traceback - print("Error creating prompt_matrix text:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - grid = image_grid(output_images, batch_size) - - if grid and (batch_size > 1 or n_iter > 1): - output_images.insert(0, grid) - - grid_count = get_next_sequence_number(outpath, 'grid-') - grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}" - grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True) - - toc = time.time() - - mem_max_used, mem_total = mem_mon.read_and_stop() - time_diff = time.time()-start_time - - info = f""" - {prompt} - Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' 
if prompt_matrix else ''}""".strip() - stats = f''' - Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) - Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' - - for comment in comments: - info += "\n\n" + comment - - #mem_mon.stop() - #del mem_mon - torch_gc() - - return output_images, seed, info, stats - - -def resize_image(resize_mode, im, width, height): - LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) - if resize_mode == 0: - res = im.resize((width, height), resample=LANCZOS) - elif resize_mode == 1: - ratio = width / height - src_ratio = im.width / im.height - - src_w = width if ratio > src_ratio else im.width * height // im.height - src_h = height if ratio <= src_ratio else im.height * width // im.width - - resized = im.resize((src_w, src_h), resample=LANCZOS) - res = Image.new("RGBA", (width, height)) - res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) - else: - ratio = width / height - src_ratio = im.width / im.height - - src_w = width if ratio < src_ratio else im.width * height // im.height - src_h = height if ratio >= src_ratio else im.height * width // im.width - - resized = im.resize((src_w, src_h), resample=LANCZOS) - res = Image.new("RGBA", (width, height)) - res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) - - if ratio < src_ratio: - fill_height = height // 2 - src_h // 2 - res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) - res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) - elif ratio > src_ratio: - fill_width = width // 2 - src_w // 2 - res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) - res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) - - return res - -def constrain_image(img, max_width, max_height): - ratio = max(img.width / max_width, img.height / max_height) - if ratio <= 1: - return img - resampler = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) - resized = img.resize((int(img.width / ratio), int(img.height / ratio)), resample=resampler) - return resized diff --git a/scripts/stable_diffusion_pipeline.py b/scripts/stable_diffusion_pipeline.py deleted file mode 100644 index 6f4f794..0000000 --- a/scripts/stable_diffusion_pipeline.py +++ /dev/null @@ -1,233 +0,0 @@ -import inspect -import warnings -from tqdm.auto import tqdm -from typing import List, Optional, Union - -import torch -from diffusers import ModelMixin -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import \ - StableDiffusionSafetyChecker -from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler, - PNDMScheduler) -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer - - -class StableDiffusionPipeline(DiffusionPipeline): - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - ): - super().__init__() - scheduler = scheduler.set_format("pt") 
- self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - - @torch.no_grad() - def __call__( - self, - prompt: Optional[Union[str, List[str]]] = None, - height: Optional[int] = 512, - width: Optional[int] = 512, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 7.5, - eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - text_embeddings: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - **kwargs, - ): - if "torch_device" in kwargs: - device = kwargs.pop("torch_device") - warnings.warn( - "`torch_device` is deprecated as an input argument to `__call__` and" - " will be removed in v0.3.0. Consider using `pipe.to(torch_device)`" - " instead." - ) - - # Set device as before (to be removed in 0.3.0) - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - self.to(device) - - if text_embeddings is None: - if isinstance(prompt, str): - batch_size = 1 - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError( - f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" - ) - - if height % 8 != 0 or width % 8 != 0: - raise ValueError( - "`height` and `width` have to be divisible by 8 but are" - f" {height} and {width}." - ) - - # get prompt text embeddings - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] - else: - batch_size = text_embeddings.shape[0] - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: - # max_length = text_input.input_ids.shape[-1] - max_length = 77 # self.tokenizer.model_max_length - uncond_input = self.tokenizer( - [""] * batch_size, - padding="max_length", - max_length=max_length, - return_tensors="pt", - ) - uncond_embeddings = self.text_encoder( - uncond_input.input_ids.to(self.device) - )[0] - - # For classifier free guidance, we need to do two forward passes. 
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - # get the initial random noise unless the user supplied it - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) - if latents is None: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.device, - ) - else: - if latents.shape != latents_shape: - raise ValueError( - f"Unexpected latents shape, got {latents.shape}, expected" - f" {latents_shape}" - ) - latents = latents.to(self.device) - - # set timesteps - accepts_offset = "offset" in set( - inspect.signature(self.scheduler.set_timesteps).parameters.keys() - ) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=text_embeddings - )["sample"] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step( - noise_pred, i, latents, **extra_step_kwargs - )["prev_sample"] - else: - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - )["prev_sample"] - - # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents - image = self.vae.decode(latents).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - - safety_cheker_input = self.feature_extractor( - self.numpy_to_pil(image), return_tensors="pt" - ).to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_cheker_input.pixel_values - ) - - if output_type == "pil": - image = self.numpy_to_pil(image) - - return {"sample": image, "nsfw_content_detected": has_nsfw_concept} - - def embed_text(self, text): - """Helper to embed some text""" - with torch.autocast("cuda"): - text_input = self.tokenizer( - text, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - with 
torch.no_grad(): - embed = self.text_encoder(text_input.input_ids.to(self.device))[0] - return embed - - -class NoCheck(ModelMixin): - """Can be used in place of safety checker. Use responsibly and at your own risk.""" - def __init__(self): - super().__init__() - self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3))) - - def forward(self, images=None, **kwargs): - return images, [False] diff --git a/scripts/stable_diffusion_walk.py b/scripts/stable_diffusion_walk.py deleted file mode 100644 index 1ce175d..0000000 --- a/scripts/stable_diffusion_walk.py +++ /dev/null @@ -1,218 +0,0 @@ -import json -import subprocess -from pathlib import Path - -import numpy as np -import torch -from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler, - PNDMScheduler) -from diffusers import ModelMixin - -from stable_diffusion_pipeline import StableDiffusionPipeline - -pipeline = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", - use_auth_token=True, - torch_dtype=torch.float16, - revision="fp16", -).to("cuda") - -default_scheduler = PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" -) -ddim_scheduler = DDIMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, -) -klms_scheduler = LMSDiscreteScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" -) -SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) - - -def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): - """helper function to spherically interpolate two arrays v1 v2""" - - if not isinstance(v0, np.ndarray): - inputs_are_torch = True - input_device = v0.device - v0 = v0.cpu().numpy() - v1 = v1.cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - if inputs_are_torch: - v2 = torch.from_numpy(v2).to(input_device) - - return v2 - - -def make_video_ffmpeg(frame_dir, output_file_name='output.mp4', frame_filename="frame%06d.jpg", fps=30): - frame_ref_path = str(frame_dir / frame_filename) - video_path = str(frame_dir / output_file_name) - subprocess.call( - f"ffmpeg -r {fps} -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p" - f" {video_path}".split() - ) - return video_path - - -def walk( - prompts=["blueberry spaghetti", "strawberry spaghetti"], - seeds=[42, 123], - num_steps=5, - output_dir="dreams", - name="berry_good_spaghetti", - height=512, - width=512, - guidance_scale=7.5, - eta=0.0, - num_inference_steps=50, - do_loop=False, - make_video=False, - use_lerp_for_text=False, - scheduler="klms", # choices: default, ddim, klms - disable_tqdm=False, - upsample=False, - fps=30, -): - """Generate video frames/a video given a list of prompts and seeds. - - Args: - prompts (List[str], optional): List of . Defaults to ["blueberry spaghetti", "strawberry spaghetti"]. - seeds (List[int], optional): List of random seeds corresponding to given prompts. - num_steps (int, optional): Number of steps to walk. Increase this value to 60-200 for good results. Defaults to 5. - output_dir (str, optional): Root dir where images will be saved. Defaults to "dreams". 
- name (str, optional): Sub directory of output_dir to save this run's files. Defaults to "berry_good_spaghetti". - height (int, optional): Height of image to generate. Defaults to 512. - width (int, optional): Width of image to generate. Defaults to 512. - guidance_scale (float, optional): Higher = more adherance to prompt. Lower = let model take the wheel. Defaults to 7.5. - eta (float, optional): ETA. Defaults to 0.0. - num_inference_steps (int, optional): Number of diffusion steps. Defaults to 50. - do_loop (bool, optional): Whether to loop from last prompt back to first. Defaults to False. - make_video (bool, optional): Whether to make a video or just save the images. Defaults to False. - use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to False. - scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms". - disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False. - upsample (bool, optional): If True, uses Real-ESRGAN to upsample images 4x. Requires it to be installed - which you can do by running: `pip install git+https://github.com/xinntao/Real-ESRGAN.git`. Defaults to False. - fps (int, optional): The frames per second (fps) that you want the video to use. Does nothing if make_video is False. Defaults to 30. - - Returns: - str: Path to video file saved if make_video=True, else None. - """ - if upsample: - from .upsampling import PipelineRealESRGAN - - upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan') - - pipeline.set_progress_bar_config(disable=disable_tqdm) - - pipeline.scheduler = SCHEDULERS[scheduler] - - output_path = Path(output_dir) / name - output_path.mkdir(exist_ok=True, parents=True) - - # Write prompt info to file in output dir so we can keep track of what we did - prompt_config_path = output_path / 'prompt_config.json' - prompt_config_path.write_text( - json.dumps( - dict( - prompts=prompts, - seeds=seeds, - num_steps=num_steps, - name=name, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - do_loop=do_loop, - make_video=make_video, - use_lerp_for_text=use_lerp_for_text, - scheduler=scheduler - ), - indent=2, - sort_keys=False, - ) - ) - - assert len(prompts) == len(seeds) - - first_prompt, *prompts = prompts - embeds_a = pipeline.embed_text(first_prompt) - - first_seed, *seeds = seeds - latents_a = torch.randn( - (1, pipeline.unet.in_channels, height // 8, width // 8), - device=pipeline.device, - generator=torch.Generator(device=pipeline.device).manual_seed(first_seed), - ) - - if do_loop: - prompts.append(first_prompt) - seeds.append(first_seed) - - frame_index = 0 - for prompt, seed in zip(prompts, seeds): - # Text - embeds_b = pipeline.embed_text(prompt) - - # Latent Noise - latents_b = torch.randn( - (1, pipeline.unet.in_channels, height // 8, width // 8), - device=pipeline.device, - generator=torch.Generator(device=pipeline.device).manual_seed(seed), - ) - - for i, t in enumerate(np.linspace(0, 1, num_steps)): - do_print_progress = (i == 0) or ((frame_index + 1) % 20 == 0) - if do_print_progress: - print(f"COUNT: {frame_index+1}/{len(seeds)*num_steps}") - - if use_lerp_for_text: - embeds = torch.lerp(embeds_a, embeds_b, float(t)) - else: - embeds = slerp(float(t), embeds_a, embeds_b) - latents = slerp(float(t), latents_a, latents_b) - - with torch.autocast("cuda"): - im = pipeline( - latents=latents, - text_embeddings=embeds, - height=height, - 
width=width, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - output_type='pil' if not upsample else 'numpy' - )["sample"][0] - - if upsample: - im = upsampling_pipeline(im) - - im.save(output_path / ("frame%06d.jpg" % frame_index)) - frame_index += 1 - - embeds_a = embeds_b - latents_a = latents_b - - if make_video: - return make_video_ffmpeg(output_path, f"{name}.mp4", fps=fps) - - -if __name__ == "__main__": - import fire - - fire.Fire(walk) diff --git a/scripts/textual_inversion.py b/scripts/textual_inversion.py deleted file mode 100644 index 3e5cc3e..0000000 --- a/scripts/textual_inversion.py +++ /dev/null @@ -1,57 +0,0 @@ -# base webui import and utils. -from webui_streamlit import st -from sd_utils import * - -# streamlit imports - - -#other imports -#from transformers import CLIPTextModel, CLIPTokenizer - -# Temp imports - - -# end of imports -#--------------------------------------------------------------------------------------------------------------- - -#def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None): - - #loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu") - - ## separate token and the embeds - #print (loaded_learned_embeds) - #trained_token = list(loaded_learned_embeds.keys())[0] - #embeds = loaded_learned_embeds[trained_token] - - ## cast to dtype of text_encoder - #dtype = text_encoder.get_input_embeddings().weight.dtype - #embeds.to(dtype) - - ## add the token in tokenizer - #token = token if token is not None else trained_token - #num_added_tokens = tokenizer.add_tokens(token) - #i = 1 - #while(num_added_tokens == 0): - #print(f"The tokenizer already contains the token {token}.") - #token = f"{token[:-1]}-{i}>" - #print(f"Attempting to add the token {token}.") - #num_added_tokens = tokenizer.add_tokens(token) - #i+=1 - - ## resize the token embeddings - #text_encoder.resize_token_embeddings(len(tokenizer)) - - ## get the id for the token and assign the embeds - #token_id = tokenizer.convert_tokens_to_ids(token) - #text_encoder.get_input_embeddings().weight.data[token_id] = embeds - #return token - -##def token_loader() -#learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", st.session_state.pipe.text_encoder, st.session_state.pipe.tokenizer, "*") -#model_content["token"] = learned_token -#models.append(model_content) - -model_id = "./models/custom/embeddings/" - -def layout(): - st.write("Textual Inversion") \ No newline at end of file diff --git a/scripts/txt2img.py b/scripts/txt2img.py deleted file mode 100644 index 6f74143..0000000 --- a/scripts/txt2img.py +++ /dev/null @@ -1,368 +0,0 @@ -# base webui import and utils. -from webui_streamlit import st -from sd_utils import * - -# streamlit imports -from streamlit import StopException -from streamlit.runtime.in_memory_file_manager import in_memory_file_manager -from streamlit.elements import image as STImage - -#other imports -import os -from typing import Union -from io import BytesIO -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler - -# Temp imports - - -# end of imports -#--------------------------------------------------------------------------------------------------------------- - - -try: - # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. 
- from transformers import logging - - logging.set_verbosity_error() -except: - pass - -class plugin_info(): - plugname = "txt2img" - description = "Text to Image" - isTab = True - displayPriority = 1 - - -if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): - GFPGAN_available = True -else: - GFPGAN_available = False - -if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")): - RealESRGAN_available = True -else: - RealESRGAN_available = False - -# -def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str, - n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None], - height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True, - save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, - save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, - RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None, - variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True): - - outpath = st.session_state['defaults'].general.outdir_txt2img or st.session_state['defaults'].general.outdir or "outputs/txt2img-samples" - - seed = seed_to_int(seed) - - #prompt_matrix = 0 in toggles - #normalize_prompt_weights = 1 in toggles - #skip_save = 2 not in toggles - #save_grid = 3 not in toggles - #sort_samples = 4 in toggles - #write_info_files = 5 in toggles - #jpg_sample = 6 in toggles - #use_GFPGAN = 7 in toggles - #use_RealESRGAN = 8 in toggles - - if sampler_name == 'PLMS': - sampler = PLMSSampler(st.session_state["model"]) - elif sampler_name == 'DDIM': - sampler = DDIMSampler(st.session_state["model"]) - elif sampler_name == 'k_dpm_2_a': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') - elif sampler_name == 'k_dpm_2': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') - elif sampler_name == 'k_euler_a': - sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') - elif sampler_name == 'k_euler': - sampler = KDiffusionSampler(st.session_state["model"],'euler') - elif sampler_name == 'k_heun': - sampler = KDiffusionSampler(st.session_state["model"],'heun') - elif sampler_name == 'k_lms': - sampler = KDiffusionSampler(st.session_state["model"],'lms') - else: - raise Exception("Unknown sampler: " + sampler_name) - - def init(): - pass - - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback, - log_every_t=int(st.session_state.update_preview_frequency)) - - return samples_ddim - - #try: - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=batch_size, - n_iter=n_iter, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=st.session_state["use_GFPGAN"], - use_RealESRGAN=st.session_state["use_RealESRGAN"], - 
realesrgan_model_name=realesrgan_model_name, - ddim_eta=ddim_eta, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg, - variant_amount=variant_amount, - variant_seed=variant_seed, - ) - - del sampler - - return output_images, seed, info, stats - - #except RuntimeError as e: - #err = e - #err_msg = f'CRASHED:
Please wait while the program restarts.' - #stats = err_msg - #return [], seed, 'err', stats - -def layout(): - with st.form("txt2img-inputs"): - st.session_state["generation_mode"] = "txt2img" - - input_col1, generate_col1 = st.columns([10,1]) - - with input_col1: - #prompt = st.text_area("Input Text","") - prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") - - # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. - generate_col1.write("") - generate_col1.write("") - generate_button = generate_col1.form_submit_button("Generate") - - # creating the page layout using columns - col1, col2, col3 = st.columns([1,2,1], gap="large") - - with col1: - width = st.slider("Width:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.width, step=64) - height = st.slider("Height:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.height, step=64) - cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") - seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.") - batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") - - bs_slider_max_value = 5 - if st.session_state.defaults.general.optimized: - bs_slider_max_value = 100 - - batch_size = st.slider( - "Batch size", - min_value=1, - max_value=bs_slider_max_value, - value=st.session_state.defaults.txt2img.batch_size, - step=1, - help="How many images are at once in a batch.\ - It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ - Default: 1") - - with st.expander("Preview Settings"): - st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2img.update_preview, - help="If enabled the image preview will be updated during the generation instead of at the end. \ - You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \ - By default this is enabled and the frequency is set to 1 step.") - - st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency, - help="Frequency in steps at which the the preview image is updated. By default the frequency \ - is set to 1 step.") - - with col2: - preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) - - with preview_tab: - #st.write("Image") - #Image for testing - #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') - #new_image = image.resize((175, 240)) - #preview_image = st.image(image) - - # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. 
- st.session_state["preview_image"] = st.empty() - - st.session_state["loading"] = st.empty() - - st.session_state["progress_bar_text"] = st.empty() - st.session_state["progress_bar"] = st.empty() - - message = st.empty() - - with col3: - # If we have custom models available on the "models/custom" - #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD - if st.session_state.CustomModel_available: - st.session_state.custom_model = st.selectbox("Custom Model:", st.session_state.custom_models, - index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model), - help="Select the model you want to use. This option is only available if you have custom models \ - on your 'models/custom' folder. The model name that will be shown here is the same as the name\ - the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ - will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") - - st.session_state.sampling_steps = st.slider("Sampling Steps", - value=st.session_state['defaults'].txt2img.sampling_steps, - min_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.lower, - max_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.upper, - step=st.session_state['defaults'].txt2img.slider_steps.sampling) - - sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] - sampler_name = st.selectbox("Sampling method", sampler_name_list, - index=sampler_name_list.index(st.session_state['defaults'].txt2img.default_sampler), help="Sampling method to use. Default: k_euler") - - - - #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) - - #with basic_tab: - #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, - #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") - - with st.expander("Advanced"): - separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.") - normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].txt2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") - save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].txt2img.save_individual_images, help="Save each image generated before any filter or enhancement is applied.") - save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].txt2img.save_grid, help="Save a grid with all the images generated into a single image.") - group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2img.group_by_prompt, - help="Saves all the images with the same prompt into the same folder. 
When using a prompt matrix each prompt combination will have its own folder.") - write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files, help="Save a file next to the image with informartion about the generation.") - save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.") - - if st.session_state["GFPGAN_available"]: - st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ - This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") - else: - st.session_state["use_GFPGAN"] = False - - if st.session_state["RealESRGAN_available"]: - st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2img.use_RealESRGAN, - help="Uses the RealESRGAN model to upscale the images after the generation.\ - This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") - st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) - else: - st.session_state["use_RealESRGAN"] = False - st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" - - variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) - variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") - galleryCont = st.empty() - - if generate_button: - #print("Loading models") - # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
- load_models(False, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"], - st.session_state["custom_model"]) - - - try: - # - output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, st.session_state["RealESRGAN_model"], batch_count, batch_size, - cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, - save_grid, group_by_prompt, save_as_jpg, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], - variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) - - message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") - - #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab'] - - #if 'latestImages' in st.session_state: - #for i in output_images: - ##push the new image to the list of latest images and remove the oldest one - ##remove the last index from the list\ - #st.session_state['latestImages'].pop() - ##add the new image to the start of the list - #st.session_state['latestImages'].insert(0, i) - #PlaceHolder.empty() - #with PlaceHolder.container(): - #col1, col2, col3 = st.columns(3) - #col1_cont = st.container() - #col2_cont = st.container() - #col3_cont = st.container() - #images = st.session_state['latestImages'] - #with col1_cont: - #with col1: - #[st.image(images[index]) for index in [0, 3, 6] if index < len(images)] - #with col2_cont: - #with col2: - #[st.image(images[index]) for index in [1, 4, 7] if index < len(images)] - #with col3_cont: - #with col3: - #[st.image(images[index]) for index in [2, 5, 8] if index < len(images)] - #historyGallery = st.empty() - - ## check if output_images length is the same as seeds length - #with gallery_tab: - #st.markdown(createHTMLGallery(output_images,seeds), unsafe_allow_html=True) - - - #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont] - - except (StopException, KeyError): - print(f"Received Streamlit StopException") - - # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. - # use the current col2 first tab to show the preview_img and update it as its generated. - #preview_image.image(output_images) - -#on import run init -def createHTMLGallery(images,info): - html3 = """ - ' - return html3 \ No newline at end of file diff --git a/scripts/txt2vid.py b/scripts/txt2vid.py deleted file mode 100644 index e1be209..0000000 --- a/scripts/txt2vid.py +++ /dev/null @@ -1,780 +0,0 @@ -# base webui import and utils. -from webui_streamlit import st -from sd_utils import * - -# streamlit imports -from streamlit import StopException -from streamlit.runtime.in_memory_file_manager import in_memory_file_manager -from streamlit.elements import image as STImage - -#other imports - -import os -from PIL import Image -import torch -import numpy as np -import time, inspect, timeit -import torch -from torch import autocast -from io import BytesIO -import imageio -from slugify import slugify - -# Temp imports - -# these are for testing txt2vid, should be removed and we should use things from our own code. 
-from diffusers import StableDiffusionPipeline -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler - -# end of imports -#--------------------------------------------------------------------------------------------------------------- - -try: - # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging - - logging.set_verbosity_error() -except: - pass - -class plugin_info(): - plugname = "txt2img" - description = "Text to Image" - isTab = True - displayPriority = 1 - - -if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): - GFPGAN_available = True -else: - GFPGAN_available = False - -if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].txt2vid.RealESRGAN_model}.pth")): - RealESRGAN_available = True -else: - RealESRGAN_available = False - -# -# ----------------------------------------------------------------------------- - -@torch.no_grad() -def diffuse( - pipe, - cond_embeddings, # text conditioning, should be (1, 77, 768) - cond_latents, # image conditioning, should be (1, 4, 64, 64) - num_inference_steps, - cfg_scale, - eta, - ): - - torch_device = cond_latents.get_device() - - # classifier guidance: add the unconditional embedding - max_length = cond_embeddings.shape[1] # 77 - uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt") - uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0] - text_embeddings = torch.cat([uncond_embeddings, cond_embeddings]) - - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas - if isinstance(pipe.scheduler, LMSDiscreteScheduler): - cond_latents = cond_latents * pipe.scheduler.sigmas[0] - - # init the scheduler - accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs) - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - - step_counter = 0 - inference_counter = 0 - - if "current_chunk_speed" not in st.session_state: - st.session_state["current_chunk_speed"] = 0 - - if "previous_chunk_speed_list" not in st.session_state: - st.session_state["previous_chunk_speed_list"] = [0] - st.session_state["previous_chunk_speed_list"].append(st.session_state["current_chunk_speed"]) - - if "update_preview_frequency_list" not in st.session_state: - st.session_state["update_preview_frequency_list"] = [0] - st.session_state["update_preview_frequency_list"].append(st.session_state['defaults'].txt2vid.update_preview_frequency) - - - # diffuse! 
- for i, t in enumerate(pipe.scheduler.timesteps): - start = timeit.default_timer() - - #status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}") - - # expand the latents for classifier free guidance - latent_model_input = torch.cat([cond_latents] * 2) - if isinstance(pipe.scheduler, LMSDiscreteScheduler): - sigma = pipe.scheduler.sigmas[i] - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - - # predict the noise residual - noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] - - # cfg - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(pipe.scheduler, LMSDiscreteScheduler): - cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"] - else: - cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"] - - #print (st.session_state["update_preview_frequency"]) - #update the preview image if it is enabled and the frequency matches the step_counter - if st.session_state['defaults'].txt2vid.update_preview: - step_counter += 1 - - if st.session_state['defaults'].txt2vid.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps: - if st.session_state.dynamic_preview_frequency: - st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["update_preview_frequency_list"]) - - #scale and decode the image latents with vae - cond_latents_2 = 1 / 0.18215 * cond_latents - image = pipe.vae.decode(cond_latents_2) - - # generate output numpy image as uint8 - image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0) - image = transforms.ToPILImage()(image.squeeze_(0)) - - st.session_state["preview_image"].image(image) - - step_counter = 0 - - duration = timeit.default_timer() - start - - st.session_state["current_chunk_speed"] = duration - - if duration >= 1: - speed = "s/it" - else: - speed = "it/s" - duration = 1 / duration - - if i > st.session_state.sampling_steps: - inference_counter += 1 - inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps)) - inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% " - else: - inference_progress = "" - - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames)) - - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} " - f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | " - f"Frame: 
{st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} " - f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}" - ) - st.session_state["progress_bar"].progress(percent if percent < 100 else 100) - - return image - -# -def txt2vid( - # -------------------------------------- - # args you probably want to change - prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about - gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on - #name:str = 'test', # name of this project, for the output directory - #rootdir:str = st.session_state['defaults'].general.outdir, - num_steps:int = 200, # number of steps between each pair of sampled points - max_frames:int = 10000, # number of frames to write and then exit the script - num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images - cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good - do_loop = False, - use_lerp_for_text = False, - seeds = None, - quality:int = 100, # for jpeg compression of the output images - eta:float = 0.0, - width:int = 256, - height:int = 256, - weights_path = "CompVis/stable-diffusion-v1-4", - scheduler="klms", # choices: default, ddim, klms - disable_tqdm = False, - #----------------------------------------------- - beta_start = 0.0001, - beta_end = 0.00012, - beta_schedule = "scaled_linear", - starting_image=None - ): - """ - prompt = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about - gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on - #name:str = 'test', # name of this project, for the output directory - #rootdir:str = st.session_state['defaults'].general.outdir, - num_steps:int = 200, # number of steps between each pair of sampled points - max_frames:int = 10000, # number of frames to write and then exit the script - num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images - cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good - do_loop = False, - use_lerp_for_text = False, - seed = None, - quality:int = 100, # for jpeg compression of the output images - eta:float = 0.0, - width:int = 256, - height:int = 256, - weights_path = "CompVis/stable-diffusion-v1-4", - scheduler="klms", # choices: default, ddim, klms - disable_tqdm = False, - beta_start = 0.0001, - beta_end = 0.00012, - beta_schedule = "scaled_linear" - """ - mem_mon = MemUsageMonitor('MemMon') - mem_mon.start() - - - seeds = seed_to_int(seeds) - - # We add an extra frame because most - # of the time the first frame is just the noise. 
- #max_frames +=1 - - assert torch.cuda.is_available() - assert height % 8 == 0 and width % 8 == 0 - torch.manual_seed(seeds) - torch_device = f"cuda:{gpu}" - - # init the output dir - sanitized_prompt = slugify(prompts) - - full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt) - - if len(full_path) > 220: - sanitized_prompt = sanitized_prompt[:220-len(full_path)] - full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt) - - os.makedirs(full_path, exist_ok=True) - - # Write prompt info to file in output dir so we can keep track of what we did - if st.session_state.write_info_files: - with open(os.path.join(full_path , f'{slugify(str(seeds))}_config.json' if len(prompts) > 1 else "prompts_config.json"), "w") as outfile: - outfile.write(json.dumps( - dict( - prompts = prompts, - gpu = gpu, - num_steps = num_steps, - max_frames = max_frames, - num_inference_steps = num_inference_steps, - cfg_scale = cfg_scale, - do_loop = do_loop, - use_lerp_for_text = use_lerp_for_text, - seeds = seeds, - quality = quality, - eta = eta, - width = width, - height = height, - weights_path = weights_path, - scheduler=scheduler, - disable_tqdm = disable_tqdm, - beta_start = beta_start, - beta_end = beta_end, - beta_schedule = beta_schedule - ), - indent=2, - sort_keys=False, - )) - - #print(scheduler) - default_scheduler = PNDMScheduler( - beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule - ) - # ------------------------------------------------------------------------------ - #Schedulers - ddim_scheduler = DDIMScheduler( - beta_start=beta_start, - beta_end=beta_end, - beta_schedule=beta_schedule, - clip_sample=False, - set_alpha_to_one=False, - ) - - klms_scheduler = LMSDiscreteScheduler( - beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule - ) - - SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) - - # ------------------------------------------------------------------------------ - st.session_state["progress_bar_text"].text("Loading models...") - - try: - if "model" in st.session_state: - del st.session_state["model"] - except: - pass - - #print (st.session_state["weights_path"] != weights_path) - - try: - if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path: - if st.session_state["weights_path"] != weights_path: - del st.session_state["weights_path"] - - st.session_state["weights_path"] = weights_path - st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained( - weights_path, - use_local_file=True, - use_auth_token=True, - torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, - revision="fp16" if not st.session_state['defaults'].general.no_half else None - ) - - st.session_state["pipe"].unet.to(torch_device) - st.session_state["pipe"].vae.to(torch_device) - st.session_state["pipe"].text_encoder.to(torch_device) - - if st.session_state.defaults.general.enable_attention_slicing: - st.session_state["pipe"].enable_attention_slicing() - if st.session_state.defaults.general.enable_minimal_memory_usage: - st.session_state["pipe"].enable_minimal_memory_usage() - - print("Tx2Vid Model Loaded") - else: - print("Tx2Vid Model already Loaded") - - except: - #del st.session_state["weights_path"] - #del st.session_state["pipe"] - - st.session_state["weights_path"] = weights_path - st.session_state["pipe"] = 
StableDiffusionPipeline.from_pretrained( - weights_path, - use_local_file=True, - use_auth_token=True, - torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, - revision="fp16" if not st.session_state['defaults'].general.no_half else None - ) - - st.session_state["pipe"].unet.to(torch_device) - st.session_state["pipe"].vae.to(torch_device) - st.session_state["pipe"].text_encoder.to(torch_device) - - if st.session_state.defaults.general.enable_attention_slicing: - st.session_state["pipe"].enable_attention_slicing() - - - print("Tx2Vid Model Loaded") - - st.session_state["pipe"].scheduler = SCHEDULERS[scheduler] - - # get the conditional text embeddings based on the prompt - text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt") - cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768] - - # - if st.session_state.defaults.general.use_sd_concepts_library: - - prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompts) - - if prompt_tokens: - # compviz - #tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer - #text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer - - # diffusers - tokenizer = st.session_state.pipe.tokenizer - text_encoder = st.session_state.pipe.text_encoder - - ext = ('pt', 'bin') - #print (prompt_tokens) - - if len(prompt_tokens) > 1: - for token_name in prompt_tokens: - embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name) - if os.path.exists(embedding_path): - for files in os.listdir(embedding_path): - if files.endswith(ext): - load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>") - else: - embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0]) - if os.path.exists(embedding_path): - for files in os.listdir(embedding_path): - if files.endswith(ext): - load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>") - - # sample a source - init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) - - if do_loop: - prompts = [prompts, prompts] - seeds = [seeds, seeds] - #first_seed, *seeds = seeds - #prompts.append(prompts) - #seeds.append(first_seed) - - - # iterate the loop - frames = [] - frame_index = 0 - - st.session_state["total_frames_avg_duration"] = [] - st.session_state["total_frames_avg_speed"] = [] - - try: - while frame_index < max_frames: - st.session_state["frame_duration"] = 0 - st.session_state["frame_speed"] = 0 - st.session_state["current_frame"] = frame_index - - # sample the destination - init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) - - for i, t in enumerate(np.linspace(0, 1, max_frames)): - start = timeit.default_timer() - print(f"COUNT: {frame_index+1}/{max_frames}") - - #if use_lerp_for_text: - #init = torch.lerp(init1, init2, float(t)) - #else: - #init = slerp(gpu, float(t), init1, init2) - - init = slerp(gpu, float(t), init1, init2) - - with autocast("cuda"): - image = 
diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta) - - #im = Image.fromarray(image) - outpath = os.path.join(full_path, 'frame%06d.png' % frame_index) - image.save(outpath, quality=quality) - - # send the image to the UI to update it - #st.session_state["preview_image"].image(im) - - #append the frames to the frames list so we can use them later. - frames.append(np.asarray(image)) - - #increase frame_index counter. - frame_index += 1 - - st.session_state["current_frame"] = frame_index - - duration = timeit.default_timer() - start - - if duration >= 1: - speed = "s/it" - else: - speed = "it/s" - duration = 1 / duration - - st.session_state["frame_duration"] = duration - st.session_state["frame_speed"] = speed - - init1 = init2 - - except StopException: - pass - - - if st.session_state['save_video']: - # write video to memory - #output = io.BytesIO() - #writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30) - try: - video_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples","temp.mp4") - writer = imageio.get_writer(video_path, fps=24) - for frame in frames: - writer.append_data(frame) - writer.close() - except: - print("Can't save video, skipping.") - - # show video preview on the UI - st.session_state["preview_video"].video(open(video_path, 'rb').read()) - - mem_max_used, mem_total = mem_mon.read_and_stop() - time_diff = time.time()- start - - info = f""" - {prompts} - Sampling Steps: {num_steps}, Sampler: {scheduler}, CFG scale: {cfg_scale}, Seed: {seeds}, Max Frames: {max_frames}""".strip() - stats = f''' - Took { round(time_diff, 2) }s total ({ round(time_diff/(max_frames),2) }s per image) - Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' - - return video_path, seeds, info, stats - -#on import run init -def createHTMLGallery(images,info): - html3 = """ - ' - return html3 -# -def layout(): - with st.form("txt2vid-inputs"): - st.session_state["generation_mode"] = "txt2vid" - - input_col1, generate_col1 = st.columns([10,1]) - with input_col1: - #prompt = st.text_area("Input Text","") - prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") - - # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. 
- generate_col1.write("") - generate_col1.write("") - generate_button = generate_col1.form_submit_button("Generate") - - # creating the page layout using columns - col1, col2, col3 = st.columns([1,2,1], gap="large") - - with col1: - width = st.slider("Width:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.width, step=64) - height = st.slider("Height:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.height, step=64) - cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2vid.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") - - #uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], - #help="Upload an image which will be used for the image to image generation.") - seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2vid.seed, help=" The seed to use, if left blank a random seed will be generated.") - #batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2vid.batch_count, step=1, help="How many iterations or batches of images to generate in total.") - #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=st.session_state['defaults'].txt2vid.batch_size, step=1, - #help="How many images are at once in a batch.\ - #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ - #Default: 1") - - st.session_state["max_frames"] = int(st.text_input("Max Frames:", value=st.session_state['defaults'].txt2vid.max_frames, help="Specify the max number of frames you want to generate.")) - - with st.expander("Preview Settings"): - st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2vid.update_preview, - help="If enabled the image preview will be updated during the generation instead of at the end. \ - You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \ - By default this is enabled and the frequency is set to 1 step.") - - st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2vid.update_preview_frequency, - help="Frequency in steps at which the the preview image is updated. By default the frequency \ - is set to 1 step.") - - # - - - - with col2: - preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) - - with preview_tab: - #st.write("Image") - #Image for testing - #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') - #new_image = image.resize((175, 240)) - #preview_image = st.image(image) - - # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. 
- st.session_state["preview_image"] = st.empty() - - st.session_state["loading"] = st.empty() - - st.session_state["progress_bar_text"] = st.empty() - st.session_state["progress_bar"] = st.empty() - - #generate_video = st.empty() - st.session_state["preview_video"] = st.empty() - - message = st.empty() - - with gallery_tab: - st.write('Here should be the image gallery, if I could make a grid in streamlit.') - - with col3: - # If we have custom models available on the "models/custom" - #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD - if st.session_state["CustomModel_available"]: - custom_model = st.selectbox("Custom Model:", st.session_state["defaults"].txt2vid.custom_models_list, - index=st.session_state["defaults"].txt2vid.custom_models_list.index(st.session_state["defaults"].txt2vid.default_model), - help="Select the model you want to use. This option is only available if you have custom models \ - on your 'models/custom' folder. The model name that will be shown here is the same as the name\ - the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ - will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") - else: - custom_model = "CompVis/stable-diffusion-v1-4" - - #st.session_state["weights_path"] = custom_model - #else: - #custom_model = "CompVis/stable-diffusion-v1-4" - #st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}" - - st.session_state.sampling_steps = st.slider("Sampling Steps", - value=st.session_state['defaults'].txt2vid.sampling_steps, - min_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.lower, - max_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.upper, - step=st.session_state['defaults'].txt2vid.slider_steps.sampling, - help="Number of steps between each pair of sampled points") - st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps, min_value=10,step=10, max_value=500, - help="Higher values (e.g. 100, 200 etc) can create better images.") - - #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] - #sampler_name = st.selectbox("Sampling method", sampler_name_list, - #index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler") - scheduler_name_list = ["klms", "ddim"] - scheduler_name = st.selectbox("Scheduler:", scheduler_name_list, - index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms") - - beta_scheduler_type_list = ["scaled_linear", "linear"] - beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list, - index=beta_scheduler_type_list.index(st.session_state['defaults'].txt2vid.beta_scheduler_type), help="Schedule Type to use. 
Default: linear") - - - #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) - - #with basic_tab: - #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, - #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") - - with st.expander("Advanced"): - st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2vid.separate_prompts, - help="Separate multiple prompts using the `|` character, and get all combinations of them.") - st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.", - value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") - st.session_state["save_individual_images"] = st.checkbox("Save individual images.", - value=st.session_state['defaults'].txt2vid.save_individual_images, help="Save each image generated before any filter or enhancement is applied.") - st.session_state["save_video"] = st.checkbox("Save video",value=st.session_state['defaults'].txt2vid.save_video, help="Save a video with all the images generated as frames at the end of the generation.") - st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt, - help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.") - st.session_state["write_info_files"] = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2vid.write_info_files, - help="Save a file next to the image with informartion about the generation.") - st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency, - help="This option tries to find the best value at which we can update \ - the preview image during generation while minimizing the impact it has in performance. Default: True") - st.session_state["do_loop"] = st.checkbox("Do Loop", value=st.session_state['defaults'].txt2vid.do_loop, - help="Do loop") - st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.") - - if GFPGAN_available: - st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") - else: - st.session_state["use_GFPGAN"] = False - - if RealESRGAN_available: - st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN, - help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. 
Disable if you need the extra VRAM.") - st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) - else: - st.session_state["use_RealESRGAN"] = False - st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" - - st.session_state["variant_amount"] = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2vid.variant_amount, min_value=0.0, max_value=1.0, step=0.01) - st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2vid.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") - st.session_state["beta_start"] = st.slider("Beta Start:", value=st.session_state['defaults'].txt2vid.beta_start, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f") - st.session_state["beta_end"] = st.slider("Beta End:", value=st.session_state['defaults'].txt2vid.beta_end, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f") - - if generate_button: - #print("Loading models") - # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. - #load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model) - - try: - # run video generation - video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu, - num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames), - num_inference_steps=st.session_state.num_inference_steps, - cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"], - seeds=seed, quality=100, eta=0.0, width=width, - height=height, weights_path=custom_model, scheduler=scheduler_name, - disable_tqdm=False, beta_start=st.session_state["beta_start"], beta_end=st.session_state["beta_end"], - beta_schedule=beta_scheduler_type, starting_image=None) - - #message.success('Done!', icon="✅") - message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") - - #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab'] - - #if 'latestVideos' in st.session_state: - #for i in video: - ##push the new image to the list of latest images and remove the oldest one - ##remove the last index from the list\ - #st.session_state['latestVideos'].pop() - ##add the new image to the start of the list - #st.session_state['latestVideos'].insert(0, i) - #PlaceHolder.empty() - - #with PlaceHolder.container(): - #col1, col2, col3 = st.columns(3) - #col1_cont = st.container() - #col2_cont = st.container() - #col3_cont = st.container() - - #with col1_cont: - #with col1: - #st.image(st.session_state['latestVideos'][0]) - #st.image(st.session_state['latestVideos'][3]) - #st.image(st.session_state['latestVideos'][6]) - #with col2_cont: - #with col2: - #st.image(st.session_state['latestVideos'][1]) - #st.image(st.session_state['latestVideos'][4]) - #st.image(st.session_state['latestVideos'][7]) - #with col3_cont: - #with col3: - #st.image(st.session_state['latestVideos'][2]) - #st.image(st.session_state['latestVideos'][5]) - #st.image(st.session_state['latestVideos'][8]) - #historyGallery = st.empty() - - ## check if output_images length is the same as seeds length - #with gallery_tab: - #st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True) - - - #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont] - - 
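The Variant Amount / Variant Seed controls above blend the noise generated from the base seed with noise generated from the variant seed; later in this patch that blend is done with spherical linear interpolation (slerp). The helper below is an illustrative sketch of that technique, not the repository's exact implementation, and the seeds shown are hypothetical.

import torch

def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995) -> torch.Tensor:
    # spherical linear interpolation between two noise tensors of the same shape
    v0_flat, v1_flat = v0.flatten(), v1.flatten()
    dot = torch.dot(v0_flat / v0_flat.norm(), v1_flat / v1_flat.norm())
    if torch.abs(dot) > dot_threshold:
        # nearly parallel vectors: a plain lerp is numerically safer
        return (1.0 - t) * v0 + t * v1
    theta = torch.acos(dot)
    sin_theta = torch.sin(theta)
    return (torch.sin((1.0 - t) * theta) / sin_theta) * v0 + (torch.sin(t * theta) / sin_theta) * v1

torch.manual_seed(42)              # hypothetical base seed
base_x = torch.randn(1, 4, 64, 64)
torch.manual_seed(1234)            # hypothetical variant seed
target_x = torch.randn(1, 4, 64, 64)
x = slerp(0.3, base_x, target_x)   # 0.3 plays the role of the Variant Amount slider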
except (StopException, KeyError): - print(f"Received Streamlit StopException") - - diff --git a/scripts/webui.py b/scripts/webui.py index eb5d32f..dd64a4c 100644 --- a/scripts/webui.py +++ b/scripts/webui.py @@ -2,10 +2,8 @@ import argparse, os, sys, glob, re import cv2 -from perlin import perlinNoise from frontend.frontend import draw_gradio_ui from frontend.job_manager import JobManager, JobInfo -from frontend.image_metadata import ImageMetadata from frontend.ui_functions import resize_image parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",) @@ -15,7 +13,7 @@ parser.add_argument("--defaults", type=str, help="path to configuration file pro parser.add_argument("--esrgan-cpu", action='store_true', help="run ESRGAN on cpu", default=False) parser.add_argument("--esrgan-gpu", type=int, help="run ESRGAN on specific gpu (overrides --gpu)", default=0) parser.add_argument("--extra-models-cpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on cpu", default=False) -parser.add_argument("--extra-models-gpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on gpu", default=False) +parser.add_argument("--extra-models-gpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on cpu", default=False) parser.add_argument("--gfpgan-cpu", action='store_true', help="run GFPGAN on cpu", default=False) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go parser.add_argument("--gfpgan-gpu", type=int, help="run GFPGAN on specific gpu (overrides --gpu) ", default=0) @@ -33,7 +31,6 @@ parser.add_argument("--outdir_img2img", type=str, nargs="?", help="dir to write parser.add_argument("--outdir_imglab", type=str, nargs="?", help="dir to write imglab results to (overrides --outdir)", default=None) parser.add_argument("--outdir_txt2img", type=str, nargs="?", help="dir to write txt2img results to (overrides --outdir)", default=None) parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None) -parser.add_argument("--filename_format", type=str, nargs="?", help="filenames format", default=None) parser.add_argument("--port", type=int, help="choose the port for the gradio webserver to use", default=7860) parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--realesrgan-dir", type=str, help="RealESRGAN directory", default=('./src/realesrgan' if os.path.exists('./src/realesrgan') else './RealESRGAN')) @@ -45,7 +42,6 @@ parser.add_argument("--skip-grid", action='store_true', help="do not save a grid parser.add_argument("--skip-save", action='store_true', help="do not save indiviual samples. 
For speed measurements.", default=False) parser.add_argument('--no-job-manager', action='store_true', help="Don't use the experimental job manager on top of gradio", default=False) parser.add_argument("--max-jobs", type=int, help="Maximum number of concurrent 'generate' commands", default=1) -parser.add_argument("--tiling", action='store_true', help="Generate tiling images", default=False) opt = parser.parse_args() #Should not be needed anymore @@ -70,27 +66,16 @@ import torch import torch.nn as nn import yaml import glob -import copy -from typing import List, Union, Dict, Callable, Any, Optional +from typing import List, Union, Dict from pathlib import Path from collections import namedtuple -from functools import partial - -# tell the user which GPU the code is actually using -if os.getenv("SD_WEBUI_DEBUG", 'False').lower() in ('true', '1', 'y'): - gpu_in_use = opt.gpu - # prioritize --esrgan-gpu and --gfpgan-gpu over --gpu, as stated in the option info - if opt.esrgan_gpu != opt.gpu: - gpu_in_use = opt.esrgan_gpu - elif opt.gfpgan_gpu != opt.gpu: - gpu_in_use = opt.gfpgan_gpu - print("Starting on GPU {selected_gpu_name}".format(selected_gpu_name=torch.cuda.get_device_name(gpu_in_use))) from contextlib import contextmanager, nullcontext from einops import rearrange, repeat from itertools import islice from omegaconf import OmegaConf -from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps, ImageChops +from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps +from PIL.PngImagePlugin import PngInfo from io import BytesIO import base64 import re @@ -99,18 +84,6 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.util import instantiate_from_config -# add global options to models -def patch_conv(**patch): - cls = torch.nn.Conv2d - init = cls.__init__ - def __init__(self, *args, **kwargs): - return init(self, *args, **kwargs, **patch) - cls.__init__ = __init__ - -if opt.tiling: - patch_conv(padding_mode='circular') - print("patched for tiling") - try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -119,14 +92,6 @@ try: except: pass -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from transformers import AutoFeatureExtractor - -# load safety model -safety_model_id = "CompVis/stable-diffusion-safety-checker" -safety_feature_extractor = None -safety_checker = None - # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') @@ -238,16 +203,7 @@ class MemUsageMonitor(threading.Thread): print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. 
\n") return print(f"[{self.name}] Recording max memory usage...\n") - # check if we're using a scoped-down GPU environment (pynvml does not listen to CUDA_VISIBLE_DEVICES) - # so that we can measure memory on the correct GPU - try: - isinstance(int(os.environ["CUDA_VISIBLE_DEVICES"]), int) - handle = pynvml.nvmlDeviceGetHandleByIndex(int(os.environ["CUDA_VISIBLE_DEVICES"])) - except (KeyError, ValueError) as pynvmlHandleError: - if os.getenv("SD_WEBUI_DEBUG", 'False').lower() in ('true', '1', 'y'): - print("[MemMon][WARNING]", pynvmlHandleError) - print("[MemMon][INFO]", "defaulting to monitoring memory on the default gpu (set via --gpu flag)") - handle = pynvml.nvmlDeviceGetHandleByIndex(opt.gpu) + handle = pynvml.nvmlDeviceGetHandleByIndex(opt.gpu) self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total while not self.stop_flag: m = pynvml.nvmlDeviceGetMemoryInfo(handle) @@ -308,21 +264,15 @@ class KDiffusionSampler: self.schedule = sampler def get_sampler_name(self): return self.schedule - def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback: Callable = None ): + def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T): sigmas = self.model_wrap.get_sigmas(S) x = x_T * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback)) + samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False) return samples_ddim, None - @classmethod - def img_callback_wrapper(cls, callback: Callable, *args): - ''' Converts a KDiffusion callback to the standard img_callback ''' - if callback: - arg_dict = args[0] - callback(image_sample=arg_dict['denoised'], iter_num=arg_dict['i']) def create_random_tensors(shape, seeds): xs = [] @@ -642,18 +592,25 @@ def check_prompt_length(prompt, comments): comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - -def save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, -skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False): - ''' saves the image according to selected parameters. Expects to find generation parameters on image, set by ImageMetadata.set_on_image() ''' - metadata = ImageMetadata.get_from_image(image) - if not skip_metadata and metadata is None: - print("No metadata passed in to save. 
Set metadata on the image before calling save_sample using the ImageMetadata.set_on_image() function.") - skip_metadata = True +def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, +normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True): filename_i = os.path.join(sample_path_i, filename) if not jpg_sample: if opt.save_metadata and not skip_metadata: - image.save(f"{filename_i}.png", pnginfo=metadata.as_png_info()) + metadata = PngInfo() + metadata.add_text("SD:prompt", prompts[i]) + metadata.add_text("SD:seed", str(seeds[i])) + metadata.add_text("SD:width", str(width)) + metadata.add_text("SD:height", str(height)) + metadata.add_text("SD:sampler_name", str(sampler_name)) + metadata.add_text("SD:steps", str(steps)) + metadata.add_text("SD:cfg_scale", str(cfg_scale)) + metadata.add_text("SD:normalize_prompt_weights", str(normalize_prompt_weights)) + if init_img is not None: + metadata.add_text("SD:denoising_strength", str(denoising_strength)) + metadata.add_text("SD:GFPGAN", str(use_GFPGAN and GFPGAN is not None)) + image.save(f"{filename_i}.png", pnginfo=metadata) else: image.save(f"{filename_i}.png") else: @@ -664,7 +621,7 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin toggles = [] if prompt_matrix: toggles.append(0) - if metadata.normalize_prompt_weights: + if normalize_prompt_weights: toggles.append(1) if init_img is not None: if uses_loopback: @@ -681,14 +638,14 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin toggles.append(5 + offset) if write_sample_info_to_log_file: toggles.append(6+offset) - if metadata.GFPGAN: + if use_GFPGAN: toggles.append(7 + offset) info_dict = dict( target="txt2img" if init_img is None else "img2img", - prompt=metadata.prompt, ddim_steps=metadata.steps, toggles=toggles, sampler_name=sampler_name, - ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=metadata.cfg_scale, - seed=metadata.seed, width=metadata.width, height=metadata.height + prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, + ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, + seed=seeds[i], width=width, height=height ) if init_img is not None: # Not yet any use for these, but they bloat up the files: @@ -818,95 +775,16 @@ def oxlamon_matrix(prompt, seed, n_iter, batch_size): return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows -def perform_masked_image_restoration(image, init_img, init_mask, mask_blur_strength, mask_restore, use_RealESRGAN, RealESRGAN): - if not mask_restore: - return image - else: - init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) - init_mask = init_mask.convert('L') - init_img = init_img.convert('RGB') - image = image.convert('RGB') - if use_RealESRGAN and RealESRGAN is not None: - output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8)) - init_mask = Image.fromarray(output) - init_mask = init_mask.convert('L') - - output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8)) - init_img = Image.fromarray(output) - init_img = init_img.convert('RGB') - - image = Image.composite(init_img, image, init_mask) - - return image - - -def perform_color_correction(img_rgb, 
correction_target_lab, do_color_correction): - try: - from skimage import exposure - except: - print("Install scikit-image to perform color correction") - return img_rgb - - if not do_color_correction: return img_rgb - if correction_target_lab is None: return img_rgb - - return ( - Image.fromarray(cv2.cvtColor(exposure.match_histograms( - cv2.cvtColor( - np.asarray(img_rgb), - cv2.COLOR_RGB2LAB - ), - correction_target_lab, - channel_axis=2 - ), cv2.COLOR_LAB2RGB).astype("uint8") - ) - ) def process_images( outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, - n_iter, steps, cfg_scale, width, height, prompt_matrix, filter_nsfw, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, + n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, - keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False, + keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False, - variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None, do_color_correction=False, correction_target=None): + variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None): """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - - def numpy_to_pil(images): - """ - Convert a numpy image or a batch of images to a PIL image. - """ - if images.ndim == 3: - images = images[None, ...] - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image) for image in images] - - return pil_images - - # load replacement of nsfw content - def load_replacement(x): - try: - hwc = x.shape - y = Image.open("images/nsfw.jpeg").convert("RGB").resize((hwc[1], hwc[0])) - y = (np.array(y)/255.0).astype(x.dtype) - assert y.shape == x.shape - return y - except Exception: - return x - - # check and replace nsfw content - def check_safety(x_image): - global safety_feature_extractor, safety_checker - if safety_feature_extractor is None: - safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) - safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) - safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt") - x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) - for i in range(len(has_nsfw_concept)): - if has_nsfw_concept[i]: - x_checked_image[i] = load_replacement(x_checked_image[i]) - return x_checked_image, has_nsfw_concept - prompt = prompt or '' torch_gc() # start time after garbage collection (or before?) @@ -926,12 +804,6 @@ def process_images( if not ("|" in prompt) and prompt.startswith("@"): prompt = prompt[1:] - negprompt = '' - if '###' in prompt: - prompt, negprompt = prompt.split('###', 1) - prompt = prompt.strip() - negprompt = negprompt.strip() - comments = [] prompt_matrix_parts = [] @@ -1010,14 +882,12 @@ def process_images( if job_info: job_info.job_status = f"Processing Iteration {n+1}/{n_iter}. 
Batch size {batch_size}" - job_info.rec_steps_imgs.clear() for idx,(p,s) in enumerate(zip(prompts,seeds)): job_info.job_status += f"\nItem {idx}: Seed {s}\nPrompt: {p}" - print(f"Current prompt: {p}") if opt.optimized: modelCS.to(device) - uc = (model if not opt.optimized else modelCS).get_learned_conditioning(len(prompts) * [negprompt]) + uc = (model if not opt.optimized else modelCS).get_learned_conditioning(len(prompts) * [""]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -1042,7 +912,7 @@ def process_images( while(torch.cuda.memory_allocated()/1e6 >= mem): time.sleep(1) - cur_variant_amount = variant_amount + cur_variant_amount = variant_amount if variant_amount == 0.0: # we manually generate all input noises because each one should have a specific seed x = create_random_tensors(shape, seeds=seeds) @@ -1065,91 +935,16 @@ def process_images( # finally, slerp base_x noise to target_x noise for creating a variant x = slerp(device, max(0.0, min(1.0, cur_variant_amount)), base_x, target_x) - # If optimized then use first stage for preview and store it on cpu until needed - if opt.optimized: - step_preview_model = modelFS - step_preview_model.cpu() - else: - step_preview_model = model - - def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int): - ''' Called from the sampler every iteration ''' - if job_info: - job_info.active_iteration_cnt = iter_num - record_periodic_image = job_info.rec_steps_enabled and (0 == iter_num % job_info.rec_steps_intrvl) - if record_periodic_image or job_info.refresh_active_image_requested.is_set(): - preview_start_time = time.time() - if opt.optimized: - step_preview_model.to(device) - - decoded_batch: List[torch.Tensor] = [] - # Break up batch to save VRAM - for sample in image_sample: - sample = sample[None, :] # expands the tensor as if it still had a batch dimension - decoded_sample = step_preview_model.decode_first_stage(sample)[0] - decoded_sample = torch.clamp((decoded_sample + 1.0) / 2.0, min=0.0, max=1.0) - decoded_sample = decoded_sample.cpu() - decoded_batch.append(decoded_sample) - - batch_size = len(decoded_batch) - - if opt.optimized: - step_preview_model.cpu() - - images: List[Image.Image] = [] - # Convert tensor to image (copied from code below) - for ddim in decoded_batch: - x_sample = 255. * rearrange(ddim.numpy(), 'c h w -> h w c') - x_sample = x_sample.astype(np.uint8) - image = Image.fromarray(x_sample) - images.append(image) - - caption = f"Iter {iter_num}" - grid = image_grid(images, len(images), force_n_rows=1, captions=[caption]*len(images)) - - # Save the images if recording steps, and append existing saved steps - if job_info.rec_steps_enabled: - gallery_img_size = tuple(int(0.25*dim) for dim in images[0].size) - job_info.rec_steps_imgs.append(grid.resize(gallery_img_size)) - - # Notify the requester that the image is updated - if job_info.refresh_active_image_requested.is_set(): - if job_info.rec_steps_enabled: - grid_rows = None if batch_size == 1 else len(job_info.rec_steps_imgs) - grid = image_grid(imgs=job_info.rec_steps_imgs[::-1], batch_size=1, force_n_rows=grid_rows) - job_info.active_image = grid - job_info.refresh_active_image_done.set() - job_info.refresh_active_image_requested.clear() - - preview_elapsed_timed = time.time() - preview_start_time - if preview_elapsed_timed / job_info.rec_steps_intrvl > 1: - print( - f"Warning: Preview generation is slowing image generation. 
It took {preview_elapsed_timed:.2f}s to generate progress images for batch of {batch_size} images!") - - # Interrupt current iteration? - if job_info.stop_cur_iter.is_set(): - job_info.stop_cur_iter.clear() - raise StopIteration() - - try: - samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name, img_callback=sample_iteration_callback) - except StopIteration: - print("Skipping iteration") - job_info.job_status = "Skipping iteration" - continue + samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) if opt.optimized: modelFS.to(device) - for i in range(len(samples_ddim)): - x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim[i].unsqueeze(0)) - x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - if filter_nsfw: - x_samples_ddim_numpy = x_sample.cpu().permute(0, 2, 3, 1).numpy() - x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy) - x_sample = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) + x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + for i, x_sample in enumerate(x_samples_ddim): sanitized_prompt = prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars}) if variant_seed != None and variant_seed != '': if variant_amount == 0.0: @@ -1163,33 +958,16 @@ def process_images( sample_path_i = os.path.join(sample_path, sanitized_prompt) os.makedirs(sample_path_i, exist_ok=True) base_count = get_next_sequence_number(sample_path_i) - filename = opt.filename_format or "[STEPS]_[SAMPLER]_[SEED]_[VARIANT_AMOUNT]" + filename = f"{base_count:05}-{steps}_{sampler_name}_{seed_used}_{cur_variant_amount:.2f}" else: sample_path_i = sample_path base_count = get_next_sequence_number(sample_path_i) - filename = opt.filename_format or "[STEPS]_[SAMPLER]_[SEED]_[VARIANT_AMOUNT]_[PROMPT]" + sanitized_prompt = sanitized_prompt + filename = f"{base_count:05}-{steps}_{sampler_name}_{seed_used}_{cur_variant_amount:.2f}_{sanitized_prompt}"[:128] #same as before - #Add new filenames tags here - filename = f"{base_count:05}-" + filename - filename = filename.replace("[STEPS]", str(steps)) - filename = filename.replace("[CFG]", str(cfg_scale)) - filename = filename.replace("[PROMPT]", sanitized_prompt[:128]) - filename = filename.replace("[PROMPT_SPACES]", prompts[i].translate({ord(x): '' for x in invalid_filename_chars})[:128]) - filename = filename.replace("[WIDTH]", str(width)) - filename = filename.replace("[HEIGHT]", str(height)) - filename = filename.replace("[SAMPLER]", sampler_name) - filename = filename.replace("[SEED]", seed_used) - filename = filename.replace("[VARIANT_AMOUNT]", f"{cur_variant_amount:.2f}") - - x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c') + x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = x_sample.astype(np.uint8) - metadata = ImageMetadata(prompt=prompts[i], seed=seeds[i], height=height, width=width, steps=steps, - cfg_scale=cfg_scale, normalize_prompt_weights=normalize_prompt_weights, denoising_strength=denoising_strength, - GFPGAN=use_GFPGAN ) image = Image.fromarray(x_sample) - image = perform_color_correction(image, correction_target, do_color_correction) - ImageMetadata.set_on_image(image, metadata) - original_sample = x_sample original_filename = filename if use_GFPGAN and GFPGAN is not None and not use_RealESRGAN: @@ -1198,18 +976,10 @@ def process_images( cropped_faces, restored_faces, restored_img = GFPGAN.enhance(original_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) gfpgan_sample = restored_img[:,:,::-1] gfpgan_image = Image.fromarray(gfpgan_sample) - gfpgan_image = perform_color_correction(gfpgan_image, correction_target, do_color_correction) - gfpgan_image = perform_masked_image_restoration( - gfpgan_image, init_img, init_mask, - mask_blur_strength, mask_restore, - use_RealESRGAN = False, RealESRGAN = None - ) - gfpgan_metadata = copy.copy(metadata) - gfpgan_metadata.GFPGAN = True - ImageMetadata.set_on_image( gfpgan_image, gfpgan_metadata ) gfpgan_filename = original_filename + '-gfpgan' - save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, -skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False) + save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, +normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) output_images.append(gfpgan_image) #287 #if simple_templating: # grid_captions.append( captions[i] + "\ngfpgan" ) @@ -1221,15 +991,9 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin esrgan_filename = original_filename + '-esrgan4x' esrgan_sample = output[:,:,::-1] esrgan_image = Image.fromarray(esrgan_sample) - esrgan_image = perform_color_correction(esrgan_image, correction_target, do_color_correction) - esrgan_image = perform_masked_image_restoration( - esrgan_image, init_img, init_mask, - mask_blur_strength, mask_restore, - use_RealESRGAN, RealESRGAN - ) - ImageMetadata.set_on_image( esrgan_image, metadata ) - save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, -skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False) + save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, +normalize_prompt_weights, use_GFPGAN,write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) output_images.append(esrgan_image) #287 #if simple_templating: # grid_captions.append( 
captions[i] + "\nesrgan" ) @@ -1243,15 +1007,9 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' gfpgan_esrgan_sample = output[:,:,::-1] gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) - gfpgan_esrgan_image = perform_color_correction(gfpgan_esrgan_image, correction_target, do_color_correction) - gfpgan_esrgan_image = perform_masked_image_restoration( - gfpgan_esrgan_image, init_img, init_mask, - mask_blur_strength, mask_restore, - use_RealESRGAN, RealESRGAN - ) - ImageMetadata.set_on_image(gfpgan_esrgan_image, metadata) - save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, -skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False) + save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, +normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, +skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True) output_images.append(gfpgan_esrgan_image) #287 #if simple_templating: # grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) @@ -1260,34 +1018,15 @@ skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, if imgProcessorTask == True: output_images.append(image) - image = perform_masked_image_restoration( - image, init_img, init_mask, - mask_blur_strength, mask_restore, - # RealESRGAN image already processed in if-case above. - use_RealESRGAN = False, RealESRGAN = None - ) - if not skip_save: - save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, + save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, +normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False) if add_original_image or not simple_templating: output_images.append(image) if simple_templating: grid_captions.append( captions[i] ) - # Save the progress images? 
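The save_sample replacement introduced earlier in this diff embeds the generation parameters as PNG text chunks via PIL's PngInfo, and the call sites above pass those parameters through. A self-contained sketch of writing and reading such chunks; the file name and parameter values are hypothetical.

from PIL import Image
from PIL.PngImagePlugin import PngInfo

image = Image.new("RGB", (64, 64))                       # stand-in for a generated sample
metadata = PngInfo()
metadata.add_text("SD:prompt", "a photo of an astronaut riding a horse")
metadata.add_text("SD:seed", "42")
metadata.add_text("SD:sampler_name", "k_lms")
metadata.add_text("SD:cfg_scale", "7.5")
image.save("sample.png", pnginfo=metadata)               # hypothetical output path

# the chunks come back as plain strings on the reloaded image
reread = Image.open("sample.png")
print(reread.text.get("SD:prompt"), reread.text.get("SD:seed"))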
- if job_info: - if job_info.rec_steps_enabled and (job_info.rec_steps_to_file or job_info.rec_steps_to_gallery): - steps_grid = image_grid(job_info.rec_steps_imgs, 1) - if job_info.rec_steps_to_gallery: - gallery_img_size = tuple(2*dim for dim in image.size) - output_images.append( steps_grid.resize( gallery_img_size ) ) - if job_info.rec_steps_to_file: - steps_grid_filename = f"{original_filename}_step_grid" - save_sample(steps_grid, sample_path_i, steps_grid_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, - skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False) - if opt.optimized: mem = torch.cuda.memory_allocated()/1e6 modelFS.to("cpu") @@ -1307,7 +1046,7 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin import traceback print("Error creating prompt_matrix text:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - elif len(output_images) > 0 and (batch_size > 1 or n_iter > 1): + elif batch_size > 1 or n_iter > 1: grid = image_grid(output_images, batch_size) if grid is not None: grid_count = get_next_sequence_number(outpath, 'grid-') @@ -1362,13 +1101,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], write_info_files = 5 in toggles write_to_one_file = 6 in toggles jpg_sample = 7 in toggles - filter_nsfw = 8 in toggles - use_GFPGAN = 9 in toggles - use_RealESRGAN = 10 in toggles - - do_color_correction = False - correction_target = None - + use_GFPGAN = 8 in toggles + use_RealESRGAN = 9 in toggles ModelLoader(['model'],True,False) if use_GFPGAN and not use_RealESRGAN: ModelLoader(['GFPGAN'],True,False) @@ -1400,8 +1134,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], def init(): pass - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None): - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=img_callback) + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x) return samples_ddim try: @@ -1421,7 +1155,6 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], width=width, height=height, prompt_matrix=prompt_matrix, - filter_nsfw=filter_nsfw, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, realesrgan_model_name=realesrgan_model_name, @@ -1435,8 +1168,6 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int], variant_amount=variant_amount, variant_seed=variant_seed, job_info=job_info, - do_color_correction=do_color_correction, - correction_target=correction_target ) del sampler @@ -1494,15 +1225,7 @@ class Flagging(gr.FlaggingCallback): print("Logged:", filenames[0]) -def blurArr(a,r=8): - im1=Image.fromarray((a*255).astype(np.int8),"L") - im2 = im1.filter(ImageFilter.GaussianBlur(radius = r)) - out= np.array(im2)/255 - return 
out - - - -def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, mask_restore: bool, ddim_steps: int, sampler_name: str, +def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str, toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None): # print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode, @@ -1526,10 +1249,8 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren write_info_files = 7 in toggles write_sample_info_to_log_file = 8 in toggles jpg_sample = 9 in toggles - do_color_correction = 10 in toggles - filter_nsfw = 11 in toggles - use_GFPGAN = 12 in toggles - use_RealESRGAN = 13 in toggles + use_GFPGAN = 10 in toggles + use_RealESRGAN = 11 in toggles ModelLoader(['model'],True,False) if use_GFPGAN and not use_RealESRGAN: ModelLoader(['GFPGAN'],True,False) @@ -1558,12 +1279,10 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren if image_editor_mode == 'Mask': init_img = init_info_mask["image"] - init_img_transparency = ImageOps.invert(init_img.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') init_img = init_img.convert("RGB") init_img = resize_image(resize_mode, init_img, width, height) init_img = init_img.convert("RGB") init_mask = init_info_mask["mask"] - init_mask = ImageChops.lighter(init_img_transparency, init_mask.convert('L')).convert('RGBA') init_mask = init_mask.convert("RGB") init_mask = resize_image(resize_mode, init_mask, width, height) init_mask = init_mask.convert("RGB") @@ -1586,7 +1305,16 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren image = torch.from_numpy(image) mask_channel = None - if image_editor_mode == "Mask": + if image_editor_mode == "Uncrop": + alpha = init_img.convert("RGBA") + alpha = resize_image(resize_mode, alpha, width // 8, height // 8) + mask_channel = alpha.split()[-1] + mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4)) + mask_channel = np.array(mask_channel) + mask_channel[mask_channel >= 255] = 255 + mask_channel[mask_channel < 255] = 0 + mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2)) + elif image_editor_mode == "Mask": alpha = init_mask.convert("RGBA") alpha = resize_image(resize_mode, alpha, width // 8, height // 8) mask_channel = alpha.split()[1] @@ -1601,62 +1329,11 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren if opt.optimized: modelFS.to(device) - #let's try and find where init_image is 0's - #shape is probably (3,width,height)? 
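The removed Uncrop branch that follows first locates the visible (non-black) region of the init image before padding it with noise. That bounding-box step can be sketched in isolation as below; the test image is a hypothetical stand-in, and the comments spell out that the original's "row" names actually index columns and vice versa.

import numpy as np

# hypothetical CHW float image: black everywhere except a pasted region
_image = np.zeros((3, 512, 512), dtype=np.float32)
_image[:, 128:384, 96:416] = np.random.rand(3, 256, 320)

cmax = np.max(_image, axis=0)          # collapse channels -> (H, W)
rowmax = np.max(cmax, axis=0)          # per-column maxima, shape (W,)
colmax = np.max(cmax, axis=1)          # per-row maxima, shape (H,)
rowwhere = np.where(rowmax > 0)[0]     # columns containing any visible pixel
colwhere = np.where(colmax > 0)[0]     # rows containing any visible pixel
rowstart, rowend = rowwhere[0], rowwhere[-1] + 1
colstart, colend = colwhere[0], colwhere[-1] + 1
print('bounding box:', rowstart, rowend, colstart, colend)   # -> 96 416 128 384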
- - if image_editor_mode == "Uncrop": - _image=image.numpy()[0] - _mask=np.ones((_image.shape[1],_image.shape[2])) - - #compute bounding box - cmax=np.max(_image,axis=0) - rowmax=np.max(cmax,axis=0) - colmax=np.max(cmax,axis=1) - rowwhere=np.where(rowmax>0)[0] - colwhere=np.where(colmax>0)[0] - rowstart=rowwhere[0] - rowend=rowwhere[-1]+1 - colstart=colwhere[0] - colend=colwhere[-1]+1 - print('bounding box: ',rowstart,rowend,colstart,colend) - - #this is where noise will get added - PAD_IMG=16 - boundingbox=np.zeros(shape=(height,width)) - boundingbox[colstart+PAD_IMG:colend-PAD_IMG,rowstart+PAD_IMG:rowend-PAD_IMG]=1 - boundingbox=blurArr(boundingbox,4) - - #this is the mask for outpainting - PAD_MASK=24 - boundingbox2=np.zeros(shape=(height,width)) - boundingbox2[colstart+PAD_MASK:colend-PAD_MASK,rowstart+PAD_MASK:rowend-PAD_MASK]=1 - boundingbox2=blurArr(boundingbox2,4) - - #noise=np.random.randn(*_image.shape) - noise=np.array([perlinNoise(height,width,height/64,width/64) for i in range(3)]) - _mask*=1-boundingbox2 - - #convert 0,1 to -1,1 - _image = 2. * _image - 1. - - #add noise - boundingbox=np.tile(boundingbox,(3,1,1)) - _image=_image*boundingbox+noise*(1-boundingbox) - - #resize mask - _mask = np.array(resize_image(resize_mode, Image.fromarray(_mask*255), width // 8, height // 8))/255 - - #convert back to torch tensor - init_image=torch.from_numpy(np.expand_dims(_image,axis=0).astype(np.float32)).to(device) - mask=torch.from_numpy(_mask.astype(np.float32)).to(device) - - else: - init_image = 2. * image - 1. - + init_image = 2. * image - 1. init_image = init_image.to(device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_latent = (model if not opt.optimized else modelFS).get_first_stage_encoding((model if not opt.optimized else modelFS).encode_first_stage(init_image)) # move to latent space - + if opt.optimized: mem = torch.cuda.memory_allocated()/1e6 modelFS.to("cpu") @@ -1665,7 +1342,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren return init_latent, mask, - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None): + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): t_enc_steps = t_enc obliterate = False if ddim_steps == t_enc_steps: @@ -1687,7 +1364,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback)) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False) else: x0, z_mask = init_data @@ -1708,14 +1385,18 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren return samples_ddim - correction_target = None + if loopback: output_images, info = None, None history = [] initial_seed = None - # turn on color correction for loopback to prevent known issue of color drift - do_color_correction = True + do_color_correction = False + try: + from skimage 
import exposure + do_color_correction = True + except: + print("Install scikit-image to perform color correction on loopback") for i in range(n_iter): if do_color_correction and i == 0: @@ -1737,7 +1418,6 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren width=width, height=height, prompt_matrix=prompt_matrix, - filter_nsfw=filter_nsfw, use_GFPGAN=use_GFPGAN, use_RealESRGAN=False, # Forcefully disable upscaling when using loopback realesrgan_model_name=realesrgan_model_name, @@ -1748,7 +1428,6 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren init_mask=init_mask, keep_mask=keep_mask, mask_blur_strength=mask_blur_strength, - mask_restore=mask_restore, denoising_strength=denoising_strength, resize_mode=resize_mode, uses_loopback=loopback, @@ -1757,9 +1436,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren write_info_files=write_info_files, write_sample_info_to_log_file=write_sample_info_to_log_file, jpg_sample=jpg_sample, - job_info=job_info, - do_color_correction=do_color_correction, - correction_target=correction_target + job_info=job_info ) if initial_seed is None: @@ -1767,6 +1444,16 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren init_img = output_images[0] + if do_color_correction and correction_target is not None: + init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(init_img), + cv2.COLOR_RGB2LAB + ), + correction_target, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8")) + if not random_seed_loopback: seed = seed + 1 else: @@ -1785,9 +1472,6 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren seed = initial_seed else: - if do_color_correction: - correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) - output_images, seed, info, stats = process_images( outpath=outpath, func_init=init, @@ -1804,7 +1488,6 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren width=width, height=height, prompt_matrix=prompt_matrix, - filter_nsfw=filter_nsfw, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, realesrgan_model_name=realesrgan_model_name, @@ -1815,16 +1498,13 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren keep_mask=keep_mask, mask_blur_strength=mask_blur_strength, denoising_strength=denoising_strength, - mask_restore=mask_restore, resize_mode=resize_mode, uses_loopback=loopback, sort_samples=sort_samples, write_info_files=write_info_files, write_sample_info_to_log_file=write_sample_info_to_log_file, jpg_sample=jpg_sample, - job_info=job_info, - do_color_correction=do_color_correction, - correction_target=correction_target + job_info=job_info ) del sampler @@ -1892,13 +1572,8 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to images = [] def processGFPGAN(image,strength): image = image.convert("RGB") - metadata = ImageMetadata.get_from_image(image) cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) result = Image.fromarray(restored_img) - if metadata: - metadata.GFPGAN = True - ImageMetadata.set_on_image(image, metadata) - if strength < 1.0: result = Image.blend(image, result, strength) @@ -1910,18 +1585,15 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to else: modelMode = imgproc_realesrgan_model_name image = 
image.convert("RGB") - metadata = ImageMetadata.get_from_image(image) RealESRGAN = load_RealESRGAN(modelMode) result, res = RealESRGAN.enhance(np.array(image, dtype=np.uint8)) result = Image.fromarray(result) - ImageMetadata.set_on_image(result, metadata) if 'x2' in imgproc_realesrgan_model_name: # downscale to 1/2 size result = result.resize((result.width//2, result.height//2), LANCZOS) return result def processGoBig(image): - metadata = ImageMetadata.get_from_image(image) result = processRealESRGAN(image,) if 'x4' in imgproc_realesrgan_model_name: #downscale to 1/2 size @@ -1966,7 +1638,6 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to init_img = result init_mask = None keep_mask = False - mask_restore = False assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' def init(): @@ -1992,7 +1663,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to return init_latent, - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None): + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): if sampler_name != 'DDIM': x0, = init_data @@ -2002,7 +1673,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to xi = x0 + noise sigma_sched = sigmas[ddim_steps - t_enc - 1:] model_wrap_cfg = CFGDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback)) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) else: x0, = init_data sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) @@ -2103,7 +1774,6 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to width=width, height=height, prompt_matrix=None, - filter_nsfw=False, use_GFPGAN=None, use_RealESRGAN=None, realesrgan_model_name=None, @@ -2114,7 +1784,6 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to keep_mask=False, mask_blur_strength=None, denoising_strength=denoising_strength, - mask_restore=mask_restore, resize_mode=resize_mode, uses_loopback=False, sort_samples=True, @@ -2139,14 +1808,11 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to del sampler torch.cuda.empty_cache() - ImageMetadata.set_on_image(combined_image, metadata) return combined_image def processLDSR(image): - metadata = ImageMetadata.get_from_image(image) result = LDSR.superResolution(image,int(imgproc_ldsr_steps),str(imgproc_ldsr_pre_downSample),str(imgproc_ldsr_post_downSample)) - ImageMetadata.set_on_image(result, metadata) - return result - + return result + if image_batch != None: if image != None: @@ -2173,7 +1839,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to if 1 in imgproc_toggles: if imgproc_upscale_toggles == 0: ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models - ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models + ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models elif imgproc_upscale_toggles == 1: 
ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models ModelLoader(['RealESGAN','model'],True,False) # Load used models @@ -2185,14 +1851,10 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models ModelLoader(['RealESGAN','model'],True,False,imgproc_realesrgan_model_name) # Load used models for image in images: - metadata = ImageMetadata.get_from_image(image) if 0 in imgproc_toggles: #recheck if GFPGAN is loaded since it's the only model that can be loaded in the loop as well ModelLoader(['GFPGAN'],True,False) # Load used models image = processGFPGAN(image,imgproc_gfpgan_strength) - if metadata: - metadata.GFPGAN = True - ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'GFPGAN') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) @@ -2200,51 +1862,47 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to if 1 not in imgproc_toggles: output.append(image) - save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) + save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) if 1 in imgproc_toggles: if imgproc_upscale_toggles == 0: image = processRealESRGAN(image) - ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'RealESRGAN') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) outFilename = str(batchNumber)+'-'+'result' output.append(image) - save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) + save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) elif imgproc_upscale_toggles == 1: image = processGoBig(image) - ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'GoBig') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) outFilename = str(batchNumber)+'-'+'result' output.append(image) - save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) + save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) elif imgproc_upscale_toggles == 2: image = processLDSR(image) - ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'LDSR') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) outFilename = str(batchNumber)+'-'+'result' output.append(image) - save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) + save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) elif imgproc_upscale_toggles == 3: image = processGoBig(image) ModelLoader(['model','GFPGAN','RealESGAN'],False,True) # Unload unused models 
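The ModelLoader calls sprinkled through this section exist to keep VRAM usage down: each branch first unloads the models it will not touch, then loads the ones it needs. A stripped-down sketch of that idea follows; the registry and the _build constructor are hypothetical and do not mirror the patch's ModelLoader signature exactly.

import torch

_loaded = {}    # hypothetical registry of currently loaded models

def _build(name: str) -> torch.nn.Module:
    # stand-in constructor; the real code loads GFPGAN / RealESRGAN / LDSR checkpoints from disk
    return torch.nn.Linear(8, 8)

def model_loader(names, load=False, unload=False):
    if unload:
        for name in names:
            _loaded.pop(name, None)        # drop the reference so the weights can be garbage collected
        if torch.cuda.is_available():
            torch.cuda.empty_cache()       # return freed blocks to the driver
    if load:
        for name in names:
            if name not in _loaded:        # reuse a model that is already resident
                _loaded[name] = _build(name)

# usage mirroring the dispatch above: drop what this branch does not need, load what it does
model_loader(['GFPGAN', 'LDSR'], unload=True)
model_loader(['RealESRGAN'], load=True)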
ModelLoader(['LDSR'],True,False) # Load used models image = processLDSR(image) - ImageMetadata.set_on_image(image, metadata) outpathDir = os.path.join(outpath,'GoLatent') os.makedirs(outpathDir, exist_ok=True) batchNumber = get_next_sequence_number(outpathDir) outFilename = str(batchNumber)+'-'+'result' output.append(image) - save_sample(image, outpathDir, outFilename, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False) + save_sample(image, outpathDir, outFilename, None, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True) #LDSR is always unloaded to avoid memory issues #ModelLoader(['LDSR'],False,True) @@ -2294,13 +1952,10 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re def run_GFPGAN(image, strength): ModelLoader(['LDSR','RealESRGAN'],False,True) ModelLoader(['GFPGAN'],True,False) - metadata = ImageMetadata.get_from_image(image) image = image.convert("RGB") cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True) res = Image.fromarray(restored_img) - metadata.GFPGAN = True - ImageMetadata.set_on_image(res, metadata) if strength < 1.0: res = Image.blend(image, res, strength) @@ -2313,12 +1968,10 @@ def run_RealESRGAN(image, model_name: str): if RealESRGAN.model.name != model_name: try_loading_RealESRGAN(model_name) - metadata = ImageMetadata.get_from_image(image) image = image.convert("RGB") output, img_mode = RealESRGAN.enhance(np.array(image, dtype=np.uint8)) res = Image.fromarray(output) - ImageMetadata.set_on_image(res, metadata) return res @@ -2344,7 +1997,6 @@ txt2img_toggles = [ 'Write sample info files', 'write sample info to log file', 'jpg samples', - 'Filter NSFW content', ] if GFPGAN is not None: @@ -2405,8 +2057,6 @@ img2img_toggles = [ 'Write sample info files', 'Write sample info to one file', 'jpg samples', - 'Color correction (always enabled on loopback mode)', - 'Filter NSFW content', ] # removed for now becuase of Image Lab implementation if GFPGAN is not None: @@ -2436,7 +2086,6 @@ img2img_defaults = { 'cfg_scale': 5.0, 'denoising_strength': 0.75, 'mask_mode': 0, - 'mask_restore': False, 'resize_mode': 0, 'seed': '', 'height': 512, @@ -2450,6 +2099,24 @@ if 'img2img' in user_defaults: img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']] img2img_image_mode = 'sketch' +def change_image_editor_mode(choice, cropped_image, resize_mode, width, height): + if choice == "Mask": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] + +def update_image_mask(cropped_image, resize_mode, width, height): + resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None + return gr.update(value=resized_cropped_image) + + + +def copy_img_to_upscale_esrgan(img): + update = gr.update(selected='realesrgan_tab') + image_data = re.sub('^data:image/.+;base64,', '', img) + processed_image = Image.open(BytesIO(base64.b64decode(image_data))) + return {'realesrgan_source': 
processed_image, 'tabs': update} + + help_text = """ ## Mask/Crop * The masking/cropping is very temperamental. @@ -2511,7 +2178,7 @@ class ServerLauncher(threading.Thread): 'inbrowser': opt.inbrowser, 'server_name': '0.0.0.0', 'server_port': opt.port, - 'share': opt.share, + 'share': opt.share, 'show_error': True } if not opt.share: diff --git a/scripts/webui_streamlit.py b/scripts/webui_streamlit.py index 15a2e2f..bd680da 100644 --- a/scripts/webui_streamlit.py +++ b/scripts/webui_streamlit.py @@ -1,31 +1,46 @@ -# base webui import and utils. -import streamlit as st - -# streamlit imports -import streamlit_nested_layout - -#streamlit components section -from st_on_hover_tabs import on_hover_tabs - -#other imports - import warnings -import os +import streamlit as st +from streamlit import StopException, StreamlitAPIException + +import base64, cv2 +import argparse, os, sys, glob, re, random, datetime +from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps +from PIL.PngImagePlugin import PngInfo +import requests +from scipy import integrate +import torch +from torchdiffeq import odeint +from tqdm.auto import trange, tqdm import k_diffusion as K +import math +import mimetypes +import numpy as np +import pynvml +import threading, asyncio +import time +import torch +from torch import autocast +from torchvision import transforms +import torch.nn as nn +import yaml +from typing import List, Union +from pathlib import Path +from tqdm import tqdm +from contextlib import contextmanager, nullcontext +from einops import rearrange, repeat +from itertools import islice from omegaconf import OmegaConf +from io import BytesIO +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor +from retry import retry -from sd_utils import * -if not "defaults" in st.session_state: - st.session_state["defaults"] = {} - -st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml") - -if (os.path.exists("configs/webui/userconfig_streamlit.yaml")): - user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml") - st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults) - -# end of imports -#--------------------------------------------------------------------------------------------------------------- +# we use python-slugify to make the filenames safe for windows and linux, its better than doing it manually +# install it with 'pip install python-slugify' +from slugify import slugify try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -38,18 +53,1412 @@ except: # remove some annoying deprecation warnings that show every now and then. warnings.filterwarnings("ignore", category=DeprecationWarning) +defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml") + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +# some of those options should not be changed at all because they would break the model, so I removed them from options. 
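Editor's note: opt_C = 4 and opt_f = 8 just below are the SD v1 latent channel count and VAE downsampling factor, which is why latents later get the shape [opt_C, height // opt_f, width // opt_f]. The grid_format setting parsed right after them is an "ext[:quality]" string ("jpg:95", "webp:-100", ...), where a negative quality on webp requests lossless output. A simplified, hedged restatement of that convention for reference (helper name is illustrative):

def parse_grid_format(spec):
    """Simplified sketch of the 'ext[:quality]' convention parsed just below (not the patch's exact code)."""
    parts = spec.lower().split(':')
    ext = 'jpg' if parts[0] in ('jpg', 'jpeg') else parts[0]
    quality = int(parts[1]) if len(parts) > 1 else 100
    lossless = ext == 'webp' and quality < 0   # e.g. "webp:-100" requests lossless webp
    return ext, abs(quality), lossless

# parse_grid_format("jpg:95")    -> ('jpg', 95, False)
# parse_grid_format("webp:-100") -> ('webp', 100, True)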
+opt_C = 4 +opt_f = 8 + +# should and will be moved to a settings menu in the UI at some point +grid_format = [s.lower() for s in defaults.general.grid_format.split(':')] +grid_lossless = False +grid_quality = 100 +if grid_format[0] == 'png': + grid_ext = 'png' + grid_format = 'png' +elif grid_format[0] in ['jpg', 'jpeg']: + grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 + grid_ext = 'jpg' + grid_format = 'jpeg' +elif grid_format[0] == 'webp': + grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 + grid_ext = 'webp' + grid_format = 'webp' + if grid_quality < 0: # e.g. webp:-100 for lossless mode + grid_lossless = True + grid_quality = abs(grid_quality) + # this should force GFPGAN and RealESRGAN onto the selected gpu as well -#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 -#os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu) +os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 +os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu) + +@retry(tries=5) +def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus"): + """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ + + print ("Loading models.") + + # Generate random run ID + # Used to link runs linked w/ continue_prev_run which is not yet implemented + # Use URL and filesystem safe version just in case. + st.session_state["run_id"] = base64.urlsafe_b64encode( + os.urandom(6) + ).decode("ascii") + + # check what models we want to use and if the they are already loaded. + + if use_GFPGAN: + if "GFPGAN" in st.session_state: + print("GFPGAN already loaded") + else: + # Load GFPGAN + if os.path.exists(defaults.general.GFPGAN_dir): + try: + st.session_state["GFPGAN"] = load_GFPGAN() + print("Loaded GFPGAN") + except Exception: + import traceback + print("Error loading GFPGAN:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if "GFPGAN" in st.session_state: + del st.session_state["GFPGAN"] + + if use_RealESRGAN: + if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model: + print("RealESRGAN already loaded") + else: + #Load RealESRGAN + try: + # We first remove the variable in case it has something there, + # some errors can load the model incorrectly and leave things in memory. + del st.session_state["RealESRGAN"] + except KeyError: + pass + + if os.path.exists(defaults.general.RealESRGAN_dir): + # st.session_state is used for keeping the models in memory across multiple pages or runs. 
+ st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model) + print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name) + + else: + if "RealESRGAN" in st.session_state: + del st.session_state["RealESRGAN"] + + + if "model" in st.session_state: + print("Model already loaded") + else: + config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml") + model = load_model_from_config(config, defaults.general.ckpt) + + st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu") + st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] ) + + print("Model loaded.") + + +def load_model_from_config(config, ckpt, verbose=False): + + print(f"Loading model from {ckpt}") + + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + +def load_sd_from_config(ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + return sd +# +@retry(tries=5) +def generation_callback(img, i=0): + + try: + if i == 0: + if img['i']: i = img['i'] + except TypeError: + pass + + + if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview: + #print (img) + #print (type(img)) + # The following lines will convert the tensor we got on img to an actual image we can render on the UI. + # It can probably be done in a better way for someone who knows what they're doing. I don't. + #print (img,isinstance(img, torch.Tensor)) + if isinstance(img, torch.Tensor): + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img) + else: + # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: + # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"]) + + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) + + # update image on the UI so we can see the progress + st.session_state["preview_image"].image(pil_image) + + # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, + # Dont worry, it doesnt affect the performance. 
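Editor's note: the progress branch that follows reports differently per mode. txt2img measures progress against the full sampling_steps, while img2img only runs about sampling_steps * denoising_strength steps, so its bar is scaled to that shorter run. A small hedged restatement of the arithmetic (helper name is illustrative):

def progress_percent(step, sampling_steps, denoising_strength=None):
    """Percent complete for the progress bar; img2img caps the step count at the denoised portion."""
    total = sampling_steps if denoising_strength is None else round(sampling_steps * denoising_strength)
    current = min(step + 1, total)
    return min(int(100 * current / total), 100)

# txt2img:  progress_percent(24, 50)        -> 50
# img2img:  progress_percent(24, 50, 0.75)  -> 65   (total becomes round(50 * 0.75) = 38 steps)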
+ if st.session_state["generation_mode"] == "txt2img": + percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) + st.session_state["progress_bar_text"].text( + f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") + else: + round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) + percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) + st.session_state["progress_bar_text"].text( + f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") + + st.session_state["progress_bar"].progress(percent if percent < 100 else 100) + + + +class MemUsageMonitor(threading.Thread): + stop_flag = False + max_usage = 0 + total = -1 + + def __init__(self, name): + threading.Thread.__init__(self) + self.name = name + + def run(self): + try: + pynvml.nvmlInit() + except: + print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n") + return + print(f"[{self.name}] Recording max memory usage...\n") + handle = pynvml.nvmlDeviceGetHandleByIndex(defaults.general.gpu) + self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total + while not self.stop_flag: + m = pynvml.nvmlDeviceGetMemoryInfo(handle) + self.max_usage = max(self.max_usage, m.used) + # print(self.max_usage) + time.sleep(0.1) + print(f"[{self.name}] Stopped recording.\n") + pynvml.nvmlShutdown() + + def read(self): + return self.max_usage, self.total + + def stop(self): + self.stop_flag = True + + def read_and_stop(self): + self.stop_flag = True + return self.max_usage, self.total + +class CFGMaskedDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi): + x_in = x + x_in = torch.cat([x_in] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + denoised = uncond + (cond - uncond) * cond_scale + + if mask is not None: + assert x0 is not None + img_orig = x0 + mask_inv = 1. - mask + denoised = (img_orig * mask_inv) + (mask * denoised) + + return denoised + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): + """Constructs the noise schedule of Karras et al. 
(2022).""" + ramp = torch.linspace(0, 1, n) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + + +def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): + """Constructs an exponential noise schedule.""" + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + return append_zero(sigmas) + + +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): + """Constructs a continuous VP noise schedule.""" + t = torch.linspace(1, eps_s, n, device=device) + sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) + return append_zero(sigmas) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + +class KDiffusionSampler: + def __init__(self, m, sampler): + self.model = m + self.model_wrap = K.external.CompVisDenoiser(m) + self.schedule = sampler + def get_sampler_name(self): + return self.schedule + def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None): + sigmas = self.model_wrap.get_sigmas(S) + x = x_T * sigmas[0] + model_wrap_cfg = CFGDenoiser(self.model_wrap) + samples_ddim = None + samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback) + # + return samples_ddim, None + + +@torch.no_grad() +def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + v = torch.randint_like(x, 2) * 2 - 1 + fevals = 0 + def ode_fn(sigma, x): + nonlocal fevals + with torch.enable_grad(): + x = x[0].detach().requires_grad_() + denoised = model(x, sigma * s_in, **extra_args) + d = to_d(x, sigma, denoised) + fevals += 1 + grad = torch.autograd.grad((d * v).sum(), x)[0] + d_ll = (v * grad).flatten(1).sum(1) + return d.detach(), d_ll + x_min = x, x.new_zeros([x.shape[0]]) + t = x.new_tensor([sigma_min, sigma_max]) + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + latent, delta_ll = sol[0][-1], sol[1][-1] + ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + return ll_prior + delta_ll, {'fevals': fevals} + + +def create_random_tensors(shape, seeds): + xs = [] + for seed in seeds: + torch.manual_seed(seed) + + # randn results depend on device; gpu and cpu get different results for same seed; + # the way I see it, it's better to do this on CPU, so that everyone gets same result; + # but the original script had it like this so i do not dare change it for now because + # it will break everyone's seeds. 
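Editor's note: as the comment above says, torch.randn seeded via torch.manual_seed gives different values on CUDA than on CPU, so generating the latent noise on the CPU would make a seed reproduce on any device but would also change every existing seed. A hedged sketch, not part of the patch, of the CPU-seeded alternative being alluded to; the loop below instead seeds the global RNG and samples directly on the configured GPU to keep seed compatibility:

import torch

def cpu_seeded_noise(shape, seed):
    """Generate latent noise on the CPU so a given seed reproduces across devices (illustrative)."""
    generator = torch.Generator(device="cpu").manual_seed(int(seed))
    return torch.randn(shape, generator=generator)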
+ xs.append(torch.randn(shape, device=defaults.general.gpu)) + x = torch.stack(xs) + return x + +def torch_gc(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + +def load_GFPGAN(): + model_name = 'GFPGANv1.3' + model_path = os.path.join(defaults.general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + raise Exception("GFPGAN model not found at path "+model_path) + + sys.path.append(os.path.abspath(defaults.general.GFPGAN_dir)) + from gfpgan import GFPGANer + + if defaults.general.gfpgan_cpu or defaults.general.extra_models_cpu: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu')) + elif defaults.general.extra_models_gpu: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gfpgan_gpu}')) + else: + instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gpu}')) + return instance + +def load_RealESRGAN(model_name: str): + from basicsr.archs.rrdbnet_arch import RRDBNet + RealESRGAN_models = { + 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), + 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) + } + + model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') + if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")): + raise Exception(model_name+".pth not found at path "+model_path) + + sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir)) + from realesrgan import RealESRGANer + + if defaults.general.esrgan_cpu or defaults.general.extra_models_cpu: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half + instance.device = torch.device('cpu') + instance.model.to('cpu') + elif defaults.general.extra_models_gpu: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.esrgan_gpu}')) + else: + instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.gpu}')) + instance.model.name = model_name + + return instance + +prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + [^:]+ # match one or more non ':' characters + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P # capture group for 'weight' + -?\\d+(?:\\.\\d+)? # match positive or negative decimal number + )? 
# end weight capture group, make optional + \\s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group +""", re.VERBOSE) + +# grabs all text up to the first occurrence of ':' as sub-prompt +# takes the value following ':' as weight +# if ':' has no value defined, defaults to 1.0 +# repeats until no text remaining +def split_weighted_subprompts(input_string, normalize=True): + parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)] + if not normalize: + return parsed_prompts + # this probably still doesn't handle negative weights very well + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + +def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995): + v0 = v0.detach().cpu().numpy() + v1 = v1.detach().cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + v2 = torch.from_numpy(v2).to(device) + + return v2 + + +def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): + #get global variables + global_vars = globals() + #check if m is in globals + if unload: + for m in models: + if m in global_vars: + #if it is, delete it + del global_vars[m] + if defaults.general.optimized: + if m == 'model': + del global_vars[m+'FS'] + del global_vars[m+'CS'] + if m =='model': + m='Stable Diffusion' + print('Unloaded ' + m) + if load: + for m in models: + if m not in global_vars or m in global_vars and type(global_vars[m]) == bool: + #if it isn't, load it + if m == 'GFPGAN': + global_vars[m] = load_GFPGAN() + elif m == 'model': + sdLoader = load_sd_from_config() + global_vars[m] = sdLoader[0] + if defaults.general.optimized: + global_vars[m+'CS'] = sdLoader[1] + global_vars[m+'FS'] = sdLoader[2] + elif m == 'RealESRGAN': + global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name) + elif m == 'LDSR': + global_vars[m] = load_LDSR() + if m =='model': + m='Stable Diffusion' + print('Loaded ' + m) + torch_gc() + + + +def get_font(fontsize): + fonts = ["arial.ttf", "DejaVuSans.ttf"] + for font_name in fonts: + try: + return ImageFont.truetype(font_name, fontsize) + except OSError: + pass + + # ImageFont.load_default() is practically unusable as it only supports + # latin1, so raise an exception instead if no usable font was found + raise Exception(f"No usable font found (tried {', '.join(fonts)})") + +def load_embeddings(fp): + if fp is not None and hasattr(st.session_state["model"], "embedding_manager"): + st.session_state["model"].embedding_manager.load(fp['name']) + +def image_grid(imgs, batch_size, force_n_rows=None, captions=None): + #print (len(imgs)) + if force_n_rows is not None: + rows = force_n_rows + elif defaults.general.n_rows > 0: + rows = defaults.general.n_rows + elif defaults.general.n_rows == 0: + rows = batch_size + else: + rows = math.sqrt(len(imgs)) + rows = round(rows) + + cols = math.ceil(len(imgs) / rows) + + w, h = imgs[0].size + grid = Image.new('RGB', size=(cols * w, rows * h), color='black') + + fnt = get_font(30) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + if 
captions and i= 2**32: + n = n >> 32 + return n + +def check_prompt_length(prompt, comments): + """this function tests if prompt is too long, and if so, adds a message to comments""" + + tokenizer = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer + max_length = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.max_length + + info = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, + return_overflowing_tokens=True, padding="max_length", return_tensors="pt") + ovf = info['overflowing_tokens'][0] + overflowing_count = ovf.shape[0] + if overflowing_count == 0: + return + + vocab = {v: k for k, v in tokenizer.get_vocab().items()} + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + + comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + +def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images): + + filename_i = os.path.join(sample_path_i, filename) + + if not jpg_sample: + if defaults.general.save_metadata: + metadata = PngInfo() + metadata.add_text("SD:prompt", prompts[i]) + metadata.add_text("SD:seed", str(seeds[i])) + metadata.add_text("SD:width", str(width)) + metadata.add_text("SD:height", str(height)) + metadata.add_text("SD:steps", str(steps)) + metadata.add_text("SD:cfg_scale", str(cfg_scale)) + metadata.add_text("SD:normalize_prompt_weights", str(normalize_prompt_weights)) + if init_img is not None: + metadata.add_text("SD:denoising_strength", str(denoising_strength)) + metadata.add_text("SD:GFPGAN", str(use_GFPGAN and st.session_state["GFPGAN"] is not None)) + image.save(f"{filename_i}.png", pnginfo=metadata) + else: + image.save(f"{filename_i}.png") + else: + image.save(f"{filename_i}.jpg", 'jpeg', quality=100, optimize=True) + + if write_info_files: + # toggles differ for txt2img vs. 
img2img: + offset = 0 if init_img is None else 2 + toggles = [] + if prompt_matrix: + toggles.append(0) + if normalize_prompt_weights: + toggles.append(1) + if init_img is not None: + if uses_loopback: + toggles.append(2) + if uses_random_seed_loopback: + toggles.append(3) + if save_individual_images: + toggles.append(2 + offset) + if save_grid: + toggles.append(3 + offset) + if sort_samples: + toggles.append(4 + offset) + if write_info_files: + toggles.append(5 + offset) + if use_GFPGAN: + toggles.append(6 + offset) + info_dict = dict( + target="txt2img" if init_img is None else "img2img", + prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, + ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, + seed=seeds[i], width=width, height=height + ) + if init_img is not None: + # Not yet any use for these, but they bloat up the files: + #info_dict["init_img"] = init_img + #info_dict["init_mask"] = init_mask + info_dict["denoising_strength"] = denoising_strength + info_dict["resize_mode"] = resize_mode + with open(f"{filename_i}.yaml", "w", encoding="utf8") as f: + yaml.dump(info_dict, f, allow_unicode=True, width=10000) + + # render the image on the frontend + st.session_state["preview_image"].image(image) + +def get_next_sequence_number(path, prefix=''): + """ + Determines and returns the next sequence number to use when saving an + image in the specified directory. + + If a prefix is given, only consider files whose names start with that + prefix, and strip the prefix from filenames before extracting their + sequence number. + + The sequence starts at 0. + """ + result = -1 + for p in Path(path).iterdir(): + if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix): + tmp = p.name[len(prefix):] + try: + result = max(int(tmp.split('-')[0]), result) + except ValueError: + pass + return result + 1 + + +def oxlamon_matrix(prompt, seed, n_iter, batch_size): + pattern = re.compile(r'(,\s){2,}') + + class PromptItem: + def __init__(self, text, parts, item): + self.text = text + self.parts = parts + if item: + self.parts.append( item ) + + def clean(txt): + return re.sub(pattern, ', ', txt) + + def getrowcount( txt ): + for data in re.finditer( ".*?\\((.*?)\\).*", txt ): + if data: + return len(data.group(1).split("|")) + break + return None + + def repliter( txt ): + for data in re.finditer( ".*?\\((.*?)\\).*", txt ): + if data: + r = data.span(1) + for item in data.group(1).split("|"): + yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip()) + break + + def iterlist( items ): + outitems = [] + for item in items: + for newitem, newpart in repliter(item.text): + outitems.append( PromptItem(newitem, item.parts.copy(), newpart) ) + + return outitems + + def getmatrix( prompt ): + dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ] + while True: + newdataitems = iterlist( dataitems ) + if len( newdataitems ) == 0: + return dataitems + dataitems = newdataitems + + def classToArrays( items, seed, n_iter ): + texts = [] + parts = [] + seeds = [] + + for item in items: + itemseed = seed + for i in range(n_iter): + texts.append( item.text ) + parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) ) + seeds.append( itemseed ) + itemseed += 1 + + return seeds, texts, parts + + all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter) + n_iter = math.ceil(len(all_prompts) / batch_size) + + needrows = getrowcount(prompt) + if needrows: + xrows = math.sqrt(len(all_prompts)) + xrows = 
round(xrows) + # if columns is to much + cols = math.ceil(len(all_prompts) / xrows) + if cols > needrows*4: + needrows *= 2 + + return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows + + +def process_images( + outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, + n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, + fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, + keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, + uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, + variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): + """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + assert prompt is not None + torch_gc() + # start time after garbage collection (or before?) + start_time = time.time() + + # We will use this date here later for the folder name, need to start_time if not need + run_start_dt = datetime.datetime.now() + + mem_mon = MemUsageMonitor('MemMon') + mem_mon.start() + + if hasattr(st.session_state["model"], "embedding_manager"): + load_embeddings(fp) + + os.makedirs(outpath, exist_ok=True) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + + if not ("|" in prompt) and prompt.startswith("@"): + prompt = prompt[1:] + + comments = [] + + prompt_matrix_parts = [] + simple_templating = False + add_original_image = not (use_RealESRGAN or use_GFPGAN) + + if prompt_matrix: + if prompt.startswith("@"): + simple_templating = True + add_original_image = not (use_RealESRGAN or use_GFPGAN) + all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size) + else: + all_prompts = [] + prompt_matrix_parts = prompt.split("|") + combination_count = 2 ** (len(prompt_matrix_parts) - 1) + for combination_num in range(combination_count): + current = prompt_matrix_parts[0] + + for n, text in enumerate(prompt_matrix_parts[1:]): + if combination_num & (2 ** n) > 0: + current += ("" if text.strip().startswith(",") else ", ") + text + + all_prompts.append(current) + + n_iter = math.ceil(len(all_prompts) / batch_size) + all_seeds = len(all_prompts) * [seed] + + print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") + else: + + if not defaults.general.no_verify_input: + try: + check_prompt_length(prompt, comments) + except: + import traceback + print("Error verifying input:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + all_prompts = batch_size * n_iter * [prompt] + all_seeds = [seed + x for x in range(len(all_prompts))] + + precision_scope = autocast if defaults.general.precision == "autocast" else nullcontext + output_images = [] + grid_captions = [] + stats = [] + with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not defaults.general.optimized else nullcontext()): + init_data = func_init() + tic = time.time() + + + # if variant_amount > 0.0 create noise from base seed + base_x = None + if variant_amount > 0.0: + target_seed_randomizer = seed_to_int('') # random seed + torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration) + base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed]) + # we don't want all_seeds to be 
sequential from starting seed with variants, + # since that makes the same variants each time, + # so we add target_seed_randomizer as a random offset + for si in range(len(all_seeds)): + all_seeds[si] += target_seed_randomizer + + for n in range(n_iter): + print(f"Iteration: {n+1}/{n_iter}") + prompts = all_prompts[n * batch_size:(n + 1) * batch_size] + captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size] + seeds = all_seeds[n * batch_size:(n + 1) * batch_size] + + print(prompt) + + if defaults.general.optimized: + modelCS.to(defaults.general.gpu) + + uc = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(len(prompts) * [""]) + + if isinstance(prompts, tuple): + prompts = list(prompts) + + # split the prompt if it has : for weighting + # TODO for speed it might help to have this occur when all_prompts filled?? + weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights) + + # sub-prompt weighting used if more than 1 + if len(weighted_subprompts) > 1: + c = torch.zeros_like(uc) # i dont know if this is correct.. but it works + for i in range(0, len(weighted_subprompts)): + # note if alpha negative, it functions same as torch.sub + c = torch.add(c, (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1]) + else: # just behave like usual + c = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(prompts) + + + shape = [opt_C, height // opt_f, width // opt_f] + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelCS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + if variant_amount == 0.0: + # we manually generate all input noises because each one should have a specific seed + x = create_random_tensors(shape, seeds=seeds) + + else: # we are making variants + # using variant_seed as sneaky toggle, + # when not None or '' use the variant_seed + # otherwise use seeds + if variant_seed != None and variant_seed != '': + specified_variant_seed = seed_to_int(variant_seed) + torch.manual_seed(specified_variant_seed) + seeds = [specified_variant_seed] + target_x = create_random_tensors(shape, seeds=seeds) + # finally, slerp base_x noise to target_x noise for creating a variant + x = slerp(defaults.general.gpu, max(0.0, min(1.0, variant_amount)), base_x, target_x) + + samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) + + if defaults.general.optimized: + modelFS.to(defaults.general.gpu) + + x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + for i, x_sample in enumerate(x_samples_ddim): + sanitized_prompt = slugify(prompts[i]) + + if sort_samples: + full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt) + + + sanitized_prompt = sanitized_prompt[:220-len(full_path)] + sample_path_i = os.path.join(sample_path, sanitized_prompt) + + #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}") + #print(os.path.join(os.getcwd(), sample_path_i)) + + os.makedirs(sample_path_i, exist_ok=True) + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}" + else: + full_path = 
os.path.join(os.getcwd(), sample_path) + sample_path_i = sample_path + base_count = get_next_sequence_number(sample_path_i) + filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before + + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = x_sample.astype(np.uint8) + image = Image.fromarray(x_sample) + original_sample = x_sample + original_filename = filename + + if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN: + #skip_save = True # #287 >_> + torch_gc() + cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + gfpgan_image = Image.fromarray(gfpgan_sample) + gfpgan_filename = original_filename + '-gfpgan' + + save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, + uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta, + n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(gfpgan_image) #287 + if simple_templating: + grid_captions.append( captions[i] + "\ngfpgan" ) + + if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: + #skip_save = True # #287 >_> + torch_gc() + + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) + esrgan_filename = original_filename + '-esrgan4x' + esrgan_sample = output[:,:,::-1] + esrgan_image = Image.fromarray(esrgan_sample) + + #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, + #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) + + save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(esrgan_image) #287 + if simple_templating: + grid_captions.append( captions[i] + "\nesrgan" ) + + if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None: + #skip_save = True # #287 >_> + torch_gc() + cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + + if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) + 
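Editor's note: this branch chains the two enhancers, restoring faces with GFPGAN first and feeding the result to RealESRGAN for upscaling. The repeated [:, :, ::-1] flips convert between RGB and the cv2-style BGR layout these libraries work with, and the final flip goes back to RGB before building a PIL image. A hedged, self-contained sketch of the same chain, where gfpgan and realesrgan stand for GFPGANer / RealESRGANer instances like the ones created by load_GFPGAN and load_RealESRGAN above:

import numpy as np
from PIL import Image

def restore_then_upscale(pil_image, gfpgan, realesrgan):
    """Face-restore with GFPGAN, then upscale the restored frame with RealESRGAN (illustrative)."""
    rgb = np.array(pil_image.convert("RGB"), dtype=np.uint8)
    bgr = rgb[:, :, ::-1]                                   # enhancers expect cv2-style BGR
    _, _, restored_bgr = gfpgan.enhance(bgr, has_aligned=False, only_center_face=False, paste_back=True)
    upscaled_bgr, _ = realesrgan.enhance(restored_bgr)
    return Image.fromarray(np.ascontiguousarray(upscaled_bgr[:, :, ::-1]))   # back to RGB for PIL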
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' + gfpgan_esrgan_sample = output[:,:,::-1] + gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) + + save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) + + output_images.append(gfpgan_esrgan_image) #287 + + if simple_templating: + grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) + + if save_individual_images: + save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, + normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, + save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images) + + if not use_GFPGAN or not use_RealESRGAN: + output_images.append(image) + + #if add_original_image or not simple_templating: + #output_images.append(image) + #if simple_templating: + #grid_captions.append( captions[i] ) + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + if prompt_matrix or save_grid: + if prompt_matrix: + if simple_templating: + grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions) + else: + grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) + try: + grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) + except: + import traceback + print("Error creating prompt_matrix text:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + grid = image_grid(output_images, batch_size) + + if grid and (batch_size > 1 or n_iter > 1): + output_images.insert(0, grid) + + grid_count = get_next_sequence_number(outpath, 'grid-') + grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}" + grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True) + + toc = time.time() + + mem_max_used, mem_total = mem_mon.read_and_stop() + time_diff = time.time()-start_time + + info = f""" + {prompt} + Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' 
if prompt_matrix else ''}""".strip() + stats = f''' + Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) + Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' + + for comment in comments: + info += "\n\n" + comment + + #mem_mon.stop() + #del mem_mon + torch_gc() + + return output_images, seed, info, stats + + +def resize_image(resize_mode, im, width, height): + LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) + if resize_mode == 0: + res = im.resize((width, height), resample=LANCZOS) + elif resize_mode == 1: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio > src_ratio else im.width * height // im.height + src_h = height if ratio <= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + else: + ratio = width / height + src_ratio = im.width / im.height + + src_w = width if ratio < src_ratio else im.width * height // im.height + src_h = height if ratio >= src_ratio else im.height * width // im.width + + resized = im.resize((src_w, src_h), resample=LANCZOS) + res = Image.new("RGBA", (width, height)) + res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + + if ratio < src_ratio: + fill_height = height // 2 - src_h // 2 + res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + elif ratio > src_ratio: + fill_width = width // 2 - src_w // 2 + res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + + return res + +def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, + ddim_steps: int = 50, sampler_name: str = 'DDIM', + n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, + seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, + variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, + write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", + separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False, + random_seed_loopback: bool = False + ): + + outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples" + err = False + #loopback = False + #skip_save = False + seed = seed_to_int(seed) + + batch_size = 1 + + #prompt_matrix = 0 + #normalize_prompt_weights = 1 in toggles + #loopback = 2 in toggles + #random_seed_loopback = 3 in toggles + #skip_save = 4 not in toggles + #save_grid = 5 in toggles + #sort_samples = 6 in toggles + #write_info_files = 7 in toggles + #write_sample_info_to_log_file = 8 in toggles + #jpg_sample = 9 in toggles + #use_GFPGAN = 10 in toggles + #use_RealESRGAN = 11 in toggles + + if sampler_name == 'PLMS': + sampler = PLMSSampler(st.session_state["model"]) + elif 
sampler_name == 'DDIM': + sampler = DDIMSampler(st.session_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(st.session_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(st.session_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(st.session_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + init_img = init_info + init_mask = None + keep_mask = False + + assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(denoising_strength * ddim_steps) + + def init(): + + image = init_img + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + + mask = None + if defaults.general.optimized: + modelFS.to(st.session_state["device"] ) + + init_image = 2. * image - 1. + init_image = init_image.to(st.session_state["device"]) + init_latent = (st.session_state["model"] if not defaults.general.optimized else modelFS).get_first_stage_encoding((st.session_state["model"] if not defaults.general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space + + if defaults.general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + modelFS.to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + return init_latent, mask, + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + t_enc_steps = t_enc + obliterate = False + if ddim_steps == t_enc_steps: + t_enc_steps = t_enc_steps - 1 + obliterate = True + + if sampler_name != 'DDIM': + x0, z_mask = init_data + + sigmas = sampler.model_wrap.get_sigmas(ddim_steps) + noise = x * sigmas[ddim_steps - t_enc_steps - 1] + + xi = x0 + noise + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=xi.device) + xi = (z_mask * noise) + ((1-z_mask) * xi) + + sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] + model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, + callback=generation_callback) + else: + + x0, z_mask = init_data + + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) + z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] )) + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=z_enc.device) + z_enc = (z_mask * random) + ((1-z_mask) * z_enc) + + # decode it + samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, + z_mask=z_mask, x0=x0) + return samples_ddim + + + + if loopback: + output_images, info = None, None + history = [] + initial_seed = None + + do_color_correction = False + try: + from skimage import exposure + do_color_correction = True + except: + 
print("Install scikit-image to perform color correction on loopback") + + for i in range(1): + if do_color_correction and i == 0: + correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) + + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=1, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback + realesrgan_model_name=RealESRGAN_model, + fp=fp, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + keep_mask=keep_mask, + mask_blur_strength=mask_blur_strength, + denoising_strength=denoising_strength, + resize_mode=resize_mode, + uses_loopback=loopback, + uses_random_seed_loopback=random_seed_loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + if initial_seed is None: + initial_seed = seed + + init_img = output_images[0] + + if do_color_correction and correction_target is not None: + init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(init_img), + cv2.COLOR_RGB2LAB + ), + correction_target, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8")) + + if not random_seed_loopback: + seed = seed + 1 + else: + seed = seed_to_int(None) + + denoising_strength = max(denoising_strength * 0.95, 0.1) + history.append(init_img) + + output_images = history + seed = initial_seed + + else: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=RealESRGAN_model, + fp=fp, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + keep_mask=keep_mask, + mask_blur_strength=2, + denoising_strength=denoising_strength, + resize_mode=resize_mode, + uses_loopback=loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + del sampler + + return output_images, seed, info, stats + +#@retry(RuntimeError, tries=3) +def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str, + n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None], + height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, + RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None, + variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True): + + outpath = defaults.general.outdir_txt2img or defaults.general.outdir or "outputs/txt2img-samples" + + err = False + seed = seed_to_int(seed) + + #prompt_matrix = 0 in toggles + #normalize_prompt_weights = 1 in toggles + #skip_save = 2 not in 
toggles + #save_grid = 3 not in toggles + #sort_samples = 4 in toggles + #write_info_files = 5 in toggles + #jpg_sample = 6 in toggles + #use_GFPGAN = 7 in toggles + #use_RealESRGAN = 8 in toggles + + if sampler_name == 'PLMS': + sampler = PLMSSampler(st.session_state["model"]) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(st.session_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(st.session_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(st.session_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(st.session_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + def init(): + pass + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback, + log_every_t=int(defaults.general.update_preview_frequency)) + + return samples_ddim + + #try: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=realesrgan_model_name, + fp=fp, + ddim_eta=ddim_eta, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg, + variant_amount=variant_amount, + variant_seed=variant_seed, + ) + + del sampler + + return output_images, seed, info, stats + + #except RuntimeError as e: + #err = e + #err_msg = f'CRASHED:
Please wait while the program restarts.' + #stats = err_msg + #return [], seed, 'err', stats + + + # functions to load css locally OR remotely starts here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection # TODO, maybe look into async loading the file especially for remote fetching def local_css(file_name): - with open(file_name) as f: - st.markdown(f'', unsafe_allow_html=True) + with open(file_name) as f: + st.markdown(f'', unsafe_allow_html=True) def remote_css(url): - st.markdown(f'', unsafe_allow_html=True) + st.markdown(f'', unsafe_allow_html=True) def load_css(isLocal, nameOrURL): if(isLocal): @@ -57,89 +1466,289 @@ def load_css(isLocal, nameOrURL): else: remote_css(nameOrURL) + +# main functions to define streamlit layout here def layout(): - """Layout functions to define all the streamlit layout here.""" - st.set_page_config(page_title="Stable Diffusion Playground", layout="wide") + + st.set_page_config(page_title="Stable Diffusion Playground", layout="wide", initial_sidebar_state="collapsed") with st.empty(): # load css as an external file, function has an option to local or remote url. Potential use when running from cloud infra that might not have access to local path. load_css(True, 'frontend/css/streamlit.main.css') - - # check if the models exist on their respective folders - if os.path.exists(os.path.join(st.session_state["defaults"].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): - st.session_state["GFPGAN_available"] = True - else: - st.session_state["GFPGAN_available"] = False - if os.path.exists(os.path.join(st.session_state["defaults"].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")): - st.session_state["RealESRGAN_available"] = True + # check if the models exist on their respective folders + if os.path.exists(os.path.join(defaults.general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): + GFPGAN_available = True else: - st.session_state["RealESRGAN_available"] = False - - # Allow for custom models to be used instead of the default one, - # an example would be Waifu-Diffusion or any other fine tune of stable diffusion - st.session_state["custom_models"]:sorted = [] - for root, dirs, files in os.walk(os.path.join("models", "custom")): - for file in files: - if os.path.splitext(file)[1] == '.ckpt': - #fullpath = os.path.join(root, file) - #print(fullpath) - st.session_state["custom_models"].append(os.path.splitext(file)[0]) - #print (os.path.splitext(file)[0]) - - if len(st.session_state["custom_models"]) > 0: - st.session_state["CustomModel_available"] = True - st.session_state["custom_models"].append("Stable Diffusion v1.4") + GFPGAN_available = False + + if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")): + RealESRGAN_available = True else: - st.session_state["CustomModel_available"] = False + RealESRGAN_available = False with st.sidebar: - # The global settings section will be moved to the Settings page. + # we should use an expander and group things together when more options are added so the sidebar is not too messy. #with st.expander("Global Settings:"): - #st.write("Global Settings:") - #defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview, - #help="If enabled the image preview will be updated during the generation instead of at the end. 
You can use the Update Preview \ - #Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.") - #st.session_state.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency, - #help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.") - - tabs = on_hover_tabs(tabName=['Stable Diffusion', "Textual Inversion","Model Manager","Settings"], - iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0) - - if tabs =='Stable Diffusion': - # txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab, concept_library_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", - # "Text-to-Video","Post-Processing", "Concept Library"]) - txt2img_tab, img2img_tab, txt2vid_tab = st.tabs( - ["Text-to-Image Unified", "Image-to-Image Unified", "Text-to-Video"] - ) - #with home_tab: - #from home import layout - #layout() - - with txt2img_tab: - from txt2img import layout - layout() - - with img2img_tab: - from img2img import layout - layout() - - with txt2vid_tab: - from txt2vid import layout - layout() - - # with concept_library_tab: - # from sd_concept_library import layout - # layout() - - # - elif tabs == 'Model Manager': - from ModelManager import layout - layout() + st.write("Global Settings:") + defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview, + help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \ + Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.") + defaults.general.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency, + help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.") + + + + txt2img_tab, img2img_tab, txt2video, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", "Text-to-Video","Post-Processing"]) + + with txt2img_tab: + with st.form("txt2img-inputs"): + st.session_state["generation_mode"] = "txt2img" + + input_col1, generate_col1 = st.columns([10,1]) + with input_col1: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. 
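# (Aside, for illustration only: a minimal, self-contained sketch of the st.form pattern used in
# this layout -- every st.form needs a st.form_submit_button, and the empty write("") calls are the
# crude vertical-alignment trick described in the comment above. None of the names below are part
# of the repo; it only assumes streamlit is installed.)
import streamlit as st

with st.form("txt2img-demo"):
    input_col, button_col = st.columns([10, 1])
    demo_prompt = input_col.text_input("Input Text", "", placeholder="A corgi wearing a top hat as an oil painting.")
    button_col.write("")   # two empty writes push the button roughly level with the text input
    button_col.write("")
    submitted = button_col.form_submit_button("Generate")

if submitted:
    st.write(f"Would generate: {demo_prompt!r}")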
+ generate_col1.write("") + generate_col1.write("") + generate_button = generate_col1.form_submit_button("Generate") + + # creating the page layout using columns + col1, col2, col3 = st.columns([1,2,1], gap="large") + + with col1: + width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.txt2img.width, step=64) + height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.txt2img.height, step=64) + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") + seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") + #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2img.batch_size, step=1, + #help="How many images are at once in a batch.\ + #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + #Default: 1") + + with col2: + preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) + + with preview_tab: + #st.write("Image") + #Image for testing + #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') + #new_image = image.resize((175, 240)) + #preview_image = st.image(image) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + st.session_state["preview_image"] = st.empty() + + st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + message = st.empty() + + with gallery_tab: + st.write('Here should be the image gallery, if I could make a grid in streamlit.') + + with col3: + st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250) + + sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + sampler_name = st.selectbox("Sampling method", sampler_name_list, + index=sampler_name_list.index(defaults.txt2img.default_sampler), help="Sampling method to use. 
Default: k_euler") + + + + #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) + + #with basic_tab: + #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, + #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") + + with st.expander("Advanced"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=False, help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=True, help="Ensure the sum of all weights add up to 1.0") + save_individual_images = st.checkbox("Save individual images.", value=True, help="Save each image generated before any filter or enhancement is applied.") + save_grid = st.checkbox("Save grid",value=True, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=True, + help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=True, help="Save a file next to the image with informartion about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=False, help="Saves the images as jpg instead of png.") + + if GFPGAN_available: + use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + use_GFPGAN = False + + if RealESRGAN_available: + use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.txt2img.use_RealESRGAN, help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") + RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + use_RealESRGAN = False + RealESRGAN_model = "RealESRGAN_x4plus" + + variant_amount = st.slider("Variant Amount:", value=defaults.txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=defaults.txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") + + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
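# (Aside: the reason only the first click is slow is that load_models() keeps everything it loads
# in st.session_state, which Streamlit preserves across reruns -- see the version of load_models()
# in the old script further down. A self-contained sketch of that caching idea, with a hypothetical
# expensive_load() standing in for reading a checkpoint:)
import streamlit as st

def expensive_load():
    return object()  # stand-in for loading a multi-gigabyte checkpoint into VRAM

if "model" not in st.session_state:
    st.session_state["model"] = expensive_load()   # only runs on the first click/rerun
model = st.session_state["model"]                  # every later rerun reuses the cached object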
+ load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model) + + try: + output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1, + cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, + save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp, + variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) - # elif tabs == 'Textual Inversion': - # from textual_inversion import layout - # layout() + message.success('Done!', icon="✅") + + except (StopException, KeyError): + print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images) + + with img2img_tab: + with st.form("img2img-inputs"): + st.session_state["generation_mode"] = "img2img" + + img2img_input_col, img2img_generate_col = st.columns([10,1]) + with img2img_input_col: + #prompt = st.text_area("Input Text","") + prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. + img2img_generate_col.write("") + img2img_generate_col.write("") + generate_button = img2img_generate_col.form_submit_button("Generate") + + + # creating the page layout using columns + col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small") + + with col1_img2img_layout: + st.session_state["sampling_steps"] = st.slider("Sampling Steps", value=defaults.img2img.sampling_steps, min_value=1, max_value=250) + st.session_state["sampler_name"] = st.selectbox("Sampling method", ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"], + index=0, help="Sampling method to use. Default: k_lms") + + uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg"], + help="Upload an image which will be used for the image to image generation." 
+ ) + + width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.img2img.width, step=64) + height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.img2img.height, step=64) + seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.img2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") + + # + with st.expander("Advanced"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=defaults.img2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=defaults.img2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") + loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.") + random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed") + save_individual_images = st.checkbox("Save individual images.", value=True, help="Save each image generated before any filter or enhancement is applied.") + save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=defaults.img2img.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=True, help="Save a file next to the image with informartion about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=False, help="Saves the images as jpg instead of png.") + + if GFPGAN_available: + use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ + This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") + else: + use_GFPGAN = False + + if RealESRGAN_available: + use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.img2img.use_RealESRGAN, help="Uses the RealESRGAN model to upscale the images after the generation.\ + This greatly improve the quality and lets you have high resolution images but uses extra VRAM. 
Disable if you need the extra VRAM.") + RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) + else: + use_RealESRGAN = False + RealESRGAN_model = "RealESRGAN_x4plus" + + variant_amount = st.slider("Variant Amount:", value=defaults.img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=defaults.img2img.variant_seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.") + cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.img2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") + batch_size = st.slider("Batch size", min_value=1, max_value=100, value=defaults.img2img.batch_size, step=1, + help="How many images are at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + Default: 1") + + st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, min_value=0.01, max_value=1.0, step=0.01) + + + with col2_img2img_layout: + editor_tab = st.tabs(["Editor"]) + + editor_image = st.empty() + st.session_state["editor_image"] = editor_image + + if uploaded_images: + image = Image.open(uploaded_images).convert('RGB') + #img_array = np.array(image) # if you want to pass it to OpenCV + new_img = image.resize((width, height)) + st.image(new_img) + + + with col3_img2img_layout: + result_tab = st.tabs(["Result"]) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + preview_image = st.empty() + st.session_state["preview_image"] = preview_image + + #st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + + message = st.empty() + + #if uploaded_images: + #image = Image.open(uploaded_images).convert('RGB') + ##img_array = np.array(image) # if you want to pass it to OpenCV + #new_img = image.resize((width, height)) + #st.image(new_img, use_column_width=True) + + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
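# (Aside: a rough, self-contained sketch of what the "Loopback" checkbox above means -- each
# batch's output becomes the init image for the next batch. fake_img2img() is a hypothetical
# stand-in; the real chaining is driven by the loopback flag passed to img2img() below.)
from PIL import Image

def fake_img2img(init_img: Image.Image) -> Image.Image:
    return init_img.copy()         # stand-in for one img2img batch conditioned on init_img

looped = Image.new("RGB", (512, 512))
for _ in range(4):                 # batch_count iterations
    looped = fake_img2img(looped)  # previous output is fed back in as the next init image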
+ load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model) + if uploaded_images: + image = Image.open(uploaded_images).convert('RGB') + new_img = image.resize((width, height)) + #img_array = np.array(image) # if you want to pass it to OpenCV + + try: + output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=st.session_state["sampling_steps"], + sampler_name=st.session_state["sampler_name"], n_iter=batch_count, + cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed, + seed=seed, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount, + ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=RealESRGAN_model, + separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, save_grid=save_grid, + group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN, + use_RealESRGAN=use_RealESRGAN if not loopback else False, loopback=loopback + ) + #show a message when the generation is complete. + message.success('Done!', icon="✅") + + except (StopException, KeyError): + print(f"Received Streamlit StopException") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images, width=750) + + if __name__ == '__main__': layout() \ No newline at end of file diff --git a/scripts/webui_streamlit_old.py b/scripts/webui_streamlit_old.py deleted file mode 100644 index ad1b9da..0000000 --- a/scripts/webui_streamlit_old.py +++ /dev/null @@ -1,2738 +0,0 @@ -import warnings - -import piexif -import piexif.helper -import json - -import streamlit as st -from streamlit import StopException - -#streamlit components section -from st_on_hover_tabs import on_hover_tabs - -import base64, cv2 -import os, sys, re, random, datetime, timeit -from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps -from PIL.PngImagePlugin import PngInfo -from scipy import integrate -import pandas as pd -import torch -from torchdiffeq import odeint -import k_diffusion as K -import math -import mimetypes -import numpy as np -import pynvml -import threading -import time, inspect -import torch -from torch import autocast -from torchvision import transforms -import torch.nn as nn -import yaml -from typing import Union -from pathlib import Path -#from tqdm import tqdm -from contextlib import nullcontext -from einops import rearrange -from omegaconf import OmegaConf -from io import StringIO -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.util import instantiate_from_config - -from retry import retry - -# these are for testing txt2vid, should be removed and we should use things from our own code. -from diffusers import StableDiffusionPipeline -from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler - -#will be used for saving and reading a video made by the txt2vid function -import imageio, io - -# we use python-slugify to make the filenames safe for windows and linux, its better than doing it manually -# install it with 'pip install python-slugify' -from slugify import slugify - -try: - # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. 
- from transformers import logging - - logging.set_verbosity_error() -except: - pass - -# remove some annoying deprecation warnings that show every now and then. -warnings.filterwarnings("ignore", category=DeprecationWarning) - -defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml") -if (os.path.exists("configs/webui/userconfig_streamlit.yaml")): - user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml"); - defaults = OmegaConf.merge(defaults, user_defaults) - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -# some of those options should not be changed at all because they would break the model, so I removed them from options. -opt_C = 4 -opt_f = 8 - -# should and will be moved to a settings menu in the UI at some point -grid_format = [s.lower() for s in defaults.general.grid_format.split(':')] -grid_lossless = False -grid_quality = 100 -if grid_format[0] == 'png': - grid_ext = 'png' - grid_format = 'png' -elif grid_format[0] in ['jpg', 'jpeg']: - grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 - grid_ext = 'jpg' - grid_format = 'jpeg' -elif grid_format[0] == 'webp': - grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100 - grid_ext = 'webp' - grid_format = 'webp' - if grid_quality < 0: # e.g. webp:-100 for lossless mode - grid_lossless = True - grid_quality = abs(grid_quality) - -# should and will be moved to a settings menu in the UI at some point -save_format = [s.lower() for s in defaults.general.save_format.split(':')] -save_lossless = False -save_quality = 100 -if save_format[0] == 'png': - save_ext = 'png' - save_format = 'png' -elif save_format[0] in ['jpg', 'jpeg']: - save_quality = int(save_format[1]) if len(save_format) > 1 else 100 - save_ext = 'jpg' - save_format = 'jpeg' -elif save_format[0] == 'webp': - save_quality = int(save_format[1]) if len(save_format) > 1 else 100 - save_ext = 'webp' - save_format = 'webp' - if save_quality < 0: # e.g. webp:-100 for lossless mode - save_lossless = True - save_quality = abs(save_quality) - -# this should force GFPGAN and RealESRGAN onto the selected gpu as well -os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 -os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu) - -@retry(tries=5) -def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus", - CustomModel_available=False, custom_model="Stable Diffusion v1.4"): - """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ - - print ("Loading models.") - - st.session_state["progress_bar_text"].text("Loading models...") - - # Generate random run ID - # Used to link runs linked w/ continue_prev_run which is not yet implemented - # Use URL and filesystem safe version just in case. - st.session_state["run_id"] = base64.urlsafe_b64encode( - os.urandom(6) - ).decode("ascii") - - # check what models we want to use and if the they are already loaded. 
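# (Each branch below follows the same pattern: if the feature is enabled, reuse the copy already
# sitting in st.session_state where possible, otherwise load it; if the feature is disabled, drop
# any previously loaded copy from st.session_state so its memory can be reclaimed.)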
- - if use_GFPGAN: - if "GFPGAN" in st.session_state: - print("GFPGAN already loaded") - else: - # Load GFPGAN - if os.path.exists(defaults.general.GFPGAN_dir): - try: - st.session_state["GFPGAN"] = load_GFPGAN() - print("Loaded GFPGAN") - except Exception: - import traceback - print("Error loading GFPGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if "GFPGAN" in st.session_state: - del st.session_state["GFPGAN"] - - if use_RealESRGAN: - if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model: - print("RealESRGAN already loaded") - else: - #Load RealESRGAN - try: - # We first remove the variable in case it has something there, - # some errors can load the model incorrectly and leave things in memory. - del st.session_state["RealESRGAN"] - except KeyError: - pass - - if os.path.exists(defaults.general.RealESRGAN_dir): - # st.session_state is used for keeping the models in memory across multiple pages or runs. - st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model) - print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name) - - else: - if "RealESRGAN" in st.session_state: - del st.session_state["RealESRGAN"] - - - - if "model" in st.session_state: - if "model" in st.session_state and st.session_state["custom_model"] == custom_model: - print("Model already loaded") - else: - try: - del st.session_state["model"] - except KeyError: - pass - - config = OmegaConf.load(defaults.general.default_model_config) - - if custom_model == defaults.general.default_model: - model = load_model_from_config(config, defaults.general.default_model_path) - else: - model = load_model_from_config(config, os.path.join("models","custom", f"{custom_model}.ckpt")) - - st.session_state["custom_model"] = custom_model - st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu") - st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] ) - else: - config = OmegaConf.load(defaults.general.default_model_config) - - if custom_model == defaults.general.default_model: - model = load_model_from_config(config, defaults.general.default_model_path) - else: - model = load_model_from_config(config, os.path.join("models","custom", f"{custom_model}.ckpt")) - - st.session_state["custom_model"] = custom_model - st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu") - st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] ) - - print("Model loaded.") - - -def load_model_from_config(config, ckpt, verbose=False): - - print(f"Loading model from {ckpt}") - - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -def load_sd_from_config(ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - return sd -# -@retry(tries=5) -def generation_callback(img, i=0): - 
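    # (Summary: this callback is handed to the samplers and fires on every step; whenever the step
    # index is a multiple of update_preview_frequency it decodes the current latents to a PIL image,
    # pushes it to st.session_state["preview_image"], and updates the progress bar/text.)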
- try: - if i == 0: - if img['i']: i = img['i'] - except TypeError: - pass - - - if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview: - #print (img) - #print (type(img)) - # The following lines will convert the tensor we got on img to an actual image we can render on the UI. - # It can probably be done in a better way for someone who knows what they're doing. I don't. - #print (img,isinstance(img, torch.Tensor)) - if isinstance(img, torch.Tensor): - x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img) - else: - # When using the k Diffusion samplers they return a dict instead of a tensor that look like this: - # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised} - x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"]) - - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0)) - - # update image on the UI so we can see the progress - st.session_state["preview_image"].image(pil_image) - - # Show a progress bar so we can keep track of the progress even when the image progress is not been shown, - # Dont worry, it doesnt affect the performance. - if st.session_state["generation_mode"] == "txt2img": - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%") - else: - if st.session_state["generation_mode"] == "img2img": - round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"]) - percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps)) - st.session_state["progress_bar_text"].text( - f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""") - else: - if st.session_state["generation_mode"] == "txt2vid": - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps}" - f"{percent if percent < 100 else 100}%") - - st.session_state["progress_bar"].progress(percent if percent < 100 else 100) - - - -class MemUsageMonitor(threading.Thread): - stop_flag = False - max_usage = 0 - total = -1 - - def __init__(self, name): - threading.Thread.__init__(self) - self.name = name - - def run(self): - try: - pynvml.nvmlInit() - except: - print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. 
\n") - return - print(f"[{self.name}] Recording max memory usage...\n") - handle = pynvml.nvmlDeviceGetHandleByIndex(defaults.general.gpu) - self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total - while not self.stop_flag: - m = pynvml.nvmlDeviceGetMemoryInfo(handle) - self.max_usage = max(self.max_usage, m.used) - # print(self.max_usage) - time.sleep(0.1) - print(f"[{self.name}] Stopped recording.\n") - pynvml.nvmlShutdown() - - def read(self): - return self.max_usage, self.total - - def stop(self): - self.stop_flag = True - - def read_and_stop(self): - self.stop_flag = True - return self.max_usage, self.total - -class CFGMaskedDenoiser(nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi): - x_in = x - x_in = torch.cat([x_in] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - denoised = uncond + (cond - uncond) * cond_scale - - if mask is not None: - assert x0 is not None - img_orig = x0 - mask_inv = 1. - mask - denoised = (img_orig * mask_inv) + (mask * denoised) - - return denoised - -class CFGDenoiser(nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - return uncond + (cond - uncond) * cond_scale -def append_zero(x): - return torch.cat([x, x.new_zeros([1])]) -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] -def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): - """Constructs the noise schedule of Karras et al. (2022).""" - ramp = torch.linspace(0, 1, n) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return append_zero(sigmas).to(device) - - -def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): - """Constructs an exponential noise schedule.""" - sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() - return append_zero(sigmas) - - -def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): - """Constructs a continuous VP noise schedule.""" - t = torch.linspace(1, eps_s, n, device=device) - sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) - return append_zero(sigmas) - - -def to_d(x, sigma, denoised): - """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / append_dims(sigma, x.ndim) -def linear_multistep_coeff(order, t, i, j): - if order - 1 > i: - raise ValueError(f'Order {order} too high for step {i}') - def fn(tau): - prod = 1. 
- for k in range(order): - if j == k: - continue - prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) - return prod - return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] - -class KDiffusionSampler: - def __init__(self, m, sampler): - self.model = m - self.model_wrap = K.external.CompVisDenoiser(m) - self.schedule = sampler - def get_sampler_name(self): - return self.schedule - def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None): - sigmas = self.model_wrap.get_sigmas(S) - x = x_T * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model_wrap) - samples_ddim = None - samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, - extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback) - # - return samples_ddim, None - - -@torch.no_grad() -def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - v = torch.randint_like(x, 2) * 2 - 1 - fevals = 0 - def ode_fn(sigma, x): - nonlocal fevals - with torch.enable_grad(): - x = x[0].detach().requires_grad_() - denoised = model(x, sigma * s_in, **extra_args) - d = to_d(x, sigma, denoised) - fevals += 1 - grad = torch.autograd.grad((d * v).sum(), x)[0] - d_ll = (v * grad).flatten(1).sum(1) - return d.detach(), d_ll - x_min = x, x.new_zeros([x.shape[0]]) - t = x.new_tensor([sigma_min, sigma_max]) - sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') - latent, delta_ll = sol[0][-1], sol[1][-1] - ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) - return ll_prior + delta_ll, {'fevals': fevals} - - -def create_random_tensors(shape, seeds): - xs = [] - for seed in seeds: - torch.manual_seed(seed) - - # randn results depend on device; gpu and cpu get different results for same seed; - # the way I see it, it's better to do this on CPU, so that everyone gets same result; - # but the original script had it like this so i do not dare change it for now because - # it will break everyone's seeds. 
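# (Concretely: torch.manual_seed(42); torch.randn(3, device="cpu") and
#  torch.manual_seed(42); torch.randn(3, device="cuda") give different tensors, because the CPU
#  and CUDA generators are separate streams -- so moving this call to the CPU would silently
#  change the image that every existing seed produces.)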
- xs.append(torch.randn(shape, device=defaults.general.gpu)) - x = torch.stack(xs) - return x - -def torch_gc(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - -def load_GFPGAN(): - model_name = 'GFPGANv1.3' - model_path = os.path.join(defaults.general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') - if not os.path.isfile(model_path): - raise Exception("GFPGAN model not found at path "+model_path) - - sys.path.append(os.path.abspath(defaults.general.GFPGAN_dir)) - from gfpgan import GFPGANer - - if defaults.general.gfpgan_cpu or defaults.general.extra_models_cpu: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu')) - elif defaults.general.extra_models_gpu: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gfpgan_gpu}')) - else: - instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gpu}')) - return instance - -def load_RealESRGAN(model_name: str): - from basicsr.archs.rrdbnet_arch import RRDBNet - RealESRGAN_models = { - 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), - 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - } - - model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') - if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")): - raise Exception(model_name+".pth not found at path "+model_path) - - sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir)) - from realesrgan import RealESRGANer - - if defaults.general.esrgan_cpu or defaults.general.extra_models_cpu: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half - instance.device = torch.device('cpu') - instance.model.to('cpu') - elif defaults.general.extra_models_gpu: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.esrgan_gpu}')) - else: - instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.gpu}')) - instance.model.name = model_name - - return instance - -prompt_parser = re.compile(""" - (?P # capture group for 'prompt' - [^:]+ # match one or more non ':' characters - ) # end 'prompt' - (?: # non-capture group - :+ # match one or more ':' characters - (?P # capture group for 'weight' - -?\\d+(?:\\.\\d+)? # match positive or negative decimal number - )? 
# end weight capture group, make optional - \\s* # strip spaces after weight - | # OR - $ # else, if no ':' then match end of line - ) # end non-capture group -""", re.VERBOSE) - -# grabs all text up to the first occurrence of ':' as sub-prompt -# takes the value following ':' as weight -# if ':' has no value defined, defaults to 1.0 -# repeats until no text remaining -def split_weighted_subprompts(input_string, normalize=True): - parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)] - if not normalize: - return parsed_prompts - # this probably still doesn't handle negative weights very well - weight_sum = sum(map(lambda x: x[1], parsed_prompts)) - return [(x[0], x[1] / weight_sum) for x in parsed_prompts] - -def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995): - v0 = v0.detach().cpu().numpy() - v1 = v1.detach().cpu().numpy() - - dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) - if np.abs(dot) > DOT_THRESHOLD: - v2 = (1 - t) * v0 + t * v1 - else: - theta_0 = np.arccos(dot) - sin_theta_0 = np.sin(theta_0) - theta_t = theta_0 * t - sin_theta_t = np.sin(theta_t) - s0 = np.sin(theta_0 - theta_t) / sin_theta_0 - s1 = sin_theta_t / sin_theta_0 - v2 = s0 * v0 + s1 * v1 - - v2 = torch.from_numpy(v2).to(device) - - return v2 - - -def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed, update_preview_frequency): - """Find the optimal update_preview_frequency value maximizing - performance while minimizing the time between updates.""" - if current_chunk_speed >= previous_chunk_speed: - #print(f"{current_chunk_speed} >= {previous_chunk_speed}") - update_preview_frequency +=1 - previous_chunk_speed = current_chunk_speed - else: - #print(f"{current_chunk_speed} <= {previous_chunk_speed}") - update_preview_frequency -=1 - previous_chunk_speed = current_chunk_speed - - return current_chunk_speed, previous_chunk_speed, update_preview_frequency - -# ----------------------------------------------------------------------------- - -@torch.no_grad() -def diffuse( - pipe, - cond_embeddings, # text conditioning, should be (1, 77, 768) - cond_latents, # image conditioning, should be (1, 4, 64, 64) - num_inference_steps, - cfg_scale, - eta, - ): - - torch_device = cond_latents.get_device() - - # classifier guidance: add the unconditional embedding - max_length = cond_embeddings.shape[1] # 77 - uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt") - uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0] - text_embeddings = torch.cat([uncond_embeddings, cond_embeddings]) - - # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas - if isinstance(pipe.scheduler, LMSDiscreteScheduler): - cond_latents = cond_latents * pipe.scheduler.sigmas[0] - - # init the scheduler - accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs) - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
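# (reading aid: eta=0.0 makes the DDIM update fully deterministic, while eta=1.0 adds the full
#  DDPM-style noise at each step)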
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - - step_counter = 0 - inference_counter = 0 - current_chunk_speed = 0 - previous_chunk_speed = 0 - - # diffuse! - for i, t in enumerate(pipe.scheduler.timesteps): - start = timeit.default_timer() - - #status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}") - - # expand the latents for classifier free guidance - latent_model_input = torch.cat([cond_latents] * 2) - if isinstance(pipe.scheduler, LMSDiscreteScheduler): - sigma = pipe.scheduler.sigmas[i] - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - - # predict the noise residual - noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] - - # cfg - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - if isinstance(pipe.scheduler, LMSDiscreteScheduler): - cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"] - else: - cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"] - - #print (st.session_state["update_preview_frequency"]) - #update the preview image if it is enabled and the frequency matches the step_counter - if defaults.general.update_preview: - step_counter += 1 - - if st.session_state.dynamic_preview_frequency: - current_chunk_speed, previous_chunk_speed, defaults.general.update_preview_frequency = optimize_update_preview_frequency( - current_chunk_speed, previous_chunk_speed, defaults.general.update_preview_frequency) - - if defaults.general.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps: - #scale and decode the image latents with vae - cond_latents_2 = 1 / 0.18215 * cond_latents - image_2 = pipe.vae.decode(cond_latents_2) - - # generate output numpy image as uint8 - image_2 = (image_2 / 2 + 0.5).clamp(0, 1) - image_2 = image_2.cpu().permute(0, 2, 3, 1).numpy() - image_2 = (image_2[0] * 255).astype(np.uint8) - - st.session_state["preview_image"].image(image_2) - - step_counter = 0 - - duration = timeit.default_timer() - start - - current_chunk_speed = duration - - if duration >= 1: - speed = "s/it" - else: - speed = "it/s" - duration = 1 / duration - - if i > st.session_state.sampling_steps: - inference_counter += 1 - inference_percent = int(100 * float(inference_counter if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps)) - inference_progress = f"{inference_counter if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% " - else: - inference_progress = "" - - percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps)) - frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames)) - - st.session_state["progress_bar_text"].text( - f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} " 
- f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | " - f"Frame: {st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} " - f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}" - ) - st.session_state["progress_bar"].progress(percent if percent < 100 else 100) - - # scale and decode the image latents with vae - cond_latents = 1 / 0.18215 * cond_latents - image = pipe.vae.decode(cond_latents) - - # generate output numpy image as uint8 - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - image = (image[0] * 255).astype(np.uint8) - - return image - - -def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'): - #get global variables - global_vars = globals() - #check if m is in globals - if unload: - for m in models: - if m in global_vars: - #if it is, delete it - del global_vars[m] - if defaults.general.optimized: - if m == 'model': - del global_vars[m+'FS'] - del global_vars[m+'CS'] - if m =='model': - m='Stable Diffusion' - print('Unloaded ' + m) - if load: - for m in models: - if m not in global_vars or m in global_vars and type(global_vars[m]) == bool: - #if it isn't, load it - if m == 'GFPGAN': - global_vars[m] = load_GFPGAN() - elif m == 'model': - sdLoader = load_sd_from_config() - global_vars[m] = sdLoader[0] - if defaults.general.optimized: - global_vars[m+'CS'] = sdLoader[1] - global_vars[m+'FS'] = sdLoader[2] - elif m == 'RealESRGAN': - global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name) - elif m == 'LDSR': - global_vars[m] = load_LDSR() - if m =='model': - m='Stable Diffusion' - print('Loaded ' + m) - torch_gc() - - - -def get_font(fontsize): - fonts = ["arial.ttf", "DejaVuSans.ttf"] - for font_name in fonts: - try: - return ImageFont.truetype(font_name, fontsize) - except OSError: - pass - - # ImageFont.load_default() is practically unusable as it only supports - # latin1, so raise an exception instead if no usable font was found - raise Exception(f"No usable font found (tried {', '.join(fonts)})") - -def load_embeddings(fp): - if fp is not None and hasattr(st.session_state["model"], "embedding_manager"): - st.session_state["model"].embedding_manager.load(fp['name']) - -def image_grid(imgs, batch_size, force_n_rows=None, captions=None): - #print (len(imgs)) - if force_n_rows is not None: - rows = force_n_rows - elif defaults.general.n_rows > 0: - rows = defaults.general.n_rows - elif defaults.general.n_rows == 0: - rows = batch_size - else: - rows = math.sqrt(len(imgs)) - rows = round(rows) - - cols = math.ceil(len(imgs) / rows) - - w, h = imgs[0].size - grid = Image.new('RGB', size=(cols * w, rows * h), color='black') - - fnt = get_font(30) - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - if captions and i= 2**32: - n = n >> 32 - return n - -def check_prompt_length(prompt, comments): - """this function tests if prompt is too long, and if so, adds a message to comments""" - - tokenizer = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer - max_length = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.max_length - - info = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, 
max_length=max_length, - return_overflowing_tokens=True, padding="max_length", return_tensors="pt") - ovf = info['overflowing_tokens'][0] - overflowing_count = ovf.shape[0] - if overflowing_count == 0: - return - - vocab = {v: k for k, v in tokenizer.get_vocab().items()} - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - - comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - -def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images): - - filename_i = os.path.join(sample_path_i, filename) - - if defaults.general.save_metadata or write_info_files: - # toggles differ for txt2img vs. img2img: - offset = 0 if init_img is None else 2 - toggles = [] - if prompt_matrix: - toggles.append(0) - if normalize_prompt_weights: - toggles.append(1) - if init_img is not None: - if uses_loopback: - toggles.append(2) - if uses_random_seed_loopback: - toggles.append(3) - if save_individual_images: - toggles.append(2 + offset) - if save_grid: - toggles.append(3 + offset) - if sort_samples: - toggles.append(4 + offset) - if write_info_files: - toggles.append(5 + offset) - if use_GFPGAN: - toggles.append(6 + offset) - metadata = \ - dict( - target="txt2img" if init_img is None else "img2img", - prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name, - ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale, - seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights) - # Not yet any use for these, but they bloat up the files: - # info_dict["init_img"] = init_img - # info_dict["init_mask"] = init_mask - if init_img is not None: - metadata["denoising_strength"] = str(denoising_strength) - metadata["resize_mode"] = resize_mode - - if write_info_files: - with open(f"{filename_i}.yaml", "w", encoding="utf8") as f: - yaml.dump(metadata, f, allow_unicode=True, width=10000) - - if defaults.general.save_metadata: - # metadata = { - # "SD:prompt": prompts[i], - # "SD:seed": str(seeds[i]), - # "SD:width": str(width), - # "SD:height": str(height), - # "SD:steps": str(steps), - # "SD:cfg_scale": str(cfg_scale), - # "SD:normalize_prompt_weights": str(normalize_prompt_weights), - # } - metadata = {"SD:" + k:v for (k,v) in metadata.items()} - - if save_ext == "png": - mdata = PngInfo() - for key in metadata: - mdata.add_text(key, str(metadata[key])) - image.save(f"{filename_i}.png", pnginfo=mdata) - else: - if jpg_sample: - image.save(f"{filename_i}.jpg", quality=save_quality, - optimize=True) - elif save_ext == "webp": - image.save(f"{filename_i}.{save_ext}", f"webp", quality=save_quality, - lossless=save_lossless) - else: - # not sure what file format this is - image.save(f"{filename_i}.{save_ext}", f"{save_ext}") - try: - exif_dict = piexif.load(f"{filename_i}.{save_ext}") - except: - exif_dict = { "Exif": dict() } - exif_dict["Exif"][piexif.ExifIFD.UserComment] = piexif.helper.UserComment.dump( - json.dumps(metadata), encoding="unicode") - piexif.insert(piexif.dump(exif_dict), f"{filename_i}.{save_ext}") - - # render the image on the frontend - 
st.session_state["preview_image"].image(image) - -def get_next_sequence_number(path, prefix=''): - """ - Determines and returns the next sequence number to use when saving an - image in the specified directory. - - If a prefix is given, only consider files whose names start with that - prefix, and strip the prefix from filenames before extracting their - sequence number. - - The sequence starts at 0. - """ - result = -1 - for p in Path(path).iterdir(): - if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix): - tmp = p.name[len(prefix):] - try: - result = max(int(tmp.split('-')[0]), result) - except ValueError: - pass - return result + 1 - - -def oxlamon_matrix(prompt, seed, n_iter, batch_size): - pattern = re.compile(r'(,\s){2,}') - - class PromptItem: - def __init__(self, text, parts, item): - self.text = text - self.parts = parts - if item: - self.parts.append( item ) - - def clean(txt): - return re.sub(pattern, ', ', txt) - - def getrowcount( txt ): - for data in re.finditer( ".*?\\((.*?)\\).*", txt ): - if data: - return len(data.group(1).split("|")) - break - return None - - def repliter( txt ): - for data in re.finditer( ".*?\\((.*?)\\).*", txt ): - if data: - r = data.span(1) - for item in data.group(1).split("|"): - yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip()) - break - - def iterlist( items ): - outitems = [] - for item in items: - for newitem, newpart in repliter(item.text): - outitems.append( PromptItem(newitem, item.parts.copy(), newpart) ) - - return outitems - - def getmatrix( prompt ): - dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ] - while True: - newdataitems = iterlist( dataitems ) - if len( newdataitems ) == 0: - return dataitems - dataitems = newdataitems - - def classToArrays( items, seed, n_iter ): - texts = [] - parts = [] - seeds = [] - - for item in items: - itemseed = seed - for i in range(n_iter): - texts.append( item.text ) - parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) ) - seeds.append( itemseed ) - itemseed += 1 - - return seeds, texts, parts - - all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter) - n_iter = math.ceil(len(all_prompts) / batch_size) - - needrows = getrowcount(prompt) - if needrows: - xrows = math.sqrt(len(all_prompts)) - xrows = round(xrows) - # if columns is to much - cols = math.ceil(len(all_prompts) / xrows) - if cols > needrows*4: - needrows *= 2 - - return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows - - -import find_noise_for_image -import matched_noise - - -def process_images( - outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, - n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, - fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, - mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, noise_mode=0, find_noise_steps=1, resize_mode=None, uses_loopback=False, - uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, - variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): - """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - assert prompt is not None - torch_gc() - # start time after garbage collection (or before?) 
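# (Aside: a self-contained example of the PNG metadata route in save_sample() above -- PIL PngInfo
# text chunks are what the save_metadata option ends up embedding in .png output; the exact keys
# below are only illustrative.)
from PIL import Image
from PIL.PngImagePlugin import PngInfo

demo_img = Image.new("RGB", (64, 64))
meta = PngInfo()
meta.add_text("SD:prompt", "a corgi wearing a top hat")
meta.add_text("SD:seed", "42")
demo_img.save("sample.png", pnginfo=meta)

print(Image.open("sample.png").text)  # {'SD:prompt': 'a corgi wearing a top hat', 'SD:seed': '42'}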
- start_time = time.time() - - # We will use this date here later for the folder name, need to start_time if not need - run_start_dt = datetime.datetime.now() - - mem_mon = MemUsageMonitor('MemMon') - mem_mon.start() - - if hasattr(st.session_state["model"], "embedding_manager"): - load_embeddings(fp) - - os.makedirs(outpath, exist_ok=True) - - sample_path = os.path.join(outpath, "samples") - os.makedirs(sample_path, exist_ok=True) - - if not ("|" in prompt) and prompt.startswith("@"): - prompt = prompt[1:] - - comments = [] - - prompt_matrix_parts = [] - simple_templating = False - add_original_image = not (use_RealESRGAN or use_GFPGAN) - - if prompt_matrix: - if prompt.startswith("@"): - simple_templating = True - add_original_image = not (use_RealESRGAN or use_GFPGAN) - all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size) - else: - all_prompts = [] - prompt_matrix_parts = prompt.split("|") - combination_count = 2 ** (len(prompt_matrix_parts) - 1) - for combination_num in range(combination_count): - current = prompt_matrix_parts[0] - - for n, text in enumerate(prompt_matrix_parts[1:]): - if combination_num & (2 ** n) > 0: - current += ("" if text.strip().startswith(",") else ", ") + text - - all_prompts.append(current) - - n_iter = math.ceil(len(all_prompts) / batch_size) - all_seeds = len(all_prompts) * [seed] - - print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") - else: - - if not defaults.general.no_verify_input: - try: - check_prompt_length(prompt, comments) - except: - import traceback - print("Error verifying input:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - all_prompts = batch_size * n_iter * [prompt] - all_seeds = [seed + x for x in range(len(all_prompts))] - - precision_scope = autocast if defaults.general.precision == "autocast" else nullcontext - output_images = [] - grid_captions = [] - stats = [] - with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not defaults.general.optimized else nullcontext()): - init_data = func_init() - tic = time.time() - - - # if variant_amount > 0.0 create noise from base seed - base_x = None - if variant_amount > 0.0: - target_seed_randomizer = seed_to_int('') # random seed - torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration) - base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed]) - # we don't want all_seeds to be sequential from starting seed with variants, - # since that makes the same variants each time, - # so we add target_seed_randomizer as a random offset - for si in range(len(all_seeds)): - all_seeds[si] += target_seed_randomizer - - for n in range(n_iter): - print(f"Iteration: {n+1}/{n_iter}") - prompts = all_prompts[n * batch_size:(n + 1) * batch_size] - captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size] - seeds = all_seeds[n * batch_size:(n + 1) * batch_size] - - print(prompt) - - if defaults.general.optimized: - modelCS.to(defaults.general.gpu) - - uc = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(len(prompts) * [""]) - - if isinstance(prompts, tuple): - prompts = list(prompts) - - # split the prompt if it has : for weighting - # TODO for speed it might help to have this occur when all_prompts filled?? 
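# (Worked example of the ':' weighting syntax handled by split_weighted_subprompts() above:
#   "a castle:2 fog:1"  ->  [("a castle", 2.0), ("fog", 1.0)]
#  and with normalize_prompt_weights=True the weights are rescaled to sum to 1.0, roughly
#   [("a castle", 0.67), ("fog", 0.33)]; a sub-prompt with no ':' keeps weight 1.0.)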
- weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights) - - # sub-prompt weighting used if more than 1 - if len(weighted_subprompts) > 1: - c = torch.zeros_like(uc) # i dont know if this is correct.. but it works - for i in range(0, len(weighted_subprompts)): - # note if alpha negative, it functions same as torch.sub - c = torch.add(c, (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1]) - else: # just behave like usual - c = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(prompts) - - - shape = [opt_C, height // opt_f, width // opt_f] - - if defaults.general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - modelCS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - if noise_mode == 1 or noise_mode == 3: - # TODO params for find_noise_to_image - x = torch.cat(batch_size * [find_noise_for_image.find_noise_for_image( - st.session_state["model"], st.session_state["device"], - init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True, - generation_callback=generation_callback, - )], dim=0) - else: - # we manually generate all input noises because each one should have a specific seed - x = create_random_tensors(shape, seeds=seeds) - - if variant_amount > 0.0: # we are making variants - # using variant_seed as sneaky toggle, - # when not None or '' use the variant_seed - # otherwise use seeds - if variant_seed != None and variant_seed != '': - specified_variant_seed = seed_to_int(variant_seed) - torch.manual_seed(specified_variant_seed) - seeds = [specified_variant_seed] - # finally, slerp base_x noise to target_x noise for creating a variant - x = slerp(defaults.general.gpu, max(0.0, min(1.0, variant_amount)), base_x, x) - - samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name) - - if defaults.general.optimized: - modelFS.to(defaults.general.gpu) - - x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - for i, x_sample in enumerate(x_samples_ddim): - sanitized_prompt = slugify(prompts[i]) - - if sort_samples: - full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt) - - - sanitized_prompt = sanitized_prompt[:220-len(full_path)] - sample_path_i = os.path.join(sample_path, sanitized_prompt) - - #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}") - #print(os.path.join(os.getcwd(), sample_path_i)) - - os.makedirs(sample_path_i, exist_ok=True) - base_count = get_next_sequence_number(sample_path_i) - filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}" - else: - full_path = os.path.join(os.getcwd(), sample_path) - sample_path_i = sample_path - base_count = get_next_sequence_number(sample_path_i) - filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before - - x_sample = 255. 
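# A minimal sketch of spherical interpolation, assumed to behave like the slerp() helper
# used above to blend the base variant noise with the per-image noise (the real helper is
# defined elsewhere in this file):
import torch

def slerp_example(t, v0, v1, dot_threshold=0.9995):
    dot = torch.sum(v0 * v1) / (torch.norm(v0) * torch.norm(v1))
    if torch.abs(dot) > dot_threshold:
        # vectors are nearly parallel; plain lerp is numerically safer
        return (1 - t) * v0 + t * v1
    theta = torch.acos(dot)
    return (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / torch.sin(theta)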
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - x_sample = x_sample.astype(np.uint8) - image = Image.fromarray(x_sample) - original_sample = x_sample - original_filename = filename - - if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN: - #skip_save = True # #287 >_> - torch_gc() - cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - gfpgan_sample = restored_img[:,:,::-1] - gfpgan_image = Image.fromarray(gfpgan_sample) - gfpgan_filename = original_filename + '-gfpgan' - - save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, - uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta, - n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) - - output_images.append(gfpgan_image) #287 - if simple_templating: - grid_captions.append( captions[i] + "\ngfpgan" ) - - if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN: - #skip_save = True # #287 >_> - torch_gc() - - if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: - #try_loading_RealESRGAN(realesrgan_model_name) - load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) - - output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1]) - esrgan_filename = original_filename + '-esrgan4x' - esrgan_sample = output[:,:,::-1] - esrgan_image = Image.fromarray(esrgan_sample) - - #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, - #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) - - save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) - - output_images.append(esrgan_image) #287 - if simple_templating: - grid_captions.append( captions[i] + "\nesrgan" ) - - if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None: - #skip_save = True # #287 >_> - torch_gc() - cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - gfpgan_sample = restored_img[:,:,::-1] - - if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: - #try_loading_RealESRGAN(realesrgan_model_name) - load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) - - output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) - gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x' - gfpgan_esrgan_sample = output[:,:,::-1] - gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) - - save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, 
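# The repeated [:, :, ::-1] slices above swap RGB and BGR channel order, because GFPGAN and
# RealESRGAN expect OpenCV-style BGR arrays while PIL/NumPy images here are RGB.
# Tiny self-contained illustration:
import numpy as np

rgb = np.zeros((4, 4, 3), dtype=np.uint8)
rgb[..., 0] = 255                 # pure red in RGB
bgr = rgb[:, :, ::-1]             # same pixels reordered to BGR
assert bgr[0, 0, 2] == 255 and bgr[0, 0, 0] == 0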
height, steps, cfg_scale, - normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False) - - output_images.append(gfpgan_esrgan_image) #287 - - if simple_templating: - grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) - - if mask_restore and init_mask: - #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask) - init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) - init_mask = init_mask.convert('L') - init_img = init_img.convert('RGB') - image = image.convert('RGB') - - if use_RealESRGAN and st.session_state["RealESRGAN"] is not None: - if st.session_state["RealESRGAN"].model.name != realesrgan_model_name: - #try_loading_RealESRGAN(realesrgan_model_name) - load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) - - output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8)) - init_img = Image.fromarray(output) - init_img = init_img.convert('RGB') - - output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8)) - init_mask = Image.fromarray(output) - init_mask = init_mask.convert('L') - - image = Image.composite(init_img, image, init_mask) - - if save_individual_images: - save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, - normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, - save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images) - - if not use_GFPGAN or not use_RealESRGAN: - output_images.append(image) - - #if add_original_image or not simple_templating: - #output_images.append(image) - #if simple_templating: - #grid_captions.append( captions[i] ) - - if defaults.general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - modelFS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - if prompt_matrix or save_grid: - if prompt_matrix: - if simple_templating: - grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions) - else: - grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) - try: - grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) - except: - import traceback - print("Error creating prompt_matrix text:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - grid = image_grid(output_images, batch_size) - - if grid and (batch_size > 1 or n_iter > 1): - output_images.insert(0, grid) - - grid_count = get_next_sequence_number(outpath, 'grid-') - grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}" - grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True) - - toc = time.time() - - mem_max_used, mem_total = mem_mon.read_and_stop() - time_diff = time.time()-start_time - - info = f""" - {prompt} - Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and 
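# Image.composite(a, b, mask) used above for mask_restore takes pixels from `a` where the
# mask is white and from `b` where it is black, which puts the original image back outside
# the inpainted region. Minimal illustration:
from PIL import Image

a = Image.new("RGB", (64, 64), "red")
b = Image.new("RGB", (64, 64), "blue")
mask = Image.new("L", (64, 64), 0)
mask.paste(255, (0, 0, 32, 64))                  # left half white
out = Image.composite(a, b, mask)                # left half from a, right half from b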
st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip() - stats = f''' - Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image) - Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' - - for comment in comments: - info += "\n\n" + comment - - #mem_mon.stop() - #del mem_mon - torch_gc() - - return output_images, seed, info, stats - - -def resize_image(resize_mode, im, width, height): - LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) - if resize_mode == 0: - res = im.resize((width, height), resample=LANCZOS) - elif resize_mode == 1: - ratio = width / height - src_ratio = im.width / im.height - - src_w = width if ratio > src_ratio else im.width * height // im.height - src_h = height if ratio <= src_ratio else im.height * width // im.width - - resized = im.resize((src_w, src_h), resample=LANCZOS) - res = Image.new("RGBA", (width, height)) - res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) - else: - ratio = width / height - src_ratio = im.width / im.height - - src_w = width if ratio < src_ratio else im.width * height // im.height - src_h = height if ratio >= src_ratio else im.height * width // im.width - - resized = im.resize((src_w, src_h), resample=LANCZOS) - res = Image.new("RGBA", (width, height)) - res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) - - if ratio < src_ratio: - fill_height = height // 2 - src_h // 2 - res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) - res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) - elif ratio > src_ratio: - fill_width = width // 2 - src_w // 2 - res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) - res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) - - return res - -import skimage - -def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, - mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM', - n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, - seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, - variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, - write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", - separate_prompts:bool = False, normalize_prompt_weights:bool = True, - save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, - save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False, - random_seed_loopback: bool = False - ): - - outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples" - err = False - #loopback = False - #skip_save = False - seed = seed_to_int(seed) - - batch_size = 1 - - #prompt_matrix = 0 - #normalize_prompt_weights = 1 in toggles - #loopback = 2 in toggles - #random_seed_loopback = 3 in toggles - #skip_save = 4 not in toggles - #save_grid = 5 in toggles - #sort_samples = 6 in toggles - #write_info_files = 7 in toggles - #write_sample_info_to_log_file = 8 in toggles - 
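# The peak-memory readout above uses -(bytes // -1_048_576), a floor-division trick that
# rounds *up* to whole MiB. Equivalent standalone form:
def ceil_div_example(a, b):
    return -(a // -b)

assert ceil_div_example(1_048_577, 1_048_576) == 2    # just over 1 MiB rounds up
assert ceil_div_example(1_048_576, 1_048_576) == 1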
#jpg_sample = 9 in toggles - #use_GFPGAN = 10 in toggles - #use_RealESRGAN = 11 in toggles - - if sampler_name == 'PLMS': - sampler = PLMSSampler(st.session_state["model"]) - elif sampler_name == 'DDIM': - sampler = DDIMSampler(st.session_state["model"]) - elif sampler_name == 'k_dpm_2_a': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') - elif sampler_name == 'k_dpm_2': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') - elif sampler_name == 'k_euler_a': - sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') - elif sampler_name == 'k_euler': - sampler = KDiffusionSampler(st.session_state["model"],'euler') - elif sampler_name == 'k_heun': - sampler = KDiffusionSampler(st.session_state["model"],'heun') - elif sampler_name == 'k_lms': - sampler = KDiffusionSampler(st.session_state["model"],'lms') - else: - raise Exception("Unknown sampler: " + sampler_name) - - def process_init_mask(init_mask: Image): - if init_mask.mode == "RGBA": - init_mask = init_mask.convert('RGBA') - background = Image.new('RGBA', init_mask.size, (0, 0, 0)) - init_mask = Image.alpha_composite(background, init_mask) - init_mask = init_mask.convert('RGB') - return init_mask - - init_img = init_info - init_mask = None - if mask_mode == 0: - if init_info_mask: - init_mask = process_init_mask(init_info_mask) - elif mask_mode == 1: - if init_info_mask: - init_mask = process_init_mask(init_info_mask) - init_mask = ImageOps.invert(init_mask) - elif mask_mode == 2: - init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1') - init_mask = init_img_transparency - init_mask = init_mask.convert("RGB") - init_mask = resize_image(resize_mode, init_mask, width, height) - init_mask = init_mask.convert("RGB") - - assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - t_enc = int(denoising_strength * ddim_steps) - - if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None: - noise_q = 0.99 - color_variation = 0.0 - mask_blend_factor = 1.0 - - np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing - np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64) - np_mask_rgb -= np.min(np_mask_rgb) - np_mask_rgb /= np.max(np_mask_rgb) - np_mask_rgb = 1. - np_mask_rgb - np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64) - blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) - blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) - #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants - #np_mask_rgb = np_mask_rgb + blurred - np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.) - np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.) - - noise_rgb = matched_noise.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation) - blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor) - noised = noise_rgb[:] - blend_mask_rgb **= (2.) - noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb - - np_mask_grey = np.sum(np_mask_rgb, axis=2)/3. 
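# In the img2img path above, t_enc = int(denoising_strength * ddim_steps): the init image is
# noised up to step t_enc and only those steps are re-denoised, so a low strength keeps the
# result close to the original. Quick illustration:
ddim_steps = 50
for denoising_strength in (0.2, 0.5, 0.8):
    t_enc = int(denoising_strength * ddim_steps)
    print(f"strength {denoising_strength}: re-denoise {t_enc} of {ddim_steps} steps")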
- ref_mask = np_mask_grey < 1e-3 - - all_mask = np.ones((height, width), dtype=bool) - noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1) - - init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB") - st.session_state["editor_image"].image(init_img) # debug - - def init(): - image = init_img.convert('RGB') - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - - mask_channel = None - if init_mask: - alpha = resize_image(resize_mode, init_mask, width // 8, height // 8) - mask_channel = alpha.split()[-1] - - mask = None - if mask_channel is not None: - mask = np.array(mask_channel).astype(np.float32) / 255.0 - mask = (1 - mask) - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) - mask = torch.from_numpy(mask).to(st.session_state["device"]) - - if defaults.general.optimized: - modelFS.to(st.session_state["device"] ) - - init_image = 2. * image - 1. - init_image = init_image.to(st.session_state["device"]) - init_latent = (st.session_state["model"] if not defaults.general.optimized else modelFS).get_first_stage_encoding((st.session_state["model"] if not defaults.general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space - - if defaults.general.optimized: - mem = torch.cuda.memory_allocated()/1e6 - modelFS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - return init_latent, mask, - - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): - t_enc_steps = t_enc - obliterate = False - if ddim_steps == t_enc_steps: - t_enc_steps = t_enc_steps - 1 - obliterate = True - - if sampler_name != 'DDIM': - x0, z_mask = init_data - - sigmas = sampler.model_wrap.get_sigmas(ddim_steps) - noise = x * sigmas[ddim_steps - t_enc_steps - 1] - - xi = x0 + noise - - # Obliterate masked image - if z_mask is not None and obliterate: - random = torch.randn(z_mask.shape, device=xi.device) - xi = (z_mask * noise) + ((1-z_mask) * xi) - - sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] - model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, - extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, - callback=generation_callback) - else: - - x0, z_mask = init_data - - sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) - z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] )) - - # Obliterate masked image - if z_mask is not None and obliterate: - random = torch.randn(z_mask.shape, device=z_enc.device) - z_enc = (z_mask * random) + ((1-z_mask) * z_enc) - - # decode it - samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=unconditional_conditioning, - z_mask=z_mask, x0=x0) - return samples_ddim - - - - if loopback: - output_images, info = None, None - history = [] - initial_seed = None - - do_color_correction = False - try: - from skimage import exposure - do_color_correction = True - except: - print("Install scikit-image to perform color correction on loopback") - - for i in range(n_iter): - if do_color_correction and i == 0: - correction_target = 
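# The init() closure above rescales the uploaded image from [0, 255] to [-1, 1] in NCHW
# layout before encoding it into latent space. A hypothetical standalone version of just
# that preprocessing step:
import numpy as np
import torch
from PIL import Image

def preprocess_example(pil_image):
    arr = np.array(pil_image.convert("RGB")).astype(np.float32) / 255.0
    arr = arr[None].transpose(0, 3, 1, 2)        # HWC -> NCHW with a batch dim
    return 2.0 * torch.from_numpy(arr) - 1.0     # [0, 1] -> [-1, 1]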
cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) - - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=1, - n_iter=1, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback - realesrgan_model_name=RealESRGAN_model, - fp=fp, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - init_img=init_img, - init_mask=init_mask, - mask_blur_strength=mask_blur_strength, - mask_restore=mask_restore, - denoising_strength=denoising_strength, - noise_mode=noise_mode, - find_noise_steps=find_noise_steps, - resize_mode=resize_mode, - uses_loopback=loopback, - uses_random_seed_loopback=random_seed_loopback, - sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg - ) - - if initial_seed is None: - initial_seed = seed - - init_img = output_images[0] - - if do_color_correction and correction_target is not None: - init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( - cv2.cvtColor( - np.asarray(init_img), - cv2.COLOR_RGB2LAB - ), - correction_target, - channel_axis=2 - ), cv2.COLOR_LAB2RGB).astype("uint8")) - - if not random_seed_loopback: - seed = seed + 1 - else: - seed = seed_to_int(None) - - denoising_strength = max(denoising_strength * 0.95, 0.1) - history.append(init_img) - - output_images = history - seed = initial_seed - - else: - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=batch_size, - n_iter=n_iter, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN, - realesrgan_model_name=RealESRGAN_model, - fp=fp, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - init_img=init_img, - init_mask=init_mask, - mask_blur_strength=mask_blur_strength, - denoising_strength=denoising_strength, - noise_mode=noise_mode, - find_noise_steps=find_noise_steps, - mask_restore=mask_restore, - resize_mode=resize_mode, - uses_loopback=loopback, - sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg - ) - - del sampler - - return output_images, seed, info, stats - -@retry((RuntimeError, KeyError) , tries=3) -def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str, - n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None], - height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True, - save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, - save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, - RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None, - variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True): - - outpath = defaults.general.outdir_txt2img or defaults.general.outdir or "outputs/txt2img-samples" - - err = False - seed = seed_to_int(seed) - - #prompt_matrix = 0 in toggles - #normalize_prompt_weights = 1 in toggles - #skip_save = 2 
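# The loopback branch above feeds each result back in as the next init image and matches its
# histogram in LAB space against the first frame so colours do not drift. A sketch of that
# correction, assuming OpenCV (cv2) and scikit-image are available:
import cv2
import numpy as np
from PIL import Image
from skimage import exposure

def match_colors_example(new_img, reference_lab):
    lab = cv2.cvtColor(np.asarray(new_img), cv2.COLOR_RGB2LAB)
    matched = exposure.match_histograms(lab, reference_lab, channel_axis=2)
    matched = np.clip(matched, 0, 255).astype("uint8")
    return Image.fromarray(cv2.cvtColor(matched, cv2.COLOR_LAB2RGB))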
not in toggles - #save_grid = 3 not in toggles - #sort_samples = 4 in toggles - #write_info_files = 5 in toggles - #jpg_sample = 6 in toggles - #use_GFPGAN = 7 in toggles - #use_RealESRGAN = 8 in toggles - - if sampler_name == 'PLMS': - sampler = PLMSSampler(st.session_state["model"]) - elif sampler_name == 'DDIM': - sampler = DDIMSampler(st.session_state["model"]) - elif sampler_name == 'k_dpm_2_a': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral') - elif sampler_name == 'k_dpm_2': - sampler = KDiffusionSampler(st.session_state["model"],'dpm_2') - elif sampler_name == 'k_euler_a': - sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral') - elif sampler_name == 'k_euler': - sampler = KDiffusionSampler(st.session_state["model"],'euler') - elif sampler_name == 'k_heun': - sampler = KDiffusionSampler(st.session_state["model"],'heun') - elif sampler_name == 'k_lms': - sampler = KDiffusionSampler(st.session_state["model"],'lms') - else: - raise Exception("Unknown sampler: " + sampler_name) - - def init(): - pass - - def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback, - log_every_t=int(defaults.general.update_preview_frequency)) - - return samples_ddim - - #try: - output_images, seed, info, stats = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_name=sampler_name, - save_grid=save_grid, - batch_size=batch_size, - n_iter=n_iter, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=separate_prompts, - use_GFPGAN=use_GFPGAN, - use_RealESRGAN=use_RealESRGAN, - realesrgan_model_name=realesrgan_model_name, - fp=fp, - ddim_eta=ddim_eta, - normalize_prompt_weights=normalize_prompt_weights, - save_individual_images=save_individual_images, - sort_samples=group_by_prompt, - write_info_files=write_info_files, - jpg_sample=save_as_jpg, - variant_amount=variant_amount, - variant_seed=variant_seed, - ) - - del sampler - - return output_images, seed, info, stats - - #except RuntimeError as e: - #err = e - #err_msg = f'CRASHED:
Please wait while the program restarts.' - #stats = err_msg - #return [], seed, 'err', stats - - -# -def txt2vid( - # -------------------------------------- - # args you probably want to change - prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about - gpu:int = defaults.general.gpu, # id of the gpu to run on - #name:str = 'test', # name of this project, for the output directory - #rootdir:str = defaults.general.outdir, - num_steps:int = 200, # number of steps between each pair of sampled points - max_frames:int = 10000, # number of frames to write and then exit the script - num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images - cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good - do_loop = False, - use_lerp_for_text = False, - seeds = None, - quality:int = 100, # for jpeg compression of the output images - eta:float = 0.0, - width:int = 256, - height:int = 256, - weights_path = "CompVis/stable-diffusion-v1-4", - scheduler="klms", # choices: default, ddim, klms - disable_tqdm = False, - #----------------------------------------------- - beta_start = 0.0001, - beta_end = 0.00012, - beta_schedule = "scaled_linear" - ): - """ - prompt = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about - gpu:int = defaults.general.gpu, # id of the gpu to run on - #name:str = 'test', # name of this project, for the output directory - #rootdir:str = defaults.general.outdir, - num_steps:int = 200, # number of steps between each pair of sampled points - max_frames:int = 10000, # number of frames to write and then exit the script - num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images - cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good - do_loop = False, - use_lerp_for_text = False, - seed = None, - quality:int = 100, # for jpeg compression of the output images - eta:float = 0.0, - width:int = 256, - height:int = 256, - weights_path = "CompVis/stable-diffusion-v1-4", - scheduler="klms", # choices: default, ddim, klms - disable_tqdm = False, - beta_start = 0.0001, - beta_end = 0.00012, - beta_schedule = "scaled_linear" - """ - mem_mon = MemUsageMonitor('MemMon') - mem_mon.start() - - - seeds = seed_to_int(seeds) - - # We add an extra frame because most - # of the time the first frame is just the noise. 
- max_frames +=1 - - assert torch.cuda.is_available() - assert height % 8 == 0 and width % 8 == 0 - torch.manual_seed(seeds) - torch_device = f"cuda:{gpu}" - - # init the output dir - sanitized_prompt = slugify(prompts) - - full_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples", "samples", sanitized_prompt) - - if len(full_path) > 220: - sanitized_prompt = sanitized_prompt[:220-len(full_path)] - full_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples", "samples", sanitized_prompt) - - os.makedirs(full_path, exist_ok=True) - - # Write prompt info to file in output dir so we can keep track of what we did - if st.session_state.write_info_files: - with open(os.path.join(full_path , f'{slugify(str(seeds))}_config.json' if len(prompts) > 1 else "prompts_config.json"), "w") as outfile: - outfile.write(json.dumps( - dict( - prompts = prompts, - gpu = gpu, - num_steps = num_steps, - max_frames = max_frames, - num_inference_steps = num_inference_steps, - cfg_scale = cfg_scale, - do_loop = do_loop, - use_lerp_for_text = use_lerp_for_text, - seeds = seeds, - quality = quality, - eta = eta, - width = width, - height = height, - weights_path = weights_path, - scheduler=scheduler, - disable_tqdm = disable_tqdm, - beta_start = beta_start, - beta_end = beta_end, - beta_schedule = beta_schedule - ), - indent=2, - sort_keys=False, - )) - - #print(scheduler) - default_scheduler = PNDMScheduler( - beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule - ) - # ------------------------------------------------------------------------------ - #Schedulers - ddim_scheduler = DDIMScheduler( - beta_start=beta_start, - beta_end=beta_end, - beta_schedule=beta_schedule, - clip_sample=False, - set_alpha_to_one=False, - ) - - klms_scheduler = LMSDiscreteScheduler( - beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule - ) - - SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) - - # ------------------------------------------------------------------------------ - - #if weights_path == "Stable Diffusion v1.4": - #weights_path = "CompVis/stable-diffusion-v1-4" - #else: - #weights_path = os.path.join("./models", "custom", f"{weights_path}.ckpt") - - try: - if "model" in st.session_state: - del st.session_state["model"] - except: - pass - - #print (st.session_state["weights_path"] != weights_path) - - try: - if not st.session_state["pipe"] or st.session_state["weights_path"] != weights_path: - if st.session_state["weights_path"] != weights_path: - del st.session_state["weights_path"] - - st.session_state["weights_path"] = weights_path - st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained( - weights_path, - use_local_file=True, - use_auth_token=True, - #torch_dtype=torch.float16 if not defaults.general.no_half else None, - revision="fp16" if not defaults.general.no_half else None - ) - - st.session_state["pipe"].unet.to(torch_device) - st.session_state["pipe"].vae.to(torch_device) - st.session_state["pipe"].text_encoder.to(torch_device) - print("Tx2Vid Model Loaded") - else: - print("Tx2Vid Model already Loaded") - - except: - #del st.session_state["weights_path"] - #del st.session_state["pipe"] - - st.session_state["weights_path"] = weights_path - st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained( - weights_path, - use_local_file=True, - use_auth_token=True, - #torch_dtype=torch.float16 if not defaults.general.no_half else None, - revision="fp16" if not defaults.general.no_half 
else None - ) - - st.session_state["pipe"].unet.to(torch_device) - st.session_state["pipe"].vae.to(torch_device) - st.session_state["pipe"].text_encoder.to(torch_device) - print("Tx2Vid Model Loaded") - - st.session_state["pipe"].scheduler = SCHEDULERS[scheduler] - - # get the conditional text embeddings based on the prompt - text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt") - cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768] - - # sample a source - init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) - - if do_loop: - prompts = [prompts, prompts] - seeds = [seeds, seeds] - #first_seed, *seeds = seeds - #prompts.append(prompts) - #seeds.append(first_seed) - - - # iterate the loop - frames = [] - frame_index = 0 - - st.session_state["frame_total_duration"] = 0 - st.session_state["frame_total_speed"] = 0 - - try: - while frame_index < max_frames: - st.session_state["frame_duration"] = 0 - st.session_state["frame_speed"] = 0 - st.session_state["current_frame"] = frame_index - - # sample the destination - init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device) - - for i, t in enumerate(np.linspace(0, 1, num_steps)): - start = timeit.default_timer() - print(f"COUNT: {frame_index+1}/{num_steps}") - - #if use_lerp_for_text: - #init = torch.lerp(init1, init2, float(t)) - #else: - #init = slerp(gpu, float(t), init1, init2) - - init = slerp(gpu, float(t), init1, init2) - - with autocast("cuda"): - image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta) - - im = Image.fromarray(image) - outpath = os.path.join(full_path, 'frame%06d.png' % frame_index) - im.save(outpath, quality=quality) - - # send the image to the UI to update it - #st.session_state["preview_image"].image(im) - - #append the frames to the frames list so we can use them later. - frames.append(np.asarray(im)) - - #increase frame_index counter. 
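# The txt2vid loop above walks from one random latent to the next and decodes a frame at
# every interpolation step. A much-reduced sketch of that outer loop (plain lerp here for
# brevity; the patch interpolates with slerp() and then calls diffuse() to turn each latent
# into an image):
import numpy as np
import torch

init1 = torch.randn(1, 4, 64, 64)              # latent for the current key frame
init2 = torch.randn(1, 4, 64, 64)              # latent for the next key frame
frames = []
for t in np.linspace(0.0, 1.0, 10):
    latent = (1 - float(t)) * init1 + float(t) * init2
    frames.append(latent)                       # a real run decodes this to an image
init1 = init2                                   # continue the walk from the new key frame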
- frame_index += 1 - - st.session_state["current_frame"] = frame_index - - duration = timeit.default_timer() - start - - if duration >= 1: - speed = "s/it" - else: - speed = "it/s" - duration = 1 / duration - - st.session_state["frame_duration"] = duration - st.session_state["frame_speed"] = speed - - init1 = init2 - - except StopException: - pass - - - if st.session_state['save_video']: - # write video to memory - #output = io.BytesIO() - #writer = imageio.get_writer(os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30) - try: - video_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples","temp.mp4") - writer = imageio.get_writer(video_path, fps=24) - for frame in frames: - writer.append_data(frame) - writer.close() - except: - print("Can't save video, skipping.") - - # show video preview on the UI - st.session_state["preview_video"].video(open(video_path, 'rb').read()) - - mem_max_used, mem_total = mem_mon.read_and_stop() - time_diff = time.time()- start - - info = f""" - {prompts} - Sampling Steps: {num_steps}, Sampler: {scheduler}, CFG scale: {cfg_scale}, Seed: {seeds}, Max Frames: {max_frames}""".strip() - stats = f''' - Took { round(time_diff, 2) }s total ({ round(time_diff/(max_frames),2) }s per image) - Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%''' - - return im, seeds, info, stats - - -# functions to load css locally OR remotely starts here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection -# TODO, maybe look into async loading the file especially for remote fetching -def local_css(file_name): - with open(file_name) as f: - st.markdown(f'', unsafe_allow_html=True) - -def remote_css(url): - st.markdown(f'', unsafe_allow_html=True) - -def load_css(isLocal, nameOrURL): - if(isLocal): - local_css(nameOrURL) - else: - remote_css(nameOrURL) - - -# main functions to define streamlit layout here -def layout(): - - st.set_page_config(page_title="Stable Diffusion Playground", layout="wide") - - with st.empty(): - # load css as an external file, function has an option to local or remote url. Potential use when running from cloud infra that might not have access to local path. - load_css(True, 'frontend/css/streamlit.main.css') - - # check if the models exist on their respective folders - if os.path.exists(os.path.join(defaults.general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): - GFPGAN_available = True - else: - GFPGAN_available = False - - if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")): - RealESRGAN_available = True - else: - RealESRGAN_available = False - - # Allow for custom models to be used instead of the default one, - # an example would be Waifu-Diffusion or any other fine tune of stable diffusion - custom_models:sorted = [] - for root, dirs, files in os.walk(os.path.join("models", "custom")): - for file in files: - if os.path.splitext(file)[1] == '.ckpt': - fullpath = os.path.join(root, file) - #print(fullpath) - custom_models.append(os.path.splitext(file)[0]) - #print (os.path.splitext(file)[0]) - - if len(custom_models) > 0: - CustomModel_available = True - custom_models.append("Stable Diffusion v1.4") - else: - CustomModel_available = False - - with st.sidebar: - # The global settings section will be moved to the Settings page. 
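# The local_css()/remote_css() helpers above lost their markup strings in this patch text.
# A typical Streamlit CSS-injection pattern (an assumption about the original strings, not a
# verbatim restoration) looks like this:
import streamlit as st

def local_css_example(file_name):
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

def remote_css_example(url):
    st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)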
- #with st.expander("Global Settings:"): - #st.write("Global Settings:") - #defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview, - #help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \ - #Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.") - #st.session_state.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency, - #help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.") - - tabs = on_hover_tabs(tabName=['Stable Diffusion', "Textual Inversion","Model Manager","Settings"], - iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0) - - - if tabs =='Stable Diffusion': - txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", - "Text-to-Video","Post-Processing"]) - with txt2img_tab: - with st.form("txt2img-inputs"): - st.session_state["generation_mode"] = "txt2img" - - input_col1, generate_col1 = st.columns([10,1]) - - with input_col1: - #prompt = st.text_area("Input Text","") - prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") - - # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. - generate_col1.write("") - generate_col1.write("") - generate_button = generate_col1.form_submit_button("Generate") - - # creating the page layout using columns - col1, col2, col3 = st.columns([1,2,1], gap="large") - - with col1: - width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.txt2img.width, step=64) - height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.txt2img.height, step=64) - cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.") - seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.") - batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.") - #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2img.batch_size, step=1, - #help="How many images are at once in a batch.\ - #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ - #Default: 1") - - with st.expander("Preview Settings"): - st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.txt2img.update_preview, - help="If enabled the image preview will be updated during the generation instead of at the end. \ - You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \ - By default this is enabled and the frequency is set to 1 step.") - - st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.txt2img.update_preview_frequency, - help="Frequency in steps at which the the preview image is updated. 
By default the frequency \ - is set to 1 step.") - - with col2: - preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"]) - - with preview_tab: - #st.write("Image") - #Image for testing - #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB') - #new_image = image.resize((175, 240)) - #preview_image = st.image(image) - - # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. - st.session_state["preview_image"] = st.empty() - st.session_state["preview_video"] = st.empty() - - st.session_state["loading"] = st.empty() - - st.session_state["progress_bar_text"] = st.empty() - st.session_state["progress_bar"] = st.empty() - - message = st.empty() - - with gallery_tab: - st.write('Here should be the image gallery, if I could make a grid in streamlit.') - - with col3: - # If we have custom models available on the "models/custom" - #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD - if CustomModel_available: - custom_model = st.selectbox("Custom Model:", custom_models, - index=custom_models.index(defaults.general.default_model), - help="Select the model you want to use. This option is only available if you have custom models \ - on your 'models/custom' folder. The model name that will be shown here is the same as the name\ - the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ - will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") - else: - custom_model = "Stable Diffusion v1.4" - - st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250) - - sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] - sampler_name = st.selectbox("Sampling method", sampler_name_list, - index=sampler_name_list.index(defaults.txt2img.default_sampler), help="Sampling method to use. Default: k_euler") - - - - #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"]) - - #with basic_tab: - #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True, - #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.") - - with st.expander("Advanced"): - separate_prompts = st.checkbox("Create Prompt Matrix.", value=False, - help="Separate multiple prompts using the `|` character, and get all combinations of them.") - normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", - value=defaults.txt2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0") - save_individual_images = st.checkbox("Save individual images.", value=defaults.txt2img.save_individual_images, - help="Save each image generated before any filter or enhancement is applied.") - save_grid = st.checkbox("Save grid",value=defaults.txt2img.save_grid, help="Save a grid with all the images generated into a single image.") - group_by_prompt = st.checkbox("Group results by prompt", value=defaults.txt2img.group_by_prompt, - help="Saves all the images with the same prompt into the same folder. 
\ - When using a prompt matrix each prompt combination will have its own folder.") - write_info_files = st.checkbox("Write Info file", value=defaults.txt2img.write_info_files, - help="Save a file next to the image with informartion about the generation.") - save_as_jpg = st.checkbox("Save samples as jpg", value=defaults.txt2img.save_as_jpg, help="Saves the images as jpg instead of png.") - - if GFPGAN_available: - use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.txt2img.use_GFPGAN, - help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and \ - consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") - else: - use_GFPGAN = False - - if RealESRGAN_available: - use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.txt2img.use_RealESRGAN, - help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the \ - quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.") - RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) - else: - use_RealESRGAN = False - RealESRGAN_model = "RealESRGAN_x4plus" - - variant_amount = st.slider("Variant Amount:", value=defaults.txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) - variant_seed = st.text_input("Variant Seed:", value=defaults.txt2img.seed, - help="The seed to use when generating a variant, if left blank a random seed will be generated.") - - - if generate_button: - #print("Loading models") - # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. - load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, CustomModel_available, custom_model) - - try: - output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1, - cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, - save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp, - variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) - - message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") - - except KeyError: - output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1, - cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images, - save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp, - variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files) - - message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") - - except (StopException): - print(f"Received Streamlit StopException") - - # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. - # use the current col2 first tab to show the preview_img and update it as its generated. 
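# Every Generate button above is a form submit button: widgets inside an st.form only
# deliver their values when the form is submitted, which keeps Streamlit from re-running
# generation on every slider change. Minimal standalone version of that pattern:
import streamlit as st

with st.form("demo-form"):
    prompt = st.text_input("Input Text", "", placeholder="A corgi wearing a top hat.")
    submitted = st.form_submit_button("Generate")

if submitted:
    st.write(f"Would generate: {prompt}")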
- #preview_image.image(output_images) - - with img2img_tab: - with st.form("img2img-inputs"): - st.session_state["generation_mode"] = "img2img" - - img2img_input_col, img2img_generate_col = st.columns([10,1]) - with img2img_input_col: - #prompt = st.text_area("Input Text","") - prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.") - - # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. - img2img_generate_col.write("") - img2img_generate_col.write("") - generate_button = img2img_generate_col.form_submit_button("Generate") - - - # creating the page layout using columns - col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small") - - with col1_img2img_layout: - # If we have custom models available on the "models/custom" - #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD - if CustomModel_available: - custom_model = st.selectbox("Custom Model:", custom_models, - index=custom_models.index(defaults.general.default_model), - help="Select the model you want to use. This option is only available if you have custom models \ - on your 'models/custom' folder. The model name that will be shown here is the same as the name\ - the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ - will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") - else: - custom_model = "Stable Diffusion v1.4" - - st.session_state["sampling_steps"] = st.slider("Sampling Steps", value=defaults.img2img.sampling_steps, min_value=1, max_value=500) - st.session_state["sampler_name"] = st.selectbox("Sampling method", - ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"], - index=sampler_name_list.index(defaults.img2img.sampler_name), - help="Sampling method to use.") - - mask_mode_list = ["Mask", "Inverted mask", "Image alpha"] - mask_mode = st.selectbox("Mask Mode", mask_mode_list, - help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\ - \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent." 
- ) - mask_mode = mask_mode_list.index(mask_mode) - - width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.img2img.width, step=64) - height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.img2img.height, step=64) - seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") - noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"] - noise_mode = st.selectbox( - "Noise Mode", noise_mode_list, - help="" - ) - noise_mode = noise_mode_list.index(noise_mode) - find_noise_steps = st.slider("Find Noise Steps", value=100, min_value=1, max_value=500) - batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.img2img.batch_count, step=1, - help="How many iterations or batches of images to generate in total.") - - # - with st.expander("Advanced"): - separate_prompts = st.checkbox("Create Prompt Matrix.", value=defaults.img2img.separate_prompts, - help="Separate multiple prompts using the `|` character, and get all combinations of them.") - normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=defaults.img2img.normalize_prompt_weights, - help="Ensure the sum of all weights add up to 1.0") - loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.") - random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed") - save_individual_images = st.checkbox("Save individual images.", value=defaults.img2img.save_individual_images, - help="Save each image generated before any filter or enhancement is applied.") - save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.") - group_by_prompt = st.checkbox("Group results by prompt", value=defaults.img2img.group_by_prompt, - help="Saves all the images with the same prompt into the same folder. \ - When using a prompt matrix each prompt combination will have its own folder.") - write_info_files = st.checkbox("Write Info file", value=defaults.img2img.write_info_files, - help="Save a file next to the image with informartion about the generation.") - save_as_jpg = st.checkbox("Save samples as jpg", value=defaults.img2img.save_as_jpg, help="Saves the images as jpg instead of png.") - - if GFPGAN_available: - use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\ - This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.") - else: - use_GFPGAN = False - - if RealESRGAN_available: - use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.img2img.use_RealESRGAN, - help="Uses the RealESRGAN model to upscale the images after the generation.\ - This greatly improve the quality and lets you have high resolution images but uses extra VRAM. 
Disable if you need the extra VRAM.") - RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0) - else: - use_RealESRGAN = False - RealESRGAN_model = "RealESRGAN_x4plus" - - variant_amount = st.slider("Variant Amount:", value=defaults.img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) - variant_seed = st.text_input("Variant Seed:", value=defaults.img2img.variant_seed, - help="The seed to use when generating a variant, if left blank a random seed will be generated.") - cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.img2img.cfg_scale, step=0.5, - help="How strongly the image should follow the prompt.") - batch_size = st.slider("Batch size", min_value=1, max_value=100, value=defaults.img2img.batch_size, step=1, - help="How many images are at once in a batch.\ - It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish \ - generation as more images are generated at once.\ - Default: 1") - - st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, - min_value=0.01, max_value=1.0, step=0.01) - - with st.expander("Preview Settings"): - st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.img2img.update_preview, - help="If enabled the image preview will be updated during the generation instead of at the end. \ - You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \ - By default this is enabled and the frequency is set to 1 step.") - - st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.img2img.update_preview_frequency, - help="Frequency in steps at which the the preview image is updated. 
By default the frequency \ - is set to 1 step.") - - with col2_img2img_layout: - editor_tab = st.tabs(["Editor"]) - - editor_image = st.empty() - st.session_state["editor_image"] = editor_image - - refresh_button = st.form_submit_button("Refresh") - - masked_image_holder = st.empty() - image_holder = st.empty() - - uploaded_images = st.file_uploader( - "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], - help="Upload an image which will be used for the image to image generation.", - ) - if uploaded_images: - image = Image.open(uploaded_images).convert('RGBA') - new_img = image.resize((width, height)) - image_holder.image(new_img) - - mask_holder = st.empty() - - uploaded_masks = st.file_uploader( - "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"], - help="Upload an mask image which will be used for masking the image to image generation.", - ) - if uploaded_masks: - mask = Image.open(uploaded_masks) - if mask.mode == "RGBA": - mask = mask.convert('RGBA') - background = Image.new('RGBA', mask.size, (0, 0, 0)) - mask = Image.alpha_composite(background, mask) - mask = mask.resize((width, height)) - mask_holder.image(mask) - - if uploaded_images and uploaded_masks: - if mask_mode != 2: - final_img = new_img.copy() - alpha_layer = mask.convert('L') - strength = st.session_state["denoising_strength"] - if mask_mode == 0: - alpha_layer = ImageOps.invert(alpha_layer) - alpha_layer = alpha_layer.point(lambda a: a * strength) - alpha_layer = ImageOps.invert(alpha_layer) - elif mask_mode == 1: - alpha_layer = alpha_layer.point(lambda a: a * strength) - alpha_layer = ImageOps.invert(alpha_layer) - - final_img.putalpha(alpha_layer) - - with masked_image_holder.container(): - st.text("Masked Image Preview") - st.image(final_img) - - - with col3_img2img_layout: - result_tab = st.tabs(["Result"]) - - # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. - preview_image = st.empty() - st.session_state["preview_image"] = preview_image - - #st.session_state["loading"] = st.empty() - - st.session_state["progress_bar_text"] = st.empty() - st.session_state["progress_bar"] = st.empty() - - - message = st.empty() - - #if uploaded_images: - #image = Image.open(uploaded_images).convert('RGB') - ##img_array = np.array(image) # if you want to pass it to OpenCV - #new_img = image.resize((width, height)) - #st.image(new_img, use_column_width=True) - - - if generate_button: - #print("Loading models") - # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
-                    load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, CustomModel_available, custom_model)
-                    if uploaded_images:
-                        image = Image.open(uploaded_images).convert('RGBA')
-                        new_img = image.resize((width, height))
-                        #img_array = np.array(image) # if you want to pass it to OpenCV
-                        new_mask = None
-                        if uploaded_masks:
-                            mask = Image.open(uploaded_masks).convert('RGBA')
-                            new_mask = mask.resize((width, height))
-
-                        try:
-                            output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode, ddim_steps=st.session_state["sampling_steps"],
-                                sampler_name=st.session_state["sampler_name"], n_iter=batch_count,
-                                cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
-                                seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount,
-                                ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=RealESRGAN_model,
-                                separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
-                                save_individual_images=save_individual_images, save_grid=save_grid,
-                                group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
-                                use_RealESRGAN=use_RealESRGAN if not loopback else False, loopback=loopback
-                                )
-
-                            # Show a message when the generation is complete.
-                            message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
-
-                        except (StopException, KeyError):
-                            print(f"Received Streamlit StopException")
-
-                    # This will render all the images at the end of the generation, but it would be better moved to a second tab inside col2 and shown as a gallery.
-                    # Use the current col2 first tab to show the preview_img and update it as it's generated.
-                    #preview_image.image(output_images, width=750)
-
-        with txt2vid_tab:
-            with st.form("txt2vid-inputs"):
-                st.session_state["generation_mode"] = "txt2vid"
-
-                input_col1, generate_col1 = st.columns([10,1])
-                with input_col1:
-                    #prompt = st.text_area("Input Text","")
-                    prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
-
-                # Every form must have a submit button; the extra blank writes are a temporary way to align it with the input field. This needs to be done in CSS or some other way.
-                generate_col1.write("")
-                generate_col1.write("")
-                generate_button = generate_col1.form_submit_button("Generate")
-
-                # Creating the page layout using columns.
-                col1, col2, col3 = st.columns([1,2,1], gap="large")
-
-                with col1:
-                    width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.txt2vid.width, step=64)
-                    height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2vid.height, step=64)
-                    cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2vid.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
-                    seed = st.text_input("Seed:", value=defaults.txt2vid.seed, help="The seed to use; if left blank a random seed will be generated.")
-                    batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2vid.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
-                    #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2vid.batch_size, step=1,
-                    #    help="How many images to generate at once in a single batch.\
-                    #    It increases VRAM usage a lot, but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
-                    #    Default: 1")
-
-                    st.session_state["max_frames"] = int(st.text_input("Max Frames:", value=defaults.txt2vid.max_frames, help="Specify the max number of frames you want to generate."))
-
-                    with st.expander("Preview Settings"):
-                        st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.txt2vid.update_preview,
-                            help="If enabled, the image preview will be updated during generation instead of only at the end. \
-                            You can use the Update Image Preview Frequency option below to customize how frequently it is updated. \
-                            By default this is enabled and the frequency is set to 1 step.")
-
-                        st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.txt2vid.update_preview_frequency,
-                            help="Frequency in steps at which the preview image is updated. By default the frequency \
-                            is set to 1 step.")
-                with col2:
-                    preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
-
-                    with preview_tab:
-                        #st.write("Image")
-                        #Image for testing
-                        #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
-                        #new_image = image.resize((175, 240))
-                        #preview_image = st.image(image)
-
-                        # Create an empty container for the image, progress bar, etc. so we can update it later and use session_state to hold them globally.
-                        st.session_state["preview_image"] = st.empty()
-
-                        st.session_state["loading"] = st.empty()
-
-                        st.session_state["progress_bar_text"] = st.empty()
-                        st.session_state["progress_bar"] = st.empty()
-
-                        generate_video = st.empty()
-                        st.session_state["preview_video"] = st.empty()
-
-                        message = st.empty()
-
-                    with gallery_tab:
-                        st.write('Here should be the image gallery, if I could make a grid in streamlit.')
-
-                with col3:
-                    # If we have custom models available in the "models/custom"
-                    # folder, then we show a menu to select which model we want to use; otherwise we use the main SD model.
-                    #if CustomModel_available:
-                    custom_model = st.selectbox("Custom Model:", defaults.txt2vid.custom_models_list,
-                        index=defaults.txt2vid.custom_models_list.index(defaults.txt2vid.default_model),
-                        help="Select the model you want to use. This option is only available if you have custom models \
-                        in your 'models/custom' folder. The model name shown here is the same as the name \
-                        the model's file has in that folder; it is recommended to give the .ckpt file a name that \
-                        will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
-
-                    #st.session_state["weights_path"] = custom_model
-                    #else:
-                    #custom_model = "CompVis/stable-diffusion-v1-4"
-                    #st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}"
-
-                    st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2vid.sampling_steps, min_value=10, step=10, max_value=500,
-                        help="Number of steps between each pair of sampled points")
-                    st.session_state.num_inference_steps = st.slider("Inference Steps:", value=defaults.txt2vid.num_inference_steps, min_value=10, step=10, max_value=500,
-                        help="Higher values (e.g. 100, 200, etc.) can create better images.")
-
-                    #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
-                    #sampler_name = st.selectbox("Sampling method", sampler_name_list,
-                    #    index=sampler_name_list.index(defaults.txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
-                    scheduler_name_list = ["klms", "ddim"]
-                    scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
-                        index=scheduler_name_list.index(defaults.txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
-
-                    beta_scheduler_type_list = ["scaled_linear", "linear"]
-                    beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list,
-                        index=beta_scheduler_type_list.index(defaults.txt2vid.beta_scheduler_type), help="Schedule Type to use. Default: linear")
-
-
-                    #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
-
-                    #with basic_tab:
-                    #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
-                    #    help="Press the Enter key to submit; when 'No' is selected you can use the Enter key to write multiple lines.")
-
-                    with st.expander("Advanced"):
-                        st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=defaults.txt2vid.separate_prompts,
-                            help="Separate multiple prompts using the `|` character, and get all combinations of them.")
-                        st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.",
-                            value=defaults.txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights adds up to 1.0")
-                        st.session_state["save_individual_images"] = st.checkbox("Save individual images.",
-                            value=defaults.txt2vid.save_individual_images, help="Save each image generated before any filter or enhancement is applied.")
-                        st.session_state["save_video"] = st.checkbox("Save video", value=defaults.txt2vid.save_video, help="Save a video with all the images generated as frames at the end of the generation.")
-                        st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=defaults.txt2vid.group_by_prompt,
-                            help="Saves all the images with the same prompt into the same folder. When using a prompt matrix, each prompt combination will have its own folder.")
-                        st.session_state["write_info_files"] = st.checkbox("Write Info file", value=defaults.txt2vid.write_info_files,
-                            help="Save a file next to the image with information about the generation.")
-                        st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=defaults.txt2vid.dynamic_preview_frequency,
-                            help="This option tries to find the best value at which we can update \
-                            the preview image during generation while minimizing the impact it has on performance. Default: True")
-                        st.session_state["do_loop"] = st.checkbox("Do Loop", value=defaults.txt2vid.do_loop,
-                            help="Do loop")
-                        st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=defaults.txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
-
-                        if GFPGAN_available:
-                            st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=defaults.txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improves the quality and consistency of faces, but uses extra VRAM. Disable if you need the extra VRAM.")
-                        else:
-                            st.session_state["use_GFPGAN"] = False
-
-                        if RealESRGAN_available:
-                            st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=defaults.txt2vid.use_RealESRGAN,
-                                help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improves the quality and lets you have high-resolution images, but uses extra VRAM. Disable if you need the extra VRAM.")
-                            st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
-                        else:
-                            st.session_state["use_RealESRGAN"] = False
-                            st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
-
-                        st.session_state["variant_amount"] = st.slider("Variant Amount:", value=defaults.txt2vid.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
-                        st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=defaults.txt2vid.seed, help="The seed to use when generating a variant; if left blank a random seed will be generated.")
-                        st.session_state["beta_start"] = st.slider("Beta Start:", value=defaults.txt2vid.beta_start, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
-                        st.session_state["beta_end"] = st.slider("Beta End:", value=defaults.txt2vid.beta_end, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
-
-                if generate_button:
-                    #print("Loading models")
-                    # Load the models when we hit the generate button for the first time; they won't be reloaded after that, so don't worry.
-                    #load_models(False, False, False, RealESRGAN_model, CustomModel_available=CustomModel_available, custom_model=custom_model)
-
-                    # Run the video generation.
-                    image, seed, info, stats = txt2vid(prompts=prompt, gpu=defaults.general.gpu,
-                        num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
-                        num_inference_steps=st.session_state.num_inference_steps,
-                        cfg_scale=cfg_scale, do_loop=st.session_state["do_loop"],
-                        seeds=seed, quality=100, eta=0.0, width=width,
-                        height=height, weights_path=custom_model, scheduler=scheduler_name,
-                        disable_tqdm=False, beta_start=st.session_state["beta_start"], beta_end=st.session_state["beta_end"],
-                        beta_schedule=beta_scheduler_type)
-
-                    #message.success('Done!', icon="✅")
-                    message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
-
-                    #except (StopException, KeyError):
-                    #print(f"Received Streamlit StopException")
-
-                    # This will render all the images at the end of the generation, but it would be better moved to a second tab inside col2 and shown as a gallery.
-                    # Use the current col2 first tab to show the preview_img and update it as it's generated.
-                    #preview_image.image(output_images)
-
-    #
-    elif tabs == 'Model Manager':
-        #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
-
-        csvString = f"""
-        ,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media
-        ,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
-        ,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
-        ,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
-        ,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt
-        ,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt
-        """
-        colms = st.columns((1, 3, 5, 5))
-        columns = ["№",'Model Name','Save Location','Download Link']
-
-        # Convert the string into a StringIO object so pandas can read it like a file.
-        csvStringIO = StringIO(csvString)
-        df = pd.read_csv(csvStringIO, sep=",", header=None, names=columns)
-
-        for col, field_name in zip(colms, columns):
-            # table header
-            col.write(field_name)
-
-        for x, model_name in enumerate(df["Model Name"]):
-            col1, col2, col3, col4 = st.columns((1, 3, 4, 6))
-            col1.write(x)  # index
-            col2.write(df['Model Name'][x])
-            col3.write(df['Save Location'][x])
-            col4.write(df['Download Link'][x])
-
-
-    elif tabs == 'Settings':
-        import Settings
-
-        st.write("Settings")
-
-if __name__ == '__main__':
-    layout()
diff --git a/setup.py b/setup.py
index 0e768e1..a24d541 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
 from setuptools import setup, find_packages
 
 setup(
-    name='sd-webui',
+    name='latent-diffusion',
     version='0.0.1',
     description='',
     packages=find_packages(),
diff --git a/webui.sh b/webui.sh
index 7b07a49..ea7028f 100755
--- a/webui.sh
+++ b/webui.sh
@@ -37,7 +37,7 @@ if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then
     ENV_UPDATED=1
 elif [[ ! -z $CONDA_FORCE_UPDATE && $CONDA_FORCE_UPDATE == "true" ]] || (( $ENV_MODIFIED > $ENV_MODIFIED_CACHED )); then
     echo "Updating conda env: ${ENV_NAME} ..."
-    PIP_EXISTS_ACTION=w conda env update --file $ENV_FILE --prune
+    conda env update --file $ENV_FILE --prune
     ENV_UPDATED=1
 fi
 
@@ -56,4 +56,4 @@ if [ ! -e "models/ldm/stable-diffusion-v1/model.ckpt" ]; then
     exit 1
 fi
 
-python scripts/relauncher.py
+python scripts/relauncher.py
\ No newline at end of file
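
Throughout the Streamlit layout above, every widget default is read from a `defaults` object (e.g. `defaults.general.gpu`, `defaults.img2img.cfg_scale`, `defaults.txt2vid.beta_start`). The loader itself is defined earlier in scripts/webui_streamlit.py and is not part of these hunks; the snippet below is only a minimal sketch of how such an object could be built, assuming OmegaConf is used and the config lives at configs/webui/webui_streamlit.yaml.

    # Sketch only: load the UI defaults into a dot-accessible config object (assumes OmegaConf).
    from omegaconf import OmegaConf

    defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml")

    # Values can then be read exactly the way the widgets above do.
    print(defaults.general.gpu)         # GPU index passed to txt2vid()
    print(defaults.img2img.cfg_scale)   # default CFG scale for the img2img slider
    print(defaults.txt2vid.beta_start)  # default Beta Start for the txt2vid slider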