Create kdiffusion.py

This commit is contained in:
hlky 2022-10-26 08:14:25 +01:00
parent a70a69c7fb
commit e71c4225ca
No known key found for this signature in database
GPG Key ID: 55A99F1E80D907D5

View File

@ -0,0 +1,55 @@
import k_diffusion as K
import torch
import torch.nn as nn
class KDiffusionSampler:
def __init__(self, m, sampler, callback=None):
self.model = m
self.model_wrap = K.external.CompVisDenoiser(m)
self.schedule = sampler
self.generation_callback = callback
def get_sampler_name(self):
return self.schedule
def sample(self, S, conditioning, unconditional_guidance_scale, unconditional_conditioning, x_T):
sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap)
samples_ddim = None
samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas,
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,'cond_scale': unconditional_guidance_scale},
disable=False, callback=self.generation_callback)
#
return samples_ddim, None
class CFGMaskedDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
x_in = x
x_in = torch.cat([x_in] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
if mask is not None:
assert x0 is not None
img_orig = x0
mask_inv = 1. - mask
denoised = (img_orig * mask_inv) + (mask * denoised)
return denoised
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale