mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-16 08:22:33 +03:00
Implement masking during sampling to improve blending (#308)
* Implement masking during sampling to improve blending - Almost inpainting-like behaviour - Works best with small erased areas - Denoiser strength to 1 results to completely in-paint with no previous image info - Uncrop kind of sometimes works, but not really - DDIM changes depend on a PR in stable-diffusion core * Remove old init_masking code on save, as the masked image is applied during sampling Co-authored-by: anon-hlhl <> Co-authored-by: hlky <106811348+hlky@users.noreply.github.com>
This commit is contained in:
parent
96aba4b36d
commit
c73fdd78d9
@ -128,13 +128,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, txt2img_defaul
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown('#### Img2Img input')
|
gr.Markdown('#### Img2Img input')
|
||||||
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
||||||
type="pil", tool="select", elem_id="img2img_editor")
|
type="pil", tool="select", elem_id="img2img_editor",
|
||||||
|
image_mode="RGBA")
|
||||||
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
||||||
type="pil", tool="sketch", visible=False,
|
type="pil", tool="sketch", visible=False,
|
||||||
elem_id="img2img_mask")
|
elem_id="img2img_mask")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"], label="Image Editor Mode",
|
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop", "Uncrop"], label="Image Editor Mode",
|
||||||
value="Crop", elem_id='edit_mode_select')
|
value="Crop", elem_id='edit_mode_select')
|
||||||
|
|
||||||
img2img_painterro_btn = gr.Button("Advanced Editor")
|
img2img_painterro_btn = gr.Button("Advanced Editor")
|
||||||
|
@ -89,7 +89,7 @@ def resize_image(resize_mode, im, width, height):
|
|||||||
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
||||||
|
|
||||||
resized = im.resize((src_w, src_h), resample=LANCZOS)
|
resized = im.resize((src_w, src_h), resample=LANCZOS)
|
||||||
res = Image.new("RGB", (width, height))
|
res = Image.new("RGBA", (width, height))
|
||||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||||
else:
|
else:
|
||||||
ratio = width / height
|
ratio = width / height
|
||||||
@ -99,7 +99,7 @@ def resize_image(resize_mode, im, width, height):
|
|||||||
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
||||||
|
|
||||||
resized = im.resize((src_w, src_h), resample=LANCZOS)
|
resized = im.resize((src_w, src_h), resample=LANCZOS)
|
||||||
res = Image.new("RGB", (width, height))
|
res = Image.new("RGBA", (width, height))
|
||||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||||
|
|
||||||
if ratio < src_ratio:
|
if ratio < src_ratio:
|
||||||
|
107
webui.py
107
webui.py
@ -200,6 +200,27 @@ class MemUsageMonitor(threading.Thread):
|
|||||||
self.stop_flag = True
|
self.stop_flag = True
|
||||||
return self.max_usage, self.total
|
return self.max_usage, self.total
|
||||||
|
|
||||||
|
class CFGMaskedDenoiser(nn.Module):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.inner_model = model
|
||||||
|
|
||||||
|
def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
|
||||||
|
x_in = x
|
||||||
|
x_in = torch.cat([x_in] * 2)
|
||||||
|
sigma_in = torch.cat([sigma] * 2)
|
||||||
|
cond_in = torch.cat([uncond, cond])
|
||||||
|
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||||
|
denoised = uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = x0
|
||||||
|
mask_inv = 1. - mask
|
||||||
|
denoised = (img_orig * mask_inv) + (mask * denoised)
|
||||||
|
|
||||||
|
return denoised
|
||||||
|
|
||||||
class CFGDenoiser(nn.Module):
|
class CFGDenoiser(nn.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -874,27 +895,8 @@ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img,
|
|||||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
|
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
|
||||||
output_images.append(gfpgan_esrgan_image) #287
|
output_images.append(gfpgan_esrgan_image) #287
|
||||||
|
|
||||||
|
|
||||||
if init_mask:
|
|
||||||
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
|
|
||||||
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
|
|
||||||
init_mask = init_mask.convert('L')
|
|
||||||
init_img = init_img.convert('RGB')
|
|
||||||
image = image.convert('RGB')
|
|
||||||
|
|
||||||
if use_RealESRGAN and RealESRGAN is not None:
|
|
||||||
if RealESRGAN.model.name != realesrgan_model_name:
|
|
||||||
try_loading_RealESRGAN(realesrgan_model_name)
|
|
||||||
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
|
|
||||||
init_img = Image.fromarray(output)
|
|
||||||
init_img = init_img.convert('RGB')
|
|
||||||
|
|
||||||
output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
|
|
||||||
init_mask = Image.fromarray(output)
|
|
||||||
init_mask = init_mask.convert('L')
|
|
||||||
|
|
||||||
image = Image.composite(init_img, image, init_mask)
|
|
||||||
if not skip_save or (not use_GFPGAN or not use_RealESRGAN):
|
if not skip_save or (not use_GFPGAN or not use_RealESRGAN):
|
||||||
|
|
||||||
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||||
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
|
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
|
||||||
@ -1116,7 +1118,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
|
|||||||
|
|
||||||
if image_editor_mode == 'Mask':
|
if image_editor_mode == 'Mask':
|
||||||
init_img = init_info["image"]
|
init_img = init_info["image"]
|
||||||
init_img = init_img.convert("RGB")
|
init_img = init_img.convert("RGBA")
|
||||||
init_img = resize_image(resize_mode, init_img, width, height)
|
init_img = resize_image(resize_mode, init_img, width, height)
|
||||||
init_mask = init_info["mask"]
|
init_mask = init_info["mask"]
|
||||||
init_mask = init_mask.convert("RGB")
|
init_mask = init_mask.convert("RGB")
|
||||||
@ -1138,6 +1140,28 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
|
|||||||
image = image[None].transpose(0, 3, 1, 2)
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
|
|
||||||
|
mask_channel = None
|
||||||
|
if image_editor_mode == "Uncrop":
|
||||||
|
alpha = init_img.convert("RGBA")
|
||||||
|
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
|
||||||
|
mask_channel = alpha.split()[-1]
|
||||||
|
mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4))
|
||||||
|
mask_channel = np.array(mask_channel)
|
||||||
|
mask_channel[mask_channel >= 255] = 255
|
||||||
|
mask_channel[mask_channel < 255] = 0
|
||||||
|
mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2))
|
||||||
|
elif init_mask is not None:
|
||||||
|
alpha = init_mask.convert("RGBA")
|
||||||
|
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
|
||||||
|
mask_channel = alpha.split()[1]
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
if mask_channel is not None:
|
||||||
|
mask = np.array(mask_channel).astype(np.float32) / 255.0
|
||||||
|
mask = (1 - mask)
|
||||||
|
mask = np.tile(mask, (4, 1, 1))
|
||||||
|
mask = mask[None].transpose(0, 1, 2, 3)
|
||||||
|
mask = torch.from_numpy(mask).to(device)
|
||||||
if opt.optimized:
|
if opt.optimized:
|
||||||
modelFS.to(device)
|
modelFS.to(device)
|
||||||
|
|
||||||
@ -1152,27 +1176,48 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
|
|||||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
return init_latent,
|
return init_latent, mask,
|
||||||
|
|
||||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||||
|
t_enc_steps = t_enc
|
||||||
|
obliterate = False
|
||||||
|
if ddim_steps == t_enc_steps:
|
||||||
|
t_enc_steps = t_enc_steps - 1
|
||||||
|
obliterate = True
|
||||||
|
|
||||||
if sampler_name != 'DDIM':
|
if sampler_name != 'DDIM':
|
||||||
x0, = init_data
|
x0, z_mask = init_data
|
||||||
|
|
||||||
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
||||||
noise = x * sigmas[ddim_steps - t_enc - 1]
|
noise = x * sigmas[ddim_steps - t_enc_steps - 1]
|
||||||
|
|
||||||
xi = x0 + noise
|
xi = x0 + noise
|
||||||
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
|
||||||
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
|
# Obliterate masked image
|
||||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
|
if z_mask is not None and obliterate:
|
||||||
|
random = torch.randn(z_mask.shape, device=xi.device)
|
||||||
|
xi = (z_mask * noise) + ((1-z_mask) * xi)
|
||||||
|
|
||||||
|
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||||
|
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||||
|
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
|
||||||
else:
|
else:
|
||||||
x0, = init_data
|
|
||||||
|
x0, z_mask = init_data
|
||||||
|
|
||||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
|
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
|
||||||
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc]*batch_size).to(device))
|
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(device))
|
||||||
|
|
||||||
|
# Obliterate masked image
|
||||||
|
if z_mask is not None and obliterate:
|
||||||
|
random = torch.randn(z_mask.shape, device=z_enc.device)
|
||||||
|
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
|
||||||
|
|
||||||
# decode it
|
# decode it
|
||||||
samples_ddim = sampler.decode(z_enc, conditioning, t_enc,
|
samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
|
||||||
unconditional_guidance_scale=cfg_scale,
|
unconditional_guidance_scale=cfg_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning,)
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
z_mask=z_mask, x0=x0)
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user