fixed all lines PyCharm was nagging me about

fixed input verification not working properly with long textual inversion tokens in some cases (plus it will prevent incorrect outputs for forks that use the :::: prompt weighing method)
changed process_images to object class with same fields as args it was previously accepting
changed options system to make it possible to explicitly specify gradio objects with args
This commit is contained in:
AUTOMATIC 2022-08-27 21:32:28 +03:00
parent 4e0fdca2f4
commit c30aee2f4b

563
webui.py
View File

@ -1,14 +1,13 @@
import argparse, os, sys, glob
import argparse
import os
import sys
from collections import namedtuple
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from itertools import islice
from einops import rearrange, repeat
from torch import autocast
import mimetypes
import random
@ -22,14 +21,13 @@ import k_diffusion.sampling
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import ldm.modules.encoders.modules
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
logging.set_verbosity_error()
except:
except Exception:
pass
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
@ -41,13 +39,13 @@ opt_C = 4
opt_f = 8
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
invalid_filename_chars = '<>:"/\|?*\n'
invalid_filename_chars = '<>:"/\\|?*\n'
config_filename = "config.json"
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
@ -64,7 +62,7 @@ css_hide_progressbar = """
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
*[SamplerData(x[0], lambda m, funcname=x[1]: KDiffusionSampler(m, funcname)) for x in [
*[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('Euler', 'sample_euler'),
@ -72,8 +70,8 @@ samplers = [
('DPM 2', 'sample_dpm_2'),
('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])],
SamplerData('DDIM', lambda m: DDIMSampler(model)),
SamplerData('PLMS', lambda m: PLMSSampler(model)),
SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
]
samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS']
@ -102,7 +100,7 @@ try:
),
]
have_realesrgan = True
except:
except Exception:
print("Error loading Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@ -111,24 +109,30 @@ except:
class Options:
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
data = None
data_labels = {
"outdir": ("", "Output dictectory; if empty, defaults to 'outputs/*'"),
"samples_save": (True, "Save indiviual samples"),
"samples_format": ('png', 'File format for indiviual samples'),
"grid_save": (True, "Save image grids"),
"grid_format": ('png', 'File format for grids'),
"grid_extended_filename": (False, "Add extended info (seed, prompt) to filename when saving grid"),
"n_rows": (-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", -1, 16),
"jpeg_quality": (80, "Quality for saved jpeg images", 1, 100),
"verify_input": (True, "Check input, and produce warning if it's too long"),
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
"prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
"sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16),
"outdir": OptionInfo("", "Output dictectory; if empty, defaults to 'outputs/*'"),
"samples_save": OptionInfo(True, "Save indiviual samples"),
"samples_format": OptionInfo('png', 'File format for indiviual samples'),
"grid_save": OptionInfo(True, "Save image grids"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
"sd_upscale_overlap": OptionInfo(64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", gr.Slider, {"minimum": 0, "maximum": 256, "step": 16}),
}
def __init__(self):
self.data = {k: v[0] for k, v in self.data_labels.items()}
self.data = {k: v.default for k, v in self.data_labels.items()}
def __setattr__(self, key, value):
if self.data is not None:
@ -143,7 +147,7 @@ class Options:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item][0]
return self.data_labels[item].default
return super(Options, self).__getattribute__(item)
@ -156,11 +160,6 @@ class Options:
self.data = json.load(file)
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
@ -181,36 +180,6 @@ def load_model_from_config(config, ckpt, verbose=False):
return model
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
class KDiffusionSampler:
def __init__(self, m, funcname):
self.model = m
self.model_wrap = k_diffusion.external.CompVisDenoiser(m)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap)
samples_ddim = self.func(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
return samples_ddim, None
def create_random_tensors(shape, seeds):
xs = []
for seed in seeds:
@ -256,7 +225,7 @@ def plaintext_to_html(text):
return text
def load_GFPGAN():
def load_gfpgan():
model_name = 'GFPGANv1.3'
model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
@ -358,7 +327,7 @@ def combine_grid(grid):
def draw_prompt_matrix(im, width, height, all_prompts):
def wrap(text, d, font, line_length):
def wrap(text, font, line_length):
lines = ['']
for word in text.split():
line = f'{lines[-1]} {word}'.strip()
@ -368,16 +337,16 @@ def draw_prompt_matrix(im, width, height, all_prompts):
lines.append(word)
return '\n'.join(lines)
def draw_texts(pos, x, y, texts, sizes):
def draw_texts(pos, draw_x, draw_y, texts, sizes):
for i, (text, size) in enumerate(zip(texts, sizes)):
active = pos & (1 << i) != 0
if not active:
text = '\u0336'.join(text) + '\u0336'
d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
y += size[1] + line_spacing
draw_y += size[1] + line_spacing
fontsize = (width + height) // 25
line_spacing = fontsize // 2
@ -399,8 +368,8 @@ def draw_prompt_matrix(im, width, height, all_prompts):
d = ImageDraw.Draw(result)
boundary = math.ceil(len(prompts) / 2)
prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]]
prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]]
prompts_horiz = [wrap(x, fnt, width) for x in prompts[:boundary]]
prompts_vert = [wrap(x, fnt, pad_left) for x in prompts[boundary:]]
sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
@ -458,25 +427,6 @@ def resize_image(resize_mode, im, width, height):
return res
def check_prompt_length(prompt, comments):
"""this function tests if prompt is too long, and if so, adds a message to comments"""
tokenizer = model.cond_stage_model.tokenizer
max_length = model.cond_stage_model.max_length
info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
ovf = info['overflowing_tokens'][0]
overflowing_count = ovf.shape[0]
if overflowing_count == 0:
return
vocab = {v: k for k, v in tokenizer.get_vocab().items()}
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
def wrap_gradio_call(func):
def f(*p1, **p2):
t = time.perf_counter()
@ -494,7 +444,7 @@ def wrap_gradio_call(func):
GFPGAN = None
if os.path.exists(cmd_opts.gfpgan_dir):
try:
GFPGAN = load_GFPGAN()
GFPGAN = load_gfpgan()
print("Loaded GFPGAN")
except Exception:
print("Error loading GFPGAN:", file=sys.stderr)
@ -506,11 +456,11 @@ class StableDiffuionModelHijack:
word_embeddings = {}
word_embeddings_checksums = {}
fixes = None
used_custom_terms = []
comments = None
dir_mtime = None
def load_textual_inversion_embeddings(self, dir, model):
mt = os.path.getmtime(dir)
def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
@ -543,10 +493,10 @@ class StableDiffuionModelHijack:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dir):
for fn in os.listdir(dirname):
try:
process_file(os.path.join(dir, fn), fn)
except:
process_file(os.path.join(dirname, fn), fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
@ -561,10 +511,10 @@ class StableDiffuionModelHijack:
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, embeddings):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
self.hijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
@ -586,12 +536,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.token_mults[ident] = mult
def forward(self, text):
self.embeddings.fixes = []
self.embeddings.used_custom_terms = []
self.hijack.fixes = []
self.hijack.comments = []
remade_batch_tokens = []
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length - 2
used_custom_terms = []
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@ -611,7 +562,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
possible_matches = self.embeddings.ids_lookup.get(token, None)
possible_matches = self.hijack.ids_lookup.get(token, None)
mult_change = self.token_mults.get(token)
if mult_change is not None:
@ -628,7 +579,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append(mult)
i += len(ids) - 1
found = True
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
@ -637,6 +588,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
i += 1
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@ -645,9 +604,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
self.embeddings.fixes.append(fixes)
self.hijack.fixes.append(fixes)
batch_multipliers.append(multipliers)
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state
@ -679,71 +641,123 @@ class EmbeddingsWithFixes(nn.Module):
for offset, word in fixes:
tensor[offset] = self.embeddings.word_embeddings[word]
return inputs_embeds
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
class StableDiffusionProcessing:
def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None):
self.outpath: str = outpath
self.prompt: str = prompt
self.seed: int = seed
self.sampler_index: int = sampler_index
self.batch_size: int = batch_size
self.n_iter: int = n_iter
self.steps: int = steps
self.cfg_scale: float = cfg_scale
self.width: int = width
self.height: int = height
self.prompt_matrix: bool = prompt_matrix
self.use_GFPGAN: bool = use_GFPGAN
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params
def init(self):
pass
def sample(self, x, conditioning, unconditional_conditioning):
raise NotImplementedError()
class VanillaStableDiffusionSampler:
def __init__(self, constructor):
self.sampler = constructor(sd_model)
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
return samples_ddim
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
class KDiffusionSampler:
def __init__(self, funcname):
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0]
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
return samples_ddim
def process_images(p: StableDiffusionProcessing):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None
prompt = p.prompt
model = sd_model
assert p.prompt is not None
torch_gc()
if seed == -1:
seed = random.randrange(4294967294)
seed = int(seed)
seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
os.makedirs(outpath, exist_ok=True)
os.makedirs(p.outpath, exist_ok=True)
sample_path = os.path.join(outpath, "samples")
sample_path = os.path.join(p.outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
grid_count = len(os.listdir(p.outpath)) - 1
comments = []
prompt_matrix_parts = []
if prompt_matrix:
if p.prompt_matrix:
all_prompts = []
prompt_matrix_parts = prompt.split("|")
combination_count = 2 ** (len(prompt_matrix_parts) - 1)
for combination_num in range(combination_count):
selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1<<n)]
selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
if opts.prompt_matrix_add_to_start:
selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
else:
selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
all_prompts.append( ", ".join(selected_prompts))
all_prompts.append(", ".join(selected_prompts))
n_iter = math.ceil(len(all_prompts) / batch_size)
p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
all_seeds = len(all_prompts) * [seed]
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
else:
if opts.verify_input:
try:
check_prompt_length(prompt, comments)
except:
import traceback
print("Error verifying input:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
all_prompts = batch_size * n_iter * [prompt]
all_prompts = p.batch_size * p.n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))]
generation_params = {
"Steps": steps,
"Sampler": samplers[sampler_index].name,
"CFG scale": cfg_scale,
"Steps": p.steps,
"Sampler": samplers[p.sampler_index].name,
"CFG scale": p.cfg_scale,
"Seed": seed,
"GFPGAN": ("GFPGAN" if use_GFPGAN and GFPGAN is not None else None)
"GFPGAN": ("GFPGAN" if p.use_GFPGAN and GFPGAN is not None else None)
}
if extra_generation_params is not None:
generation_params.update(extra_generation_params)
if p.extra_generation_params is not None:
generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
@ -755,32 +769,32 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
output_images = []
with torch.no_grad(), autocast("cuda"), model.ema_scope():
init_data = func_init()
p.init()
for n in range(n_iter):
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
for n in range(p.n_iter):
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
uc = model.get_learned_conditioning(len(prompts) * [""])
c = model.get_learned_conditioning(prompts)
if len(model_hijack.used_custom_terms) > 0:
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
if len(model_hijack.comments) > 0:
comments += model_hijack.comments
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc)
samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if prompt_matrix or opts.samples_save or opts.grid_save:
if p.prompt_matrix or opts.samples_save or opts.grid_save:
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
if use_GFPGAN and GFPGAN is not None:
if p.use_GFPGAN and GFPGAN is not None:
torch_gc()
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
x_sample = restored_img
@ -791,44 +805,44 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
output_images.append(image)
base_count += 1
if (prompt_matrix or opts.grid_save) and not do_not_save_grid:
if prompt_matrix:
grid = image_grid(output_images, batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid:
if p.prompt_matrix:
grid = image_grid(output_images, p.batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
try:
grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
except:
grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
except Exception:
import traceback
print("Error creating prompt_matrix text:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
output_images.insert(0, grid)
else:
grid = image_grid(output_images, batch_size)
grid = image_grid(output_images, p.batch_size)
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
grid_count += 1
torch_gc()
return output_images, seed, infotext()
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
outpath = opts.outdir or "outputs/txt2img-samples"
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
sampler = samplers[sampler_index].constructor(model)
def init(self):
self.sampler = samplers[self.sampler_index].constructor()
def init():
pass
def sample(init_data, x, conditioning, unconditional_conditioning):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
def sample(self, x, conditioning, unconditional_conditioning):
samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples_ddim
output_images, seed, info = process_images(
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
outpath = opts.outdir or "outputs/txt2img-samples"
p = StableDiffusionProcessingTxt2Img(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
@ -842,7 +856,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool,
use_GFPGAN=use_GFPGAN
)
del sampler
output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info)
@ -858,7 +872,7 @@ class Flagging(gr.FlaggingCallback):
os.makedirs("log/images", exist_ok=True)
# those must match the "txt2img" function
prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
filenames = []
@ -896,7 +910,6 @@ txt2img_interface = gr.Interface(
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
@ -914,40 +927,77 @@ txt2img_interface = gr.Interface(
)
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
self.resize_mode: int = resize_mode
self.denoising_strength: float = denoising_strength
self.init_latent = None
def init(self):
self.sampler = samplers_for_img2img[self.sampler_index].constructor()
imgs = []
for img in self.init_images:
image = img.convert("RGB")
image = resize_image(self.resize_mode, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
imgs.append(image)
if len(imgs) == 1:
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
else:
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
image = torch.from_numpy(batch_images)
image = 2. * image - 1.
image = image.to(device)
self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))
def sample(self, x, conditioning, unconditional_conditioning):
t_enc = int(self.denoising_strength * self.steps)
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
noise = x * sigmas[self.steps - t_enc - 1]
xi = self.init_latent + noise
sigma_sched = sigmas[self.steps - t_enc - 1:]
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
return samples_ddim
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples"
sampler = samplers_for_img2img[sampler_index].constructor(model)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
def init():
image = init_img.convert("RGB")
image = resize_image(resize_mode, image, width, height)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
init_image = 2. * image - 1.
init_image = init_image.to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
return init_latent,
def sample(init_data, x, conditioning, unconditional_conditioning):
t_enc = int(denoising_strength * ddim_steps)
x0, = init_data
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
noise = x * sigmas[ddim_steps - t_enc - 1]
xi = x0 + noise
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
samples_ddim = sampler.func(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
return samples_ddim
p = StableDiffusionProcessingImg2Img(
outpath=outpath,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
init_images=[init_img],
resize_mode=resize_mode,
denoising_strength=denoising_strength,
extra_generation_params={"Denoising Strength": denoising_strength}
)
if loopback:
output_images, info = None, None
@ -955,32 +1005,19 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
initial_seed = None
for i in range(n_iter):
output_images, seed, info = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=1,
n_iter=1,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
do_not_save_grid=True,
extra_generation_params={"Denoising Strength": denoising_strength},
)
p.n_iter = 1
p.batch_size = 1
p.do_not_save_grid = True
output_images, seed, info = process_images(p)
if initial_seed is None:
initial_seed = seed
init_img = output_images[0]
seed = seed + 1
denoising_strength = max(denoising_strength * 0.95, 0.1)
history.append(init_img)
p.init_img = output_images[0]
p.seed = seed + 1
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
history.append(output_images[0])
grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1)
@ -1000,39 +1037,36 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)
p.n_iter = 1
p.do_not_save_grid = True
print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.")
work = []
work_results = []
for y, h, row in grid.tiles:
for tiledata in row:
init_img = tiledata[2]
work.append(tiledata[2])
output_images, seed, info = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=1, # since process_images can't work with multiple different images we have to do this for now
n_iter=1,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
do_not_save_grid=True,
extra_generation_params={"Denoising Strength": denoising_strength},
)
batch_count = math.ceil(len(work) / p.batch_size)
print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")
if initial_seed is None:
initial_seed = seed
initial_info = info
for i in range(batch_count):
p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
seed += 1
output_images, seed, info = process_images(p)
tiledata[2] = output_images[0]
if initial_seed is None:
initial_seed = seed
initial_info = info
p.seed = seed + 1
work_results += output_images
image_index = 0
for y, h, row in grid.tiles:
for tiledata in row:
tiledata[2] = work_results[image_index]
image_index += 1
combined_image = combine_grid(grid)
@ -1044,25 +1078,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
info = initial_info
else:
output_images, seed, info = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
extra_generation_params={"Denoising Strength": denoising_strength},
)
del sampler
output_images, seed, info = process_images(p)
return output_images, seed, plaintext_to_html(info)
@ -1178,22 +1194,19 @@ def run_settings(*args):
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key][0]
return opts.data[key] if key in opts.data else opts.data_labels[key].default
labelinfo = opts.data_labels[key]
t = type(labelinfo[0])
label = labelinfo[1]
if t == str:
item = gr.Textbox(label=label, value=fun, lines=1)
info = opts.data_labels[key]
t = type(info.default)
if info.component is not None:
item = info.component(label=info.label, value=fun, **(info.component_args or {}))
elif t == str:
item = gr.Textbox(label=info.label, value=fun, lines=1)
elif t == int:
if len(labelinfo) == 5:
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun)
elif len(labelinfo) == 4:
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
else:
item = gr.Number(label=label, value=fun)
item = gr.Number(label=info.label, value=fun)
elif t == bool:
item = gr.Checkbox(label=label, value=fun)
item = gr.Checkbox(label=info.label, value=fun)
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
@ -1219,14 +1232,14 @@ interfaces = [
(settings_interface, "Settings"),
]
config = OmegaConf.load(cmd_opts.config)
model = load_model_from_config(config, cmd_opts.ckpt)
sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device)
sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
model_hijack = StableDiffuionModelHijack()
model_hijack.hijack(model)
model_hijack.hijack(sd_model)
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],