Merge pull request #1 from ShinkoNet/master

Memory Patch
hlky 2022-08-24 15:22:09 +01:00 committed by GitHub
commit 0ed0dd7e90
2 changed files with 188 additions and 65 deletions

relauncher.py (new file, +11)

@@ -0,0 +1,11 @@
import os, time

n = 0
while True:
    print('Relauncher: Launching...')
    if n > 0:
        print(f'\tRelaunch count: {n}')
    os.system("python scripts/webui.py")
    print('Relauncher: Process ending. Relaunching in 0.5s...')
    n += 1
    time.sleep(0.5)
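Usage note (an assumption, not shown in the diff): relauncher.py is presumably started instead of webui.py, so the UI comes back up automatically whenever crash() in webui.py force-exits the process:

    python relauncher.py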

webui.py (+177 −65)

@@ -1,25 +1,25 @@
import argparse, os, sys, glob
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw
from itertools import islice
from einops import rearrange, repeat
from torch import autocast
from contextlib import contextmanager, nullcontext
import mimetypes
import random
import k_diffusion as K
import math
import mimetypes
import numpy as np
import pynvml
import random
import threading
import time
import torch
import torch.nn as nn
import k_diffusion as K
from ldm.util import instantiate_from_config
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from itertools import islice
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw
from torch import autocast
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
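Most of the churn in this hunk is an alphabetical reorder of imports that already existed (hence the duplicated lines above: old position removed, new position added). The genuinely new imports are pynvml, threading, and time (plus math), which back the crash handler and VRAM monitor added below.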
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -87,6 +87,50 @@ def load_model_from_config(config, ckpt, verbose=False):
    model.eval()
    return model

def crash(e, s):
    global model
    global device

    print(s, '\n', e)

    del model
    del device

    print('exiting...calling os._exit(0)')
    t = threading.Timer(0.25, os._exit, args=[0])
    t.start()
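The quarter-second threading.Timer lets crash() return normally, so the calling Gradio handler can still hand its error message to the browser, before os._exit(0) terminates the process; relauncher.py then brings the server back up.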
class MemUsageMonitor(threading.Thread):
    stop_flag = False
    max_usage = 0
    total = 0

    def __init__(self, name):
        threading.Thread.__init__(self)
        self.name = name

    def run(self):
        print(f"[{self.name}] Recording max memory usage...\n")
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
        while not self.stop_flag:
            m = pynvml.nvmlDeviceGetMemoryInfo(handle)
            self.max_usage = max(self.max_usage, m.used)
            # print(self.max_usage)
            time.sleep(0.1)
        print(f"[{self.name}] Stopped recording.\n")
        pynvml.nvmlShutdown()

    def read(self):
        return self.max_usage, self.total

    def stop(self):
        self.stop_flag = True

    def read_and_stop(self):
        self.stop_flag = True
        return self.max_usage, self.total
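A minimal sketch of how the monitor is presumably driven (the mem_mon name matches the read_and_stop() call further down; the start-up lines fall outside the visible hunks):

    mem_mon = MemUsageMonitor('MemMon')  # hypothetical thread name
    mem_mon.start()                      # begins polling GPU 0 every 0.1s via pynvml
    # ... run the sampler ...
    mem_max_used, mem_total = mem_mon.read_and_stop()  # peak used and total VRAM, in bytes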
class CFGDenoiser(nn.Module):
    def __init__(self, model):
@@ -389,8 +433,10 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    output_images = []
    stats = []
    with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
        init_data = func_init()
        tic = time.time()
        for n in range(n_iter):
            prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
@@ -432,7 +478,6 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
        grid = image_grid(output_images, batch_size, round_down=prompt_matrix)

        if prompt_matrix:
            try:
                grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
            except Exception:
@@ -442,31 +487,38 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
            output_images.insert(0, grid)

        grid_file = f"grid-{grid_count:05}-{seed}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg"
        grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=80, optimize=True)
        grid_count += 1

        toc = time.time()

    mem_max_used, mem_total = mem_mon.read_and_stop()
    time_diff = time.time()-start_time

    notes = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(batch_size*n_iter),2) }s per image)<br>
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%<br>
'''

    mem_max_used, mem_total = mem_mon.read_and_stop()
    time_diff = time.time()-start_time

    info = f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip()
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()

    stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
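The -(x // -1_048_576) idiom is ceiling division by 2**20: it converts bytes to MiB rounding up rather than down, with no math.ceil needed. A quick check:

    x = 1_048_577              # one byte over 1 MiB
    print(x // 1_048_576)      # 1 (plain floor division rounds down)
    print(-(x // -1_048_576))  # 2 (negate-divide-negate rounds up)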
    for comment in comments:
        info += "\n\n" + comment

    #mem_mon.stop()
    #del mem_mon

    torch_gc()

    return output_images, seed, info, notes
    return output_images, seed, info, stats
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
    outpath = opt.outdir or "outputs/txt2img-samples"
    err = False

    if sampler_name == 'PLMS':
        sampler = PLMSSampler(model)
@@ -483,27 +535,35 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, p
    def sample(init_data, x, conditioning, unconditional_conditioning):
        samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
        return samples_ddim

    try:
        output_images, seed, info, stats = process_images(
            outpath=outpath,
            func_init=init,
            func_sample=sample,
            prompt=prompt,
            seed=seed,
            sampler_name=sampler_name,
            batch_size=batch_size,
            n_iter=n_iter,
            steps=ddim_steps,
            cfg_scale=cfg_scale,
            width=width,
            height=height,
            prompt_matrix=prompt_matrix,
            use_GFPGAN=use_GFPGAN
        )
    output_images, seed, info, notes = process_images(
        outpath=outpath,
        func_init=init,
        func_sample=sample,
        prompt=prompt,
        seed=seed,
        sampler_name=sampler_name,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=ddim_steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN
    )

    del sampler
        del sampler

    return output_images, seed, info, notes
        return output_images, seed, info, stats
    except RuntimeError as e:
        err = e
        err_msg = f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
        stats = err_msg
        return [], 1
    finally:
        if err:
            crash(err, '!!Runtime error (txt2img)!!')
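The flow here is deliberate: the except branch builds an HTML error message and returns a minimal result so the handler exits cleanly, and only then does the finally block call crash(), which schedules the os._exit(0) that hands control back to relauncher.py.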
class Flagging(gr.FlaggingCallback):
@@ -567,16 +627,17 @@ txt2img_interface = gr.Interface(
        gr.Gallery(label="Images"),
        gr.Number(label='Seed'),
        gr.Textbox(label="Copy-paste generation parameters"),
        gr.HTML(label='Notes'),
        gr.HTML(label='Stats'),
    ],
    title="Stable Diffusion Text-to-Image K",
    description="Generate images from text with Stable Diffusion (using K-LMS)",
    title="Stable Diffusion Text-to-Image Unified",
    description="Generate images from text with Stable Diffusion",
    flagging_callback=Flagging()
)
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
    outpath = opt.outdir or "outputs/img2img-samples"
    err = False

    sampler = KDiffusionSampler(model)
@@ -609,26 +670,77 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
        samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
        return samples_ddim
    output_images, seed, info, notes = process_images(
        outpath=outpath,
        func_init=init,
        func_sample=sample,
        prompt=prompt,
        seed=seed,
        sampler_name='k-diffusion',
        batch_size=batch_size,
        n_iter=n_iter,
        steps=ddim_steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN
    )

    del sampler
    try:
        if loopback:
            output_images, info = None, None
            history = []
            initial_seed = None

            for i in range(n_iter):
                output_images, seed, info, stats = process_images(
                    outpath=outpath,
                    func_init=init,
                    func_sample=sample,
                    prompt=prompt,
                    seed=seed,
                    sampler_name='k-diffusion',
                    batch_size=1,
                    n_iter=1,
                    steps=ddim_steps,
                    cfg_scale=cfg_scale,
                    width=width,
                    height=height,
                    prompt_matrix=prompt_matrix,
                    use_GFPGAN=use_GFPGAN,
                    do_not_save_grid=True
                )

                if initial_seed is None:
                    initial_seed = seed

                init_img = output_images[0]
                seed = seed + 1
                denoising_strength = max(denoising_strength * 0.95, 0.1)
                history.append(init_img)

            grid_count = len(os.listdir(outpath)) - 1
            grid = image_grid(history, batch_size, force_n_rows=1)
            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.{opt.grid_format}'))

            output_images = history
            seed = initial_seed
        else:
            output_images, seed, info, stats = process_images(
                outpath=outpath,
                func_init=init,
                func_sample=sample,
                prompt=prompt,
                seed=seed,
                sampler_name='k-diffusion',
                batch_size=batch_size,
                n_iter=n_iter,
                steps=ddim_steps,
                cfg_scale=cfg_scale,
                width=width,
                height=height,
                prompt_matrix=prompt_matrix,
                use_GFPGAN=use_GFPGAN
            )

        del sampler

        return output_images, seed, info, stats
    except RuntimeError as e:
        err = e
        err_msg = f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
        stats = err_msg
        return [], 1
    finally:
        if err:
            crash(err, '!!Runtime error (img2img)!!')

    return output_images, seed, info, notes
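In the loopback branch each pass feeds its output image back in as the next init image, bumps the seed by one, and decays denoising_strength by 5% per pass with a floor of 0.1 (e.g. 0.75 → 0.7125 → ~0.677), so successive iterations drift progressively less from the previous image.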
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
@@ -655,9 +767,9 @@ img2img_interface = gr.Interface(
        gr.Gallery(),
        gr.Number(label='Seed'),
        gr.Textbox(label="Copy-paste generation parameters"),
        gr.HTML(label='Notes'),
        gr.HTML(label='Stats'),
    ],
    title="Stable Diffusion Image-to-Image",
    title="Stable Diffusion Image-to-Image Unified",
    description="Generate images from images with Stable Diffusion",
    allow_flagging="never",
)
@@ -700,4 +812,4 @@ demo = gr.TabbedInterface(
    css=("" if opt.no_progressbar_hiding else css_hide_progressbar)
)
demo.launch()
demo.launch()