stable-diffusion-webui/modules/sd_samplers_compvis.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

225 lines
11 KiB
Python
Raw Normal View History

import math
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import numpy as np
import torch
from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared
2023-02-10 14:30:20 +03:00
import modules.models.diffusion.uni_pc
# Registry entries for the CompVis/ldm-style samplers exposed by this module.
# Each constructor is a lambda taking the loaded model, so the wrapped sampler
# class is looked up and instantiated only when the sampler is actually created.
# Fields, in order: display name, constructor, aliases (none), options dict.
samplers_data_compvis = [
    sd_samplers_common.SamplerData(
        'DDIM',
        lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model),
        [],
        {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True},
    ),
    sd_samplers_common.SamplerData(
        'PLMS',
        lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model),
        [],
        {"no_sdxl": True},
    ),
    sd_samplers_common.SamplerData(
        'UniPC',
        lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model),
        [],
        {"no_sdxl": True},
    ),
]
class VanillaStableDiffusionSampler:
    """Run an ldm (CompVis) DDIM/PLMS sampler, or the UniPC sampler, inside webui.

    For DDIM/PLMS, the wrapped sampler's per-step function (``p_sample_ddim`` or
    ``p_sample_plms``) is replaced by :meth:`p_sample_ddim_hook`. UniPC instead
    receives callbacks through its ``set_hooks`` method. Each hook does several jobs:

    * it raises ``sd_samplers_common.InterruptedException`` on interrupt, skip or ``stop_at``;
    * it rebuilds the per-step conditioning from the prompt schedule;
    * it blends the inpainting mask into the latent;
    * it publishes progress and the latest latent.
    """

    def __init__(self, constructor, sd_model):
        """Instantiate the wrapped sampler and reset per-run state.

        Args:
            constructor: sampler class/factory, called as ``constructor(sd_model)``.
            sd_model: the loaded model; ``sd_model.model.conditioning_key`` decides
                how image conditioning is packed into the conditioning dicts.
        """
        self.sampler = constructor(sd_model)

        # The sampler kind is detected from the per-step function it exposes.
        self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
        self.is_plms = hasattr(self.sampler, 'p_sample_plms')
        self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)

        # Keep the unpatched per-step function. initialize() overwrites the
        # attribute on the sampler instance with p_sample_ddim_hook, which
        # delegates back to this saved callable.
        self.orig_p_sample_ddim = None
        if self.is_plms:
            self.orig_p_sample_ddim = self.sampler.p_sample_plms
        elif self.is_ddim:
            self.orig_p_sample_ddim = self.sampler.p_sample_ddim

        self.mask = None
        self.nmask = None
        self.init_latent = None
        self.sampler_noises = None
        self.step = 0
        # Not assigned anywhere else in this class (presumably set by the caller).
        # before_sample stops sampling once self.step exceeds it.
        self.stop_at = None
        self.eta = None
        # Not assigned anywhere else in this class (presumably set by the caller).
        # adjust_steps_if_invalid reads self.config.name.
        self.config = None
        self.last_latent = None

        self.conditioning_key = sd_model.model.conditioning_key

    def number_of_needed_noises(self, p):
        """Return 0: this wrapper requests no extra noise tensors from the caller."""
        return 0

    def launch_sampling(self, steps, func):
        """Reset progress state for ``steps`` steps and run ``func()``.

        Returns:
            The result of ``func()``. If ``sd_samplers_common.InterruptedException``
            is raised during sampling, returns the most recent latent
            (``self.last_latent``) instead.
        """
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except sd_samplers_common.InterruptedException:
            # before_sample raises this on interrupt/skip/stop_at; hand back the
            # latest latent rather than failing the whole generation.
            return self.last_latent

    def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
        """Replacement for the wrapped sampler's ``p_sample_ddim`` / ``p_sample_plms``.

        Pre-processes the inputs with before_sample, then calls the original
        per-step function, then post-processes with after_sample. Returns the
        original function's result.
        """
        x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)

        res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs)

        x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)

        return res

    def before_sample(self, x, ts, cond, unconditional_conditioning):
        """Prepare the inputs of one sampling step.

        Unwraps dict (inpainting / adm) conditioning and rebuilds the conditioning
        batches for ``self.step`` from the prompt schedule. It then pads or truncates
        the unconditional tensor to match the conditional one along dimension 1,
        blends ``self.init_latent`` back in where ``self.mask`` is set, and re-wraps
        the conditioning dicts.

        Returns:
            ``(x, ts, cond, unconditional_conditioning)`` for the wrapped sampler.

        Raises:
            sd_samplers_common.InterruptedException: if the user interrupted or
                skipped, or ``self.step`` has passed ``self.stop_at``.
            AssertionError: if the prompt uses AND composition (a batch entry
                has more than one conditioning).
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        if self.stop_at is not None and self.step > self.stop_at:
            raise sd_samplers_common.InterruptedException

        # Have to unwrap the inpainting conditioning here to perform pre-processing
        image_conditioning = None
        uc_image_conditioning = None
        if isinstance(cond, dict):
            if self.conditioning_key == "crossattn-adm":
                image_conditioning = cond["c_adm"]
                uc_image_conditioning = unconditional_conditioning["c_adm"]
            else:
                image_conditioning = cond["c_concat"][0]
            cond = cond["c_crossattn"][0]
            unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)

        # This validates user input (the prompt), so use an explicit check rather
        # than `assert`, which is stripped under `python -O`. The AssertionError
        # type is kept so callers see the same exception as before.
        if not all(len(conds) == 1 for conds in conds_list):
            raise AssertionError('composition via AND is not supported for DDIM/PLMS samplers')
        cond = tensor

        # for DDIM, shapes must match, we can't just process cond and uncond independently;
        # filling unconditional_conditioning with repeats of the last vector to match length is
        # not 100% correct but should work well enough
        if unconditional_conditioning.shape[1] < cond.shape[1]:
            last_vector = unconditional_conditioning[:, -1:]
            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
            unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
        elif unconditional_conditioning.shape[1] > cond.shape[1]:
            unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]

        if self.mask is not None:
            # Re-noise the original latent to the current timestep and keep it
            # where mask applies; the sampler's latent is kept where nmask applies.
            img_orig = self.sampler.model.q_sample(self.init_latent, ts)
            x = img_orig * self.mask + self.nmask * x

        # Wrap the image conditioning back up since the DDIM code can accept the dict directly.
        # Note that they need to be lists because it just concatenates them later.
        if image_conditioning is not None:
            if self.conditioning_key == "crossattn-adm":
                cond = {"c_adm": image_conditioning, "c_crossattn": [cond]}
                unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]}
            else:
                cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
                unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

        return x, ts, cond, unconditional_conditioning

    def update_step(self, last_latent):
        """Record ``last_latent`` as the current latent and advance the step counter.

        When a mask is set, the stored latent blends ``self.init_latent`` (weighted
        by ``self.mask``) with ``last_latent`` (weighted by ``self.nmask``). The
        result is published via ``sd_samplers_common.store_latent``, and
        ``state.sampling_step`` and ``shared.total_tqdm`` are updated.
        """
        if self.mask is not None:
            self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
        else:
            self.last_latent = last_latent

        sd_samplers_common.store_latent(self.last_latent)

        self.step += 1
        state.sampling_step = self.step
        shared.total_tqdm.update()

    def after_sample(self, x, ts, cond, uncond, res):
        """Post-process one sampling step and return the inputs unchanged.

        For DDIM/PLMS, ``res[1]`` is passed to update_step. Presumably this is the
        predicted x0 returned by the ldm per-step function; TODO confirm. UniPC
        reports progress through unipc_after_update instead.
        """
        if not self.is_unipc:
            self.update_step(res[1])

        return x, ts, cond, uncond, res

    def unipc_after_update(self, x, model_x):
        """UniPC callback: record ``x`` as the current latent. ``model_x`` is unused."""
        self.update_step(x)

    def initialize(self, p):
        """Configure eta, record generation parameters and install the sampling hooks.

        Args:
            p: processing object. This reads ``p.eta``, writes
                ``p.extra_generation_params``, and reads the optional ``p.mask`` and
                ``p.nmask``.
        """
        if self.is_ddim:
            self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
        else:
            self.eta = 0.0

        if self.eta != 0.0:
            p.extra_generation_params["Eta DDIM"] = self.eta

        if self.is_unipc:
            # Only record UniPC options that differ from their defaults.
            keys = [
                ('UniPC variant', 'uni_pc_variant'),
                ('UniPC skip type', 'uni_pc_skip_type'),
                ('UniPC order', 'uni_pc_order'),
                ('UniPC lower order final', 'uni_pc_lower_order_final'),
            ]

            for name, key in keys:
                v = getattr(shared.opts, key)
                if v != shared.opts.get_default(key):
                    p.extra_generation_params[name] = v

        # Patch the per-step function on this sampler instance so every step goes
        # through p_sample_ddim_hook (the original was saved in __init__).
        for fieldname in ['p_sample_ddim', 'p_sample_plms']:
            if hasattr(self.sampler, fieldname):
                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
        if self.is_unipc:
            self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))

        self.mask = getattr(p, 'mask', None)
        self.nmask = getattr(p, 'nmask', None)

    def adjust_steps_if_invalid(self, p, num_steps):
        """Return a step count adjusted for the uniform DDIM/PLMS/UniPC schedules.

        This only applies when the sampler is UniPC, PLMS, or DDIM with
        ``p.ddim_discretize == 'uniform'``. For UniPC, the count is first raised
        to at least ``shared.opts.uni_pc_order``. If 999 is then an exact multiple
        of ``1000 // num_steps``, ``999 / (1000 // num_steps) + 1`` is returned.
        Presumably this is because a uniform grid with that stride lands on the
        last DDPM timestep; TODO confirm against ldm's DDIM timestep schedule.

        Step counts above 1000 (or non-positive counts) make the stride 0. The
        previous code divided by it and raised ZeroDivisionError; such counts are
        now returned unchanged.
        """
        uses_uniform_schedule = (
            (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform')
            or self.config.name == 'PLMS'
            or self.config.name == 'UniPC'
        )
        if uses_uniform_schedule:
            if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order:
                num_steps = shared.opts.uni_pc_order

            # 1000 / 999 are the original constants (a 1000-timestep DDPM schedule is assumed).
            stride = 1000 // num_steps if num_steps > 0 else 0
            if stride > 0:
                valid_step = 999 / stride
                if valid_step == math.floor(valid_step):
                    return int(valid_step) + 1

        return num_steps

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Run img2img sampling starting from the init latent ``x``.

        ``x`` is encoded to step ``t_enc`` with the sampler's ``stochastic_encode``
        using ``noise``, then decoded for ``t_enc + 1`` progress steps. Returns
        the decoded latents, or the latest latent if interrupted.
        """
        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
        steps = self.adjust_steps_if_invalid(p, steps)
        self.initialize(p)

        self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
        # One t_enc entry per batch item.
        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)

        self.init_latent = x
        self.last_latent = x
        self.step = 0

        # Wrap the conditioning models with additional image conditioning for inpainting model
        if image_conditioning is not None:
            if self.conditioning_key == "crossattn-adm":
                conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]}
                unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]}
            else:
                conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
                unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}

        samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))

        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Run txt2img sampling starting from the latent ``x`` (passed as ``x_T``).

        Uses ``steps`` if given, otherwise ``p.steps``. Returns the first element
        of the wrapped sampler's ``sample()`` result, or the latest latent if
        interrupted.
        """
        self.initialize(p)

        self.init_latent = None
        self.last_latent = x
        self.step = 0

        steps = self.adjust_steps_if_invalid(p, steps or p.steps)

        # Wrap the conditioning models with additional image conditioning for inpainting model
        # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
        if image_conditioning is not None:
            if self.conditioning_key == "crossattn-adm":
                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning}
                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)}
            else:
                conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
                unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}

        samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])

        return samples_ddim