stable-diffusion-webui/webui.py

975 lines
40 KiB
Python
Raw Normal View History

2022-08-22 17:15:46 +03:00
import argparse, os, sys, glob
import gradio as gr
import k_diffusion as K
import math
import mimetypes
import numpy as np
import pynvml
import random
import threading
import time
2022-08-22 17:15:46 +03:00
import torch
import torch.nn as nn
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from itertools import islice
2022-08-22 17:15:46 +03:00
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
2022-08-22 17:15:46 +03:00
from torch import autocast
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
2022-08-22 17:15:46 +03:00
try:
    # Silence the "Some weights of the model checkpoint were not used when
    # initializing..." message that transformers prints at start-up.
    from transformers import logging
    logging.set_verbosity_error()
except Exception:
    # Best effort only: a missing or incompatible transformers install must not
    # stop the UI from starting. `Exception` (rather than a bare `except`)
    # lets KeyboardInterrupt / SystemExit propagate.
    pass
2022-08-22 17:15:46 +03:00
# This is a fix for Windows users. Without it, javascript files will be served
# with a text/html content-type and the browser will not show any UI.
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

# Some of those options should not be changed at all because they would break
# the model, so they were removed from the command-line options.
opt_C = 4
opt_f = 8

# Pillow moved the resampling constants into Image.Resampling; fall back to the
# old location when that attribute is absent.
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)

# Characters stripped from prompts before they are used in file names.
# The backslash is written as an explicit escape: the previous '\|' was an
# invalid escape sequence (same resulting characters, but it triggers a warning).
invalid_filename_chars = '<>:"/\\|?*\n'
2022-08-22 17:15:46 +03:00
# Command-line options. They are parsed once at import time into the
# module-level `opt` namespace that the rest of the file reads.
parser = argparse.ArgumentParser()
# --- output handling ---
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
parser.add_argument("--skip_save", action='store_true', help="do not save indiviual samples. For speed measurements.",)
parser.add_argument("--n_rows", type=int, default=-1, help="rows in the grid; use -1 for autodetect and 0 for n_rows to be same as batch_size (default: -1)",)
2022-08-22 17:15:46 +03:00
# --- model / runtime options ---
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
# The default prefers ./src/gfpgan when that directory exists and falls back to ./GFPGAN.
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
2022-08-22 17:15:46 +03:00
# Parse the command line immediately; `opt` is read as a global below.
opt = parser.parse_args()
GFPGAN_dir = opt.gfpgan_dir
# CSS handed to the gradio UI unless --no-progressbar-hiding is given: hides the
# progress widgets and shows a static "Loading..." text instead.
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
2022-08-22 17:15:46 +03:00
def chunk(it, size):
    """Return an iterator over consecutive tuples of up to *size* items from *it*.

    The final tuple may be shorter; iteration ends at the first empty tuple.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if batch == ():
                return
            yield batch

    return _batches()
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by *config* and load weights from *ckpt*.

    Args:
        config: configuration object whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: path of the checkpoint file; it must contain a ``"state_dict"`` entry.
        verbose: when true, print the missing / unexpected state-dict keys.

    Returns:
        The model in eval mode, moved to CUDA when CUDA is available.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: tolerate key mismatches and report them only when verbose.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    # Only move to the GPU when there is one; the module-level set-up falls back
    # to the CPU device, and an unconditional .cuda() would raise before that.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
def crash(e, s):
    """Report error *e* under heading *s*, drop the model globals and exit the process.

    The exit is scheduled on a 0.25 s timer rather than performed inline, so
    this function returns to its caller first.
    """
    global model, device
    print(s, '\n', e)
    del model, device
    print('exiting...calling os._exit(0)')
    threading.Timer(0.25, os._exit, args=[0]).start()
class MemUsageMonitor(threading.Thread):
    """Background thread that records the peak used memory of GPU 0 via NVML.

    NOTE(review): this class is defined a second time further down in this
    file; the later definition is the one bound to the name after import.
    """

    # Class-level defaults; instances overwrite them on assignment.
    stop_flag = False
    max_usage = 0
    total = 0

    def __init__(self, name):
        super().__init__()
        self.name = name

    def run(self):
        """Poll NVML every 0.1 s until the stop flag is set."""
        print(f"[{self.name}] Recording max memory usage...\n")
        pynvml.nvmlInit()
        gpu = pynvml.nvmlDeviceGetHandleByIndex(0)
        self.total = pynvml.nvmlDeviceGetMemoryInfo(gpu).total
        while not self.stop_flag:
            used_now = pynvml.nvmlDeviceGetMemoryInfo(gpu).used
            self.max_usage = max(self.max_usage, used_now)
            time.sleep(0.1)
        print(f"[{self.name}] Stopped recording.\n")
        pynvml.nvmlShutdown()

    def read(self):
        """Return ``(max_usage, total)`` as recorded so far."""
        return self.max_usage, self.total

    def stop(self):
        """Ask the polling loop to finish."""
        self.stop_flag = True

    def read_and_stop(self):
        """Ask the polling loop to finish and return ``(max_usage, total)``."""
        self.stop()
        return self.read()
2022-08-22 17:15:46 +03:00
class CFGDenoiser(nn.Module):
    """Wraps a denoiser so one call yields the guidance-combined prediction."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Run the inner model once on a doubled batch and blend both halves.

        The first half of the batch is paired with *uncond*, the second with
        *cond*; the result is ``uncond + (cond - uncond) * cond_scale``.
        """
        doubled_x = torch.cat((x, x))
        doubled_sigma = torch.cat((sigma, sigma))
        stacked_cond = torch.cat((uncond, cond))
        prediction = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond)
        pred_uncond, pred_cond = prediction.chunk(2)
        return pred_uncond + (pred_cond - pred_uncond) * cond_scale
class KDiffusionSampler:
    """Adapter giving a k-diffusion sampling function the ``sample(...)`` call shape used by this file."""

    def __init__(self, m, sampler):
        self.model = m
        self.model_wrap = K.external.CompVisDenoiser(m)
        # Suffix of the K.sampling function to call, e.g. 'lms' -> sample_lms.
        self.schedule = sampler

    def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
        """Run ``K.sampling.sample_<schedule>`` for *S* steps starting from *x_T*.

        ``batch_size``, ``shape``, ``verbose`` and ``eta`` are accepted but not used.
        Returns ``(samples, None)``.
        """
        sigmas = self.model_wrap.get_sigmas(S)
        # Scale the supplied noise by the first sigma of the schedule.
        start = x_T * sigmas[0]
        denoiser = CFGDenoiser(self.model_wrap)
        sample_fn = vars(K.sampling)[f'sample_{self.schedule}']
        guidance_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
        }
        samples_ddim = sample_fn(denoiser, start, sigmas, extra_args=guidance_args, disable=False)
        return samples_ddim, None
2022-08-24 18:45:55 +03:00
class MemUsageMonitor(threading.Thread):
    """Background thread that records the peak used memory of GPU 0 via NVML.

    NOTE(review): an identical class with this name is also defined earlier in
    this file; this later definition is the binding that is in effect.

    Usage: ``start()`` the thread, then call ``read_and_stop()`` (or ``stop()``)
    to end polling. ``total`` stays 0 until the thread has queried NVML.
    """

    # Class-level defaults; instances overwrite them on assignment.
    stop_flag = False
    max_usage = 0
    total = 0

    def __init__(self, name):
        threading.Thread.__init__(self)
        self.name = name

    def run(self):
        """Poll NVML every 0.1 s until the stop flag is set."""
        print(f"[{self.name}] Recording max memory usage...\n")
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
            while not self.stop_flag:
                m = pynvml.nvmlDeviceGetMemoryInfo(handle)
                self.max_usage = max(self.max_usage, m.used)
                time.sleep(0.1)
            print(f"[{self.name}] Stopped recording.\n")
        finally:
            # Release NVML even when a query raises mid-loop.
            pynvml.nvmlShutdown()

    def read(self):
        """Return ``(max_usage, total)`` as recorded so far."""
        return self.max_usage, self.total

    def stop(self):
        """Ask the polling loop to finish."""
        self.stop_flag = True

    def read_and_stop(self):
        """Ask the polling loop to finish and return ``(max_usage, total)``."""
        self.stop_flag = True
        return self.max_usage, self.total
def create_random_tensors(shape, seeds):
    """Return a stacked tensor holding one ``randn(shape)`` sample per seed.

    The global generator is re-seeded with each seed right before its sample
    is drawn on the module-level ``device``.
    """
    def _noise_for(seed):
        torch.manual_seed(seed)
        # randn results depend on the device: GPU and CPU give different values
        # for the same seed. Generating on the CPU would make results portable,
        # but the original script drew on `device`, and changing that would
        # break everyone's existing seeds.
        return torch.randn(shape, device=device)

    return torch.stack([_noise_for(seed) for seed in seeds])
def torch_gc():
    """Release cached CUDA memory; does nothing when CUDA is unavailable.

    The availability guard matters because this file falls back to a CPU
    device, where the CUDA-only calls have nothing to operate on.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
