Merge pull request #1079 from unnamedplugins/dev

Add masked image restoration to streamlit, make color correction respect the setting as well
This commit is contained in:
ZeroCool 2022-09-14 06:02:16 -07:00 committed by GitHub
commit 2f55766576
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 1 deletions

View File

@ -114,6 +114,7 @@ img2img:
# 0: Keep masked area # 0: Keep masked area
# 1: Regenerate only masked area # 1: Regenerate only masked area
mask_mode: 0 mask_mode: 0
mask_restore: False
# 0: Just resize # 0: Just resize
# 1: Crop and resize # 1: Crop and resize
# 2: Resize and fill # 2: Resize and fill

View File

@ -279,6 +279,7 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
if initial_seed is None: if initial_seed is None:
initial_seed = seed initial_seed = seed
input_image = init_img
init_img = output_images[0] init_img = output_images[0]
if do_color_correction and correction_target is not None: if do_color_correction and correction_target is not None:
@ -290,6 +291,13 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
correction_target, correction_target,
channel_axis=2 channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8")) ), cv2.COLOR_LAB2RGB).astype("uint8"))
if mask_restore is True:
color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
color_mask = color_mask.convert('L')
source_image = input_image.convert('RGB')
target_image = init_img.convert('RGB')
init_img = Image.composite(source_image, target_image, color_mask)
if not random_seed_loopback: if not random_seed_loopback:
seed = seed + 1 seed = seed + 1
@ -411,6 +419,9 @@ def layout():
help="Ensure the sum of all weights add up to 1.0") help="Ensure the sum of all weights add up to 1.0")
loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.") loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.")
random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed") random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed")
img2img_mask_restore = st.checkbox("Only modify regenerated parts of image",
value=defaults.img2img.mask_restore,
help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail")
save_individual_images = st.checkbox("Save individual images.", value=defaults.img2img.save_individual_images, save_individual_images = st.checkbox("Save individual images.", value=defaults.img2img.save_individual_images,
help="Save each image generated before any filter or enhancement is applied.") help="Save each image generated before any filter or enhancement is applied.")
save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.") save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.")
@ -554,7 +565,7 @@ def layout():
try: try:
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode, output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
ddim_steps=st.session_state["sampling_steps"], mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
sampler_name=st.session_state["sampler_name"], n_iter=batch_count, sampler_name=st.session_state["sampler_name"], n_iter=batch_count,
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed, cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width, seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width,