Optimized support force push

This commit is contained in:
hlky 2022-08-27 02:40:37 +01:00
parent 8c1ea1311a
commit b0b9c8a1d9

213
webui.py
View File

@ -11,6 +11,7 @@ parser.add_argument("--n_rows", type=int, default=-1, help="rows in the grid; us
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--optimized", action='store_true', help="load the model onto the device piecemeal instead of all at once to reduce VRAM usage at the cost of performance")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
parser.add_argument("--realesrgan-dir", type=str, help="RealESRGAN directory", default=('./src/realesrgan' if os.path.exists('./src/realesrgan') else './RealESRGAN'))
parser.add_argument("--realesrgan-model", type=str, help="Upscaling model for RealESRGAN", default=('RealESRGAN_x4plus'))
@ -110,6 +111,14 @@ def load_model_from_config(config, ckpt, verbose=False):
model.eval()
return model
def load_sd_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
return sd
def crash(e, s):
global model
global device
@ -270,11 +279,49 @@ def try_loading_RealESRGAN(model_name: str):
print(traceback.format_exc(), file=sys.stderr)
try_loading_RealESRGAN('RealESRGAN_x4plus')
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")
if opt.optimized:
config = OmegaConf.load("optimizedSD/v1-inference.yaml")
sd = load_sd_from_config("models/ldm/stable-diffusion-v1/model.ckpt")
li, lo = [], []
for key, v_ in sd.items():
sp = key.split('.')
if(sp[0]) == 'model':
if('input_blocks' in sp):
li.append(key)
elif('middle_block' in sp):
li.append(key)
elif('time_embed' in sp):
li.append(key)
else:
lo.append(key)
for key in li:
sd['model1.' + key[6:]] = sd.pop(key)
for key in lo:
sd['model2.' + key[6:]] = sd.pop(key)
device = torch.device(f"cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if opt.no_half else model.half()).to(device)
config.modelUNet.params.small_batch = False
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model if opt.no_half else model.half()
modelCS = modelCS if opt.no_half else modelCS.half()
else:
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if opt.no_half else model.half()).to(device)
def load_embeddings(fp):
if fp is not None and hasattr(model, "embedding_manager"):
@ -305,9 +352,9 @@ def seed_to_int(s):
if type(s) is int:
return s
if s is None or s == '':
return random.randint(0, 2**32 - 1)
n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
while n >= 2**32:
return random.randint(0,2**32)
n = abs(int(s) if s.isdigit() else hash(s))
while n > 2**32:
n = n >> 32
return n
@ -425,10 +472,10 @@ def resize_image(resize_mode, im, width, height):
def check_prompt_length(prompt, comments):
"""this function tests if prompt is too long, and if so, adds a message to comments"""
tokenizer = model.cond_stage_model.tokenizer
max_length = model.cond_stage_model.max_length
tokenizer = (model if not opt.optimized else modelCS).cond_stage_model.tokenizer
max_length = (model if not opt.optimized else modelCS).cond_stage_model.max_length
info = model.cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
info = (model if not opt.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
ovf = info['overflowing_tokens'][0]
overflowing_count = ovf.shape[0]
if overflowing_count == 0:
@ -445,7 +492,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
keep_mask=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None
@ -463,7 +510,7 @@ def process_images(
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
grid_count = len([_ for x in os.listdir() if x.endswith(('.png', '.jpg'))]) - 1 # start at 0
grid_count = len(os.listdir(outpath)) - 1
comments = []
@ -501,7 +548,7 @@ def process_images(
precision_scope = autocast if opt.precision == "autocast" else nullcontext
output_images = []
stats = []
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
with torch.no_grad(), precision_scope("cuda"), (model.ema_scope() if not opt.optimized else nullcontext()):
init_data = func_init()
tic = time.time()
@ -509,7 +556,9 @@ def process_images(
prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
uc = model.get_learned_conditioning(len(prompts) * [""])
if opt.optimized:
modelCS.to(device)
uc = (model if not opt.optimized else modelCS).get_learned_conditioning(len(prompts) * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
@ -528,18 +577,26 @@ def process_images(
weight = weight / totalPromptWeight
#print(f"{subprompts[i]} {weight*100.0}%")
# note if alpha negative, it functions same as torch.sub
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c, (model if not opt.optimized else modelCS).get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just behave like usual
c = model.get_learned_conditioning(prompts)
c = (model if not opt.optimized else modelCS).get_learned_conditioning(prompts)
shape = [opt_C, height // opt_f, width // opt_f]
if opt.optimized:
mem = torch.cuda.memory_allocated()/1e6
modelCS.to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
if opt.optimized:
modelFS.to(device)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
@ -561,7 +618,7 @@ def process_images(
image = Image.fromarray(x_sample)
if init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.filter(ImageFilter.GaussianBlur(3))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
@ -584,11 +641,11 @@ def process_images(
sanitized_prompt = sanitized_prompt[:128] #200 is too long
sample_path_i = os.path.join(sample_path, sanitized_prompt)
os.makedirs(sample_path_i, exist_ok=True)
base_count = len([_ for x in os.listdir() if x.endswith(('.png', '.jpg'))]) - 1 # start at 0
base_count = len(os.listdir(sample_path_i))
filename = f"{base_count:05}-{seeds[i]}"
else:
sample_path_i = sample_path
base_count = len([_ for x in os.listdir() if x.endswith(('.png', '.jpg'))]) - 1 # start at 0
base_count = len(os.listdir(sample_path_i))
sanitized_prompt = sanitized_prompt
filename = f"{base_count:05}-{seeds[i]}_{sanitized_prompt}"[:128] #same as before
if not skip_save:
@ -667,6 +724,13 @@ def process_images(
grid_file = f"grid-{grid_count:05}-{seed}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg"
grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=100, optimize=True)
grid_count += 1
if opt.optimized:
mem = torch.cuda.memory_allocated()/1e6
modelFS.to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
toc = time.time()
mem_max_used, mem_total = mem_mon.read_and_stop()
@ -815,7 +879,7 @@ class Flagging(gr.FlaggingCallback):
print("Logged:", filenames[0])
def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, ddim_steps: int, sampler_name: str,
toggles: List[int], realesrgan_model_name: str, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float,
seed: int, height: int, width: int, resize_mode: int, fp):
outpath = opt.outdir_img2img or opt.outdir or "outputs/img2img-samples"
@ -875,10 +939,19 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
if opt.optimized:
modelFS.to(device)
init_image = 2. * image - 1.
init_image = init_image.to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
init_latent = (model if not opt.optimized else modelFS).get_first_stage_encoding((model if not opt.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
if opt.optimized:
mem = torch.cuda.memory_allocated()/1e6
modelFS.to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
return init_latent,
@ -936,7 +1009,6 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
init_img=init_img,
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength,
resize_mode=resize_mode,
uses_loopback=loopback,
@ -958,7 +1030,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
history.append(init_img)
if not skip_grid:
grid_count = len([_ for x in os.listdir() if x.endswith(('.png', '.jpg'))]) - 1 # start at 0
grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1)
grid_file = f"grid-{grid_count:05}-{seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg"
grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=100, optimize=True)
@ -992,7 +1064,6 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
init_img=init_img,
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength,
resize_mode=resize_mode,
uses_loopback=loopback,
@ -1187,8 +1258,8 @@ img2img_image_mode = 'sketch'
def change_image_editor_mode(choice, cropped_image, resize_mode, width, height):
if choice == "Mask":
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)]
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)]
def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
@ -1203,28 +1274,6 @@ def copy_img_to_input(selected=1, imgs = []):
except IndexError:
return [None, None]
help_text = """
## Mask/Crop
* The masking/cropping is very temperamental.
* It may take some time for the image to show when switching from Crop to Mask.
* If the image doesn't appear after switching to Mask, switch back to Crop and then back again to Mask
* If the mask appears distorted (the brush is weirdly shaped instead of round), switch back to Crop and then back again to Mask.
## Advanced Editor
* For now the button needs to be clicked twice the first time.
* Once you have edited your image, you _need_ to click the save button for the next step to work.
* Clear the image from the crop editor (click the x)
* Click "Get Image from Advanced Editor" to get the image you saved. If it doesn't work, try opening the editor and saving again.
If it keeps not working, try switching modes again, switch tabs, clear the image or reload.
"""
def show_help():
return [gr.update(visible=False), gr.update(visible=True), gr.update(value=help_text)]
def hide_help():
return [gr.update(visible=True), gr.update(visible=False), gr.update(value="")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI") as demo:
with gr.Tabs():
with gr.TabItem("Stable Diffusion Text-to-Image Unified"):
@ -1265,16 +1314,19 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI")
gr.Markdown("Generate images from images with Stable Diffusion")
img2img_prompt = gr.Textbox(label="Prompt", placeholder="A fantasy landscape, trending on artstation.", lines=1, value=img2img_defaults['prompt'])
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"], label="Image Editor Mode", value="Crop")
img2img_show_help_btn = gr.Button("Show Hints")
img2img_hide_help_btn = gr.Button("Hide Hints", visible=False)
img2img_help = gr.Markdown(visible=False, value="")
with gr.Row():
img2img_painterro_btn = gr.Button("Advanced Editor")
img2img_copy_from_painterro_btn = gr.Button(value="Get Image from Advanced Editor")
gr.Markdown(
"""
The masking/cropping is very temperamental.
It may take some time for the image to show when switching from Crop to Mask.
* If the image doesn't appear after switching to Mask, switch back to Crop and then back again to Mask
* If the mask appears distorted (the brush is weirdly shaped instead of round), switch back to Crop and then back again to Mask.
If it keeps not working, try switching modes again, switch tabs, clear the image or reload.
"""
)
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="select")
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch", visible=False)
img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", type="index", value=img2img_mask_modes[img2img_defaults['mask_mode']], visible=False)
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1, label="How much blurry should the mask be? (to avoid hard edges)", value=3, visible=False)
img2img_mask = gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", type="index", value=img2img_mask_modes[img2img_defaults['mask_mode']])
img2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="Sampling Steps", value=img2img_defaults['ddim_steps'])
img2img_sampling = gr.Radio(label='Sampling method (k_lms is default k-diffusion sampler)', choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms'], value=img2img_defaults['sampler_name'])
img2img_toggles = gr.CheckboxGroup(label='', choices=img2img_toggles, value=img2img_toggle_defaults, type="index")
@ -1288,8 +1340,8 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI")
img2img_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=img2img_defaults["width"])
img2img_resize = gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value=img2img_resize_modes[img2img_defaults['resize_mode']])
img2img_embeddings = gr.File(label = "Embeddings file for textual inversion", visible=hasattr(model, "embedding_manager"))
img2img_btn_mask = gr.Button("Generate", visible=False).style(full_width=True)
img2img_btn_editor = gr.Button("Generate").style(full_width=True)
img2img_btn_mask = gr.Button("Generate", visible=False)
img2img_btn_editor = gr.Button("Generate")
with gr.Column():
output_img2img_gallery = gr.Gallery(label="Images")
output_img2img_select_image = gr.Number(label='Select image number from results for copying', value=1, precision=None)
@ -1302,7 +1354,7 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI")
img2img_image_editor_mode.change(
change_image_editor_mode,
[img2img_image_editor_mode, img2img_image_editor, img2img_resize, img2img_width, img2img_height],
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask, img2img_painterro_btn, img2img_copy_from_painterro_btn, img2img_mask, img2img_mask_blur_strength]
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask]
)
img2img_image_editor.edit(
@ -1311,18 +1363,6 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI")
img2img_image_mask
)
img2img_show_help_btn.click(
show_help,
None,
[img2img_show_help_btn, img2img_hide_help_btn, img2img_help]
)
img2img_hide_help_btn.click(
hide_help,
None,
[img2img_show_help_btn, img2img_hide_help_btn, img2img_help]
)
output_img2img_copy_to_input_btn.click(
copy_img_to_input,
[output_img2img_select_image, output_img2img_gallery],
@ -1337,41 +1377,16 @@ with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion WebUI")
img2img_btn_mask.click(
img2img,
[img2img_prompt, img2img_image_editor_mode, img2img_image_mask, img2img_mask, img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings],
[img2img_prompt, img2img_image_editor_mode, img2img_image_mask, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings],
[output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats]
)
img2img_btn_editor.click(
img2img,
[img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_mask, img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings],
[img2img_prompt, img2img_image_editor_mode, img2img_image_editor, img2img_mask, img2img_steps, img2img_sampling, img2img_toggles, img2img_realesrgan_model_name, img2img_batch_count, img2img_batch_size, img2img_cfg, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_embeddings],
[output_img2img_gallery, output_img2img_seed, output_img2img_params, output_img2img_stats]
)
img2img_painterro_btn.click(None, [img2img_image_editor], None, _js="""(img) => {
try {
Painterro({
hiddenTools: ['arrow'],
saveHandler: function (image, done) {
localStorage.setItem('painterro-image', image.asDataURL());
done(true);
},
}).show(Array.isArray(img) ? img[0] : img);
} catch(e) {
const script = document.createElement('script');
script.src = 'https://unpkg.com/painterro@1.2.78/build/painterro.min.js';
document.head.appendChild(script);
const style = document.createElement('style');
style.appendChild(document.createTextNode('.ptro-holder-wrapper { z-index: 9999 !important; }'));
document.head.appendChild(style);
}
return [];
}""")
img2img_copy_from_painterro_btn.click(None, None, [img2img_image_editor, img2img_image_mask], _js="""() => {
const image = localStorage.getItem('painterro-image')
return [image, image];
}""")
if GFPGAN is not None:
gfpgan_defaults = {
'strength': 100,