2022-08-22 17:15:46 +03:00
def load_GFPGAN():
    """Build a ``GFPGANer`` from the GFPGANv1.3 weights under ``GFPGAN_dir``.

    Raises:
        Exception: when the weights file does not exist.
    """
    model_name = 'GFPGANv1.3'
    weights_dir = 'experiments/pretrained_models'
    model_path = os.path.join(GFPGAN_dir, weights_dir, model_name + '.pth')
    if os.path.isfile(model_path):
        # Make the checkout importable before importing from it.
        sys.path.append(os.path.abspath(GFPGAN_dir))
        from gfpgan import GFPGANer

        return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    raise Exception("GFPGAN model not found at path "+model_path)
# Optional GFPGAN face restorer: stays None when the directory is missing or
# loading fails, and the rest of the file checks `GFPGAN is not None`.
GFPGAN = None
if os.path.exists(GFPGAN_dir):
    try:
        GFPGAN = load_GFPGAN()
        print("Loaded GFPGAN")
    except Exception:
        # Loading is best effort: report the traceback and continue without GFPGAN.
        import traceback
        print("Error loading GFPGAN:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
# Load the model once at import time. The paths come from --config / --ckpt;
# their argparse defaults are the values that used to be hard-coded here, so
# behaviour without those flags is unchanged.
config = OmegaConf.load(opt.config)
model = load_model_from_config(config, opt.ckpt)

# Use the GPU when there is one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Half precision unless --no-half was given.
model = (model if opt.no_half else model.half()).to(device)
2022-08-22 17:15:46 +03:00
def load_embeddings(fp):
    """Load the embeddings file *fp* into the model's embedding manager.

    Does nothing when *fp* is None or the model has no ``embedding_manager``.
    """
    if fp is None:
        return
    if not hasattr(model, "embedding_manager"):
        return
    model.embedding_manager.load(fp.name)
2022-08-22 17:15:46 +03:00
2022-08-24 18:53:35 +03:00
def image_grid(imgs, batch_size, round_down=False, force_n_rows=None):
    """Paste *imgs* row by row onto one black RGB canvas and return it.

    Row count, in priority order: *force_n_rows*; ``opt.n_rows`` when positive;
    *batch_size* when ``opt.n_rows`` is 0; otherwise the square root of the
    image count (truncated when *round_down*, rounded otherwise).
    Every cell has the size of the first image.
    """
    def _row_count():
        if force_n_rows is not None:
            return force_n_rows
        if opt.n_rows > 0:
            return opt.n_rows
        if opt.n_rows == 0:
            return batch_size
        root = math.sqrt(len(imgs))
        return int(root) if round_down else round(root)

    rows = _row_count()
    cols = math.ceil(len(imgs) / rows)

    cell_w, cell_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h), color='black')

    for index, picture in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(picture, box=(col * cell_w, row * cell_h))
    return canvas
