img2img-mask

Great work, thanks anon!! 💯 🔥
hlky 2022-08-25 00:03:47 +01:00
parent 8d6c046a08
commit 9927f7c38a


@@ -15,7 +15,7 @@ from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from itertools import islice
from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw
from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
from torch import autocast
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
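
Note on the import change: ImageFilter and ImageOps are pulled in for the mask handling added further down in this diff. A minimal sketch of the two operations, with illustrative sizes:

from PIL import Image, ImageFilter, ImageOps

mask = Image.new("L", (64, 64), 0)            # black = untouched area
mask.paste(255, (16, 16, 48, 48))             # white square = painted mask
inverted = ImageOps.invert(mask)              # swaps kept/regenerated regions
feathered = inverted.filter(ImageFilter.GaussianBlur(3))  # softens the seam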
@@ -104,11 +104,11 @@ class MemUsageMonitor(threading.Thread):
stop_flag = False
max_usage = 0
total = 0
def __init__(self, name):
threading.Thread.__init__(self)
self.name = name
def run(self):
print(f"[{self.name}] Recording max memory usage...\n")
pynvml.nvmlInit()
@@ -121,13 +121,13 @@ class MemUsageMonitor(threading.Thread):
time.sleep(0.1)
print(f"[{self.name}] Stopped recording.\n")
pynvml.nvmlShutdown()
def read(self):
return self.max_usage, self.total
def stop(self):
self.stop_flag = True
def read_and_stop(self):
self.stop_flag = True
return self.max_usage, self.total
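
For context, the polling loop elided by this hunk reads GPU memory through NVML; a standalone sketch of the same calls (first GPU assumed, as the monitor does):

import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)    # device index 0 assumed
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
print(f"used {info.used} of {info.total} bytes")
pynvml.nvmlShutdown()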
@@ -163,11 +163,11 @@ class MemUsageMonitor(threading.Thread):
stop_flag = False
max_usage = 0
total = 0
def __init__(self, name):
threading.Thread.__init__(self)
self.name = name
def run(self):
print(f"[{self.name}] Recording max memory usage...\n")
pynvml.nvmlInit()
@@ -180,13 +180,13 @@ class MemUsageMonitor(threading.Thread):
time.sleep(0.1)
print(f"[{self.name}] Stopped recording.\n")
pynvml.nvmlShutdown()
def read(self):
return self.max_usage, self.total
def stop(self):
self.stop_flag = True
def read_and_stop(self):
self.stop_flag = True
return self.max_usage, self.total
@@ -389,7 +389,7 @@ def check_prompt_length(prompt, comments):
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True):
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=True):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert prompt is not None
torch_gc()
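
The three new keyword arguments default to None/None/True, so existing txt2img call sites are unaffected; only img2img passes them. A hypothetical reduction of that design:

def process_images_sketch(prompt, init_img=None, init_mask=None, keep_mask=True):
    # Mask arguments default off, so txt2img callers need no changes.
    if init_mask is not None:
        print("compositing:", "keep masked area" if keep_mask else "regenerate masked area")
    print("sampling for:", prompt)

process_images_sketch("a fantasy landscape")   # txt2img path, masks untouched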
@@ -458,7 +458,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
uc = model.get_learned_conditioning(len(prompts) * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
# split the prompt if it has : for weighting
# TODO for speed it might help to have this occur when all_prompts filled??
subprompts,weights = split_weighted_subprompts(prompts[0])
@@ -474,16 +474,16 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
weight = weight / totalPromptWeight
#print(f"{subprompts[i]} {weight*100.0}%")
# note if alpha negative, it functions same as torch.sub
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just behave like usual
c = model.get_learned_conditioning(prompts)
shape = [opt_C, height // opt_f, width // opt_f]
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
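
In isolation, the weighted-conditioning branch above is a normalized weighted sum of subprompt embeddings; a sketch with dummy tensors, since model.get_learned_conditioning needs the loaded model (the 77x768 shape is illustrative):

import torch

embeddings = [torch.randn(1, 77, 768) for _ in range(3)]  # stand-ins for CLIP embeddings
weights = [1.0, 0.5, -0.25]                               # e.g. from "a:1 b:0.5 c:-0.25"
total = sum(weights)                                      # totalPromptWeight above

c = torch.zeros_like(embeddings[0])
for emb, w in zip(embeddings, weights):
    c = torch.add(c, emb, alpha=w / total)  # negative alpha acts like torch.sub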
@@ -497,6 +497,14 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
image = Image.fromarray(x_sample)
if init_mask:
init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
init_mask = init_mask.filter(ImageFilter.GaussianBlur(3))
init_mask = init_mask.convert('L')
init_img = init_img.convert('RGB')
image = image.convert('RGB')
image = Image.composite(init_img, image, init_mask)
filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
if not skip_save:
image.save(os.path.join(sample_path, filename))
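
Image.composite(a, b, mask) takes pixels from a where the mask is white and from b where it is black, so with keep_mask the painted region survives from the original image, and the 3px Gaussian blur feathers the seam. A self-contained mirror of the new block:

from PIL import Image, ImageFilter, ImageOps

def apply_mask(generated, init_img, init_mask, keep_mask=True):
    # White mask pixels keep init_img; inverting regenerates only the painted area.
    mask = init_mask if keep_mask else ImageOps.invert(init_mask)
    mask = mask.filter(ImageFilter.GaussianBlur(3)).convert("L")
    return Image.composite(init_img.convert("RGB"), generated.convert("RGB"), mask)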
@@ -532,10 +540,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
for comment in comments:
info += "\n\n" + comment
#mem_mon.stop()
#del mem_mon
torch_gc()
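
The -(x // -y) idiom in the stats string is ceiling division on integers, so partially used mebibytes round up instead of down:

mem = 1_500_000                  # bytes, illustrative
mib = -(mem // -1_048_576)       # negated floor division == ceiling division
assert mib == 2                  # ~1.43 MiB reported as 2 MiB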
@@ -673,7 +681,7 @@ txt2img_interface = gr.Interface(
)
def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
def img2img(prompt: str, init_info, mask_mode, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
outpath = opt.outdir or "outputs/img2img-samples"
err = False
seed = seed_to_int(seed)
@@ -685,6 +693,14 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
else:
raise Exception("Unknown sampler: " + sampler_name)
init_img = init_info["image"]
init_img = init_img.convert("RGB")
init_img = resize_image(resize_mode, init_img, width, height)
init_mask = init_info["mask"]
init_mask = init_mask.convert("RGB")
init_mask = resize_image(resize_mode, init_mask, width, height)
keep_mask = mask_mode == "Keep masked area"
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(denoising_strength * ddim_steps)
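
t_enc is how many of the ddim_steps are actually re-noised and denoised, so a lower denoising_strength stays closer to init_img. For example:

ddim_steps = 50
denoising_strength = 0.75
t_enc = int(denoising_strength * ddim_steps)   # 37 of 50 steps regenerate detail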
@@ -789,9 +805,12 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
fp=fp,
normalize_prompt_weights=normalize_prompt_weights
normalize_prompt_weights=normalize_prompt_weights,
init_img=init_img,
init_mask=init_mask,
keep_mask=keep_mask
)
del sampler
return output_images, seed, info, stats
@@ -813,7 +832,8 @@ img2img_interface = gr.Interface(
img2img,
inputs=[
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch"),
gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", value="Keep masked area"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=["DDIM", "k-diffusion"], value="k-diffusion"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
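
With tool="sketch" (gradio 3.x-era API, as used at the time of this commit), the Image component passes the function a dict with "image" and "mask" keys instead of a bare PIL image, which is why img2img now unpacks init_info. A minimal standalone demo of that contract:

import gradio as gr

def show_sizes(init_info):
    # tool="sketch" yields {"image": PIL.Image, "mask": PIL.Image}
    return f'image {init_info["image"].size}, mask {init_info["mask"].size}'

demo = gr.Interface(
    show_sizes,
    inputs=gr.Image(source="upload", type="pil", tool="sketch"),
    outputs="text",
)
# demo.launch()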
@@ -930,4 +950,4 @@ demo = gr.TabbedInterface(
theme="default",
)
demo.queue(concurrency_count=1)
demo.launch(show_error=True, server_name='0.0.0.0')