stable-diffusion-webui/modules/sd_samplers_common.py
2023-08-08 22:17:25 +03:00

301 lines
10 KiB
Python

import inspect
from collections import namedtuple, deque
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
import k_diffusion.sampling
# Lightweight record describing a sampler entry. Fields, in order:
# name, constructor, aliases, options.
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
def setup_img2img_steps(p, steps=None):
    """Compute the sampler step count and the encode step for img2img.

    Args:
        p: processing object; ``p.steps`` and ``p.denoising_strength`` are read.
        steps: optional explicit step count; when given (or when
            ``opts.img2img_fix_steps`` is set) the total step count is scaled up
            by the inverse of the denoising strength.

    Returns:
        A ``(steps, t_enc)`` tuple.
    """
    if not opts.img2img_fix_steps and steps is None:
        # Plain mode: run p.steps in total and encode up to a fraction of them.
        total_steps = p.steps
        return total_steps, int(min(p.denoising_strength, 0.999) * total_steps)

    # "Fix steps" mode: the requested count is how many steps get denoised, so
    # the total is stretched by 1 / strength (strength capped at 0.999).
    requested_steps = steps or p.steps
    if p.denoising_strength > 0:
        total_steps = int(requested_steps / min(p.denoising_strength, 0.999))
    else:
        total_steps = 0

    return total_steps, requested_steps - 1
# Maps the decode/encode method names stored in options (show_progress_type,
# sd_vae_decode_method, sd_vae_encode_method) to the integer ids that
# samples_to_images_tensor() and images_tensor_to_samples() branch on.
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
def samples_to_images_tensor(sample, approximation=None, model=None):
    '''latents -> images [-1, 1]

    Args:
        sample: latent tensor to decode.
        approximation: integer id from ``approximation_indexes``; when None it is
            looked up from ``opts.show_progress_type`` (unknown names map to 0).
        model: model used for the full decode; defaults to ``shared.sd_model``.
            Only consulted when no approximate decoder is selected.

    Returns:
        The decoded image tensor.
    '''
    if approximation is None:
        approximation = approximation_indexes.get(opts.show_progress_type, 0)

    if approximation == 2:
        # "Approx cheap"
        return sd_vae_approx.cheap_approximation(sample)

    if approximation == 1:
        # "Approx NN"
        approx_net = sd_vae_approx.model()
        return approx_net(sample.to(devices.device, devices.dtype)).detach()

    if approximation == 3:
        # "TAESD": latents are scaled by 1.5 before decoding and the output is
        # remapped with x * 2 - 1 (presumably from [0, 1] to [-1, 1] - confirm).
        scaled = sample * 1.5
        taesd_decoder = sd_vae_taesd.decoder_model()
        decoded = taesd_decoder(scaled.to(devices.device, devices.dtype)).detach()
        return decoded * 2 - 1

    # "Full" (and any unrecognised id): decode with the model's own first stage.
    vae_owner = shared.sd_model if model is None else model
    return vae_owner.decode_first_stage(sample.to(vae_owner.first_stage_model.dtype))
def single_sample_to_image(sample, approximation=None):
    """Decode a single latent (no batch dimension) into a PIL image.

    Args:
        sample: one latent tensor; a batch dimension is added before decoding.
        approximation: decoder id forwarded to ``samples_to_images_tensor``.

    Returns:
        The image built by ``Image.fromarray`` from the uint8 pixel array.
    """
    batch = sample.unsqueeze(0)
    decoded = samples_to_images_tensor(batch, approximation)[0]

    # Shift the decoder output by * 0.5 + 0.5, clip to [0, 1], then move the
    # first axis to the end and scale to 0..255.
    unit_range = torch.clamp(decoded * 0.5 + 0.5, min=0.0, max=1.0)
    channels_last = np.moveaxis(unit_range.cpu().numpy(), 0, 2)
    pixels = (255. * channels_last).astype(np.uint8)

    return Image.fromarray(pixels)
def decode_first_stage(model, x):
    """Decode latents ``x`` with the method selected by ``opts.sd_vae_decode_method``.

    The latents are first cast to ``devices.dtype_vae``; an unknown method name
    falls back to id 0.
    """
    latents = x.to(devices.dtype_vae)
    method_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
    return samples_to_images_tensor(latents, approximation=method_index, model=model)
def sample_to_image(samples, index=0, approximation=None):
    """Convert the latent at position ``index`` of ``samples`` into a PIL image."""
    chosen = samples[index]
    return single_sample_to_image(chosen, approximation)
def samples_to_image_grid(samples, approximation=None):
    """Convert every latent in ``samples`` to an image and pass them to ``images.image_grid``."""
    make_grid = images.image_grid
    previews = [single_sample_to_image(latent, approximation) for latent in samples]
    return make_grid(previews)