2022-08-24 20:51:28 +03:00
def seed_to_int(s):
    """Convert a user-supplied seed into a non-negative integer seed.

    Args:
        s: the seed. ``'random'`` or a blank string yields a random seed, a
            string of digits is used as that number, any other text is hashed,
            and int/float values are accepted directly.

    Returns:
        An int; values above 2**32 are shifted right by 32 bits until they fit.
    """
    if isinstance(s, (int, float)):
        # Numeric input (e.g. from a number widget) has no .isdigit().
        n = abs(int(s))
    else:
        text = str(s).strip()
        if text == '' or text == 'random':
            return random.randint(0, 2**32)
        try:
            n = int(text) if text.isdigit() else hash(text)
        except ValueError:
            # Some characters satisfy isdigit() but are rejected by int()
            # (e.g. superscript digits); treat those as free text.
            n = hash(text)
        # NOTE: hash() of a str can differ between interpreter runs, so text
        # seeds are only stable within one process.
        n = abs(n)
    while n > 2**32:
        n = n >> 32
    return n
2022-08-24 18:45:55 +03:00
def draw_prompt_matrix(im, width, height, all_prompts):
    """Return a copy of grid image *im* with prompt-part labels drawn around it.

    *width* / *height* are the size of one grid cell. The first entry of
    *all_prompts* is not drawn; the remaining entries are split into a first
    half drawn above each column and a second half drawn left of each row.
    A label is struck through for cells whose index bit is not set.
    """
    def wrap(text, d, font, line_length):
        # Greedy word wrap: keep appending words to the current line while the
        # measured text length stays within line_length.
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if d.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return '\n'.join(lines)

    def draw_texts(pos, x, y, texts, sizes):
        # Draw the labels stacked vertically starting at y, centred on x.
        for i, (text, size) in enumerate(zip(texts, sizes)):
            # Bit i of pos says whether label i applies to this column/row.
            active = pos & (1 << i) != 0
            if not active:
                # U+0336 after every character renders the text struck through.
                text = '\u0336'.join(text) + '\u0336'
            d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
            y += size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2
    # NOTE(review): requires an "arial.ttf" that Pillow can locate; truetype()
    # raises otherwise.
    fnt = ImageFont.truetype("arial.ttf", fontsize)
    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)
    # Margins reserved for labels; the left margin is only added when there are
    # more than two prompt parts.
    pad_top = height // 4
    pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0
    cols = im.width // width
    rows = im.height // height
    prompts = all_prompts[1:]
    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))
    d = ImageDraw.Draw(result)
    # First half of the parts labels columns, the rest labels rows.
    boundary = math.ceil(len(prompts) / 2)
    prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]]
    prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]]
    # (width, height) of each wrapped label, from its bounding box.
    sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
    sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
    hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing
    ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing
    for col in range(cols):
        # Centre the column labels in the top margin above this column.
        x = pad_left + width * col + width / 2
        y = pad_top / 2 - hor_text_height / 2
        draw_texts(col, x, y, prompts_horiz, sizes_hor)
    for row in range(rows):
        # Centre the row labels in the left margin beside this row.
        x = pad_left / 2
        y = pad_top + height * row + height / 2 - ver_text_height / 2
        draw_texts(row, x, y, prompts_vert, sizes_ver)
    return result
def resize_image(resize_mode, im, width, height):
    """Return *im* resized to ``width`` x ``height`` according to *resize_mode*.

    Modes:
        0: plain resize, aspect ratio is not preserved.
        1: scale so the image covers the target, centred; the excess is cropped.
        other: scale so the image fits inside the target, centred; the empty
            bands are filled by stretching the image's outermost row/column.
    """
    if resize_mode == 0:
        return im.resize((width, height), resample=LANCZOS)

    ratio = width / height
    src_ratio = im.width / im.height

    if resize_mode == 1:
        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width
        resized = im.resize((src_w, src_h), resample=LANCZOS)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
        return res

    src_w = width if ratio < src_ratio else im.width * height // im.height
    src_h = height if ratio >= src_ratio else im.height * width // im.width
    resized = im.resize((src_w, src_h), resample=LANCZOS)
    res = Image.new("RGB", (width, height))
    res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    if ratio < src_ratio:
        fill_height = height // 2 - src_h // 2
        # Integer rounding can leave no band to fill even though the ratios
        # differ; resizing to a zero-sized strip is rejected by Pillow.
        if fill_height > 0:
            res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
            res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
    elif ratio > src_ratio:
        fill_width = width // 2 - src_w // 2
        if fill_width > 0:
            res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
            res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
    return res
