Add mask_restore option to give users the option to restore images based on mask, fixing #665.

Before commit c73fdd78  (Implement masking during sampling to improve blending, #308)
image mask was applied after sampling, resulting in masked parts that are not regenerated
to actually stay the same.
Since c73fdd78 the masked img2img will change the whole image, even in masked areas.
It gives better looking results at first glance, but will result in image degredation when
applied a few times. See issue #665.

In the workflow of using repeated masked img2img, users may want to use this options to keep the parts
of image they actually want to keep without image degradation. A final masked img2img or whole image img2img with mask_restore disabled
will give the better blending of "Implement masking during sampling".
This commit is contained in:
xaedes 2022-09-09 09:58:46 +02:00
parent 90a922c320
commit a7be43ba92
3 changed files with 102 additions and 14 deletions

View File

@ -197,9 +197,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
value=img2img_mask_modes[img2img_defaults['mask_mode']], value=img2img_mask_modes[img2img_defaults['mask_mode']],
visible=True) visible=True)
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1, img2img_mask_restore = gr.Checkbox(label="Restore image by mask",
value=img2img_defaults['mask_restore'],
visible=True)
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1,
label="How much blurry should the mask be? (to avoid hard edges)", label="How much blurry should the mask be? (to avoid hard edges)",
value=3, visible=False) value=3, visible=True)
img2img_resize = gr.Radio(label="Resize mode", img2img_resize = gr.Radio(label="Resize mode",
choices=["Just resize", "Crop and resize", choices=["Just resize", "Crop and resize",
@ -290,8 +294,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
img2img_width, img2img_width,
img2img_height img2img_height
], ],
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask, [img2img_image_editor,
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength] img2img_image_mask,
img2img_btn_editor,
img2img_btn_mask,
img2img_painterro_btn,
img2img_mask,
img2img_mask_blur_strength,
img2img_mask_restore]
) )
# img2img_image_editor_mode.change( # img2img_image_editor_mode.change(
@ -332,8 +342,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
) )
img2img_func = img2img img2img_func = img2img
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength,
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles,
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg, img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
img2img_image_editor, img2img_image_mask, img2img_embeddings] img2img_image_editor, img2img_image_mask, img2img_embeddings]

View File

@ -781,7 +781,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False, uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None): variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@ -1018,6 +1018,26 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
if imgProcessorTask == True: if imgProcessorTask == True:
output_images.append(image) output_images.append(image)
if mask_restore and init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
if use_RealESRGAN and RealESRGAN is not None:
if RealESRGAN.model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')
output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')
image = Image.composite(init_img, image, init_mask)
if not skip_save: if not skip_save:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
@ -1225,7 +1245,7 @@ class Flagging(gr.FlaggingCallback):
print("Logged:", filenames[0]) print("Logged:", filenames[0])
def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str, def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, mask_restore: bool, ddim_steps: int, sampler_name: str,
toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float, toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float,
seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None): seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None):
# print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode, # print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode,
@ -1428,6 +1448,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
init_mask=init_mask, init_mask=init_mask,
keep_mask=keep_mask, keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength, mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=loopback, uses_loopback=loopback,
@ -1498,6 +1519,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
keep_mask=keep_mask, keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength, mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=loopback, uses_loopback=loopback,
sort_samples=sort_samples, sort_samples=sort_samples,
@ -1638,6 +1660,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
init_img = result init_img = result
init_mask = None init_mask = None
keep_mask = False keep_mask = False
mask_restore = False
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
def init(): def init():
@ -1784,6 +1807,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
keep_mask=False, keep_mask=False,
mask_blur_strength=None, mask_blur_strength=None,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=False, uses_loopback=False,
sort_samples=True, sort_samples=True,
@ -2086,6 +2110,7 @@ img2img_defaults = {
'cfg_scale': 5.0, 'cfg_scale': 5.0,
'denoising_strength': 0.75, 'denoising_strength': 0.75,
'mask_mode': 0, 'mask_mode': 0,
'mask_restore': False,
'resize_mode': 0, 'resize_mode': 0,
'seed': '', 'seed': '',
'height': 512, 'height': 512,
@ -2099,10 +2124,39 @@ if 'img2img' in user_defaults:
img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']] img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']]
img2img_image_mode = 'sketch' img2img_image_mode = 'sketch'
def change_image_editor_mode(choice, cropped_image, resize_mode, width, height): def change_image_editor_mode(choice, cropped_image, mask, resize_mode, width, height):
if choice == "Mask": if choice == "Mask":
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] update_image_editor = gr.update(visible=False)
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] update_image_mask = gr.update(visible=True)
update_btn_editor = gr.update(visible=False)
update_btn_mask = gr.update(visible=True)
update_painterro_btn = gr.update(visible=False)
update_mask = gr.update(visible=False)
update_mask_blur_strength = gr.update(visible=True)
update_mask_restore = gr.update(visible=True)
# unknown = gr.update(visible=True)
else:
update_image_editor = gr.update(visible=True)
update_image_mask = gr.update(visible=False)
update_btn_editor = gr.update(visible=True)
update_btn_mask = gr.update(visible=False)
update_painterro_btn = gr.update(visible=True)
update_mask = gr.update(visible=True)
update_mask_blur_strength = gr.update(visible=False)
update_mask_restore = gr.update(visible=False)
# unknown = gr.update(visible=False)
return [
update_image_editor,
update_image_mask,
update_btn_editor,
update_btn_mask,
update_painterro_btn,
update_mask,
update_mask_blur_strength,
update_mask_restore,
# unknown,
]
def update_image_mask(cropped_image, resize_mode, width, height): def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None

View File

@ -913,7 +913,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@ -1157,6 +1157,28 @@ def process_images(
if simple_templating: if simple_templating:
grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
if mask_restore and init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
if use_RealESRGAN and st.session_state["RealESRGAN"] is not None:
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')
image = Image.composite(init_img, image, init_mask)
if save_individual_images: if save_individual_images:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
@ -1257,7 +1279,7 @@ def resize_image(resize_mode, im, width, height):
return res return res
def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
ddim_steps: int = 50, sampler_name: str = 'DDIM', mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
@ -1426,6 +1448,7 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
init_mask=init_mask, init_mask=init_mask,
keep_mask=keep_mask, keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength, mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=loopback, uses_loopback=loopback,
@ -1486,8 +1509,9 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
init_img=init_img, init_img=init_img,
init_mask=init_mask, init_mask=init_mask,
keep_mask=keep_mask, keep_mask=keep_mask,
mask_blur_strength=2, mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=loopback, uses_loopback=loopback,
sort_samples=group_by_prompt, sort_samples=group_by_prompt,