def images_tensor_to_samples(image, approximation=None, model=None):
    '''image[0, 1] -> latent

    Args:
        image: image tensor to encode.
        approximation: integer id from ``approximation_indexes``; when None it is
            looked up from ``opts.sd_vae_encode_method`` (unknown names map to 0).
        model: model used for the full encode; defaults to ``shared.sd_model``.
            Ignored when TAESD (id 3) is selected.

    Returns:
        The latent tensor.
    '''
    if approximation is None:
        approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)

    if approximation == 3:
        # TAESD encoder takes the image as-is (no rescaling is applied here).
        taesd_input = image.to(devices.device, devices.dtype)
        return sd_vae_taesd.encoder_model()(taesd_input)

    encoder_owner = shared.sd_model if model is None else model
    # Remap with x * 2 - 1 before handing the image to the model's first stage.
    rescaled = image.to(shared.device, dtype=devices.dtype_vae) * 2 - 1
    return encoder_owner.get_first_stage_encoding(encoder_owner.encode_first_stage(rescaled))
def store_latent(decoded):
    """Record the current latent and, when due, refresh the live preview image.

    The latent is always stored on ``state.current_latent``. A preview image is
    only assigned when live previews are enabled, the current sampling step is
    a multiple of ``opts.show_progress_every_n_steps``, and parallel processing
    is not allowed.
    """
    state.current_latent = decoded

    if not opts.live_previews_enable:
        return

    every_n = opts.show_progress_every_n_steps
    if not (every_n > 0) or shared.state.sampling_step % every_n != 0:
        return

    if shared.parallel_processing_allowed:
        return

    shared.state.assign_current_image(sample_to_image(decoded))
def is_sampler_using_eta_noise_seed_delta(p):
    """returns whether sampler from config will use eta noise seed delta for image creation

    The effective eta is resolved in this order: ``p.eta``, then ``p.sampler.eta``
    (if a sampler is attached), then the sampler config's default (0 when its
    options set ``default_eta_is_0``, otherwise 1.0).

    Returns False when the effective eta is 0 or when no sampler config exists
    for ``p.sampler_name``; otherwise the config's ``uses_ensd`` option
    (False when absent).
    """
    sampler_config = sd_samplers.find_sampler_config(p.sampler_name)

    eta = p.eta

    if eta is None and p.sampler is not None:
        eta = p.sampler.eta

    if eta is None and sampler_config is not None:
        eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0

    if eta == 0:
        return False

    # The config lookup may come back empty (it is already guarded above);
    # without a config there are no options to say ENSD is used, so answer
    # False instead of failing on ``None.options``.
    if sampler_config is None:
        return False

    return sampler_config.options.get("uses_ensd", False)
class InterruptedException(BaseException):
    """Signal used to stop sampling early.

    Derives from ``BaseException`` rather than ``Exception``, so generic
    ``except Exception`` handlers do not intercept it.
    """
def replace_torchsde_browinan():
    """Swap torchsde's internal ``_randn`` for one backed by ``devices.randn_local``.

    After this runs, ``torchsde._brownian.brownian_interval._randn`` generates
    its noise through ``devices.randn_local(seed, size)`` and then moves the
    result to the requested device and dtype.
    """
    import torchsde._brownian.brownian_interval as brownian_interval

    def torchsde_randn(size, dtype, device, seed):
        noise = devices.randn_local(seed, size)
        return noise.to(device=device, dtype=dtype)

    brownian_interval._randn = torchsde_randn


# Apply the patch once, when this module is imported.
replace_torchsde_browinan()
def apply_refiner(sampler):
    """Switch to the refiner checkpoint once sampling has progressed far enough.

    Args:
        sampler: object exposing ``step``, ``steps``, ``p`` and
            ``update_inner_model()``.

    Returns:
        True if the refiner checkpoint was loaded, False if nothing was done.

    Raises:
        Exception: if the configured refiner checkpoint cannot be found.
    """
    completed_ratio = sampler.step / sampler.steps

    # Nothing to do while the switch point has not been passed, when no refiner
    # is configured, or when the refiner is already the loaded checkpoint.
    if (
        completed_ratio <= shared.opts.sd_refiner_switch_at
        or shared.opts.sd_refiner_checkpoint == "None"
        or shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint
    ):
        return False

    refiner_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
    if refiner_info is None:
        raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')

    # Record the switch in the generation parameters.
    sampler.p.extra_generation_params['Refiner'] = refiner_info.short_title
    sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at

    # Load the refiner weights inside SkipWritingToConfig.
    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=refiner_info)

    devices.torch_gc()

    # Rebuild conditioning and re-point the sampler at the newly loaded model.
    sampler.p.setup_conds()
    sampler.update_inner_model()

    return True