def check_prompt_length(prompt, comments):
    """Append a warning to *comments* when tokenizing *prompt* overflows the model's max length."""
    cond_model = model.cond_stage_model
    tokenizer = cond_model.tokenizer
    max_length = cond_model.max_length

    encoded = tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
    overflow_ids = encoded['overflowing_tokens'][0]
    if overflow_ids.shape[0] == 0:
        return

    # Invert the vocabulary so token ids map back to token strings.
    id_to_token = {token_id: token for token, token_id in tokenizer.get_vocab().items()}
    overflowing_words = [id_to_token.get(int(token_id), "") for token_id in overflow_ids]
    overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))

    comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
def _filename_fragment(text):
    """Return *text* made usable inside a file name: spaces to '_', invalid chars dropped, max 128 chars."""
    return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]


def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix,
                   use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=True):
    """Main loop shared by txt2img and img2img.

    Calls *func_init* once inside all the scopes and *func_sample* once per
    batch, decodes the samples, optionally runs GFPGAN and mask compositing,
    saves images / the grid, and builds the info and stats texts.

    Args:
        outpath: output directory; individual images go to ``<outpath>/samples``.
        func_init: zero-argument callable; its result is passed to *func_sample*.
        func_sample: callable taking ``init_data, x, conditioning,
            unconditional_conditioning, sampler_name`` and returning latents.
        prompt: prompt text; with *prompt_matrix* it is split on ``|``.
        seed: base seed; image k uses ``seed + k`` (prompt matrix: same seed for all).
        skip_grid / skip_save / do_not_save_grid: control what is written to disk.
        use_GFPGAN: run the face restorer on each image when it is loaded.
        fp: embeddings file object or None.
        normalize_prompt_weights: divide sub-prompt weights by their sum.
        init_img, init_mask, keep_mask: when *init_mask* is given, each result
            is composited with *init_img* through the (optionally inverted),
            blurred mask.

    Returns:
        ``(output_images, seed, info, stats)``.
    """
    assert prompt is not None
    torch_gc()
    # start time after garbage collection
    start_time = time.time()

    mem_mon = MemUsageMonitor('MemMon')
    mem_mon.start()
    try:
        if hasattr(model, "embedding_manager"):
            load_embeddings(fp)

        os.makedirs(outpath, exist_ok=True)
        sample_path = os.path.join(outpath, "samples")
        os.makedirs(sample_path, exist_ok=True)
        base_count = len(os.listdir(sample_path))
        grid_count = len(os.listdir(outpath)) - 1

        comments = []
        prompt_matrix_parts = []
        if prompt_matrix:
            # Every combination of the optional parts, encoded as the bits of
            # combination_num, appended to the mandatory first part.
            all_prompts = []
            prompt_matrix_parts = prompt.split("|")
            combination_count = 2 ** (len(prompt_matrix_parts) - 1)
            for combination_num in range(combination_count):
                current = prompt_matrix_parts[0]
                for n, text in enumerate(prompt_matrix_parts[1:]):
                    if (combination_num & (2 ** n)) > 0:
                        current += ("" if text.strip().startswith(",") else ", ") + text
                all_prompts.append(current)

            n_iter = math.ceil(len(all_prompts) / batch_size)
            all_seeds = len(all_prompts) * [seed]
            print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
        else:
            if not opt.no_verify_input:
                try:
                    check_prompt_length(prompt, comments)
                except Exception:
                    # Verification is advisory only; report and carry on.
                    import traceback
                    print("Error verifying input:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

            all_prompts = batch_size * n_iter * [prompt]
            all_seeds = [seed + x for x in range(len(all_prompts))]

        # Prepare the compositing mask ONCE. Doing this per image re-inverted
        # and re-blurred the mask for every image after the first.
        composite_mask = None
        composite_base = None
        if init_mask:
            composite_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
            composite_mask = composite_mask.filter(ImageFilter.GaussianBlur(3))
            composite_mask = composite_mask.convert('L')
            composite_base = init_img.convert('RGB')

        precision_scope = autocast if opt.precision == "autocast" else nullcontext
        output_images = []
        # Prompt used in the grid file name: the last image's prompt.
        grid_prompt = prompt

        with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
            init_data = func_init()

            for n in range(n_iter):
                prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
                seeds = all_seeds[n * batch_size:(n + 1) * batch_size]

                uc = None
                if cfg_scale != 1.0:
                    uc = model.get_learned_conditioning(len(prompts) * [""])
                if isinstance(prompts, tuple):
                    prompts = list(prompts)

                # split the prompt if it has : for weighting
                subprompts, weights = split_weighted_subprompts(prompts[0])
                if len(subprompts) > 1:
                    # total weight for normalizing; gets weird with large negative values
                    total_weight = sum(weights)
                    # uc is None when cfg_scale == 1.0; encode the empty prompt
                    # just to obtain a tensor of the right shape.
                    template = uc if uc is not None else model.get_learned_conditioning(len(prompts) * [""])
                    c = torch.zeros_like(template)
                    for subprompt, weight in zip(subprompts, weights):
                        # Skip normalizing when the weights cancel out to zero.
                        if normalize_prompt_weights and total_weight != 0:
                            weight = weight / total_weight
                        # a negative alpha behaves like torch.sub
                        c = torch.add(c, model.get_learned_conditioning(subprompt), alpha=weight)
                else:  # just behave like usual
                    c = model.get_learned_conditioning(prompts)

                # we manually generate all input noises because each one should have a specific seed
                x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
                samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)

                x_samples_ddim = model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                for i, x_sample in enumerate(x_samples_ddim):
                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    x_sample = x_sample.astype(np.uint8)

                    if use_GFPGAN and GFPGAN is not None:
                        # Channel order is reversed going in and coming out.
                        cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample[:, :, ::-1], has_aligned=False, only_center_face=False, paste_back=True)
                        x_sample = restored_img[:, :, ::-1]

                    image = Image.fromarray(x_sample)
                    if composite_mask is not None:
                        image = Image.composite(composite_base, image.convert('RGB'), composite_mask)

                    filename = f"{base_count:05}-{seeds[i]}_{_filename_fragment(prompts[i])}.png"
                    if not skip_save:
                        image.save(os.path.join(sample_path, filename))

                    output_images.append(image)
                    base_count += 1
                    grid_prompt = prompts[i]

            if (prompt_matrix or not skip_grid) and not do_not_save_grid:
                grid = image_grid(output_images, batch_size, round_down=prompt_matrix)

                if prompt_matrix:
                    try:
                        grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
                    except Exception:
                        import traceback
                        print("Error creating prompt_matrix text:", file=sys.stderr)
                        print(traceback.format_exc(), file=sys.stderr)

                    output_images.insert(0, grid)

                grid_file = f"grid-{grid_count:05}-{seed}_{_filename_fragment(grid_prompt)}.jpg"
                grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=100, optimize=True)
                grid_count += 1
    finally:
        # Always stop the monitor thread, otherwise it polls forever after an error.
        mem_max_used, mem_total = mem_mon.read_and_stop()

    time_diff = time.time() - start_time

    gfpgan_note = ', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''
    matrix_note = ', Prompt Matrix Mode.' if prompt_matrix else ''
    info = (
        f"\n{prompt}\n"
        f"Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{gfpgan_note}{matrix_note}"
    ).strip()

    # Guard the divisions: there may be no prompts, and mem_total stays 0 when
    # the monitor thread never managed to query the GPU.
    per_image = time_diff / max(len(all_prompts), 1)
    mem_percent = round(mem_max_used / mem_total * 100, 3) if mem_total else 0
    stats = (
        f"\nTook { round(time_diff, 2) }s total ({ round(per_image, 2) }s per image)"
        # -(a // -b) is ceiling division: bytes rounded up to whole MiB.
        f"\nPeak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { mem_percent }%"
    )

    for comment in comments:
        info += "\n\n" + comment

    torch_gc()

    return output_images, seed, info, stats
