Add mask_restore to restore images based on mask, fixing #665 (#898)

* Add mask_restore option to give users the option to restore images based on mask, fixing #665.

Before commit c73fdd78  (Implement masking during sampling to improve blending, #308)
image mask was applied after sampling, resulting in masked parts that are not regenerated
to actually stay the same.
Since c73fdd78 the masked img2img will change the whole image, even in masked areas.
It gives better looking results at first glance, but will result in image degredation when
applied a few times. See issue #665.

In the workflow of using repeated masked img2img, users may want to use this options to keep the parts
of image they actually want to keep without image degradation. A final masked img2img or whole image img2img with mask_restore disabled
will give the better blending of "Implement masking during sampling".

* revert changes of a7be43ba in change_image_editor_mode

* fix ui_functions.change_image_editor_mode by adding gr.update to the end of the list it returns

* revert inserted newlines and whitespaces to match format of previous code

* improve caption of new option mask_restore

"Only modify regenerated parts of image"

* fix ui_functions.change_image_editor_mode by adding gr.update to the end of the list it returns

an old copy of the function exists in webui.py, this superflous function mistakenly was changed by the earlier commit b6a9e16b

* remove unused functions that are near duplicates of functions in ui_functions.py
This commit is contained in:
xaedes 2022-09-09 23:07:14 +02:00 committed by GitHub
parent c7d6855bae
commit f6aa2c64eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 30 deletions

View File

@ -200,9 +200,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
value=img2img_mask_modes[img2img_defaults['mask_mode']], value=img2img_mask_modes[img2img_defaults['mask_mode']],
visible=True) visible=True)
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1, img2img_mask_restore = gr.Checkbox(label="Only modify regenerated parts of image",
value=img2img_defaults['mask_restore'],
visible=True)
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1,
label="How much blurry should the mask be? (to avoid hard edges)", label="How much blurry should the mask be? (to avoid hard edges)",
value=3, visible=False) value=3, visible=True)
img2img_resize = gr.Radio(label="Resize mode", img2img_resize = gr.Radio(label="Resize mode",
choices=["Just resize", "Crop and resize", choices=["Just resize", "Crop and resize",
@ -294,7 +298,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
img2img_height img2img_height
], ],
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask, [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength] img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_restore]
) )
# img2img_image_editor_mode.change( # img2img_image_editor_mode.change(
@ -335,8 +339,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
) )
img2img_func = img2img img2img_func = img2img
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength,
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles, img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles,
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg, img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize, img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
img2img_image_editor, img2img_image_mask, img2img_embeddings] img2img_image_editor, img2img_image_mask, img2img_embeddings]

View File

@ -9,10 +9,10 @@ import re
def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height): def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height):
if choice == "Mask": if choice == "Mask":
update_image_result = update_image_mask(cropped_image, resize_mode, width, height) update_image_result = update_image_mask(cropped_image, resize_mode, width, height)
return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)] return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height) update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height)
return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)] return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
def update_image_mask(cropped_image, resize_mode, width, height): def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None

View File

@ -790,7 +790,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False, uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None): variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@ -1045,6 +1045,26 @@ skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size,
if imgProcessorTask == True: if imgProcessorTask == True:
output_images.append(image) output_images.append(image)
if mask_restore and init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
if use_RealESRGAN and RealESRGAN is not None:
if RealESRGAN.model.name != realesrgan_model_name:
try_loading_RealESRGAN(realesrgan_model_name)
output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')
output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')
image = Image.composite(init_img, image, init_mask)
if not skip_save: if not skip_save:
save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save, save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False) skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
@ -1259,7 +1279,7 @@ def blurArr(a,r=8):
def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str, def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, mask_restore: bool, ddim_steps: int, sampler_name: str,
toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float, toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float,
seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None): seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None):
# print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode, # print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode,
@ -1504,6 +1524,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
init_mask=init_mask, init_mask=init_mask,
keep_mask=keep_mask, keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength, mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=loopback, uses_loopback=loopback,
@ -1574,6 +1595,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
keep_mask=keep_mask, keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength, mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=loopback, uses_loopback=loopback,
sort_samples=sort_samples, sort_samples=sort_samples,
@ -1722,6 +1744,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
init_img = result init_img = result
init_mask = None init_mask = None
keep_mask = False keep_mask = False
mask_restore = False
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
def init(): def init():
@ -1868,6 +1891,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
keep_mask=False, keep_mask=False,
mask_blur_strength=None, mask_blur_strength=None,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=False, uses_loopback=False,
sort_samples=True, sort_samples=True,
@ -2186,6 +2210,7 @@ img2img_defaults = {
'cfg_scale': 5.0, 'cfg_scale': 5.0,
'denoising_strength': 0.75, 'denoising_strength': 0.75,
'mask_mode': 0, 'mask_mode': 0,
'mask_restore': False,
'resize_mode': 0, 'resize_mode': 0,
'seed': '', 'seed': '',
'height': 512, 'height': 512,
@ -2199,24 +2224,6 @@ if 'img2img' in user_defaults:
img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']] img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']]
img2img_image_mode = 'sketch' img2img_image_mode = 'sketch'
def change_image_editor_mode(choice, cropped_image, resize_mode, width, height):
if choice == "Mask":
return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
return gr.update(value=resized_cropped_image)
def copy_img_to_upscale_esrgan(img):
update = gr.update(selected='realesrgan_tab')
image_data = re.sub('^data:image/.+;base64,', '', img)
processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
return {'realesrgan_source': processed_image, 'tabs': update}
help_text = """ help_text = """
## Mask/Crop ## Mask/Crop
* The masking/cropping is very temperamental. * The masking/cropping is very temperamental.

View File

@ -913,7 +913,7 @@ def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size, outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None, fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False, keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False, uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
variant_amount=0.0, variant_seed=None, save_individual_images: bool = True): variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@ -1156,7 +1156,29 @@ def process_images(
if simple_templating: if simple_templating:
grid_captions.append( captions[i] + "\ngfpgan_esrgan" ) grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
if mask_restore and init_mask:
#init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
if use_RealESRGAN and st.session_state["RealESRGAN"] is not None:
if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
init_img = Image.fromarray(output)
init_img = init_img.convert('RGB')
output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
init_mask = Image.fromarray(output)
init_mask = init_mask.convert('L')
image = Image.composite(init_img, image, init_mask)
if save_individual_images: if save_individual_images:
save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
@ -1257,7 +1279,7 @@ def resize_image(resize_mode, im, width, height):
return res return res
def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
ddim_steps: int = 50, sampler_name: str = 'DDIM', mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
@ -1426,6 +1448,7 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
init_mask=init_mask, init_mask=init_mask,
keep_mask=keep_mask, keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength, mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=loopback, uses_loopback=loopback,
@ -1486,8 +1509,9 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
init_img=init_img, init_img=init_img,
init_mask=init_mask, init_mask=init_mask,
keep_mask=keep_mask, keep_mask=keep_mask,
mask_blur_strength=2, mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
mask_restore=mask_restore,
resize_mode=resize_mode, resize_mode=resize_mode,
uses_loopback=loopback, uses_loopback=loopback,
sort_samples=group_by_prompt, sort_samples=group_by_prompt,