From c73fdd78d959018679b741ba523ee5f9ea2d0b93 Mon Sep 17 00:00:00 2001 From: anon-hlhl <112447525+anon-hlhl@users.noreply.github.com> Date: Tue, 30 Aug 2022 11:32:15 +0100 Subject: [PATCH] Implement masking during sampling to improve blending (#308) * Implement masking during sampling to improve blending - Almost inpainting-like behaviour - Works best with small erased areas - Denoiser strength to 1 results to completely in-paint with no previous image info - Uncrop kind of sometimes works, but not really - DDIM changes depend on a PR in stable-diffusion core * Remove old init_masking code on save, as the masked image is applied during sampling Co-authored-by: anon-hlhl <> Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> --- frontend/frontend.py | 5 +- frontend/ui_functions.py | 4 +- webui.py | 107 +++++++++++++++++++++++++++------------ 3 files changed, 81 insertions(+), 35 deletions(-) diff --git a/frontend/frontend.py b/frontend/frontend.py index 131992c..a0c7939 100644 --- a/frontend/frontend.py +++ b/frontend/frontend.py @@ -128,13 +128,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, txt2img_defaul with gr.Column(): gr.Markdown('#### Img2Img input') img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True, - type="pil", tool="select", elem_id="img2img_editor") + type="pil", tool="select", elem_id="img2img_editor", + image_mode="RGBA") img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch", visible=False, elem_id="img2img_mask") with gr.Row(): - img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"], label="Image Editor Mode", + img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop", "Uncrop"], label="Image Editor Mode", value="Crop", elem_id='edit_mode_select') img2img_painterro_btn = gr.Button("Advanced Editor") diff --git a/frontend/ui_functions.py b/frontend/ui_functions.py index 84032b1..013397b 100644 --- a/frontend/ui_functions.py +++ b/frontend/ui_functions.py @@ -89,7 +89,7 @@ def resize_image(resize_mode, im, width, height): src_h = height if ratio <= src_ratio else im.height * width // im.width resized = im.resize((src_w, src_h), resample=LANCZOS) - res = Image.new("RGB", (width, height)) + res = Image.new("RGBA", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) else: ratio = width / height @@ -99,7 +99,7 @@ def resize_image(resize_mode, im, width, height): src_h = height if ratio >= src_ratio else im.height * width // im.width resized = im.resize((src_w, src_h), resample=LANCZOS) - res = Image.new("RGB", (width, height)) + res = Image.new("RGBA", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) if ratio < src_ratio: diff --git a/webui.py b/webui.py index 6546f7e..aaedd3d 100644 --- a/webui.py +++ b/webui.py @@ -200,6 +200,27 @@ class MemUsageMonitor(threading.Thread): self.stop_flag = True return self.max_usage, self.total +class CFGMaskedDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi): + x_in = x + x_in = torch.cat([x_in] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + denoised = uncond + (cond - uncond) * cond_scale + + if mask is not None: + assert x0 is not None + img_orig = x0 + mask_inv = 1. - mask + denoised = (img_orig * mask_inv) + (mask * denoised) + + return denoised + class CFGDenoiser(nn.Module): def __init__(self, model): super().__init__() @@ -874,27 +895,8 @@ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) output_images.append(gfpgan_esrgan_image) #287 - - if init_mask: - #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask) - init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) - init_mask = init_mask.convert('L') - init_img = init_img.convert('RGB') - image = image.convert('RGB') - - if use_RealESRGAN and RealESRGAN is not None: - if RealESRGAN.model.name != realesrgan_model_name: - try_loading_RealESRGAN(realesrgan_model_name) - output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8)) - init_img = Image.fromarray(output) - init_img = init_img.convert('RGB') - - output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8)) - init_mask = Image.fromarray(output) - init_mask = init_mask.convert('L') - - image = Image.composite(init_img, image, init_mask) if not skip_save or (not use_GFPGAN or not use_RealESRGAN): + save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) @@ -1116,7 +1118,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask if image_editor_mode == 'Mask': init_img = init_info["image"] - init_img = init_img.convert("RGB") + init_img = init_img.convert("RGBA") init_img = resize_image(resize_mode, init_img, width, height) init_mask = init_info["mask"] init_mask = init_mask.convert("RGB") @@ -1138,6 +1140,28 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) + mask_channel = None + if image_editor_mode == "Uncrop": + alpha = init_img.convert("RGBA") + alpha = resize_image(resize_mode, alpha, width // 8, height // 8) + mask_channel = alpha.split()[-1] + mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4)) + mask_channel = np.array(mask_channel) + mask_channel[mask_channel >= 255] = 255 + mask_channel[mask_channel < 255] = 0 + mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2)) + elif init_mask is not None: + alpha = init_mask.convert("RGBA") + alpha = resize_image(resize_mode, alpha, width // 8, height // 8) + mask_channel = alpha.split()[1] + + mask = None + if mask_channel is not None: + mask = np.array(mask_channel).astype(np.float32) / 255.0 + mask = (1 - mask) + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) + mask = torch.from_numpy(mask).to(device) if opt.optimized: modelFS.to(device) @@ -1152,27 +1176,48 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask while(torch.cuda.memory_allocated()/1e6 >= mem): time.sleep(1) - return init_latent, + return init_latent, mask, def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + t_enc_steps = t_enc + obliterate = False + if ddim_steps == t_enc_steps: + t_enc_steps = t_enc_steps - 1 + obliterate = True + if sampler_name != 'DDIM': - x0, = init_data + x0, z_mask = init_data sigmas = sampler.model_wrap.get_sigmas(ddim_steps) - noise = x * sigmas[ddim_steps - t_enc - 1] + noise = x * sigmas[ddim_steps - t_enc_steps - 1] xi = x0 + noise - sigma_sched = sigmas[ddim_steps - t_enc - 1:] - model_wrap_cfg = CFGDenoiser(sampler.model_wrap) - samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=xi.device) + xi = (z_mask * noise) + ((1-z_mask) * xi) + + sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] + model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False) else: - x0, = init_data + + x0, z_mask = init_data + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) - z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc]*batch_size).to(device)) + z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(device)) + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=z_enc.device) + z_enc = (z_mask * random) + ((1-z_mask) * z_enc) + # decode it - samples_ddim = sampler.decode(z_enc, conditioning, t_enc, + samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=unconditional_conditioning,) + unconditional_conditioning=unconditional_conditioning, + z_mask=z_mask, x0=x0) return samples_ddim