2022-08-22 17:15:46 +03:00
2022-08-24 18:45:55 +03:00
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, skip_grid: bool, skip_save: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, normalize_prompt_weights: bool, fp):
    """Gradio handler for the txt2img tab.

    Builds the requested sampler and delegates to ``process_images``. On a
    RuntimeError it returns an error payload and triggers ``crash``.

    Returns:
        ``(output_images, seed, info, stats)``.
    """
    outpath = opt.outdir or "outputs/txt2img-samples"
    err = False
    seed = seed_to_int(seed)

    # UI name -> suffix of the k-diffusion sampling function.
    k_schedules = {
        'k_dpm_2_a': 'dpm_2_ancestral',
        'k_dpm_2': 'dpm_2',
        'k_euler_a': 'euler_ancestral',
        'k_euler': 'euler',
        'k_heun': 'heun',
        'k_lms': 'lms',
    }
    if sampler_name == 'PLMS':
        sampler = PLMSSampler(model)
    elif sampler_name == 'DDIM':
        sampler = DDIMSampler(model)
    elif sampler_name in k_schedules:
        sampler = KDiffusionSampler(model, k_schedules[sampler_name])
    else:
        raise Exception("Unknown sampler: " + sampler_name)

    def init():
        # txt2img has nothing to prepare before sampling.
        pass

    def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
        result = sampler.sample(
            S=ddim_steps,
            conditioning=conditioning,
            batch_size=int(x.shape[0]),
            shape=x[0].shape,
            verbose=False,
            unconditional_guidance_scale=cfg_scale,
            unconditional_conditioning=unconditional_conditioning,
            eta=ddim_eta,
            x_T=x,
        )
        samples_ddim, _ = result
        return samples_ddim

    try:
        output_images, seed, info, stats = process_images(
            outpath=outpath,
            func_init=init,
            func_sample=sample,
            prompt=prompt,
            seed=seed,
            sampler_name=sampler_name,
            skip_save=skip_save,
            skip_grid=skip_grid,
            batch_size=batch_size,
            n_iter=n_iter,
            steps=ddim_steps,
            cfg_scale=cfg_scale,
            width=width,
            height=height,
            prompt_matrix=prompt_matrix,
            use_GFPGAN=use_GFPGAN,
            fp=fp,
            normalize_prompt_weights=normalize_prompt_weights
        )
        del sampler
        return output_images, seed, info, stats
    except RuntimeError as e:
        err = e
        stats = f'CRASHED:<br><textarea rows="5" style="color:white;background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
        return [], seed, 'err', stats
    finally:
        # Runs after the return value above is built; crash() schedules the exit.
        if err:
            crash(err, '!!Runtime error (txt2img)!!')
