img2img-mask

Great work, thanks anon!! 💯 🔥
Author: hlky · 2022-08-25 00:03:47 +01:00
Parent: 8d6c046a08
Commit: 9927f7c38a


@@ -15,7 +15,7 @@ from contextlib import contextmanager, nullcontext
 from einops import rearrange, repeat
 from itertools import islice
 from omegaconf import OmegaConf
-from PIL import Image, ImageFont, ImageDraw
+from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
 from torch import autocast
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
@@ -104,11 +104,11 @@ class MemUsageMonitor(threading.Thread):
     stop_flag = False
     max_usage = 0
     total = 0

     def __init__(self, name):
         threading.Thread.__init__(self)
         self.name = name

     def run(self):
         print(f"[{self.name}] Recording max memory usage...\n")
         pynvml.nvmlInit()
@@ -121,13 +121,13 @@ class MemUsageMonitor(threading.Thread):
             time.sleep(0.1)
         print(f"[{self.name}] Stopped recording.\n")
         pynvml.nvmlShutdown()

     def read(self):
         return self.max_usage, self.total

     def stop(self):
         self.stop_flag = True

     def read_and_stop(self):
         self.stop_flag = True
         return self.max_usage, self.total
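For context, this monitor wraps a generation call on a worker thread. A hypothetical usage sketch, not part of this commit, assuming an NVIDIA GPU and the pynvml package installed:

# Hypothetical usage of the MemUsageMonitor shown above (not in this commit).
monitor = MemUsageMonitor('img2img')
monitor.start()                       # polls VRAM via pynvml every 0.1 s

run_generation()                      # placeholder for the actual sampling call

max_used, total = monitor.read_and_stop()
print(f"peak {max_used / 2**20:.0f} MiB of {total / 2**20:.0f} MiB")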
@@ -163,11 +163,11 @@ class MemUsageMonitor(threading.Thread):
     stop_flag = False
     max_usage = 0
     total = 0

     def __init__(self, name):
         threading.Thread.__init__(self)
         self.name = name

     def run(self):
         print(f"[{self.name}] Recording max memory usage...\n")
         pynvml.nvmlInit()
@@ -180,13 +180,13 @@ class MemUsageMonitor(threading.Thread):
             time.sleep(0.1)
         print(f"[{self.name}] Stopped recording.\n")
         pynvml.nvmlShutdown()

     def read(self):
         return self.max_usage, self.total

     def stop(self):
         self.stop_flag = True

     def read_and_stop(self):
         self.stop_flag = True
         return self.max_usage, self.total
@@ -389,7 +389,7 @@ def check_prompt_length(prompt, comments):
     comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

-def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True):
+def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=True):
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
     assert prompt is not None
     torch_gc()
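The only change here is the signature: three new keyword arguments, all defaulted so the existing txt2img call sites keep working unchanged. Their roles, as used later in this diff (summary comments are mine):

# New process_images keyword arguments (defaults from the diff):
#   init_img=None    source image; composited back over masked regions
#   init_mask=None   mask drawn in the UI; a falsy mask skips compositing
#   keep_mask=True   True  -> the painted area is kept from init_img,
#                    False -> the mask is inverted first, so the painted
#                             area is the part that gets regenerated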
@@ -458,7 +458,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
         uc = model.get_learned_conditioning(len(prompts) * [""])
         if isinstance(prompts, tuple):
             prompts = list(prompts)

         # split the prompt if it has : for weighting
         # TODO for speed it might help to have this occur when all_prompts filled??
         subprompts,weights = split_weighted_subprompts(prompts[0])
@@ -474,16 +474,16 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
                 weight = weight / totalPromptWeight
                 #print(f"{subprompts[i]} {weight*100.0}%")
                 # note if alpha negative, it functions same as torch.sub
                 c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
         else: # just behave like usual
             c = model.get_learned_conditioning(prompts)

         shape = [opt_C, height // opt_f, width // opt_f]

         # we manually generate all input noises because each one should have a specific seed
         x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
         samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)

         x_samples_ddim = model.decode_first_stage(samples_ddim)
         x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -497,6 +497,14 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
             image = Image.fromarray(x_sample)
+            if init_mask:
+                init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
+                init_mask = init_mask.filter(ImageFilter.GaussianBlur(3))
+                init_mask = init_mask.convert('L')
+                init_img = init_img.convert('RGB')
+                image = image.convert('RGB')
+                image = Image.composite(init_img, image, init_mask)
+
             filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
             if not skip_save:
                 image.save(os.path.join(sample_path, filename))
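The inserted block is the heart of the feature: after decoding, the sample is blended with the original image through a blurred mask, and Image.composite takes init_img wherever the mask is white. A self-contained sketch of the same logic with placeholder images (the solid-color images and rectangle are mine; the mask handling mirrors the diff):

from PIL import Image, ImageFilter, ImageOps

init_img = Image.new('RGB', (512, 512), 'red')    # stand-in for the input image
image = Image.new('RGB', (512, 512), 'blue')      # stand-in for the sampler output
init_mask = Image.new('L', (512, 512), 0)
init_mask.paste(255, (128, 128, 384, 384))        # white = region the user painted

keep_mask = True  # "Keep masked area"
mask = init_mask if keep_mask else ImageOps.invert(init_mask)
mask = mask.filter(ImageFilter.GaussianBlur(3))   # soften the seam between regions
mask = mask.convert('L')

# composite(a, b, mask) takes `a` where the mask is white and `b` where it is
# black, so with keep_mask=True the painted region is restored from init_img.
result = Image.composite(init_img.convert('RGB'), image.convert('RGB'), mask)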
@@ -532,10 +540,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
     stats = f'''
 Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
 Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''

     for comment in comments:
         info += "\n\n" + comment

     #mem_mon.stop()
     #del mem_mon
     torch_gc()
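Unchanged in this hunk, but worth decoding: `-(x // -1_048_576)` is integer ceiling division by 2**20, converting bytes to MiB rounded up without any floats. A quick illustration (example values are mine):

def mib_ceil(n_bytes: int) -> int:
    # -(a // -b) is ceiling division for positive integers
    return -(n_bytes // -1_048_576)

assert mib_ceil(1_048_576) == 1   # exactly 1 MiB
assert mib_ceil(1_048_577) == 2   # one byte more rounds up to 2 MiB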
@@ -673,7 +681,7 @@ txt2img_interface = gr.Interface(
 )

-def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
+def img2img(prompt: str, init_info, mask_mode, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
     outpath = opt.outdir or "outputs/img2img-samples"

     err = False
     seed = seed_to_int(seed)
@@ -685,6 +693,14 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
     else:
         raise Exception("Unknown sampler: " + sampler_name)

+    init_img = init_info["image"]
+    init_img = init_img.convert("RGB")
+    init_img = resize_image(resize_mode, init_img, width, height)
+    init_mask = init_info["mask"]
+    init_mask = init_mask.convert("RGB")
+    init_mask = resize_image(resize_mode, init_mask, width, height)
+    keep_mask = mask_mode == "Keep masked area"
+
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
     t_enc = int(denoising_strength * ddim_steps)
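These lines rely on the payload shape gradio produces once the sketch tool is enabled (see the interface hunk below): the image component's value becomes a dict with "image" and "mask" entries instead of a bare PIL image. Roughly what the handler now receives (illustrative placeholder values, not from the diff):

from PIL import Image

# Shape of init_info when tool="sketch" is set on the gr.Image input:
init_info = {
    "image": Image.new("RGB", (512, 512)),  # the uploaded/edited picture
    "mask":  Image.new("RGB", (512, 512)),  # white where the user painted
}
mask_mode = "Keep masked area"              # current gr.Radio selection

init_img = init_info["image"].convert("RGB")
init_mask = init_info["mask"].convert("RGB")
keep_mask = mask_mode == "Keep masked area"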
@@ -789,9 +805,12 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
                 prompt_matrix=prompt_matrix,
                 use_GFPGAN=use_GFPGAN,
                 fp=fp,
-                normalize_prompt_weights=normalize_prompt_weights
+                normalize_prompt_weights=normalize_prompt_weights,
+                init_img=init_img,
+                init_mask=init_mask,
+                keep_mask=keep_mask
             )

         del sampler

         return output_images, seed, info, stats
@@ -813,7 +832,8 @@ img2img_interface = gr.Interface(
     img2img,
     inputs=[
         gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
-        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
+        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch"),
+        gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", value="Keep masked area"),
         gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
         gr.Radio(label='Sampling method', choices=["DDIM", "k-diffusion"], value="k-diffusion"),
         gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
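Adding tool="sketch" turns the upload widget into a paintable canvas, and the new radio decides how the painted mask is interpreted. A minimal standalone demo of just these two inputs (gradio 3.x-era API, matching this file; the echo function is a placeholder of mine):

import gradio as gr

def echo(init_info, mask_mode):
    # init_info arrives as {"image": PIL.Image, "mask": PIL.Image}
    return init_info["image"]

demo = gr.Interface(
    echo,
    inputs=[
        gr.Image(source="upload", interactive=True, type="pil", tool="sketch"),
        gr.Radio(choices=["Keep masked area", "Regenerate only masked area"],
                 label="Mask Mode", value="Keep masked area"),
    ],
    outputs=gr.Image(type="pil"),
)
demo.launch()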
@@ -930,4 +950,4 @@ demo = gr.TabbedInterface(
     theme="default",
 )
 demo.queue(concurrency_count=1)
 demo.launch(show_error=True, server_name='0.0.0.0')