mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-14 06:35:14 +03:00
Create kdiffusion.py
This commit is contained in:
parent
a70a69c7fb
commit
e71c4225ca
55
ldm/models/diffusion/kdiffusion.py
Normal file
55
ldm/models/diffusion/kdiffusion.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import k_diffusion as K
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class KDiffusionSampler:
|
||||||
|
def __init__(self, m, sampler, callback=None):
|
||||||
|
self.model = m
|
||||||
|
self.model_wrap = K.external.CompVisDenoiser(m)
|
||||||
|
self.schedule = sampler
|
||||||
|
self.generation_callback = callback
|
||||||
|
def get_sampler_name(self):
|
||||||
|
return self.schedule
|
||||||
|
def sample(self, S, conditioning, unconditional_guidance_scale, unconditional_conditioning, x_T):
|
||||||
|
sigmas = self.model_wrap.get_sigmas(S)
|
||||||
|
x = x_T * sigmas[0]
|
||||||
|
model_wrap_cfg = CFGDenoiser(self.model_wrap)
|
||||||
|
samples_ddim = None
|
||||||
|
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||||
|
model_wrap_cfg, x, sigmas,
|
||||||
|
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,'cond_scale': unconditional_guidance_scale},
|
||||||
|
disable=False, callback=self.generation_callback)
|
||||||
|
#
|
||||||
|
return samples_ddim, None
|
||||||
|
class CFGMaskedDenoiser(nn.Module):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.inner_model = model
|
||||||
|
|
||||||
|
def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
|
||||||
|
x_in = x
|
||||||
|
x_in = torch.cat([x_in] * 2)
|
||||||
|
sigma_in = torch.cat([sigma] * 2)
|
||||||
|
cond_in = torch.cat([uncond, cond])
|
||||||
|
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||||
|
denoised = uncond + (cond - uncond) * cond_scale
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = x0
|
||||||
|
mask_inv = 1. - mask
|
||||||
|
denoised = (img_orig * mask_inv) + (mask * denoised)
|
||||||
|
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
class CFGDenoiser(nn.Module):
|
||||||
|
def __init__(self, model):
|
||||||
|
super().__init__()
|
||||||
|
self.inner_model = model
|
||||||
|
|
||||||
|
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
sigma_in = torch.cat([sigma] * 2)
|
||||||
|
cond_in = torch.cat([uncond, cond])
|
||||||
|
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||||
|
return uncond + (cond - uncond) * cond_scale
|
Loading…
Reference in New Issue
Block a user