Implement masking during sampling to improve blending (#308)

* Implement masking during sampling to improve blending

 - Almost inpainting-like behaviour
 - Works best with small erased areas
 - Denoiser strength to 1 results to completely in-paint with no previous image info
 - Uncrop kind of sometimes works, but not really
 - DDIM changes depend on a PR in stable-diffusion core

* Remove old init_masking code on save, as the masked image is applied during sampling

Co-authored-by: anon-hlhl <>
Co-authored-by: hlky <106811348+hlky@users.noreply.github.com>
This commit is contained in:
anon-hlhl 2022-08-30 11:32:15 +01:00 committed by GitHub
parent 96aba4b36d
commit c73fdd78d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 35 deletions

View File

@ -128,13 +128,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, txt2img_defaul
with gr.Column(): with gr.Column():
gr.Markdown('#### Img2Img input') gr.Markdown('#### Img2Img input')
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True, img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True,
type="pil", tool="select", elem_id="img2img_editor") type="pil", tool="select", elem_id="img2img_editor",
image_mode="RGBA")
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True, img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True,
type="pil", tool="sketch", visible=False, type="pil", tool="sketch", visible=False,
elem_id="img2img_mask") elem_id="img2img_mask")
with gr.Row(): with gr.Row():
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"], label="Image Editor Mode", img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop", "Uncrop"], label="Image Editor Mode",
value="Crop", elem_id='edit_mode_select') value="Crop", elem_id='edit_mode_select')
img2img_painterro_btn = gr.Button("Advanced Editor") img2img_painterro_btn = gr.Button("Advanced Editor")

View File

@ -89,7 +89,7 @@ def resize_image(resize_mode, im, width, height):
src_h = height if ratio <= src_ratio else im.height * width // im.width src_h = height if ratio <= src_ratio else im.height * width // im.width
resized = im.resize((src_w, src_h), resample=LANCZOS) resized = im.resize((src_w, src_h), resample=LANCZOS)
res = Image.new("RGB", (width, height)) res = Image.new("RGBA", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
else: else:
ratio = width / height ratio = width / height
@ -99,7 +99,7 @@ def resize_image(resize_mode, im, width, height):
src_h = height if ratio >= src_ratio else im.height * width // im.width src_h = height if ratio >= src_ratio else im.height * width // im.width
resized = im.resize((src_w, src_h), resample=LANCZOS) resized = im.resize((src_w, src_h), resample=LANCZOS)
res = Image.new("RGB", (width, height)) res = Image.new("RGBA", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
if ratio < src_ratio: if ratio < src_ratio:

107
webui.py
View File

@ -200,6 +200,27 @@ class MemUsageMonitor(threading.Thread):
self.stop_flag = True self.stop_flag = True
return self.max_usage, self.total return self.max_usage, self.total
class CFGMaskedDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
x_in = x
x_in = torch.cat([x_in] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
if mask is not None:
assert x0 is not None
img_orig = x0
mask_inv = 1. - mask
denoised = (img_orig * mask_inv) + (mask * denoised)
return denoised
class CFGDenoiser(nn.Module): class CFGDenoiser(nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
@ -874,27 +895,8 @@ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
output_images.append(gfpgan_esrgan_image) #287 output_images.append(gfpgan_esrgan_image) #287
if init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
if use_RealESRGAN and RealESRGAN is not None:
if RealESRGAN.model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')
output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')
image = Image.composite(init_img, image, init_mask)
if not skip_save or (not use_GFPGAN or not use_RealESRGAN): if not skip_save or (not use_GFPGAN or not use_RealESRGAN):
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode) skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
@ -1116,7 +1118,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
if image_editor_mode == 'Mask': if image_editor_mode == 'Mask':
init_img = init_info["image"] init_img = init_info["image"]
init_img = init_img.convert("RGB") init_img = init_img.convert("RGBA")
init_img = resize_image(resize_mode, init_img, width, height) init_img = resize_image(resize_mode, init_img, width, height)
init_mask = init_info["mask"] init_mask = init_info["mask"]
init_mask = init_mask.convert("RGB") init_mask = init_mask.convert("RGB")
@ -1138,6 +1140,28 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
image = image[None].transpose(0, 3, 1, 2) image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image) image = torch.from_numpy(image)
mask_channel = None
if image_editor_mode == "Uncrop":
alpha = init_img.convert("RGBA")
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
mask_channel = alpha.split()[-1]
mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4))
mask_channel = np.array(mask_channel)
mask_channel[mask_channel >= 255] = 255
mask_channel[mask_channel < 255] = 0
mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2))
elif init_mask is not None:
alpha = init_mask.convert("RGBA")
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
mask_channel = alpha.split()[1]
mask = None
if mask_channel is not None:
mask = np.array(mask_channel).astype(np.float32) / 255.0
mask = (1 - mask)
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3)
mask = torch.from_numpy(mask).to(device)
if opt.optimized: if opt.optimized:
modelFS.to(device) modelFS.to(device)
@ -1152,27 +1176,48 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
while(torch.cuda.memory_allocated()/1e6 >= mem): while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1) time.sleep(1)
return init_latent, return init_latent, mask,
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
t_enc_steps = t_enc
obliterate = False
if ddim_steps == t_enc_steps:
t_enc_steps = t_enc_steps - 1
obliterate = True
if sampler_name != 'DDIM': if sampler_name != 'DDIM':
x0, = init_data x0, z_mask = init_data
sigmas = sampler.model_wrap.get_sigmas(ddim_steps) sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
noise = x * sigmas[ddim_steps - t_enc - 1] noise = x * sigmas[ddim_steps - t_enc_steps - 1]
xi = x0 + noise xi = x0 + noise
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
model_wrap_cfg = CFGDenoiser(sampler.model_wrap) # Obliterate masked image
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False) if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=xi.device)
xi = (z_mask * noise) + ((1-z_mask) * xi)
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
else: else:
x0, = init_data
x0, z_mask = init_data
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc]*batch_size).to(device)) z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(device))
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=z_enc.device)
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
# decode it # decode it
samples_ddim = sampler.decode(z_enc, conditioning, t_enc, samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=unconditional_conditioning,) unconditional_conditioning=unconditional_conditioning,
z_mask=z_mask, x0=x0)
return samples_ddim return samples_ddim