class TorchHijack:
    """Stand-in for the ``torch`` module that serves pre-generated noise.

    ``randn_like`` hands out the queued noise tensors in order; every other
    attribute is forwarded to the real ``torch`` module.
    """

    def __init__(self, sampler_noises):
        # A deque lets the noises be consumed front-to-back cheaply, in the same
        # order as the previous index-based implementation.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        """Forward attribute lookups that are not found on this object to ``torch``."""
        if item == 'randn_like':
            return self.randn_like

        try:
            return getattr(torch, item)
        except AttributeError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") from None

    def randn_like(self, x):
        """Return the next queued noise if its shape matches ``x``, else fresh noise.

        A queued noise whose shape does not match is still consumed (dropped).
        """
        if not self.sampler_noises:
            return devices.randn_like(x)

        candidate = self.sampler_noises.popleft()
        return candidate if candidate.shape == x.shape else devices.randn_like(x)
class Sampler:
    """Base state and helpers shared by sampler implementations.

    ``model_wrap_cfg`` starts as None here but is dereferenced by
    ``launch_sampling`` and ``initialize``, so it has to be assigned before
    those are called (presumably by subclasses - confirm against callers).
    """

    def __init__(self, funcname):
        self.funcname = funcname
        self.func = funcname  # inspected via inspect.signature() in initialize()
        self.extra_params = []  # names of optional p attributes forwarded to self.func
        self.sampler_noises = None  # pre-generated noises handed to TorchHijack
        self.stop_at = None  # step index after which callback_state aborts sampling
        self.eta = None
        self.config = None  # set by the function calling the constructor
        self.last_latent = None  # returned by launch_sampling when sampling is cut short
        self.s_min_uncond = None

        # Defaults for the sigma parameters; initialize() only forwards values
        # that differ from these.
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.eta_option_field = 'eta_ancestral'  # opts attribute used when p.eta is None
        self.eta_infotext_field = 'Eta'  # key written to p.extra_generation_params

        self.conditioning_key = shared.sd_model.model.conditioning_key

        self.p = None
        self.model_wrap_cfg = None
        self.sampler_extra_args = None

    def callback_state(self, d):
        """Per-step callback: publish the current step and honour ``stop_at``.

        Args:
            d: mapping with the current step index under key ``'i'``.

        Raises:
            InterruptedException: when the step index is past ``self.stop_at``.
        """
        step = d['i']

        if self.stop_at is not None and step > self.stop_at:
            raise InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        """Run ``func`` and return its result, falling back to ``self.last_latent``.

        Resets the shared step counters first. If ``func`` raises
        ``RecursionError`` (a hint is printed) or ``InterruptedException``,
        ``self.last_latent`` is returned instead.
        """
        self.model_wrap_cfg.steps = steps
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        """Return how many noises are needed for ``p``; here simply ``p.steps``."""
        return p.steps

    def initialize(self, p) -> dict:
        """Bind ``p`` to this sampler and build the extra kwargs for ``self.func``.

        Copies the relevant fields of ``p`` onto ``self.model_wrap_cfg``,
        resolves eta, installs a ``TorchHijack`` as ``k_diffusion.sampling.torch``
        and collects the optional arguments that ``self.func`` accepts. Values
        that differ from the defaults are also written to
        ``p.extra_generation_params``.

        Returns:
            Keyword arguments to pass to ``self.func``.
        """
        self.p = p
        self.model_wrap_cfg.p = p
        self.model_wrap_cfg.mask = getattr(p, 'mask', None)
        self.model_wrap_cfg.nmask = getattr(p, 'nmask', None)
        self.model_wrap_cfg.step = 0
        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

        # The signature does not change during this call, so inspect it once
        # instead of once per extra parameter and again for eta.
        func_parameters = inspect.signature(self.func).parameters

        extra_params_kwargs = {
            param_name: getattr(p, param_name)
            for param_name in self.extra_params
            if hasattr(p, param_name) and param_name in func_parameters
        }

        if 'eta' in func_parameters:
            if self.eta != 1.0:
                p.extra_generation_params[self.eta_infotext_field] = self.eta

            extra_params_kwargs['eta'] = self.eta

        if self.extra_params:
            # Option values take precedence over the values carried by p.
            requested = {
                's_churn': getattr(opts, 's_churn', p.s_churn),
                's_tmin': getattr(opts, 's_tmin', p.s_tmin),
                's_tmax': getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax,  # 0 = inf
                's_noise': getattr(opts, 's_noise', p.s_noise),
            }
            infotext_keys = {
                's_churn': 'Sigma churn',
                's_tmin': 'Sigma tmin',
                's_tmax': 'Sigma tmax',
                's_noise': 'Sigma noise',
            }

            # Only non-default values are forwarded, written back to p and
            # recorded in the generation parameters.
            for field, value in requested.items():
                if value != getattr(self, field):
                    extra_params_kwargs[field] = value
                    setattr(p, field, value)
                    p.extra_generation_params[infotext_keys[field]] = value

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes

        Returns None when ``shared.opts.no_dpmpp_sde_batch_determinism`` is set;
        otherwise a ``BrownianTreeNoiseSampler`` seeded with the seeds of the
        current iteration's batch.
        """
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler

        # Smallest strictly positive sigma and the largest sigma bound the sampler.
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)