class Flagging(gr.FlaggingCallback):
    """Flagging callback that logs a txt2img run to ``log/log.csv`` and saves its images under ``log/images``."""

    def setup(self, components, flagging_dir: str):
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        """Append one CSV row for the flagged run and write its images as PNG files."""
        import base64
        import csv

        os.makedirs("log/images", exist_ok=True)

        # These must match the "txt2img" function inputs, followed by its
        # outputs (images, seed, comment, stats). NOTE: changes to the UI must
        # be reflected here too. The order is: ..., use_GFPGAN, prompt_matrix,
        # skip_grid, skip_save, ... - unpacking it in a different order logs
        # the wrong value in the prompt_matrix column.
        prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, skip_grid, skip_save, ddim_eta, n_iter, n_samples, cfg_scale, input_seed, height, width, normalize_prompt_weights, fp, images, seed, comment, stats = flag_data

        filenames = []

        with open("log/log.csv", "a", encoding="utf8", newline='') as file:
            # An empty file needs the header row first.
            at_start = file.tell() == 0
            writer = csv.writer(file)
            if at_start:
                writer.writerow(["prompt", "seed", "width", "height", "sampler", "use_GFPGAN", "prompt_matrix", "n_iter", "n_samples", "cfg_scale", "steps", "filename"])

            # Millisecond timestamp as the base name; several images get -1, -2, ...
            filename_base = str(int(time.time() * 1000))
            prefix = "data:image/png;base64,"
            for i, filedata in enumerate(images):
                filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"

                if filedata.startswith(prefix):
                    filedata = filedata[len(prefix):]

                with open(filename, "wb") as imgfile:
                    imgfile.write(base64.decodebytes(filedata.encode('utf-8')))

                filenames.append(filename)

            # A run without images still gets a row, with an empty file name.
            first_file = filenames[0] if filenames else ""
            writer.writerow([prompt, seed, width, height, sampler_name, use_GFPGAN, prompt_matrix, n_iter, n_samples, cfg_scale, ddim_steps, first_file])

        print("Logged:", first_file)
