mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-15 15:22:55 +03:00
Implement masking during sampling to improve blending (#308)
* Implement masking during sampling to improve blending - Almost inpainting-like behaviour - Works best with small erased areas - Denoiser strength to 1 results to completely in-paint with no previous image info - Uncrop kind of sometimes works, but not really - DDIM changes depend on a PR in stable-diffusion core * Remove old init_masking code on save, as the masked image is applied during sampling Co-authored-by: anon-hlhl <> Co-authored-by: hlky <106811348+hlky@users.noreply.github.com>
This commit is contained in:
parent
96aba4b36d
commit
c73fdd78d9
@ -128,13 +128,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, txt2img_defaul
|
||||
with gr.Column():
|
||||
gr.Markdown('#### Img2Img input')
|
||||
img2img_image_editor = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
||||
type="pil", tool="select", elem_id="img2img_editor")
|
||||
type="pil", tool="select", elem_id="img2img_editor",
|
||||
image_mode="RGBA")
|
||||
img2img_image_mask = gr.Image(value=sample_img2img, source="upload", interactive=True,
|
||||
type="pil", tool="sketch", visible=False,
|
||||
elem_id="img2img_mask")
|
||||
|
||||
with gr.Row():
|
||||
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"], label="Image Editor Mode",
|
||||
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop", "Uncrop"], label="Image Editor Mode",
|
||||
value="Crop", elem_id='edit_mode_select')
|
||||
|
||||
img2img_painterro_btn = gr.Button("Advanced Editor")
|
||||
|
@ -89,7 +89,7 @@ def resize_image(resize_mode, im, width, height):
|
||||
src_h = height if ratio <= src_ratio else im.height * width // im.width
|
||||
|
||||
resized = im.resize((src_w, src_h), resample=LANCZOS)
|
||||
res = Image.new("RGB", (width, height))
|
||||
res = Image.new("RGBA", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
else:
|
||||
ratio = width / height
|
||||
@ -99,7 +99,7 @@ def resize_image(resize_mode, im, width, height):
|
||||
src_h = height if ratio >= src_ratio else im.height * width // im.width
|
||||
|
||||
resized = im.resize((src_w, src_h), resample=LANCZOS)
|
||||
res = Image.new("RGB", (width, height))
|
||||
res = Image.new("RGBA", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
|
||||
if ratio < src_ratio:
|
||||
|
107
webui.py
107
webui.py
@ -200,6 +200,27 @@ class MemUsageMonitor(threading.Thread):
|
||||
self.stop_flag = True
|
||||
return self.max_usage, self.total
|
||||
|
||||
class CFGMaskedDenoiser(nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
|
||||
x_in = x
|
||||
x_in = torch.cat([x_in] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||
denoised = uncond + (cond - uncond) * cond_scale
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = x0
|
||||
mask_inv = 1. - mask
|
||||
denoised = (img_orig * mask_inv) + (mask * denoised)
|
||||
|
||||
return denoised
|
||||
|
||||
class CFGDenoiser(nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
@ -874,27 +895,8 @@ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
|
||||
output_images.append(gfpgan_esrgan_image) #287
|
||||
|
||||
|
||||
if init_mask:
|
||||
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
|
||||
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
|
||||
init_mask = init_mask.convert('L')
|
||||
init_img = init_img.convert('RGB')
|
||||
image = image.convert('RGB')
|
||||
|
||||
if use_RealESRGAN and RealESRGAN is not None:
|
||||
if RealESRGAN.model.name != realesrgan_model_name:
|
||||
try_loading_RealESRGAN(realesrgan_model_name)
|
||||
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
|
||||
init_img = Image.fromarray(output)
|
||||
init_img = init_img.convert('RGB')
|
||||
|
||||
output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
|
||||
init_mask = Image.fromarray(output)
|
||||
init_mask = init_mask.convert('L')
|
||||
|
||||
image = Image.composite(init_img, image, init_mask)
|
||||
if not skip_save or (not use_GFPGAN or not use_RealESRGAN):
|
||||
|
||||
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
|
||||
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
|
||||
@ -1116,7 +1118,7 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
|
||||
|
||||
if image_editor_mode == 'Mask':
|
||||
init_img = init_info["image"]
|
||||
init_img = init_img.convert("RGB")
|
||||
init_img = init_img.convert("RGBA")
|
||||
init_img = resize_image(resize_mode, init_img, width, height)
|
||||
init_mask = init_info["mask"]
|
||||
init_mask = init_mask.convert("RGB")
|
||||
@ -1138,6 +1140,28 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask_channel = None
|
||||
if image_editor_mode == "Uncrop":
|
||||
alpha = init_img.convert("RGBA")
|
||||
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
|
||||
mask_channel = alpha.split()[-1]
|
||||
mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4))
|
||||
mask_channel = np.array(mask_channel)
|
||||
mask_channel[mask_channel >= 255] = 255
|
||||
mask_channel[mask_channel < 255] = 0
|
||||
mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2))
|
||||
elif init_mask is not None:
|
||||
alpha = init_mask.convert("RGBA")
|
||||
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
|
||||
mask_channel = alpha.split()[1]
|
||||
|
||||
mask = None
|
||||
if mask_channel is not None:
|
||||
mask = np.array(mask_channel).astype(np.float32) / 255.0
|
||||
mask = (1 - mask)
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3)
|
||||
mask = torch.from_numpy(mask).to(device)
|
||||
if opt.optimized:
|
||||
modelFS.to(device)
|
||||
|
||||
@ -1152,27 +1176,48 @@ def img2img(prompt: str, image_editor_mode: str, init_info, mask_mode: str, mask
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
return init_latent,
|
||||
return init_latent, mask,
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
t_enc_steps = t_enc
|
||||
obliterate = False
|
||||
if ddim_steps == t_enc_steps:
|
||||
t_enc_steps = t_enc_steps - 1
|
||||
obliterate = True
|
||||
|
||||
if sampler_name != 'DDIM':
|
||||
x0, = init_data
|
||||
x0, z_mask = init_data
|
||||
|
||||
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
||||
noise = x * sigmas[ddim_steps - t_enc - 1]
|
||||
noise = x * sigmas[ddim_steps - t_enc_steps - 1]
|
||||
|
||||
xi = x0 + noise
|
||||
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
|
||||
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=xi.device)
|
||||
xi = (z_mask * noise) + ((1-z_mask) * xi)
|
||||
|
||||
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
|
||||
else:
|
||||
x0, = init_data
|
||||
|
||||
x0, z_mask = init_data
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc]*batch_size).to(device))
|
||||
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(device))
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=z_enc.device)
|
||||
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
|
||||
|
||||
# decode it
|
||||
samples_ddim = sampler.decode(z_enc, conditioning, t_enc,
|
||||
samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,)
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
z_mask=z_mask, x0=x0)
|
||||
return samples_ddim
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user