Mirror of https://github.com/sd-webui/stable-diffusion-webui.git (synced 2025-01-05 20:28:01 +03:00)

Commit 9927f7c38a (parent 8d6c046a08): img2img-mask

> Great work, thanks anon!! 💯 🔥

1 changed file: webui.py (66 lines changed)
```diff
@@ -15,7 +15,7 @@ from contextlib import contextmanager, nullcontext
 from einops import rearrange, repeat
 from itertools import islice
 from omegaconf import OmegaConf
-from PIL import Image, ImageFont, ImageDraw
+from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
 from torch import autocast
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
```
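The only import change: ImageFilter and ImageOps join the PIL imports. They carry the whole feature further down, where ImageFilter.GaussianBlur feathers the mask's edge and ImageOps.invert flips it for the "regenerate only masked area" mode.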
The next four hunks cover the MemUsageMonitor class, which this revision of webui.py defines twice; no textual change is visible in any of them, so the edits are presumably whitespace-only (the same holds for the hunks at 458, 474, 532 and 930 below).

```diff
@@ -104,11 +104,11 @@ class MemUsageMonitor(threading.Thread):
     stop_flag = False
     max_usage = 0
     total = 0

     def __init__(self, name):
         threading.Thread.__init__(self)
         self.name = name

     def run(self):
         print(f"[{self.name}] Recording max memory usage...\n")
         pynvml.nvmlInit()
@@ -121,13 +121,13 @@ class MemUsageMonitor(threading.Thread):
             time.sleep(0.1)
         print(f"[{self.name}] Stopped recording.\n")
         pynvml.nvmlShutdown()

     def read(self):
         return self.max_usage, self.total

     def stop(self):
         self.stop_flag = True

     def read_and_stop(self):
         self.stop_flag = True
         return self.max_usage, self.total
@@ -163,11 +163,11 @@ class MemUsageMonitor(threading.Thread):
     stop_flag = False
     max_usage = 0
     total = 0

     def __init__(self, name):
         threading.Thread.__init__(self)
         self.name = name

     def run(self):
         print(f"[{self.name}] Recording max memory usage...\n")
         pynvml.nvmlInit()
@@ -180,13 +180,13 @@ class MemUsageMonitor(threading.Thread):
             time.sleep(0.1)
         print(f"[{self.name}] Stopped recording.\n")
         pynvml.nvmlShutdown()

     def read(self):
         return self.max_usage, self.total

     def stop(self):
         self.stop_flag = True

     def read_and_stop(self):
         self.stop_flag = True
         return self.max_usage, self.total
```
```diff
@@ -389,7 +389,7 @@ def check_prompt_length(prompt, comments):
     comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")


-def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True):
+def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, fp, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None, keep_mask=True):
     """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
     assert prompt is not None
     torch_gc()
```
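The main loop gains three keyword arguments, init_img, init_mask, and keep_mask, all defaulted (None, None, True), so txt2img, which never passes them, behaves exactly as before; the new compositing step below is guarded by `if init_mask:`.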
```diff
@@ -458,7 +458,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
             uc = model.get_learned_conditioning(len(prompts) * [""])
             if isinstance(prompts, tuple):
                 prompts = list(prompts)

             # split the prompt if it has : for weighting
             # TODO for speed it might help to have this occur when all_prompts filled??
             subprompts,weights = split_weighted_subprompts(prompts[0])
@@ -474,16 +474,16 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
                         weight = weight / totalPromptWeight
                     #print(f"{subprompts[i]} {weight*100.0}%")
                     # note if alpha negative, it functions same as torch.sub
                     c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight)
             else: # just behave like usual
                 c = model.get_learned_conditioning(prompts)

             shape = [opt_C, height // opt_f, width // opt_f]

             # we manually generate all input noises because each one should have a specific seed
             x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
             samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)


             x_samples_ddim = model.decode_first_stage(samples_ddim)
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
```
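Context for the unchanged hunk above: prompt weighting works by accumulating a weighted sum of conditioning embeddings, and `torch.add(c, e, alpha=w)` computes `c + w * e`, so a negative weight subtracts (as the inline comment notes). A minimal, self-contained sketch with stand-in tensors in place of `model.get_learned_conditioning` (shapes and the example prompt are illustrative, not from the patch):

```python
import torch

# stand-ins for conditioning embeddings; the real code gets these from
# model.get_learned_conditioning for each subprompt
embeddings = [torch.randn(1, 77, 768) for _ in range(3)]
weights = [1.0, 0.5, -0.5]           # e.g. parsed from "cat:1 fluffy:0.5 dog:-0.5"

total = sum(weights)                 # the normalize_prompt_weights=True behavior
weights = [w / total for w in weights]

c = torch.zeros_like(embeddings[0])
for e, w in zip(embeddings, weights):
    c = torch.add(c, e, alpha=w)     # c + w * e; a negative w acts like torch.sub
```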
```diff
@@ -497,6 +497,14 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,


                 image = Image.fromarray(x_sample)
+                if init_mask:
+                    init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
+                    init_mask = init_mask.filter(ImageFilter.GaussianBlur(3))
+                    init_mask = init_mask.convert('L')
+                    init_img = init_img.convert('RGB')
+                    image = image.convert('RGB')
+                    image = Image.composite(init_img, image, init_mask)
+
                 filename = f"{base_count:05}-{seeds[i]}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
                 if not skip_save:
                     image.save(os.path.join(sample_path, filename))
```
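This is the core of the feature: after each sample is decoded, the original image is composited back over the generation. `Image.composite(a, b, mask)` takes pixels from `a` where the grayscale mask is white, from `b` where it is black, and interpolates in between, so the 3-pixel Gaussian blur feathers the seam instead of leaving a hard edge, and `ImageOps.invert` flips which region survives. A self-contained sketch of the same technique (the function name and standalone framing are mine, not the patch's; images are assumed equal in size):

```python
from PIL import Image, ImageFilter, ImageOps

def composite_masked(init_img, generated, mask, keep_mask=True):
    """Paste init_img back over generated according to mask.

    keep_mask=True  -> white mask areas keep the original pixels
    keep_mask=False -> mask is inverted first, so only the painted
                       region keeps the fresh generation
    """
    mask = mask.convert("RGB")                        # ImageOps.invert wants RGB or L
    if not keep_mask:
        mask = ImageOps.invert(mask)
    mask = mask.filter(ImageFilter.GaussianBlur(3))   # feather the seam
    mask = mask.convert("L")                          # composite() needs a single band
    return Image.composite(init_img.convert("RGB"),
                           generated.convert("RGB"),
                           mask)
```

Note that sampling itself is unmasked here: the whole latent is regenerated and the mask is applied as a post-hoc paste, so content outside the kept region is generated and then discarded.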
```diff
@@ -532,10 +540,10 @@ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{',
     stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''

     for comment in comments:
         info += "\n\n" + comment

     #mem_mon.stop()
     #del mem_mon
     torch_gc()
```
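An aside on the memory line above: `-(mem_max_used // -1_048_576)` is ceiling division, converting bytes to MiB rounded up without importing math:

```python
n = 1_500_000                   # bytes
assert n // 1_048_576 == 1      # floor division rounds down
assert -(n // -1_048_576) == 2  # negate, divide, negate: rounds up
```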
```diff
@@ -673,7 +681,7 @@ txt2img_interface = gr.Interface(
 )


-def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
+def img2img(prompt: str, init_info, mask_mode, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix, loopback: bool, skip_grid: bool, skip_save: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, normalize_prompt_weights: bool, fp):
     outpath = opt.outdir or "outputs/img2img-samples"
     err = False
     seed = seed_to_int(seed)
```
```diff
@@ -685,6 +693,14 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
     else:
         raise Exception("Unknown sampler: " + sampler_name)

+    init_img = init_info["image"]
+    init_img = init_img.convert("RGB")
+    init_img = resize_image(resize_mode, init_img, width, height)
+    init_mask = init_info["mask"]
+    init_mask = init_mask.convert("RGB")
+    init_mask = resize_image(resize_mode, init_mask, width, height)
+    keep_mask = mask_mode == "Keep masked area"
+
     assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
     t_enc = int(denoising_strength * ddim_steps)

```
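Because the image input now uses tool="sketch" (see the interface hunk below), Gradio of this era delivers a dict rather than a bare PIL image: the upload under "image" and the painted strokes under "mask", hence the rename from init_img to init_info. A hedged sketch of the same preprocessing, with the repo's resize_image helper swapped for a plain PIL resize:

```python
from PIL import Image

def prepare_init(init_info: dict, width: int, height: int, mask_mode: str):
    # init_info comes from gr.Image(..., type="pil", tool="sketch"):
    # {"image": <uploaded PIL image>, "mask": <painted PIL image>}
    init_img = init_info["image"].convert("RGB").resize((width, height), Image.LANCZOS)
    init_mask = init_info["mask"].convert("RGB").resize((width, height), Image.LANCZOS)
    keep_mask = mask_mode == "Keep masked area"
    return init_img, init_mask, keep_mask
```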
```diff
@@ -789,9 +805,12 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_name: str, use_GFPGA
         prompt_matrix=prompt_matrix,
         use_GFPGAN=use_GFPGAN,
         fp=fp,
-        normalize_prompt_weights=normalize_prompt_weights
+        normalize_prompt_weights=normalize_prompt_weights,
+        init_img=init_img,
+        init_mask=init_mask,
+        keep_mask=keep_mask
     )

     del sampler

     return output_images, seed, info, stats
```
```diff
@@ -813,7 +832,8 @@ img2img_interface = gr.Interface(
     img2img,
     inputs=[
         gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
-        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
+        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil", tool="sketch"),
+        gr.Radio(choices=["Keep masked area", "Regenerate only masked area"], label="Mask Mode", value="Keep masked area"),
         gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
         gr.Radio(label='Sampling method', choices=["DDIM", "k-diffusion"], value="k-diffusion"),
         gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
```
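Two small couplings worth noting: the sketch tool is what makes Gradio hand img2img the image/mask dict, and keep_mask is derived by comparing mask_mode against the exact string "Keep masked area", so this Radio's first choice and that comparison must stay in sync.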
```diff
@@ -930,4 +950,4 @@ demo = gr.TabbedInterface(
     theme="default",
 )
 demo.queue(concurrency_count=1)
 demo.launch(show_error=True, server_name='0.0.0.0')
```