From c30aee2f4b4c3a46f6ae878b880d0b837609f63d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 27 Aug 2022 21:32:28 +0300 Subject: [PATCH] fixed all lines PyCharm was nagging me about fixed input verification not working properly with long textual inversion tokens in some cases (plus it will prevent incorrect outputs for forks that use the :::: prompt weighing method) changed process_images to object class with same fields as args it was previously accepting changed options system to make it possible to explicitly specify gradio objects with args --- webui.py | 563 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 288 insertions(+), 275 deletions(-) diff --git a/webui.py b/webui.py index 13e5112a..8de1bcf2 100644 --- a/webui.py +++ b/webui.py @@ -1,14 +1,13 @@ -import argparse, os, sys, glob +import argparse +import os +import sys from collections import namedtuple - import torch import torch.nn as nn import numpy as np import gradio as gr from omegaconf import OmegaConf from PIL import Image, ImageFont, ImageDraw, PngImagePlugin -from itertools import islice -from einops import rearrange, repeat from torch import autocast import mimetypes import random @@ -22,14 +21,13 @@ import k_diffusion.sampling from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler -import ldm.modules.encoders.modules try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. from transformers import logging logging.set_verbosity_error() -except: +except Exception: pass # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI @@ -41,13 +39,13 @@ opt_C = 4 opt_f = 8 LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) -invalid_filename_chars = '<>:"/\|?*\n' +invalid_filename_chars = '<>:"/\\|?*\n' config_filename = "config.json" parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",) -parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go +parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") @@ -64,7 +62,7 @@ css_hide_progressbar = """ SamplerData = namedtuple('SamplerData', ['name', 'constructor']) samplers = [ - *[SamplerData(x[0], lambda m, funcname=x[1]: KDiffusionSampler(m, funcname)) for x in [ + *[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [ ('LMS', 'sample_lms'), ('Heun', 'sample_heun'), ('Euler', 'sample_euler'), @@ -72,8 +70,8 @@ samplers = [ ('DPM 2', 'sample_dpm_2'), ('DPM 2 Ancestral', 'sample_dpm_2_ancestral'), ] if hasattr(k_diffusion.sampling, x[1])], - SamplerData('DDIM', lambda m: DDIMSampler(model)), - SamplerData('PLMS', lambda m: PLMSSampler(model)), + SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)), + SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)), ] samplers_for_img2img = [x for x in samplers if x.name != 'DDIM' and x.name != 'PLMS'] @@ -102,7 +100,7 @@ try: ), ] have_realesrgan = True -except: +except Exception: print("Error loading Real-ESRGAN:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) @@ -111,24 +109,30 @@ except: class Options: + class OptionInfo: + def __init__(self, default=None, label="", component=None, component_args=None): + self.default = default + self.label = label + self.component = component + self.component_args = component_args + data = None data_labels = { - "outdir": ("", "Output dictectory; if empty, defaults to 'outputs/*'"), - "samples_save": (True, "Save indiviual samples"), - "samples_format": ('png', 'File format for indiviual samples'), - "grid_save": (True, "Save image grids"), - "grid_format": ('png', 'File format for grids'), - "grid_extended_filename": (False, "Add extended info (seed, prompt) to filename when saving grid"), - "n_rows": (-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", -1, 16), - "jpeg_quality": (80, "Quality for saved jpeg images", 1, 100), - "verify_input": (True, "Check input, and produce warning if it's too long"), - "enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"), - "prompt_matrix_add_to_start": (True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"), - "sd_upscale_overlap": (64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", 0, 256, 16), + "outdir": OptionInfo("", "Output dictectory; if empty, defaults to 'outputs/*'"), + "samples_save": OptionInfo(True, "Save indiviual samples"), + "samples_format": OptionInfo('png', 'File format for indiviual samples'), + "grid_save": OptionInfo(True, "Save image grids"), + "grid_format": OptionInfo('png', 'File format for grids'), + "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), + "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), + "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), + "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), + "prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"), + "sd_upscale_overlap": OptionInfo(64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", gr.Slider, {"minimum": 0, "maximum": 256, "step": 16}), } def __init__(self): - self.data = {k: v[0] for k, v in self.data_labels.items()} + self.data = {k: v.default for k, v in self.data_labels.items()} def __setattr__(self, key, value): if self.data is not None: @@ -143,7 +147,7 @@ class Options: return self.data[item] if item in self.data_labels: - return self.data_labels[item][0] + return self.data_labels[item].default return super(Options, self).__getattribute__(item) @@ -156,11 +160,6 @@ class Options: self.data = json.load(file) -def chunk(it, size): - it = iter(it) - return iter(lambda: tuple(islice(it, size)), ()) - - def load_model_from_config(config, ckpt, verbose=False): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") @@ -181,36 +180,6 @@ def load_model_from_config(config, ckpt, verbose=False): return model -class CFGDenoiser(nn.Module): - def __init__(self, model): - super().__init__() - self.inner_model = model - - def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - return uncond + (cond - uncond) * cond_scale - - -class KDiffusionSampler: - def __init__(self, m, funcname): - self.model = m - self.model_wrap = k_diffusion.external.CompVisDenoiser(m) - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - - def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T): - sigmas = self.model_wrap.get_sigmas(S) - x = x_T * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model_wrap) - - samples_ddim = self.func(model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False) - - return samples_ddim, None - - def create_random_tensors(shape, seeds): xs = [] for seed in seeds: @@ -256,7 +225,7 @@ def plaintext_to_html(text): return text -def load_GFPGAN(): +def load_gfpgan(): model_name = 'GFPGANv1.3' model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth') if not os.path.isfile(model_path): @@ -358,7 +327,7 @@ def combine_grid(grid): def draw_prompt_matrix(im, width, height, all_prompts): - def wrap(text, d, font, line_length): + def wrap(text, font, line_length): lines = [''] for word in text.split(): line = f'{lines[-1]} {word}'.strip() @@ -368,16 +337,16 @@ def draw_prompt_matrix(im, width, height, all_prompts): lines.append(word) return '\n'.join(lines) - def draw_texts(pos, x, y, texts, sizes): + def draw_texts(pos, draw_x, draw_y, texts, sizes): for i, (text, size) in enumerate(zip(texts, sizes)): active = pos & (1 << i) != 0 if not active: text = '\u0336'.join(text) + '\u0336' - d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center") + d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center") - y += size[1] + line_spacing + draw_y += size[1] + line_spacing fontsize = (width + height) // 25 line_spacing = fontsize // 2 @@ -399,8 +368,8 @@ def draw_prompt_matrix(im, width, height, all_prompts): d = ImageDraw.Draw(result) boundary = math.ceil(len(prompts) / 2) - prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]] - prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]] + prompts_horiz = [wrap(x, fnt, width) for x in prompts[:boundary]] + prompts_vert = [wrap(x, fnt, pad_left) for x in prompts[boundary:]] sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]] sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]] @@ -458,25 +427,6 @@ def resize_image(resize_mode, im, width, height): return res -def check_prompt_length(prompt, comments): - """this function tests if prompt is too long, and if so, adds a message to comments""" - - tokenizer = model.cond_stage_model.tokenizer - max_length = model.cond_stage_model.max_length - - info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt") - ovf = info['overflowing_tokens'][0] - overflowing_count = ovf.shape[0] - if overflowing_count == 0: - return - - vocab = {v: k for k, v in tokenizer.get_vocab().items()} - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - - comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - - def wrap_gradio_call(func): def f(*p1, **p2): t = time.perf_counter() @@ -494,7 +444,7 @@ def wrap_gradio_call(func): GFPGAN = None if os.path.exists(cmd_opts.gfpgan_dir): try: - GFPGAN = load_GFPGAN() + GFPGAN = load_gfpgan() print("Loaded GFPGAN") except Exception: print("Error loading GFPGAN:", file=sys.stderr) @@ -506,11 +456,11 @@ class StableDiffuionModelHijack: word_embeddings = {} word_embeddings_checksums = {} fixes = None - used_custom_terms = [] + comments = None dir_mtime = None - def load_textual_inversion_embeddings(self, dir, model): - mt = os.path.getmtime(dir) + def load_textual_inversion_embeddings(self, dirname, model): + mt = os.path.getmtime(dirname) if self.dir_mtime is not None and mt <= self.dir_mtime: return @@ -543,10 +493,10 @@ class StableDiffuionModelHijack: self.ids_lookup[first_id] = [] self.ids_lookup[first_id].append((ids, name)) - for fn in os.listdir(dir): + for fn in os.listdir(dirname): try: - process_file(os.path.join(dir, fn), fn) - except: + process_file(os.path.join(dirname, fn), fn) + except Exception: print(f"Error loading emedding {fn}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) continue @@ -561,10 +511,10 @@ class StableDiffuionModelHijack: class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): - def __init__(self, wrapped, embeddings): + def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.embeddings = embeddings + self.hijack = hijack self.tokenizer = wrapped.tokenizer self.max_length = wrapped.max_length self.token_mults = {} @@ -586,12 +536,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.token_mults[ident] = mult def forward(self, text): - self.embeddings.fixes = [] - self.embeddings.used_custom_terms = [] + self.hijack.fixes = [] + self.hijack.comments = [] remade_batch_tokens = [] id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id maxlen = self.wrapped.max_length - 2 + used_custom_terms = [] cache = {} batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] @@ -611,7 +562,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.embeddings.ids_lookup.get(token, None) + possible_matches = self.hijack.ids_lookup.get(token, None) mult_change = self.token_mults.get(token) if mult_change is not None: @@ -628,7 +579,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append(mult) i += len(ids) - 1 found = True - self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word])) + used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) break if not found: @@ -637,6 +588,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): i += 1 + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + + self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) @@ -645,9 +604,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] remade_batch_tokens.append(remade_tokens) - self.embeddings.fixes.append(fixes) + self.hijack.fixes.append(fixes) batch_multipliers.append(multipliers) + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device) outputs = self.wrapped.transformer(input_ids=tokens) z = outputs.last_hidden_state @@ -679,71 +641,123 @@ class EmbeddingsWithFixes(nn.Module): for offset, word in fixes: tensor[offset] = self.embeddings.word_embeddings[word] - return inputs_embeds -def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None): +class StableDiffusionProcessing: + def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None): + self.outpath: str = outpath + self.prompt: str = prompt + self.seed: int = seed + self.sampler_index: int = sampler_index + self.batch_size: int = batch_size + self.n_iter: int = n_iter + self.steps: int = steps + self.cfg_scale: float = cfg_scale + self.width: int = width + self.height: int = height + self.prompt_matrix: bool = prompt_matrix + self.use_GFPGAN: bool = use_GFPGAN + self.do_not_save_grid: bool = do_not_save_grid + self.extra_generation_params: dict = extra_generation_params + + def init(self): + pass + + def sample(self, x, conditioning, unconditional_conditioning): + raise NotImplementedError() + + +class VanillaStableDiffusionSampler: + def __init__(self, constructor): + self.sampler = constructor(sd_model) + + def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning): + samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x) + return samples_ddim + + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale + + +class KDiffusionSampler: + def __init__(self, funcname): + self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model) + self.funcname = funcname + self.func = getattr(k_diffusion.sampling, self.funcname) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + + def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning): + sigmas = self.model_wrap.get_sigmas(p.steps) + x = x * sigmas[0] + + samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False) + return samples_ddim + + +def process_images(p: StableDiffusionProcessing): """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - assert prompt is not None + prompt = p.prompt + model = sd_model + + assert p.prompt is not None torch_gc() - if seed == -1: - seed = random.randrange(4294967294) - seed = int(seed) + seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed) - os.makedirs(outpath, exist_ok=True) + os.makedirs(p.outpath, exist_ok=True) - sample_path = os.path.join(outpath, "samples") + sample_path = os.path.join(p.outpath, "samples") os.makedirs(sample_path, exist_ok=True) base_count = len(os.listdir(sample_path)) - grid_count = len(os.listdir(outpath)) - 1 + grid_count = len(os.listdir(p.outpath)) - 1 comments = [] prompt_matrix_parts = [] - if prompt_matrix: + if p.prompt_matrix: all_prompts = [] prompt_matrix_parts = prompt.split("|") combination_count = 2 ** (len(prompt_matrix_parts) - 1) for combination_num in range(combination_count): - selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1< 0: - comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms])) + if len(model_hijack.comments) > 0: + comments += model_hijack.comments # we manually generate all input noises because each one should have a specific seed - x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds) + x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds) - samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc) + samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - if prompt_matrix or opts.samples_save or opts.grid_save: + if p.prompt_matrix or opts.samples_save or opts.grid_save: for i, x_sample in enumerate(x_samples_ddim): - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) - if use_GFPGAN and GFPGAN is not None: + if p.use_GFPGAN and GFPGAN is not None: torch_gc() cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True) x_sample = restored_img @@ -791,44 +805,44 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, output_images.append(image) base_count += 1 - if (prompt_matrix or opts.grid_save) and not do_not_save_grid: - if prompt_matrix: - grid = image_grid(output_images, batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) + if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid: + if p.prompt_matrix: + grid = image_grid(output_images, p.batch_size, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2)) try: - grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts) - except: + grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts) + except Exception: import traceback print("Error creating prompt_matrix text:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) output_images.insert(0, grid) else: - grid = image_grid(output_images, batch_size) + grid = image_grid(output_images, p.batch_size) - save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) + save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) grid_count += 1 torch_gc() return output_images, seed, infotext() -def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): - outpath = opts.outdir or "outputs/txt2img-samples" +class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): + sampler = None - sampler = samplers[sampler_index].constructor(model) + def init(self): + self.sampler = samplers[self.sampler_index].constructor() - def init(): - pass - - def sample(init_data, x, conditioning, unconditional_conditioning): - samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x) + def sample(self, x, conditioning, unconditional_conditioning): + samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning) return samples_ddim - output_images, seed, info = process_images( + +def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): + outpath = opts.outdir or "outputs/txt2img-samples" + + p = StableDiffusionProcessingTxt2Img( outpath=outpath, - func_init=init, - func_sample=sample, prompt=prompt, seed=seed, sampler_index=sampler_index, @@ -842,7 +856,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, use_GFPGAN=use_GFPGAN ) - del sampler + output_images, seed, info = process_images(p) return output_images, seed, plaintext_to_html(info) @@ -858,7 +872,7 @@ class Flagging(gr.FlaggingCallback): os.makedirs("log/images", exist_ok=True) # those must match the "txt2img" function - prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data + prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data filenames = [] @@ -896,7 +910,6 @@ txt2img_interface = gr.Interface( gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), - gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1), gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), @@ -914,40 +927,77 @@ txt2img_interface = gr.Interface( ) +class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): + sampler = None + + def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs): + super().__init__(**kwargs) + + self.init_images = init_images + self.resize_mode: int = resize_mode + self.denoising_strength: float = denoising_strength + self.init_latent = None + + def init(self): + self.sampler = samplers_for_img2img[self.sampler_index].constructor() + + imgs = [] + for img in self.init_images: + image = img.convert("RGB") + image = resize_image(self.resize_mode, image, self.width, self.height) + image = np.array(image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + imgs.append(image) + + if len(imgs) == 1: + batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0) + elif len(imgs) <= self.batch_size: + self.batch_size = len(imgs) + batch_images = np.array(imgs) + else: + raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") + + image = torch.from_numpy(batch_images) + image = 2. * image - 1. + image = image.to(device) + + self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image)) + + def sample(self, x, conditioning, unconditional_conditioning): + t_enc = int(self.denoising_strength * self.steps) + + sigmas = self.sampler.model_wrap.get_sigmas(self.steps) + noise = x * sigmas[self.steps - t_enc - 1] + + xi = self.init_latent + noise + sigma_sched = sigmas[self.steps - t_enc - 1:] + samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False) + return samples_ddim + + def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): outpath = opts.outdir or "outputs/img2img-samples" - sampler = samplers_for_img2img[sampler_index].constructor(model) - assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' - def init(): - image = init_img.convert("RGB") - image = resize_image(resize_mode, image, width, height) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - - init_image = 2. * image - 1. - init_image = init_image.to(device) - init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) - init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space - - return init_latent, - - def sample(init_data, x, conditioning, unconditional_conditioning): - t_enc = int(denoising_strength * ddim_steps) - - x0, = init_data - - sigmas = sampler.model_wrap.get_sigmas(ddim_steps) - noise = x * sigmas[ddim_steps - t_enc - 1] - - xi = x0 + noise - sigma_sched = sigmas[ddim_steps - t_enc - 1:] - model_wrap_cfg = CFGDenoiser(sampler.model_wrap) - samples_ddim = sampler.func(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) - return samples_ddim + p = StableDiffusionProcessingImg2Img( + outpath=outpath, + prompt=prompt, + seed=seed, + sampler_index=sampler_index, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=prompt_matrix, + use_GFPGAN=use_GFPGAN, + init_images=[init_img], + resize_mode=resize_mode, + denoising_strength=denoising_strength, + extra_generation_params={"Denoising Strength": denoising_strength} + ) if loopback: output_images, info = None, None @@ -955,32 +1005,19 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG initial_seed = None for i in range(n_iter): - output_images, seed, info = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_index=sampler_index, - batch_size=1, - n_iter=1, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=prompt_matrix, - use_GFPGAN=use_GFPGAN, - do_not_save_grid=True, - extra_generation_params={"Denoising Strength": denoising_strength}, - ) + p.n_iter = 1 + p.batch_size = 1 + p.do_not_save_grid = True + + output_images, seed, info = process_images(p) if initial_seed is None: initial_seed = seed - init_img = output_images[0] - seed = seed + 1 - denoising_strength = max(denoising_strength * 0.95, 0.1) - history.append(init_img) + p.init_img = output_images[0] + p.seed = seed + 1 + p.denoising_strength = max(p.denoising_strength * 0.95, 0.1) + history.append(output_images[0]) grid_count = len(os.listdir(outpath)) - 1 grid = image_grid(history, batch_size, force_n_rows=1) @@ -1000,39 +1037,36 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap) + p.n_iter = 1 + p.do_not_save_grid = True - print(f"SD upscaling will process a total of {len(grid.tiles[0][2])}x{len(grid.tiles)} images.") + work = [] + work_results = [] for y, h, row in grid.tiles: for tiledata in row: - init_img = tiledata[2] + work.append(tiledata[2]) - output_images, seed, info = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_index=sampler_index, - batch_size=1, # since process_images can't work with multiple different images we have to do this for now - n_iter=1, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=prompt_matrix, - use_GFPGAN=use_GFPGAN, - do_not_save_grid=True, - extra_generation_params={"Denoising Strength": denoising_strength}, - ) + batch_count = math.ceil(len(work) / p.batch_size) + print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.") - if initial_seed is None: - initial_seed = seed - initial_info = info + for i in range(batch_count): + p.init_images = work[i*p.batch_size:(i+1)*p.batch_size] - seed += 1 + output_images, seed, info = process_images(p) - tiledata[2] = output_images[0] + if initial_seed is None: + initial_seed = seed + initial_info = info + + p.seed = seed + 1 + work_results += output_images + + image_index = 0 + for y, h, row in grid.tiles: + for tiledata in row: + tiledata[2] = work_results[image_index] + image_index += 1 combined_image = combine_grid(grid) @@ -1044,25 +1078,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG info = initial_info else: - output_images, seed, info = process_images( - outpath=outpath, - func_init=init, - func_sample=sample, - prompt=prompt, - seed=seed, - sampler_index=sampler_index, - batch_size=batch_size, - n_iter=n_iter, - steps=ddim_steps, - cfg_scale=cfg_scale, - width=width, - height=height, - prompt_matrix=prompt_matrix, - use_GFPGAN=use_GFPGAN, - extra_generation_params={"Denoising Strength": denoising_strength}, - ) - - del sampler + output_images, seed, info = process_images(p) return output_images, seed, plaintext_to_html(info) @@ -1178,22 +1194,19 @@ def run_settings(*args): def create_setting_component(key): def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key][0] + return opts.data[key] if key in opts.data else opts.data_labels[key].default - labelinfo = opts.data_labels[key] - t = type(labelinfo[0]) - label = labelinfo[1] - if t == str: - item = gr.Textbox(label=label, value=fun, lines=1) + info = opts.data_labels[key] + t = type(info.default) + + if info.component is not None: + item = info.component(label=info.label, value=fun, **(info.component_args or {})) + elif t == str: + item = gr.Textbox(label=info.label, value=fun, lines=1) elif t == int: - if len(labelinfo) == 5: - item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=labelinfo[4], label=label, value=fun) - elif len(labelinfo) == 4: - item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun) - else: - item = gr.Number(label=label, value=fun) + item = gr.Number(label=info.label, value=fun) elif t == bool: - item = gr.Checkbox(label=label, value=fun) + item = gr.Checkbox(label=info.label, value=fun) else: raise Exception(f'bad options item type: {str(t)} for key {key}') @@ -1219,14 +1232,14 @@ interfaces = [ (settings_interface, "Settings"), ] -config = OmegaConf.load(cmd_opts.config) -model = load_model_from_config(config, cmd_opts.ckpt) +sd_config = OmegaConf.load(cmd_opts.config) +sd_model = load_model_from_config(sd_config, cmd_opts.ckpt) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") -model = (model if cmd_opts.no_half else model.half()).to(device) +sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device) model_hijack = StableDiffuionModelHijack() -model_hijack.hijack(model) +model_hijack.hijack(sd_model) demo = gr.TabbedInterface( interface_list=[x[0] for x in interfaces],