2022-08-22 17:15:46 +03:00
# txt2img tab. The input components are passed positionally to txt2img(), so
# their order must match that function's parameter order; the outputs map to
# its returned (images, seed, info, stats).
txt2img_interface = gr.Interface(
    txt2img,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", "PLMS", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value="k_lms"),
        # Only shown when GFPGAN was loaded.
        gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
        gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
        gr.Checkbox(label='Skip grid', value=False),
        gr.Checkbox(label='Skip save individual images', value=False),
        # Hidden control; its fixed value 0.0 is still passed as ddim_eta.
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
        gr.Slider(minimum=1, maximum=250, step=1, label='Batch count (how many batches of images to generate)', value=1),
        gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
        gr.Textbox(label="Seed ('random' to randomize)", lines=1, value="random"),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
        gr.Checkbox(label="Normalize Prompt Weights (ensure sum of weights add up to 1.0)", value=True),
        # Only shown when the model has an embedding manager.
        gr.File(label = "Embeddings file for textual inversion", visible=hasattr(model, "embedding_manager")),
    ],
    outputs=[
        gr.Gallery(label="Images"),
        gr.Number(label='Seed'),
        gr.Textbox(label="Copy-paste generation parameters"),
        gr.HTML(label='Stats'),
    ],
    title="Stable Diffusion Text-to-Image Unified",
    description="Generate images from text with Stable Diffusion",
    flagging_callback=Flagging(),
    theme="default",
)
def img2img(prompt: str, init_info, mask_mode, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
    """Gradio handler for the img2img tab.

    Encodes the uploaded image, samples from a partially noised version of it
    with the chosen sampler and delegates image handling to ``process_images``.
    With *loopback*, each iteration's first output becomes the next input.
    On a RuntimeError it returns an error payload and triggers ``crash``.

    Args:
        init_info: dict with ``"image"`` and ``"mask"`` entries (sketch tool output).
        mask_mode: ``"Keep masked area"`` keeps the masked region of the input.
        denoising_strength: fraction in [0.0, 1.0] of the steps to re-noise.
        resize_mode: index understood by ``resize_image``.

    Returns:
        ``(output_images, seed, info, stats)``.
    """
    outpath = opt.outdir or "outputs/img2img-samples"
    err = False
    seed = seed_to_int(seed)

    # UI name -> suffix of the k-diffusion sampling function.
    k_schedules = {
        'k_dpm_2_a': 'dpm_2_ancestral',
        'k_dpm_2': 'dpm_2',
        'k_euler_a': 'euler_ancestral',
        'k_euler': 'euler',
        'k_heun': 'heun',
        'k_lms': 'lms',
    }
    if sampler_name == 'DDIM':
        sampler = DDIMSampler(model)
    elif sampler_name in k_schedules:
        sampler = KDiffusionSampler(model, k_schedules[sampler_name])
    else:
        raise Exception("Unknown sampler: " + sampler_name)

    init_img = init_info["image"]
    init_img = init_img.convert("RGB")
    init_img = resize_image(resize_mode, init_img, width, height)
    init_mask = init_info["mask"]
    init_mask = init_mask.convert("RGB")
    init_mask = resize_image(resize_mode, init_mask, width, height)
    keep_mask = mask_mode == "Keep masked area"
    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
    # Number of steps' worth of noise added to the input before denoising.
    t_enc = int(denoising_strength * ddim_steps)

    def init():
        # Reads init_img from the enclosing scope, so loopback's reassignment
        # below is picked up on the next call.
        image = init_img.convert("RGB")
        image = resize_image(resize_mode, image, width, height)
        image = np.array(image).astype(np.float32) / 255.0
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)

        # Map [0, 1] pixel values to [-1, 1].
        init_image = 2. * image - 1.
        init_image = init_image.to(device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

        return init_latent,

    def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
        x0, = init_data

        # Dispatch on the sampler object. The UI never sends the literal name
        # 'k-diffusion', so comparing against it sent every k-sampler down the
        # DDIM path, which calls methods KDiffusionSampler does not have.
        if isinstance(sampler, KDiffusionSampler):
            sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
            # Clamp so strength 1.0 starts at the first sigma instead of
            # wrapping around to index -1.
            start = max(ddim_steps - t_enc - 1, 0)
            noise = x * sigmas[start]
            xi = x0 + noise
            sigma_sched = sigmas[start:]
            model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
            # Honour the selected schedule rather than always using sample_lms.
            sample_fn = K.sampling.__dict__[f'sample_{sampler.schedule}']
            return sample_fn(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)

        sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
        z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc]*batch_size).to(device))
        # decode it
        return sampler.decode(z_enc, conditioning, t_enc,
                              unconditional_guidance_scale=cfg_scale,
                              unconditional_conditioning=unconditional_conditioning,)

    try:
        if loopback:
            output_images, info, stats = None, None, None
            history = []
            initial_seed = None

            for i in range(n_iter):
                output_images, seed, info, stats = process_images(
                    outpath=outpath,
                    func_init=init,
                    func_sample=sample,
                    prompt=prompt,
                    seed=seed,
                    sampler_name=sampler_name,
                    skip_save=skip_save,
                    skip_grid=skip_grid,
                    batch_size=1,
                    n_iter=1,
                    steps=ddim_steps,
                    cfg_scale=cfg_scale,
                    width=width,
                    height=height,
                    prompt_matrix=prompt_matrix,
                    use_GFPGAN=use_GFPGAN,
                    fp=fp,
                    do_not_save_grid=True,
                    # Respect the checkbox here too, as the non-loopback path does.
                    normalize_prompt_weights=normalize_prompt_weights
                )

                if initial_seed is None:
                    initial_seed = seed

                # Feed the first result back in as the next input image.
                init_img = output_images[0]
                seed = seed + 1
                # NOTE(review): t_enc is computed once above and is not
                # recomputed from this decayed value, so the decay currently
                # has no effect on sampling - confirm the intent.
                denoising_strength = max(denoising_strength * 0.95, 0.1)
                history.append(init_img)

            if not skip_grid:
                grid_count = len(os.listdir(outpath)) - 1
                grid = image_grid(history, batch_size, force_n_rows=1)
                grid_file = f"grid-{grid_count:05}-{seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg"
                grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=100, optimize=True)

            output_images = history
            seed = initial_seed

        else:
            output_images, seed, info, stats = process_images(
                outpath=outpath,
                func_init=init,
                func_sample=sample,
                prompt=prompt,
                seed=seed,
                sampler_name=sampler_name,
                skip_save=skip_save,
                skip_grid=skip_grid,
                batch_size=batch_size,
                n_iter=n_iter,
                steps=ddim_steps,
                cfg_scale=cfg_scale,
                width=width,
                height=height,
                prompt_matrix=prompt_matrix,
                use_GFPGAN=use_GFPGAN,
                fp=fp,
                normalize_prompt_weights=normalize_prompt_weights,
                init_img=init_img,
                init_mask=init_mask,
                keep_mask=keep_mask
            )

        del sampler

        return output_images, seed, info, stats
    except RuntimeError as e:
        err = e
        err_msg = f'CRASHED:<br><textarea rows="5" style="color:white;background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
        stats = err_msg
        return [], seed, 'err', stats
    finally:
        # Runs after the return value above is built; crash() schedules the exit.
        if err:
            crash(err, '!!Runtime error (img2img)!!')
2022-08-24 18:45:55 +03:00
# Pre-load the bundled sample image when it exists; otherwise start empty.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None

