img2img-mask

Great work, thanks anon!! 💯 🔥
This commit is contained in:
hlky 2022-08-25 00:03:47 +01:00
parent 8d6c046a08
commit 9927f7c38a
No known key found for this signature in database
GPG Key ID: 55A99F1E80D907D5

View File

@ -15,7 +15,7 @@ from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from itertools import islice
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw
from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
from torch import autocast
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
@ -389,7 +389,7 @@ def check_prompt_length(prompt, comments):
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True):
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=True):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None
torch_gc()
@ -497,6 +497,14 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
image = Image.fromarray(x_sample)
if init_mask:
init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(3))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
image = Image.composite(init_img, image, init_mask)
filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
if not skip_save:
image.save(os.path.join(sample_path, filename))
@ -673,7 +681,7 @@ txt2img_interface = gr.Interface(
)
def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
def img2img(prompt: str, init_info, mask_mode, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
outpath = opt.outdir or "outputs/img2img-samples"
err = False
seed = seed_to_int(seed)
@ -685,6 +693,14 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
else:
raise Exception("Unknown sampler: " + sampler_name)
init_img = init_info["image"]
init_img = init_img.convert("RGB")
init_img = resize_image(resize_mode, init_img, width, height)
init_mask = init_info["mask"]
init_mask = init_mask.convert("RGB")
init_mask = resize_image(resize_mode, init_mask, width, height)
keep_mask = mask_mode == "Keep masked area"
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(denoising_strength * ddim_steps)
@ -789,7 +805,10 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
fp=fp,
normalize_prompt_weights=normalize_prompt_weights
normalize_prompt_weights=normalize_prompt_weights,
init_img=init_img,
init_mask=init_mask,
keep_mask=keep_mask
)
del sampler
@ -813,7 +832,8 @@ img2img_interface = gr.Interface(
img2img,
inputs=[
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch"),
gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", value="Keep masked area"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=["DDIM", "k-diffusion"], value="k-diffusion"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),