From 9927f7c38ab2bbf71cd509847b43a87f5bf62321 Mon Sep 17 00:00:00 2001
From: hlky <106811348+hlky@users.noreply.github.com>
Date: Thu, 25 Aug 2022 00:03:47 +0100
Subject: [PATCH] img2img-mask

Great work, thanks anon!! :100: :fire:
---
 webui.py | 66 ++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 43 insertions(+), 23 deletions(-)

diff --git a/webui.py b/webui.py
index 495bb4c..de1d219 100644
--- a/webui.py
+++ b/webui.py
@@ -15,7 +15,7 @@ from contextlib import contextmanager, nullcontext
 from einops import rearrange, repeat
 from itertools import islice
 from omegaconf import OmegaConf
-from PIL import Image, ImageFont, ImageDraw
+from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
 from torch import autocast
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -104,11 +104,11 @@ class MemUsageMonitor(threading.Thread):
     stop_flag = False
     max_usage = 0
     total = 0
-    
+
     def __init__(self, name):
         threading.Thread.__init__(self)
         self.name = name
-    
+
     def run(self):
         print(f"[{self.name}] Recording max memory usage...\n")
         pynvml.nvmlInit()
@@ -121,13 +121,13 @@ class MemUsageMonitor(threading.Thread):
             time.sleep(0.1)
         print(f"[{self.name}] Stopped recording.\n")
         pynvml.nvmlShutdown()
-    
+
     def read(self):
         return self.max_usage, self.total
-    
+
     def stop(self):
         self.stop_flag = True
-    
+
     def read_and_stop(self):
         self.stop_flag = True
         return self.max_usage, self.total
@@ -163,11 +163,11 @@ class MemUsageMonitor(threading.Thread):
     stop_flag = False
     max_usage = 0
     total = 0
-    
+
     def __init__(self, name):
         threading.Thread.__init__(self)
         self.name = name
-    
+
     def run(self):
         print(f"[{self.name}] Recording max memory usage...\n")
         pynvml.nvmlInit()
@@ -180,13 +180,13 @@ class MemUsageMonitor(threading.Thread):
             time.sleep(0.1)
         print(f"[{self.name}] Stopped recording.\n")
         pynvml.nvmlShutdown()
-    
+
     def read(self):
         return self.max_usage, self.total
-    
+
     def stop(self):
         self.stop_flag = True
-    
+
     def read_and_stop(self):
         self.stop_flag = True
         return self.max_usage, self.total
@@ -389,7 +389,7 @@ def check_prompt_length(prompt, comments):
     comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
 
 
-def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True):
+def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=True):
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
     assert prompt is not None
     torch_gc()
@@ -458,7 +458,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
                 uc = model.get_learned_conditioning(len(prompts) * [""])
                 if isinstance(prompts, tuple):
                     prompts = list(prompts)
-                
+
                 # split the prompt if it has : for weighting
                 # TODO for speed it might help to have this occur when all_prompts filled??
                 subprompts,weights = split_weighted_subprompts(prompts[0])
@@ -474,16 +474,16 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
                         weight = weight / totalPromptWeight
                         #print(f"{subprompts[i]} {weight*100.0}%")
                         # note if alpha negative, it functions same as torch.sub
-                        c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)   
+                        c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
                 else: # just behave like usual
                     c = model.get_learned_conditioning(prompts)
-                
+
                 shape = [opt_C, height // opt_f, width // opt_f]
 
                 # we manually generate all input noises because each one should have a specific seed
                 x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
                 samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
-                
+
                 x_samples_ddim = model.decode_first_stage(samples_ddim)
                 x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
@@ -497,6 +497,14 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
 
                     image = Image.fromarray(x_sample)
 
+                    if init_mask:
+                        init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
+                        init_mask = init_mask.filter(ImageFilter.GaussianBlur(3))
+                        init_mask = init_mask.convert('L')
+                        init_img = init_img.convert('RGB')
+                        image = image.convert('RGB')
+                        image = Image.composite(init_img, image, init_mask)
+
                     filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
                     if not skip_save:
                         image.save(os.path.join(sample_path, filename))
@@ -532,10 +540,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
     stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
-    
+
     for comment in comments:
         info += "\n\n" + comment
-    
+
     #mem_mon.stop()
     #del mem_mon
     torch_gc()
@@ -673,7 +681,7 @@ txt2img_interface = gr.Interface(
 )
 
 
-def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
+def img2img(prompt: str, init_info, mask_mode, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
     outpath = opt.outdir or "outputs/img2img-samples"
     err = False
     seed = seed_to_int(seed)
@@ -685,6 +693,14 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
     else:
         raise Exception("Unknown sampler: " + sampler_name)
 
+    init_img = init_info["image"]
+    init_img = init_img.convert("RGB")
+    init_img = resize_image(resize_mode, init_img, width, height)
+    init_mask = init_info["mask"]
+    init_mask = init_mask.convert("RGB")
+    init_mask = resize_image(resize_mode, init_mask, width, height)
+    keep_mask = mask_mode == "Keep masked area"
+
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
     t_enc = int(denoising_strength * ddim_steps)
 
@@ -789,9 +805,12 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
             prompt_matrix=prompt_matrix,
             use_GFPGAN=use_GFPGAN,
             fp=fp,
-            normalize_prompt_weights=normalize_prompt_weights
+            normalize_prompt_weights=normalize_prompt_weights,
+            init_img=init_img,
+            init_mask=init_mask,
+            keep_mask=keep_mask
         )
-    
+
     del sampler
 
     return output_images, seed, info, stats
@@ -813,7 +832,8 @@ img2img_interface = gr.Interface(
     img2img,
     inputs=[
         gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
-        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
+        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch"),
+        gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", value="Keep masked area"),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Radio(label='Sampling method', choices=["DDIM", "k-diffusion"], value="k-diffusion"),
        gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
@@ -930,4 +950,4 @@ demo = gr.TabbedInterface(
     theme="default",
 )
 demo.queue(concurrency_count=1)
-demo.launch(show_error=True, server_name='0.0.0.0')
+demo.launch(show_error=True, server_name='0.0.0.0')
\ No newline at end of file
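
A minimal standalone sketch of the compositing step this patch adds inside
process_images. The function name composite_with_mask and the argument names
generated/original are illustrative, not from the patch; the PIL calls mirror
the added lines:

    from PIL import Image, ImageFilter, ImageOps

    def composite_with_mask(generated, original, mask, keep_mask=True):
        # "Keep masked area": painted (white) pixels are restored from the
        # original image; otherwise invert so only the painted area is
        # regenerated and everything else is restored.
        if not keep_mask:
            mask = ImageOps.invert(mask)
        # Feather the mask edge so the boundary between kept and generated
        # regions blends instead of cutting off sharply.
        mask = mask.filter(ImageFilter.GaussianBlur(3)).convert('L')
        # Image.composite takes pixels from its first argument where the
        # mask is white and from its second where it is black.
        return Image.composite(original.convert('RGB'),
                               generated.convert('RGB'),
                               mask)

Note this is post-hoc compositing rather than true inpainting: the sampler
still denoises the full latent, and the protected region is pasted back over
the decoded image afterwards, which is why the blur radius matters for
hiding the seam.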
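
For context on where init_info comes from: in the Gradio releases current at
the time of this patch, passing tool="sketch" to gr.Image makes the component
hand the event handler a dict with "image" and "mask" entries instead of a
bare PIL image, which is why img2img now unpacks init_info. A rough, runnable
illustration of that shape using blank placeholder images (the literal sizes
and colors are made up):

    from PIL import Image

    # Stand-in for the dict gr.Image(type="pil", tool="sketch") delivers:
    init_info = {
        "image": Image.new("RGB", (512, 512), "gray"),  # uploaded base image
        "mask": Image.new("RGB", (512, 512), "white"),  # white where painted
    }
    mask_mode = "Keep masked area"  # value from the new Mask Mode radio

    init_img = init_info["image"].convert("RGB")
    init_mask = init_info["mask"].convert("RGB")
    keep_mask = mask_mode == "Keep masked area"

Resizing both the image and the mask with the same resize_mode keeps them the
same dimensions, which Image.composite requires.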