# img2img tab. The input components are passed positionally to img2img(), so
# their order must match that function's parameter order; the outputs map to
# its returned (images, seed, info, stats).
img2img_interface = gr.Interface(
    img2img,
    inputs=[
        gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
        # The sketch tool delivers a dict with "image" and "mask" entries.
        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch"),
        gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", value="Keep masked area"),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value="k_lms"),
        gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
        gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
        gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
        gr.Checkbox(label='Skip grid', value=False),
        gr.Checkbox(label='Skip save individual images', value=False),
        # Same limits as the txt2img tab: many batches are allowed, but the
        # memory-hungry batch size is capped low. The two maxima were swapped.
        gr.Slider(minimum=1, maximum=250, step=1, label='Batch count (how many batches of images to generate)', value=1),
        gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
        gr.Textbox(label="Seed ('random' to randomize)", lines=1, value="random"),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
        # type="index": the handler receives the position of the chosen option.
        gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize"),
        gr.Checkbox(label="Normalize Prompt Weights (ensure sum of weights add up to 1.0)", value=True),
        gr.File(label = "Embeddings file for textual inversion", visible=hasattr(model, "embedding_manager")),
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed'),
        gr.Textbox(label="Copy-paste generation parameters"),
        gr.HTML(label='Stats'),
    ],
    title="Stable Diffusion Image-to-Image Unified",
    description="Generate images from images with Stable Diffusion",
    allow_flagging="never",
    theme="default",
)
# (interface, tab name) pairs shown in the tabbed UI; the GFPGAN tab is
# appended further down when GFPGAN was loaded.
interfaces = [
    (txt2img_interface, "txt2img"),
    (img2img_interface, "img2img")
]
# grabs all text up to the first occurrence of ':' as sub-prompt
# takes the value following ':' as weight
# if ':' has no value defined, defaults to 1.0
# repeats until no text remaining
# TODO this could probably be done with less code
def split_weighted_subprompts(text):
    """Split a prompt of the form ``"sub-prompt:weight sub-prompt:weight ..."``.

    Everything up to the next ':' is taken as a sub-prompt.  The characters
    after the ':' up to the first space or comma (or the end of the text) are
    parsed as that sub-prompt's weight; the single terminating space/comma is
    discarded.  A missing or unparsable weight falls back to 1.0 (a warning is
    printed in the unparsable case).  Text left over without any ':' becomes a
    final sub-prompt of weight 1.0.  Sub-prompts are returned verbatim, i.e.
    surrounding whitespace is not stripped.

    Args:
        text: the raw prompt string.

    Returns:
        A ``(prompts, weights)`` tuple of two equally long lists: the
        sub-prompt strings and their float weights.  Empty input yields
        ``([], [])``.
    """
    prompts = []
    weights = []
    while text:
        prompt, colon, rest = text.partition(":")
        if not colon:
            # no ':' left: the remainder is one sub-prompt with the default weight
            prompts.append(text)
            weights.append(1.0)
            break

        # the weight token ends at the closest space or comma, else at end of text
        end = len(rest)
        for delimiter in (" ", ","):
            pos = rest.find(delimiter)
            if pos != -1:
                end = min(end, pos)
        token = rest[:end]

        weight = 1.0  # default when ':' is not followed by a value
        if token:
            try:
                weight = float(token)
            except ValueError:  # couldn't treat as float
                print(f"Warning: '{token}' is not a value, are you missing a space or comma after a value?")

        prompts.append(prompt)
        weights.append(weight)
        # drop the weight token together with the delimiter that ended it
        text = rest[end + 1:]
    return prompts, weights
def run_GFPGAN(image, strength):
    """Restore faces in ``image`` with GFPGAN and blend the result with the input.

    Args:
        image: source PIL image; it is converted to RGB before processing.
        strength: blend factor passed to ``Image.blend``.  Values below 1.0 mix
            the original (weight ``1 - strength``) with the restored image
            (weight ``strength``); 1.0 or more returns the restored image as is.

    Returns:
        A PIL image holding the (optionally blended) restored picture.
    """
    image = image.convert("RGB")
    # GFPGAN's enhance() follows the OpenCV convention of BGR channel order,
    # while PIL hands out RGB, so flip the channel axis on the way in and out.
    # NOTE(review): assumes the module-level GFPGAN object is a gfpgan GFPGANer — confirm.
    bgr_input = np.ascontiguousarray(np.array(image, dtype=np.uint8)[:, :, ::-1])
    _, _, restored_bgr = GFPGAN.enhance(bgr_input, has_aligned=False, only_center_face=False, paste_back=True)
    res = Image.fromarray(np.ascontiguousarray(restored_bgr[:, :, ::-1]))
    if strength < 1.0:
        res = Image.blend(image, res, strength)
    return res
# Register the face-restoration tab only when GFPGAN is available.
if GFPGAN is not None:
    interfaces.append((gr.Interface(
        run_GFPGAN,
        inputs=[
            gr.Image(label="Source", source="upload", interactive=True, type="pil"),
            # The default has to lie inside [minimum, maximum]; 1.0 means the
            # restored image is returned without blending in the original.
            gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Effect strength", value=1.0),
        ],
        outputs=[
            gr.Image(label="Result"),
        ],
        title="GFPGAN",
        description="Fix faces on images",
        allow_flagging="never",
        theme="default",
    ), "GFPGAN"))
# Combine every registered (interface, title) pair into one tabbed application.
demo = gr.TabbedInterface(
    interface_list=[iface for iface, _ in interfaces],
    tab_names=[title for _, title in interfaces],
    # the progress-bar-hiding CSS is applied unless the user opted out
    css=("" if opt.no_progressbar_hiding else css_hide_progressbar),
    theme="default",
)

# Enable the request queue with concurrency_count=1.
demo.queue(concurrency_count=1)
# server_name='0.0.0.0' listens on all network interfaces, not just localhost.
demo.launch(show_error=True, server_name='0.0.0.0')