mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-14 14:05:36 +03:00
Improved optimized mode speed and VRAM usage.
This commit is contained in:
parent
a3d5e5e548
commit
7ccfc86397
@ -6,7 +6,8 @@ https://github.com/CompVis/taming-transformers
|
||||
-- merci
|
||||
"""
|
||||
|
||||
import time
|
||||
import time, math
|
||||
from tqdm.auto import trange, tqdm
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from tqdm import tqdm
|
||||
@ -21,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
|
||||
from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff
|
||||
|
||||
def disabled_train(self):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
@ -92,7 +93,6 @@ class DDPM(pl.LightningModule):
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
@ -104,7 +104,6 @@ class DDPM(pl.LightningModule):
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
||||
|
||||
|
||||
class FirstStage(DDPM):
|
||||
@ -403,7 +402,7 @@ class UNet(DDPM):
|
||||
h,emb,hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
|
||||
bs = cond.shape[0]
|
||||
|
||||
assert bs%2 == 0
|
||||
# assert bs%2 == 0
|
||||
lenhs = len(hs)
|
||||
|
||||
for i in range(step,bs,step):
|
||||
@ -446,15 +445,14 @@ class UNet(DDPM):
|
||||
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
|
||||
assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
|
||||
to_torch = lambda x: x.to(self.cdevice)
|
||||
|
||||
self.register_buffer1('betas', to_torch(self.betas))
|
||||
self.register_buffer1('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer1('alphas_cumprod_prev', to_torch(self.alphas_cumprod_prev))
|
||||
self.register_buffer1('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
|
||||
self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod))
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
@ -463,25 +461,21 @@ class UNet(DDPM):
|
||||
self.register_buffer1('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
self.ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer1('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
seed,
|
||||
conditioning=None,
|
||||
conditioning,
|
||||
x0=None,
|
||||
shape = None,
|
||||
seed=1234,
|
||||
callback=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
sampler = "plms",
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
@ -492,41 +486,74 @@ class UNet(DDPM):
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
|
||||
if(self.turbo):
|
||||
self.model1.to(self.cdevice)
|
||||
self.model2.to(self.cdevice)
|
||||
|
||||
samples = self.plms_sampling(conditioning, size, seed,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
if x0 is None:
|
||||
batch_size, b1, b2, b3 = shape
|
||||
img_shape = (1, b1, b2, b3)
|
||||
tens = []
|
||||
print("seeds used = ", [seed+s for s in range(batch_size)])
|
||||
for _ in range(batch_size):
|
||||
torch.manual_seed(seed)
|
||||
tens.append(torch.randn(img_shape, device=self.cdevice))
|
||||
seed+=1
|
||||
noise = torch.cat(tens)
|
||||
del tens
|
||||
|
||||
x_latent = noise if x0 is None else x0
|
||||
# sampling
|
||||
|
||||
if sampler == "plms":
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||
print(f'Data shape for PLMS sampling is {shape}')
|
||||
samples = self.plms_sampling(conditioning, batch_size, x_latent,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
)
|
||||
|
||||
elif sampler == "ddim":
|
||||
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask = mask,init_latent=x_T,use_original_steps=False)
|
||||
|
||||
elif sampler == "euler":
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||
samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
elif sampler == "euler_a":
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||
samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
elif sampler == "dpm2":
|
||||
samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
elif sampler == "heun":
|
||||
samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
elif sampler == "dpm2_a":
|
||||
samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
|
||||
elif sampler == "lms":
|
||||
samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
if(self.turbo):
|
||||
self.model1.to("cpu")
|
||||
@ -535,36 +562,17 @@ class UNet(DDPM):
|
||||
return samples
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape, seed,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
def plms_sampling(self, cond,b, img,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
|
||||
device = self.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
_, b1, b2, b3 = shape
|
||||
img_shape = (1, b1, b2, b3)
|
||||
tens = []
|
||||
print("seeds used = ", [seed+s for s in range(b)])
|
||||
for _ in range(b):
|
||||
torch.manual_seed(seed)
|
||||
tens.append(torch.randn(img_shape, device=device))
|
||||
seed+=1
|
||||
img = torch.cat(tens)
|
||||
del tens
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
timesteps = self.ddim_timesteps
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
@ -618,10 +626,10 @@ class UNet(DDPM):
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
alphas = self.ddim_alphas
|
||||
alphas_prev = self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
@ -664,17 +672,11 @@ class UNet(DDPM):
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None, mask=None):
|
||||
def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
|
||||
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
|
||||
if noise is None:
|
||||
b0, b1, b2, b3 = x0.shape
|
||||
@ -687,50 +689,53 @@ class UNet(DDPM):
|
||||
seed+=1
|
||||
noise = torch.cat(tens)
|
||||
del tens
|
||||
if mask is not None:
|
||||
noise = noise*mask
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(self.cdevice), t, x0.shape) * noise)
|
||||
extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
mask = None,use_original_steps=False):
|
||||
def add_noise(self, x0, t):
|
||||
|
||||
|
||||
if(self.turbo):
|
||||
self.model1.to(self.cdevice)
|
||||
self.model2.to(self.cdevice)
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
noise = torch.randn(x0.shape, device=x0.device)
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
# print(extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape),
|
||||
# extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape))
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
mask = None,init_latent=None,use_original_steps=False):
|
||||
|
||||
timesteps = self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
# x0 = x_latent
|
||||
x0 = init_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
|
||||
# if mask is not None:
|
||||
# x_dec = x0 * mask + (1. - mask) * x_dec
|
||||
if mask is not None:
|
||||
# x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
|
||||
x0_noisy = x0
|
||||
x_dec = x0_noisy* mask + (1. - mask) * x_dec
|
||||
|
||||
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
# if mask is not None:
|
||||
# return x0 * mask + (1. - mask) * x_dec
|
||||
|
||||
if(self.turbo):
|
||||
self.model1.to("cpu")
|
||||
self.model2.to("cpu")
|
||||
if mask is not None:
|
||||
return x0 * mask + (1. - mask) * x_dec
|
||||
|
||||
return x_dec
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
@ -743,7 +748,6 @@ class UNet(DDPM):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
# print("xin shape = ", x_in.shape)
|
||||
e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
@ -751,10 +755,10 @@ class UNet(DDPM):
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
alphas = self.ddim_alphas
|
||||
alphas_prev = self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
@ -771,4 +775,256 @@ class UNet(DDPM):
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev
|
||||
return x_prev
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = (sigmas[i] * (gamma + 1)).half()
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
|
||||
s_i = sigma_hat * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
|
||||
s_i = sigmas[i] * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
x = x + torch.randn_like(x) * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
cvd = CompVisDenoiser(alphas_cumprod=ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = (sigmas[i] * (gamma + 1)).half()
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
|
||||
s_i = sigma_hat * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
s_i = sigmas[i + 1] * s_in
|
||||
x_in = torch.cat([x_2] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
|
||||
s_i = sigma_hat * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
|
||||
dt_1 = sigma_mid - sigma_hat
|
||||
dt_2 = sigmas[i + 1] - sigma_hat
|
||||
x_2 = x + d * dt_1
|
||||
|
||||
s_i = sigma_mid * s_in
|
||||
x_in = torch.cat([x_2] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None):
|
||||
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
s_in = x.new_ones([x.shape[0]]).half()
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
|
||||
s_i = sigmas[i] * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
|
||||
s_i = sigma_mid * s_in
|
||||
x_in = torch.cat([x_2] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = x + torch.randn_like(x) * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
cvd = CompVisDenoiser(ac)
|
||||
sigmas = cvd.get_sigmas(S)
|
||||
x = x*sigmas[0]
|
||||
|
||||
ds = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
|
||||
s_i = sigmas[i] * s_in
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([s_i] * 2)
|
||||
cond_in = torch.cat([unconditional_conditioning, cond])
|
||||
c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
|
||||
eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
|
||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > order:
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
||||
|
13
optimizedSD/diffusers_txt2img.py
Normal file
13
optimizedSD/diffusers_txt2img.py
Normal file
@ -0,0 +1,13 @@
|
||||
import torch
|
||||
from diffusers import LDMTextToImagePipeline
|
||||
|
||||
pipe = LDMTextToImagePipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)
|
||||
|
||||
prompt = "19th Century wooden engraving of Elon musk"
|
||||
|
||||
seed = torch.manual_seed(1024)
|
||||
images = pipe([prompt], batch_size=1, num_inference_steps=50, guidance_scale=7, generator=seed,torch_device="cpu" )["sample"]
|
||||
|
||||
# save images
|
||||
for idx, image in enumerate(images):
|
||||
image.save(f"image-{idx}.png")
|
@ -13,7 +13,7 @@ from ldm.modules.diffusionmodules.util import (
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from ldm.modules.attention import SpatialTransformer
|
||||
from .splitAttention import SpatialTransformer
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
|
362
optimizedSD/optimized_img2img.py
Normal file
362
optimizedSD/optimized_img2img.py
Normal file
@ -0,0 +1,362 @@
|
||||
import argparse, os, re
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange
|
||||
from torchvision.utils import make_grid
|
||||
import time
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import instantiate_from_config
|
||||
from optimUtils import split_weighted_subprompts, logger
|
||||
from transformers import logging
|
||||
import pandas as pd
|
||||
logging.set_verbosity_error()
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
return sd
|
||||
|
||||
|
||||
def load_img(path, h0, w0):
|
||||
|
||||
image = Image.open(path).convert("RGB")
|
||||
w, h = image.size
|
||||
|
||||
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||
if h0 is not None and w0 is not None:
|
||||
h, w = h0, w0
|
||||
|
||||
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 32
|
||||
|
||||
print(f"New image size ({w}, {h})")
|
||||
image = image.resize((w, h), resample=Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
config = "optimizedSD/v1-inference.yaml"
|
||||
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
|
||||
)
|
||||
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples")
|
||||
parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action="store_true",
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_save",
|
||||
action="store_true",
|
||||
help="do not save individual samples. For speed measurements.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="number of ddim sampling steps",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_eta",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_iter",
|
||||
type=int,
|
||||
default=1,
|
||||
help="sample this often",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--H",
|
||||
type=int,
|
||||
default=None,
|
||||
help="image height, in pixel space",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--W",
|
||||
type=int,
|
||||
default=None,
|
||||
help="image width, in pixel space",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strength",
|
||||
type=float,
|
||||
default=0.75,
|
||||
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_samples",
|
||||
type=int,
|
||||
default=5,
|
||||
help="how many samples to produce for each given prompt. A.k.a. batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
help="if specified, load prompts from this file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="CPU or GPU (cuda/cuda:0/cuda:1/...)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unet_bs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--turbo",
|
||||
action="store_true",
|
||||
help="Reduces inference time on the expense of 1GB VRAM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
help="output image format",
|
||||
choices=["jpg", "png"],
|
||||
default="png",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampler",
|
||||
type=str,
|
||||
help="sampler",
|
||||
choices=["ddim"],
|
||||
default="ddim",
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
tic = time.time()
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
if opt.seed == None:
|
||||
opt.seed = randint(0, 1000000)
|
||||
seed_everything(opt.seed)
|
||||
|
||||
# Logging
|
||||
logger(vars(opt), log_csv = "logs/img2img_logs.csv")
|
||||
|
||||
sd = load_model_from_config(f"{ckpt}")
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
sp = key.split(".")
|
||||
if (sp[0]) == "model":
|
||||
if "input_blocks" in sp:
|
||||
li.append(key)
|
||||
elif "middle_block" in sp:
|
||||
li.append(key)
|
||||
elif "time_embed" in sp:
|
||||
li.append(key)
|
||||
else:
|
||||
lo.append(key)
|
||||
for key in li:
|
||||
sd["model1." + key[6:]] = sd.pop(key)
|
||||
for key in lo:
|
||||
sd["model2." + key[6:]] = sd.pop(key)
|
||||
|
||||
config = OmegaConf.load(f"{config}")
|
||||
|
||||
assert os.path.isfile(opt.init_img)
|
||||
init_image = load_img(opt.init_img, opt.H, opt.W).to(opt.device)
|
||||
|
||||
model = instantiate_from_config(config.modelUNet)
|
||||
_, _ = model.load_state_dict(sd, strict=False)
|
||||
model.eval()
|
||||
model.cdevice = opt.device
|
||||
model.unet_bs = opt.unet_bs
|
||||
model.turbo = opt.turbo
|
||||
|
||||
modelCS = instantiate_from_config(config.modelCondStage)
|
||||
_, _ = modelCS.load_state_dict(sd, strict=False)
|
||||
modelCS.eval()
|
||||
modelCS.cond_stage_model.device = opt.device
|
||||
|
||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||
modelFS.eval()
|
||||
del sd
|
||||
if opt.device != "cpu" and opt.precision == "autocast":
|
||||
model.half()
|
||||
modelCS.half()
|
||||
modelFS.half()
|
||||
init_image = init_image.half()
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
assert opt.prompt is not None
|
||||
prompt = opt.prompt
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = batch_size * list(data)
|
||||
data = list(chunk(sorted(data), batch_size))
|
||||
|
||||
modelFS.to(opt.device)
|
||||
|
||||
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
|
||||
t_enc = int(opt.strength * opt.ddim_steps)
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
|
||||
if opt.precision == "autocast" and opt.device != "cpu":
|
||||
precision_scope = autocast
|
||||
else:
|
||||
precision_scope = nullcontext
|
||||
|
||||
seeds = ""
|
||||
with torch.no_grad():
|
||||
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
|
||||
sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(opt.device)
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = modelCS.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
subprompts, weights = split_weighted_subprompts(prompts[0])
|
||||
if len(subprompts) > 1:
|
||||
c = torch.zeros_like(uc)
|
||||
totalWeight = sum(weights)
|
||||
# normalize each "sub prompt" and add it
|
||||
for i in range(len(subprompts)):
|
||||
weight = weights[i]
|
||||
# if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
|
||||
modelCS.to("cpu")
|
||||
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
init_latent,
|
||||
torch.tensor([t_enc] * batch_size).to(opt.device),
|
||||
opt.seed,
|
||||
opt.ddim_eta,
|
||||
opt.ddim_steps,
|
||||
)
|
||||
# decode it
|
||||
samples_ddim = model.sample(
|
||||
t_enc,
|
||||
c,
|
||||
z_enc,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
sampler = opt.sampler
|
||||
)
|
||||
|
||||
modelFS.to(opt.device)
|
||||
print("saving images")
|
||||
for i in range(batch_size):
|
||||
|
||||
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
|
||||
)
|
||||
seeds += str(opt.seed) + ","
|
||||
opt.seed += 1
|
||||
base_count += 1
|
||||
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
del samples_ddim
|
||||
print("memory_final = ", torch.cuda.memory_allocated(device=opt.device) / 1e6)
|
||||
|
||||
toc = time.time()
|
||||
|
||||
time_taken = (toc - tic) / 60.0
|
||||
|
||||
print(
|
||||
(
|
||||
"Samples finished in {0:.2f} minutes and exported to "
|
||||
+ sample_path
|
||||
+ "\n Seeds used = "
|
||||
+ seeds[:-1]
|
||||
).format(time_taken)
|
||||
)
|
@ -1,7 +1,7 @@
|
||||
import argparse, os, sys, glob, random
|
||||
import argparse, os, re
|
||||
import torch
|
||||
import numpy as np
|
||||
import copy
|
||||
from random import randint
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
@ -13,6 +13,10 @@ from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from ldm.util import instantiate_from_config
|
||||
from optimUtils import split_weighted_subprompts, logger
|
||||
from transformers import logging
|
||||
# from samplers import CompVisDenoiser
|
||||
logging.set_verbosity_error()
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
@ -30,33 +34,22 @@ def load_model_from_config(ckpt, verbose=False):
|
||||
|
||||
|
||||
config = "optimizedSD/v1-inference.yaml"
|
||||
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
device = "cuda"
|
||||
DEFAULT_CKPT = "models/ldm/stable-diffusion-v1/model.ckpt"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="a painting of a virus monster playing guitar",
|
||||
help="the prompt to render"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outdir",
|
||||
type=str,
|
||||
nargs="?",
|
||||
help="dir to write results to",
|
||||
default="outputs/txt2img-samples"
|
||||
"--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
|
||||
)
|
||||
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples")
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
action="store_true",
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_save",
|
||||
action='store_true',
|
||||
action="store_true",
|
||||
help="do not save individual samples. For speed measurements.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -68,7 +61,7 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument(
|
||||
"--fixed_code",
|
||||
action='store_true',
|
||||
action="store_true",
|
||||
help="if enabled, uses the same starting code across samples ",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -125,6 +118,12 @@ parser.add_argument(
|
||||
default=7.5,
|
||||
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
help="specify GPU (cuda/cuda:0/cuda:1/...)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-file",
|
||||
type=str,
|
||||
@ -133,165 +132,216 @@ parser.add_argument(
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
default=None,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--small_batch",
|
||||
action='store_true',
|
||||
help="Reduce inference time when generate a smaller batch of images",
|
||||
"--unet_bs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
"--turbo",
|
||||
action="store_true",
|
||||
help="Reduces inference time on the expense of 1GB VRAM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
help="evaluate at this precision",
|
||||
choices=["full", "autocast"],
|
||||
default="autocast"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
help="output image format",
|
||||
choices=["jpg", "png"],
|
||||
default="png",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampler",
|
||||
type=str,
|
||||
help="sampler",
|
||||
choices=["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"],
|
||||
default="plms",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
help="path to checkpoint of model",
|
||||
default=DEFAULT_CKPT,
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
tic = time.time()
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
outpath = opt.outdir
|
||||
|
||||
sample_path = os.path.join(outpath, "samples", "_".join(opt.prompt.split())[:255])
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
if opt.seed == None:
|
||||
opt.seed = randint(0, 1000000)
|
||||
seed_everything(opt.seed)
|
||||
|
||||
sd = load_model_from_config(f"{ckpt}")
|
||||
li = []
|
||||
lo = []
|
||||
# Logging
|
||||
logger(vars(opt), log_csv = "logs/txt2img_logs.csv")
|
||||
|
||||
sd = load_model_from_config(f"{opt.ckpt}")
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
sp = key.split('.')
|
||||
if(sp[0]) == 'model':
|
||||
if('input_blocks' in sp):
|
||||
sp = key.split(".")
|
||||
if (sp[0]) == "model":
|
||||
if "input_blocks" in sp:
|
||||
li.append(key)
|
||||
elif('middle_block' in sp):
|
||||
elif "middle_block" in sp:
|
||||
li.append(key)
|
||||
elif('time_embed' in sp):
|
||||
elif "time_embed" in sp:
|
||||
li.append(key)
|
||||
else:
|
||||
lo.append(key)
|
||||
for key in li:
|
||||
sd['model1.' + key[6:]] = sd.pop(key)
|
||||
sd["model1." + key[6:]] = sd.pop(key)
|
||||
for key in lo:
|
||||
sd['model2.' + key[6:]] = sd.pop(key)
|
||||
sd["model2." + key[6:]] = sd.pop(key)
|
||||
|
||||
config = OmegaConf.load(f"{config}")
|
||||
config.modelUNet.params.ddim_steps = opt.ddim_steps
|
||||
|
||||
if opt.small_batch:
|
||||
config.modelUNet.params.small_batch = True
|
||||
else:
|
||||
config.modelUNet.params.small_batch = False
|
||||
|
||||
|
||||
|
||||
model = instantiate_from_config(config.modelUNet)
|
||||
_, _ = model.load_state_dict(sd, strict=False)
|
||||
model.eval()
|
||||
|
||||
model.unet_bs = opt.unet_bs
|
||||
model.cdevice = opt.device
|
||||
model.turbo = opt.turbo
|
||||
|
||||
modelCS = instantiate_from_config(config.modelCondStage)
|
||||
_, _ = modelCS.load_state_dict(sd, strict=False)
|
||||
modelCS.eval()
|
||||
|
||||
modelCS.cond_stage_model.device = opt.device
|
||||
|
||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||
modelFS.eval()
|
||||
del sd
|
||||
|
||||
if opt.precision == "autocast":
|
||||
if opt.device != "cpu" and opt.precision == "autocast":
|
||||
model.half()
|
||||
modelCS.half()
|
||||
|
||||
start_code = None
|
||||
if opt.fixed_code:
|
||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=opt.device)
|
||||
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
assert opt.prompt is not None
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
print(f"Using prompt: {prompt}")
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
else:
|
||||
print(f"reading prompts from {opt.from_file}")
|
||||
with open(opt.from_file, "r") as f:
|
||||
data = f.read().splitlines()
|
||||
data = list(chunk(data, batch_size))
|
||||
text = f.read()
|
||||
print(f"Using prompt: {text.strip()}")
|
||||
data = text.splitlines()
|
||||
data = batch_size * list(data)
|
||||
data = list(chunk(sorted(data), batch_size))
|
||||
|
||||
|
||||
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
||||
if opt.precision == "autocast" and opt.device != "cpu":
|
||||
precision_scope = autocast
|
||||
else:
|
||||
precision_scope = nullcontext
|
||||
|
||||
seeds = ""
|
||||
with torch.no_grad():
|
||||
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(device)
|
||||
|
||||
sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(opt.device)
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = modelCS.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
modelCS.to("cpu")
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
subprompts, weights = split_weighted_subprompts(prompts[0])
|
||||
if len(subprompts) > 1:
|
||||
c = torch.zeros_like(uc)
|
||||
totalWeight = sum(weights)
|
||||
# normalize each "sub prompt" and add it
|
||||
for i in range(len(subprompts)):
|
||||
weight = weights[i]
|
||||
# if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
samples_ddim = model.sample(S=opt.ddim_steps,
|
||||
conditioning=c,
|
||||
batch_size=opt.n_samples,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt.ddim_eta,
|
||||
x_T=start_code)
|
||||
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||
|
||||
modelFS.to(device)
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelCS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
samples_ddim = model.sample(
|
||||
S=opt.ddim_steps,
|
||||
conditioning=c,
|
||||
seed=opt.seed,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt.scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt.ddim_eta,
|
||||
x_T=start_code,
|
||||
sampler = opt.sampler,
|
||||
)
|
||||
|
||||
modelFS.to(opt.device)
|
||||
|
||||
print(samples_ddim.shape)
|
||||
print("saving images")
|
||||
for i in range(batch_size):
|
||||
|
||||
|
||||
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
# for x_sample in x_samples_ddim:
|
||||
x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c')
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||
os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
|
||||
)
|
||||
seeds += str(opt.seed) + ","
|
||||
opt.seed += 1
|
||||
base_count += 1
|
||||
|
||||
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
modelFS.to("cpu")
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
# if not opt.skip_grid:
|
||||
# all_samples.append(x_samples_ddim)
|
||||
if opt.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
del samples_ddim
|
||||
print("memory_final = ", torch.cuda.memory_allocated()/1e6)
|
||||
|
||||
# if not skip_grid:
|
||||
# # additionally, save as grid
|
||||
# grid = torch.stack(all_samples, 0)
|
||||
# grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
# grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# # to image
|
||||
# grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
# Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
# grid_count += 1
|
||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
||||
|
||||
toc = time.time()
|
||||
|
||||
time_taken = (toc-tic)/60.0
|
||||
time_taken = (toc - tic) / 60.0
|
||||
|
||||
print(("Your samples are ready in {0:.2f} minutes and waiting for you here \n" + sample_path).format(time_taken))
|
||||
print(
|
||||
(
|
||||
"Samples finished in {0:.2f} minutes and exported to "
|
||||
+ sample_path
|
||||
+ "\n Seeds used = "
|
||||
+ seeds[:-1]
|
||||
).format(time_taken)
|
||||
)
|
||||
|
252
optimizedSD/samplers.py
Normal file
252
optimizedSD/samplers.py
Normal file
@ -0,0 +1,252 @@
|
||||
from scipy import integrate
|
||||
import torch
|
||||
from tqdm.auto import trange, tqdm
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to):
|
||||
"""Calculates the noise level (sigma_down) to step down to and the amount
|
||||
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
||||
sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
|
||||
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
|
||||
return sigma_down, sigma_up
|
||||
|
||||
|
||||
class DiscreteSchedule(nn.Module):
|
||||
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
|
||||
levels."""
|
||||
|
||||
def __init__(self, sigmas, quantize):
|
||||
super().__init__()
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.quantize = quantize
|
||||
|
||||
def get_sigmas(self, n=None):
|
||||
if n is None:
|
||||
return append_zero(self.sigmas.flip(0))
|
||||
t_max = len(self.sigmas) - 1
|
||||
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
||||
return append_zero(self.t_to_sigma(t))
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
quantize = self.quantize if quantize is None else quantize
|
||||
dists = torch.abs(sigma - self.sigmas[:, None])
|
||||
if quantize:
|
||||
return torch.argmin(dists, dim=0).view(sigma.shape)
|
||||
low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
|
||||
low, high = self.sigmas[low_idx], self.sigmas[high_idx]
|
||||
w = (low - sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
return t.view(sigma.shape)
|
||||
|
||||
def t_to_sigma(self, t):
|
||||
t = t.float()
|
||||
low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
|
||||
# print(low_idx, high_idx, w )
|
||||
return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx]
|
||||
|
||||
|
||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||
noise)."""
|
||||
|
||||
def __init__(self, alphas_cumprod, quantize):
|
||||
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_out, c_in
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model(*args, **kwargs)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return input + eps * c_out
|
||||
|
||||
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
"""A wrapper for CompVis diffusion models."""
|
||||
|
||||
def __init__(self, alphas_cumprod, quantize=False, device='cpu'):
|
||||
super().__init__(alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model.apply_model(*args, **kwargs)
|
||||
|
||||
|
||||
def to_d(x, sigma, denoised):
|
||||
"""Converts a denoiser output to a Karras ODE derivative."""
|
||||
return (x - denoised) / append_dims(sigma, x.ndim)
|
||||
|
||||
|
||||
def get_ancestral_step(sigma_from, sigma_to):
|
||||
"""Calculates the noise level (sigma_down) to step down to and the amount
|
||||
of noise to add (sigma_up) when doing an ancestral sampling step."""
|
||||
sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
|
||||
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
|
||||
return sigma_down, sigma_up
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""Ancestral sampling with Euler method steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
x = x + torch.randn_like(x) * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
dt = sigmas[i + 1] - sigma_hat
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
x = x + d * dt
|
||||
else:
|
||||
# Heun's method
|
||||
x_2 = x + d * dt
|
||||
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||
d_prime = (d + d_2) / 2
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||
eps = torch.randn_like(x) * s_noise
|
||||
sigma_hat = sigmas[i] * (gamma + 1)
|
||||
if gamma > 0:
|
||||
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
|
||||
denoised = model(x, sigma_hat * s_in, **extra_args)
|
||||
d = to_d(x, sigma_hat, denoised)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||
sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
|
||||
dt_1 = sigma_mid - sigma_hat
|
||||
dt_2 = sigmas[i + 1] - sigma_hat
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
|
||||
dt_1 = sigma_mid - sigmas[i]
|
||||
dt_2 = sigma_down - sigmas[i]
|
||||
x_2 = x + d * dt_1
|
||||
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||
x = x + d_2 * dt_2
|
||||
x = x + torch.randn_like(x) * sigma_up
|
||||
return x
|
||||
|
||||
|
||||
def linear_multistep_coeff(order, t, i, j):
|
||||
if order - 1 > i:
|
||||
raise ValueError(f'Order {order} too high for step {i}')
|
||||
def fn(tau):
|
||||
prod = 1.
|
||||
for k in range(order):
|
||||
if j == k:
|
||||
continue
|
||||
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
|
||||
return prod
|
||||
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
ds = []
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
ds.append(d)
|
||||
if len(ds) > order:
|
||||
ds.pop(0)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
cur_order = min(i + 1, order)
|
||||
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
|
||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||
return x
|
280
optimizedSD/splitAttention.py
Normal file
280
optimizedSD/splitAttention.py
Normal file
@ -0,0 +1,280 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.att_step = att_step
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
|
||||
limit = k.shape[0]
|
||||
att_step = self.att_step
|
||||
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
|
||||
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
|
||||
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
|
||||
|
||||
q_chunks.reverse()
|
||||
k_chunks.reverse()
|
||||
v_chunks.reverse()
|
||||
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
del k, q, v
|
||||
for i in range (0, limit, att_step):
|
||||
|
||||
q_buffer = q_chunks.pop()
|
||||
k_buffer = k_chunks.pop()
|
||||
v_buffer = v_chunks.pop()
|
||||
sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
||||
|
||||
del k_buffer, q_buffer
|
||||
# attention, what we cannot get enough of, by chunks
|
||||
|
||||
sim_buffer = sim_buffer.softmax(dim=-1)
|
||||
|
||||
sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
||||
del v_buffer
|
||||
sim[i:i+att_step,:,:] = sim_buffer
|
||||
|
||||
del sim_buffer
|
||||
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(sim)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x)) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
Loading…
Reference in New Issue
Block a user