mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-15 07:12:58 +03:00
img2img-mask
Great work, thanks anon!! 💯 🔥
This commit is contained in:
parent
8d6c046a08
commit
9927f7c38a
66
webui.py
66
webui.py
@ -15,7 +15,7 @@ from contextlib import contextmanager, nullcontext
|
||||
from einops import rearrange, repeat
|
||||
from itertools import islice
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image, ImageFont, ImageDraw
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
|
||||
from torch import autocast
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
@ -104,11 +104,11 @@ class MemUsageMonitor(threading.Thread):
|
||||
stop_flag = False
|
||||
max_usage = 0
|
||||
total = 0
|
||||
|
||||
|
||||
def __init__(self, name):
|
||||
threading.Thread.__init__(self)
|
||||
self.name = name
|
||||
|
||||
|
||||
def run(self):
|
||||
print(f"[{self.name}] Recording max memory usage...\n")
|
||||
pynvml.nvmlInit()
|
||||
@ -121,13 +121,13 @@ class MemUsageMonitor(threading.Thread):
|
||||
time.sleep(0.1)
|
||||
print(f"[{self.name}] Stopped recording.\n")
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
|
||||
def read(self):
|
||||
return self.max_usage, self.total
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.stop_flag = True
|
||||
|
||||
|
||||
def read_and_stop(self):
|
||||
self.stop_flag = True
|
||||
return self.max_usage, self.total
|
||||
@ -163,11 +163,11 @@ class MemUsageMonitor(threading.Thread):
|
||||
stop_flag = False
|
||||
max_usage = 0
|
||||
total = 0
|
||||
|
||||
|
||||
def __init__(self, name):
|
||||
threading.Thread.__init__(self)
|
||||
self.name = name
|
||||
|
||||
|
||||
def run(self):
|
||||
print(f"[{self.name}] Recording max memory usage...\n")
|
||||
pynvml.nvmlInit()
|
||||
@ -180,13 +180,13 @@ class MemUsageMonitor(threading.Thread):
|
||||
time.sleep(0.1)
|
||||
print(f"[{self.name}] Stopped recording.\n")
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
|
||||
def read(self):
|
||||
return self.max_usage, self.total
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.stop_flag = True
|
||||
|
||||
|
||||
def read_and_stop(self):
|
||||
self.stop_flag = True
|
||||
return self.max_usage, self.total
|
||||
@ -389,7 +389,7 @@ def check_prompt_length(prompt, comments):
|
||||
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True):
|
||||
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=True):
|
||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||
assert prompt is not None
|
||||
torch_gc()
|
||||
@ -458,7 +458,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
uc = model.get_learned_conditioning(len(prompts) * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
|
||||
# split the prompt if it has : for weighting
|
||||
# TODO for speed it might help to have this occur when all_prompts filled??
|
||||
subprompts,weights = split_weighted_subprompts(prompts[0])
|
||||
@ -474,16 +474,16 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
weight = weight / totalPromptWeight
|
||||
#print(f"{subprompts[i]} {weight*100.0}%")
|
||||
# note if alpha negative, it functions same as torch.sub
|
||||
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else: # just behave like usual
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
|
||||
|
||||
shape = [opt_C, height // opt_f, width // opt_f]
|
||||
|
||||
# we manually generate all input noises because each one should have a specific seed
|
||||
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
|
||||
samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
|
||||
|
||||
|
||||
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
@ -497,6 +497,14 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
|
||||
|
||||
|
||||
image = Image.fromarray(x_sample)
|
||||
if init_mask:
|
||||
init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
|
||||
init_mask = init_mask.filter(ImageFilter.GaussianBlur(3))
|
||||
init_mask = init_mask.convert('L')
|
||||
init_img = init_img.convert('RGB')
|
||||
image = image.convert('RGB')
|
||||
image = Image.composite(init_img, image, init_mask)
|
||||
|
||||
filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
|
||||
if not skip_save:
|
||||
image.save(os.path.join(sample_path, filename))
|
||||
@ -532,10 +540,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
|
||||
stats = f'''
|
||||
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
|
||||
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
|
||||
|
||||
|
||||
for comment in comments:
|
||||
info += "\n\n" + comment
|
||||
|
||||
|
||||
#mem_mon.stop()
|
||||
#del mem_mon
|
||||
torch_gc()
|
||||
@ -673,7 +681,7 @@ txt2img_interface = gr.Interface(
|
||||
)
|
||||
|
||||
|
||||
def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
|
||||
def img2img(prompt: str, init_info, mask_mode, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
|
||||
outpath = opt.outdir or "outputs/img2img-samples"
|
||||
err = False
|
||||
seed = seed_to_int(seed)
|
||||
@ -685,6 +693,14 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
init_img = init_info["image"]
|
||||
init_img = init_img.convert("RGB")
|
||||
init_img = resize_image(resize_mode, init_img, width, height)
|
||||
init_mask = init_info["mask"]
|
||||
init_mask = init_mask.convert("RGB")
|
||||
init_mask = resize_image(resize_mode, init_mask, width, height)
|
||||
keep_mask = mask_mode == "Keep masked area"
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
|
||||
@ -789,9 +805,12 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
|
||||
prompt_matrix=prompt_matrix,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
fp=fp,
|
||||
normalize_prompt_weights=normalize_prompt_weights
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
init_img=init_img,
|
||||
init_mask=init_mask,
|
||||
keep_mask=keep_mask
|
||||
)
|
||||
|
||||
|
||||
del sampler
|
||||
|
||||
return output_images, seed, info, stats
|
||||
@ -813,7 +832,8 @@ img2img_interface = gr.Interface(
|
||||
img2img,
|
||||
inputs=[
|
||||
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
|
||||
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
|
||||
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch"),
|
||||
gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", value="Keep masked area"),
|
||||
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
|
||||
gr.Radio(label='Sampling method', choices=["DDIM", "k-diffusion"], value="k-diffusion"),
|
||||
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
|
||||
@ -930,4 +950,4 @@ demo = gr.TabbedInterface(
|
||||
theme="default",
|
||||
)
|
||||
demo.queue(concurrency_count=1)
|
||||
demo.launch(show_error=True, server_name='0.0.0.0')
|
||||
demo.launch(show_error=True, server_name='0.0.0.0')
|
Loading…
Reference in New Issue
Block a user