Improved optimized mode speed and VRAM usage.

This commit is contained in:
ZeroCool940711 2022-12-03 09:33:43 -07:00
parent a3d5e5e548
commit 7ccfc86397
No known key found for this signature in database
GPG Key ID: 4E4072992B5BC640
7 changed files with 1430 additions and 217 deletions

View File

@ -6,7 +6,8 @@ https://github.com/CompVis/taming-transformers
-- merci -- merci
""" """
import time import time, math
from tqdm.auto import trange, tqdm
import torch import torch
from einops import rearrange from einops import rearrange
from tqdm import tqdm from tqdm import tqdm
@ -21,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
from ldm.modules.diffusionmodules.util import make_beta_schedule from ldm.modules.diffusionmodules.util import make_beta_schedule
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff
def disabled_train(self): def disabled_train(self):
"""Overwrite model.train with this function to make sure train/eval mode """Overwrite model.train with this function to make sure train/eval mode
@ -92,7 +93,6 @@ class DDPM(pl.LightningModule):
cosine_s=cosine_s) cosine_s=cosine_s)
alphas = 1. - betas alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape timesteps, = betas.shape
self.num_timesteps = int(timesteps) self.num_timesteps = int(timesteps)
@ -104,7 +104,6 @@ class DDPM(pl.LightningModule):
self.register_buffer('betas', to_torch(betas)) self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
class FirstStage(DDPM): class FirstStage(DDPM):
@ -403,7 +402,7 @@ class UNet(DDPM):
h,emb,hs = self.model1(x_noisy[0:step], t[:step], cond[:step]) h,emb,hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
bs = cond.shape[0] bs = cond.shape[0]
assert bs%2 == 0 # assert bs%2 == 0
lenhs = len(hs) lenhs = len(hs)
for i in range(step,bs,step): for i in range(step,bs,step):
@ -446,15 +445,14 @@ class UNet(DDPM):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.num_timesteps,verbose=verbose) num_ddpm_timesteps=self.num_timesteps,verbose=verbose)
alphas_cumprod = self.alphas_cumprod
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.to(self.cdevice) to_torch = lambda x: x.to(self.cdevice)
self.register_buffer1('betas', to_torch(self.betas)) self.register_buffer1('betas', to_torch(self.betas))
self.register_buffer1('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod))
self.register_buffer1('alphas_cumprod_prev', to_torch(self.alphas_cumprod_prev))
self.register_buffer1('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
# ddim sampling parameters # ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(), ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps, ddim_timesteps=self.ddim_timesteps,
@ -463,25 +461,21 @@ class UNet(DDPM):
self.register_buffer1('ddim_alphas', ddim_alphas) self.register_buffer1('ddim_alphas', ddim_alphas)
self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
self.ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer1('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad() @torch.no_grad()
def sample(self, def sample(self,
S, S,
batch_size, conditioning,
shape, x0=None,
seed, shape = None,
conditioning=None, seed=1234,
callback=None, callback=None,
img_callback=None, img_callback=None,
quantize_x0=False, quantize_x0=False,
eta=0., eta=0.,
mask=None, mask=None,
x0=None, sampler = "plms",
temperature=1., temperature=1.,
noise_dropout=0., noise_dropout=0.,
score_corrector=None, score_corrector=None,
@ -492,41 +486,74 @@ class UNet(DDPM):
unconditional_guidance_scale=1., unconditional_guidance_scale=1.,
unconditional_conditioning=None, unconditional_conditioning=None,
): ):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
if(self.turbo): if(self.turbo):
self.model1.to(self.cdevice) self.model1.to(self.cdevice)
self.model2.to(self.cdevice) self.model2.to(self.cdevice)
samples = self.plms_sampling(conditioning, size, seed, if x0 is None:
callback=callback, batch_size, b1, b2, b3 = shape
img_callback=img_callback, img_shape = (1, b1, b2, b3)
quantize_denoised=quantize_x0, tens = []
mask=mask, x0=x0, print("seeds used = ", [seed+s for s in range(batch_size)])
ddim_use_original_steps=False, for _ in range(batch_size):
noise_dropout=noise_dropout, torch.manual_seed(seed)
temperature=temperature, tens.append(torch.randn(img_shape, device=self.cdevice))
score_corrector=score_corrector, seed+=1
corrector_kwargs=corrector_kwargs, noise = torch.cat(tens)
x_T=x_T, del tens
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale, x_latent = noise if x0 is None else x0
unconditional_conditioning=unconditional_conditioning, # sampling
)
if sampler == "plms":
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
print(f'Data shape for PLMS sampling is {shape}')
samples = self.plms_sampling(conditioning, batch_size, x_latent,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask = mask,init_latent=x_T,use_original_steps=False)
elif sampler == "euler":
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale)
elif sampler == "euler_a":
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale)
elif sampler == "dpm2":
samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale)
elif sampler == "heun":
samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale)
elif sampler == "dpm2_a":
samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale)
elif sampler == "lms":
samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale)
if(self.turbo): if(self.turbo):
self.model1.to("cpu") self.model1.to("cpu")
@ -535,36 +562,17 @@ class UNet(DDPM):
return samples return samples
@torch.no_grad() @torch.no_grad()
def plms_sampling(self, cond, shape, seed, def plms_sampling(self, cond,b, img,
x_T=None, ddim_use_original_steps=False, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False, callback=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100, mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,): unconditional_guidance_scale=1., unconditional_conditioning=None,):
device = self.betas.device device = self.betas.device
b = shape[0] timesteps = self.ddim_timesteps
if x_T is None: time_range = np.flip(timesteps)
_, b1, b2, b3 = shape total_steps = timesteps.shape[0]
img_shape = (1, b1, b2, b3)
tens = []
print("seeds used = ", [seed+s for s in range(b)])
for _ in range(b):
torch.manual_seed(seed)
tens.append(torch.randn(img_shape, device=device))
seed+=1
img = torch.cat(tens)
del tens
else:
img = x_T
if timesteps is None:
timesteps = self.num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps") print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
@ -618,10 +626,10 @@ class UNet(DDPM):
return e_t return e_t
alphas = self.alphas_cumprod if use_original_steps else self.ddim_alphas alphas = self.ddim_alphas
alphas_prev = self.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev alphas_prev = self.ddim_alphas_prev
sqrt_one_minus_alphas = self.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas sigmas = self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index): def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
@ -664,17 +672,11 @@ class UNet(DDPM):
@torch.no_grad() @torch.no_grad()
def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None, mask=None): def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction # fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas # t serves as an index to gather the correct alphas
self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False) self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None: if noise is None:
b0, b1, b2, b3 = x0.shape b0, b1, b2, b3 = x0.shape
@ -687,50 +689,53 @@ class UNet(DDPM):
seed+=1 seed+=1
noise = torch.cat(tens) noise = torch.cat(tens)
del tens del tens
if mask is not None:
noise = noise*mask
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(self.cdevice), t, x0.shape) * noise) extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
@torch.no_grad() @torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, def add_noise(self, x0, t):
mask = None,use_original_steps=False):
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
if(self.turbo): noise = torch.randn(x0.shape, device=x0.device)
self.model1.to(self.cdevice)
self.model2.to(self.cdevice)
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps # print(extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape),
# extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape))
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
@torch.no_grad()
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
mask = None,init_latent=None,use_original_steps=False):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start] timesteps = timesteps[:t_start]
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps") print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps) iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent x_dec = x_latent
# x0 = x_latent x0 = init_latent
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
# if mask is not None: if mask is not None:
# x_dec = x0 * mask + (1. - mask) * x_dec # x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
x0_noisy = x0
x_dec = x0_noisy* mask + (1. - mask) * x_dec
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning) unconditional_conditioning=unconditional_conditioning)
# if mask is not None:
# return x0 * mask + (1. - mask) * x_dec
if(self.turbo): if mask is not None:
self.model1.to("cpu") return x0 * mask + (1. - mask) * x_dec
self.model2.to("cpu")
return x_dec return x_dec
@torch.no_grad() @torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
@ -743,7 +748,6 @@ class UNet(DDPM):
x_in = torch.cat([x] * 2) x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c]) c_in = torch.cat([unconditional_conditioning, c])
# print("xin shape = ", x_in.shape)
e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2) e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
@ -751,10 +755,10 @@ class UNet(DDPM):
assert self.model.parameterization == "eps" assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas = self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev alphas_prev = self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas sigmas = self.ddim_sigmas
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
@ -771,4 +775,256 @@ class UNet(DDPM):
if noise_dropout > 0.: if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev return x_prev
@torch.no_grad()
def euler_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                   extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'),
                   s_noise=1.):
    """Sample with Euler steps (Algorithm 2 of Karras et al., 2022).

    Args:
        ac: Forwarded to ``CompVisDenoiser``; presumably the cumulative
            alpha products of the diffusion schedule - TODO confirm.
        x: Starting latent; it is multiplied by the first sigma before the loop.
        S: Forwarded to ``get_sigmas`` to build the sigma schedule.
        cond: Conditioning, concatenated after ``unconditional_conditioning``.
        unconditional_conditioning: Unconditional conditioning. It is passed
            straight to ``torch.cat``, so callers have to supply it.
        unconditional_guidance_scale: Weight applied to the difference between
            the conditional and unconditional predictions.
        extra_args: Accepted for signature parity; not used by this method.
        callback: Optional callable receiving a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` on every step.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        s_churn, s_tmin, s_tmax, s_noise: Control the extra noise injected
            before a step while the current sigma is within ``[s_tmin, s_tmax]``.

    Returns:
        The latent after the final step.
    """
    extra_args = extra_args if extra_args is not None else {}
    cvd = CompVisDenoiser(ac)
    sigmas = cvd.get_sigmas(S)
    n_steps = len(sigmas) - 1
    x = x * sigmas[0]
    # Per-sample vector of ones, cast to float16, used to broadcast a scalar sigma.
    s_in = x.new_ones([x.shape[0]]).half()

    def guided_denoise(latent, sigma_batch):
        # One doubled batch: unconditional half first, conditional half second.
        latent_pair = torch.cat([latent] * 2)
        sigma_pair = torch.cat([sigma_batch] * 2)
        cond_pair = torch.cat([unconditional_conditioning, cond])
        scale_out, scale_in = [append_dims(s, latent_pair.ndim) for s in cvd.get_scalings(sigma_pair)]
        model_eps = self.apply_model(latent_pair * scale_in, cvd.sigma_to_t(sigma_pair), cond_pair)
        uncond_part, cond_part = (latent_pair + model_eps * scale_out).chunk(2)
        return uncond_part + unconditional_guidance_scale * (cond_part - uncond_part)

    for i in trange(n_steps, disable=disable):
        sigma = sigmas[i]
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Drawn on every step (even when unused) to keep RNG consumption unchanged.
        noise = torch.randn_like(x) * s_noise
        sigma_hat = (sigma * (gamma + 1)).half()
        if gamma > 0:
            x = x + noise * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = guided_denoise(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Euler method
        x = x + d * (sigmas[i + 1] - sigma_hat)
    return x
@torch.no_grad()
def euler_ancestral_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                             extra_args=None, callback=None, disable=None):
    """Ancestral sampling with Euler method steps.

    Args:
        ac: Forwarded to ``CompVisDenoiser``; presumably the cumulative
            alpha products of the diffusion schedule - TODO confirm.
        x: Starting latent; it is multiplied by the first sigma before the loop.
        S: Forwarded to ``get_sigmas`` to build the sigma schedule.
        cond: Conditioning, concatenated after ``unconditional_conditioning``.
        unconditional_conditioning: Unconditional conditioning. It is passed
            straight to ``torch.cat``, so callers have to supply it.
        unconditional_guidance_scale: Weight applied to the difference between
            the conditional and unconditional predictions.
        extra_args: Accepted for signature parity; not used by this method.
        callback: Optional callable receiving a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` on every step.
        disable: Forwarded to ``trange`` as its ``disable`` argument.

    Returns:
        The latent after the final step.
    """
    extra_args = extra_args if extra_args is not None else {}
    cvd = CompVisDenoiser(ac)
    sigmas = cvd.get_sigmas(S)
    x = x * sigmas[0]
    # Per-sample vector of ones, cast to float16, used to broadcast a scalar sigma.
    s_in = x.new_ones([x.shape[0]]).half()

    def guided_denoise(latent, sigma_batch):
        # One doubled batch: unconditional half first, conditional half second.
        latent_pair = torch.cat([latent] * 2)
        sigma_pair = torch.cat([sigma_batch] * 2)
        cond_pair = torch.cat([unconditional_conditioning, cond])
        scale_out, scale_in = [append_dims(s, latent_pair.ndim) for s in cvd.get_scalings(sigma_pair)]
        model_eps = self.apply_model(latent_pair * scale_in, cvd.sigma_to_t(sigma_pair), cond_pair)
        uncond_part, cond_part = (latent_pair + model_eps * scale_out).chunk(2)
        return uncond_part + unconditional_guidance_scale * (cond_part - uncond_part)

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = guided_denoise(x, sigma * s_in)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigma, denoised)
        # Euler method: step down to sigma_down, then add fresh noise scaled by sigma_up.
        x = x + d * (sigma_down - sigma)
        x = x + torch.randn_like(x) * sigma_up
    return x
@torch.no_grad()
def heun_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                  extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'),
                  s_noise=1.):
    """Sample with Heun steps (Algorithm 2 of Karras et al., 2022).

    Each step evaluates the model twice and averages the two slopes, except
    when the next sigma equals zero, where a plain Euler step is taken.

    Args:
        ac: Passed to ``CompVisDenoiser`` as ``alphas_cumprod``.
        x: Starting latent; it is multiplied by the first sigma before the loop.
        S: Forwarded to ``get_sigmas`` to build the sigma schedule.
        cond: Conditioning, concatenated after ``unconditional_conditioning``.
        unconditional_conditioning: Unconditional conditioning. It is passed
            straight to ``torch.cat``, so callers have to supply it.
        unconditional_guidance_scale: Weight applied to the difference between
            the conditional and unconditional predictions.
        extra_args: Accepted for signature parity; not used by this method.
        callback: Optional callable receiving a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` on every step.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        s_churn, s_tmin, s_tmax, s_noise: Control the extra noise injected
            before a step while the current sigma is within ``[s_tmin, s_tmax]``.

    Returns:
        The latent after the final step.
    """
    extra_args = extra_args if extra_args is not None else {}
    cvd = CompVisDenoiser(alphas_cumprod=ac)
    sigmas = cvd.get_sigmas(S)
    n_steps = len(sigmas) - 1
    x = x * sigmas[0]
    # Per-sample vector of ones, cast to float16, used to broadcast a scalar sigma.
    s_in = x.new_ones([x.shape[0]]).half()

    def guided_denoise(latent, sigma_batch):
        # One doubled batch: unconditional half first, conditional half second.
        latent_pair = torch.cat([latent] * 2)
        sigma_pair = torch.cat([sigma_batch] * 2)
        cond_pair = torch.cat([unconditional_conditioning, cond])
        scale_out, scale_in = [append_dims(s, latent_pair.ndim) for s in cvd.get_scalings(sigma_pair)]
        model_eps = self.apply_model(latent_pair * scale_in, cvd.sigma_to_t(sigma_pair), cond_pair)
        uncond_part, cond_part = (latent_pair + model_eps * scale_out).chunk(2)
        return uncond_part + unconditional_guidance_scale * (cond_part - uncond_part)

    for i in trange(n_steps, disable=disable):
        sigma = sigmas[i]
        sigma_next = sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Drawn on every step (even when unused) to keep RNG consumption unchanged.
        noise = torch.randn_like(x) * s_noise
        sigma_hat = (sigma * (gamma + 1)).half()
        if gamma > 0:
            x = x + noise * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = guided_denoise(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigma_next - sigma_hat
        if sigma_next == 0:
            # Euler method
            x = x + d * dt
            continue
        # Heun's method: probe the Euler endpoint and average both slopes.
        x_2 = x + d * dt
        denoised_2 = guided_denoise(x_2, sigma_next * s_in)
        d_2 = to_d(x_2, sigma_next, denoised_2)
        x = x + ((d + d_2) / 2) * dt
    return x
@torch.no_grad()
def dpm_2_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                   extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'),
                   s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Every step evaluates the model at the current sigma and once more at a
    midpoint sigma, then advances using the slope measured at the midpoint.

    Args:
        ac: Forwarded to ``CompVisDenoiser``; presumably the cumulative
            alpha products of the diffusion schedule - TODO confirm.
        x: Starting latent; it is multiplied by the first sigma before the loop.
        S: Forwarded to ``get_sigmas`` to build the sigma schedule.
        cond: Conditioning, concatenated after ``unconditional_conditioning``.
        unconditional_conditioning: Unconditional conditioning. It is passed
            straight to ``torch.cat``, so callers have to supply it.
        unconditional_guidance_scale: Weight applied to the difference between
            the conditional and unconditional predictions.
        extra_args: Accepted for signature parity; not used by this method.
        callback: Optional callable receiving a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` once per
            step, after the first model evaluation.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        s_churn, s_tmin, s_tmax, s_noise: Control the extra noise injected
            before a step while the current sigma is within ``[s_tmin, s_tmax]``.

    Returns:
        The latent after the final step.
    """
    cvd = CompVisDenoiser(ac)
    sigmas = cvd.get_sigmas(S)
    n_steps = len(sigmas) - 1
    x = x * sigmas[0]
    # Per-sample vector of ones, cast to float16, used to broadcast a scalar sigma.
    s_in = x.new_ones([x.shape[0]]).half()

    def guided_denoise(latent, sigma_batch):
        # One doubled batch: unconditional half first, conditional half second.
        latent_pair = torch.cat([latent] * 2)
        sigma_pair = torch.cat([sigma_batch] * 2)
        cond_pair = torch.cat([unconditional_conditioning, cond])
        scale_out, scale_in = [append_dims(s, latent_pair.ndim) for s in cvd.get_scalings(sigma_pair)]
        model_eps = self.apply_model(latent_pair * scale_in, cvd.sigma_to_t(sigma_pair), cond_pair)
        uncond_part, cond_part = (latent_pair + model_eps * scale_out).chunk(2)
        return uncond_part + unconditional_guidance_scale * (cond_part - uncond_part)

    for i in trange(n_steps, disable=disable):
        gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        # Drawn on every step (even when unused) to keep RNG consumption unchanged.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = guided_denoise(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        # The callback used to be accepted but silently ignored here; report
        # progress with the same payload the Euler and Heun samplers use.
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = guided_denoise(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    return x
@torch.no_grad()
def dpm_2_ancestral_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                             extra_args=None, callback=None, disable=None):
    """Ancestral sampling with DPM-Solver inspired second-order steps.

    Args:
        ac: Forwarded to ``CompVisDenoiser``; presumably the cumulative
            alpha products of the diffusion schedule - TODO confirm.
        x: Starting latent; it is multiplied by the first sigma before the loop.
        S: Forwarded to ``get_sigmas`` to build the sigma schedule.
        cond: Conditioning, concatenated after ``unconditional_conditioning``.
        unconditional_conditioning: Unconditional conditioning. It is passed
            straight to ``torch.cat``, so callers have to supply it.
        unconditional_guidance_scale: Weight applied to the difference between
            the conditional and unconditional predictions.
        extra_args: Accepted for signature parity; not used by this method.
        callback: Optional callable receiving a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` on every step.
        disable: Forwarded to ``trange`` as its ``disable`` argument.

    Returns:
        The latent after the final step.
    """
    extra_args = extra_args if extra_args is not None else {}
    cvd = CompVisDenoiser(ac)
    sigmas = cvd.get_sigmas(S)
    x = x * sigmas[0]
    # Per-sample vector of ones, cast to float16, used to broadcast a scalar sigma.
    s_in = x.new_ones([x.shape[0]]).half()

    def guided_denoise(latent, sigma_batch):
        # One doubled batch: unconditional half first, conditional half second.
        latent_pair = torch.cat([latent] * 2)
        sigma_pair = torch.cat([sigma_batch] * 2)
        cond_pair = torch.cat([unconditional_conditioning, cond])
        scale_out, scale_in = [append_dims(s, latent_pair.ndim) for s in cvd.get_scalings(sigma_pair)]
        model_eps = self.apply_model(latent_pair * scale_in, cvd.sigma_to_t(sigma_pair), cond_pair)
        uncond_part, cond_part = (latent_pair + model_eps * scale_out).chunk(2)
        return uncond_part + unconditional_guidance_scale * (cond_part - uncond_part)

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = guided_denoise(x, sigma * s_in)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigma, denoised)
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
        x_2 = x + d * (sigma_mid - sigma)
        denoised_2 = guided_denoise(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * (sigma_down - sigma)
        # Ancestral part: add fresh noise scaled by sigma_up.
        x = x + torch.randn_like(x) * sigma_up
    return x
@torch.no_grad()
def lms_sampling(self, ac, x, S, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
                 extra_args=None, callback=None, disable=None, order=4):
    """Sample with a linear multistep update over the most recent slopes.

    Up to ``order`` previous slopes are kept; each step combines them using
    the weights returned by ``linear_multistep_coeff``.

    Args:
        ac: Forwarded to ``CompVisDenoiser``; presumably the cumulative
            alpha products of the diffusion schedule - TODO confirm.
        x: Starting latent; it is multiplied by the first sigma before the loop.
        S: Forwarded to ``get_sigmas`` to build the sigma schedule.
        cond: Conditioning, concatenated after ``unconditional_conditioning``.
        unconditional_conditioning: Unconditional conditioning. It is passed
            straight to ``torch.cat``, so callers have to supply it.
        unconditional_guidance_scale: Weight applied to the difference between
            the conditional and unconditional predictions.
        extra_args: Accepted for signature parity; not used by this method.
        callback: Optional callable receiving a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised`` on every step.
        disable: Forwarded to ``trange`` as its ``disable`` argument.
        order: Maximum number of stored slopes combined in one step.

    Returns:
        The latent after the final step.
    """
    s_in = x.new_ones([x.shape[0]])
    cvd = CompVisDenoiser(ac)
    sigmas = cvd.get_sigmas(S)
    # The schedule never changes inside the loop, so copy it to the CPU once
    # instead of once per coefficient on every step.
    sigmas_cpu = sigmas.cpu()
    x = x * sigmas[0]
    ds = []
    for i in trange(len(sigmas) - 1, disable=disable):
        s_i = sigmas[i] * s_in
        # One doubled batch: unconditional half first, conditional half second.
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([s_i] * 2)
        cond_in = torch.cat([unconditional_conditioning, cond])
        c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)]
        eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in)
        e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
        denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if len(ds) > order:
            # Keep only the `order` most recent slopes.
            ds.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Fewer slopes are available during the first steps.
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        # Newest slope pairs with the first coefficient.
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x

View File

@ -0,0 +1,13 @@
import torch
from diffusers import LDMTextToImagePipeline
# Minimal text-to-image smoke test built on the diffusers pipeline.
pipe = LDMTextToImagePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-3-diffusers",
    use_auth_token=True,
)

prompt = "19th Century wooden engraving of Elon musk"
# torch.manual_seed returns the seeded generator, which is handed to the pipeline below.
seed = torch.manual_seed(1024)

output = pipe(
    [prompt],
    batch_size=1,
    num_inference_steps=50,
    guidance_scale=7,
    generator=seed,
    torch_device="cpu",
)
images = output["sample"]

# Write every generated image to the working directory as image-<index>.png.
for idx, image in enumerate(images):
    image.save(f"image-{idx}.png")

View File

@ -13,7 +13,7 @@ from ldm.modules.diffusionmodules.util import (
normalization, normalization,
timestep_embedding, timestep_embedding,
) )
from ldm.modules.attention import SpatialTransformer from .splitAttention import SpatialTransformer
class AttentionPool2d(nn.Module): class AttentionPool2d(nn.Module):

View File

@ -0,0 +1,362 @@
import argparse, os, re
import torch
import numpy as np
from random import randint
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger
from transformers import logging
import pandas as pd
logging.set_verbosity_error()
def chunk(it, size):
    """Lazily split *it* into consecutive tuples of at most *size* items.

    The last tuple may be shorter. Iteration stops as soon as an empty
    tuple would be produced, i.e. when the input is exhausted.
    """
    source = iter(it)

    def take_next():
        return tuple(islice(source, size))

    # Two-argument iter(): call take_next() until it returns the sentinel ().
    return iter(take_next, ())
def load_model_from_config(ckpt, verbose=False):
    """Load a checkpoint onto the CPU and return its ``"state_dict"`` entry.

    Prints the checkpoint path and, when the checkpoint has one, its
    ``"global_step"`` value. ``verbose`` is accepted but not used.

    NOTE(review): ``torch.load`` unpickles the file, so only load
    checkpoints that come from a trusted source.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        step = checkpoint["global_step"]
        print(f"Global Step: {step}")
    return checkpoint["state_dict"]
def load_img(path, h0, w0):
    """Read an image file and return it as a ``(1, 3, H, W)`` float tensor in [-1, 1].

    The image is converted to RGB. When both ``h0`` and ``w0`` are given they
    replace the file's own height and width. Each side is then rounded down
    to a multiple of 64 before the image is resized with Lanczos resampling.
    """
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")

    if not (h0 is None or w0 is None):
        h, w = h0, w0

    # Round each side down to an integer multiple of 64.
    w = w - w % 64
    h = h - h % 64
    print(f"New image size ({w}, {h})")

    resized = image.resize((w, h), resample=Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W)
    batched = pixels[None].transpose(0, 3, 1, 2)
    # Map [0, 1] to [-1, 1].
    return 2.0 * torch.from_numpy(batched) - 1.0
# Fixed locations of the model config and checkpoint used by this script.
config = "optimizedSD/v1-inference.yaml"
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
# Command-line interface of the img2img script; the parsed result is stored in `opt`.
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
)
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples")
# No default: opt.init_img stays None unless the flag is passed.
parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
parser.add_argument(
"--skip_grid",
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action="store_true",
help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
"--n_iter",
type=int,
default=1,
help="sample this often",
)
# --H / --W default to None; load_img then keeps the input image's own size.
parser.add_argument(
"--H",
type=int,
default=None,
help="image height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=None,
help="image width, in pixel space",
)
parser.add_argument(
"--strength",
type=float,
default=0.75,
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
)
parser.add_argument(
"--n_samples",
type=int,
default=5,
help="how many samples to produce for each given prompt. A.k.a. batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=7.5,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file",
)
# A seed of None is replaced by a random one after parsing.
parser.add_argument(
"--seed",
type=int,
default=None,
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="CPU or GPU (cuda/cuda:0/cuda:1/...)",
)
parser.add_argument(
"--unet_bs",
type=int,
default=1,
help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
)
parser.add_argument(
"--turbo",
action="store_true",
help="Reduces inference time on the expense of 1GB VRAM",
)
parser.add_argument(
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
)
parser.add_argument(
"--format",
type=str,
help="output image format",
choices=["jpg", "png"],
default="png",
)
# Only "ddim" is accepted as a sampler by this script.
parser.add_argument(
"--sampler",
type=str,
help="sampler",
choices=["ddim"],
default="ddim",
)
opt = parser.parse_args()
# Start the wall-clock timer and make sure the output directory exists.
tic = time.time()
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
grid_count = len(os.listdir(outpath)) - 1

# Pick a random seed when none was supplied. `is None` is an identity test,
# unlike `== None`, which goes through the equality protocol.
if opt.seed is None:
    opt.seed = randint(0, 1000000)
seed_everything(opt.seed)

# Record the options used for this run in the CSV log.
logger(vars(opt), log_csv="logs/img2img_logs.csv")
# Load the checkpoint's state dict (mapped to CPU by load_model_from_config).
sd = load_model_from_config(f"{ckpt}")
# Sort the keys whose first dotted component is "model": keys containing
# input_blocks, middle_block or time_embed are collected in `li`; `lo`
# receives the keys handled by the `else` branch.
li, lo = [], []
for key, value in sd.items():
sp = key.split(".")
if (sp[0]) == "model":
if "input_blocks" in sp:
li.append(key)
elif "middle_block" in sp:
li.append(key)
elif "time_embed" in sp:
li.append(key)
else:
lo.append(key)
# key[6:] drops the first six characters (the "model." prefix for the keys
# selected above), so the entries are re-keyed under "model1." / "model2.".
for key in li:
sd["model1." + key[6:]] = sd.pop(key)
for key in lo:
sd["model2." + key[6:]] = sd.pop(key)
config = OmegaConf.load(f"{config}")
# NOTE(review): opt.init_img is None when --init-img is omitted, and this
# assert is stripped under `python -O` — confirm that is acceptable.
assert os.path.isfile(opt.init_img)
init_image = load_img(opt.init_img, opt.H, opt.W).to(opt.device)
# Build the three models described by the config and fill each from the same
# state dict; strict=False lets every model ignore keys it does not own.
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()
model.cdevice = opt.device
model.unet_bs = opt.unet_bs
model.turbo = opt.turbo
modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()
modelCS.cond_stage_model.device = opt.device
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
# The state dict is no longer needed once all three models are loaded.
del sd
# Half precision is used only off-CPU and only in autocast mode.
if opt.device != "cpu" and opt.precision == "autocast":
model.half()
modelCS.half()
modelFS.half()
init_image = init_image.half()
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
# `data` becomes a list of prompt batches, each holding batch_size prompts.
if not opt.from_file:
assert opt.prompt is not None
prompt = opt.prompt
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
# Repeat the list of lines batch_size times, sort it so identical lines sit
# next to each other, then cut it into tuples of batch_size.
data = batch_size * list(data)
data = list(chunk(sorted(data), batch_size))
# Encode the init image once, repeated along the batch dimension.
modelFS.to(opt.device)
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
if opt.device != "cpu":
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
modelFS.to("cpu")
# Poll once per second until allocated CUDA memory drops below the value
# recorded before the model was moved to the CPU.
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
time.sleep(1)
assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
# Number of steps used for encoding/decoding: strength times the step count, truncated.
t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps")
# autocast is used only off-CPU in autocast mode; otherwise a no-op context.
if opt.precision == "autocast" and opt.device != "cpu":
precision_scope = autocast
else:
precision_scope = nullcontext
# Comma-separated record of every seed value used for a saved image.
seeds = ""
with torch.no_grad():
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
# Output folder name: the first prompt of the batch with ":" and " " replaced
# by "_"; the joined path (outdir included) is cut to 150 characters.
sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
with precision_scope("cuda"):
modelCS.to(opt.device)
# The unconditional conditioning is only computed when guidance is active.
uc = None
if opt.scale != 1.0:
uc = modelCS.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
subprompts, weights = split_weighted_subprompts(prompts[0])
if len(subprompts) > 1:
# NOTE(review): uc is None when opt.scale == 1.0, so torch.zeros_like(uc)
# looks like it would fail for weighted sub-prompts in that case — confirm.
c = torch.zeros_like(uc)
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(len(subprompts)):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = modelCS.get_learned_conditioning(prompts)
if opt.device != "cpu":
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
modelCS.to("cpu")
# Poll until allocated CUDA memory drops below the pre-move value.
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
time.sleep(1)
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(opt.device),
opt.seed,
opt.ddim_eta,
opt.ddim_steps,
)
# decode it
samples_ddim = model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
sampler = opt.sampler
)
modelFS.to(opt.device)
print("saving images")
# Decode and save the samples one at a time.
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
# Map the decoder output from [-1, 1] to [0, 1], then to 0..255 in HWC layout.
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
)
# opt.seed is incremented after every saved image and recorded in `seeds`.
seeds += str(opt.seed) + ","
opt.seed += 1
base_count += 1
if opt.device != "cpu":
mem = torch.cuda.memory_allocated(device=opt.device) / 1e6
modelFS.to("cpu")
while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem:
time.sleep(1)
del samples_ddim
# NOTE(review): this CUDA query does not appear to be guarded by the
# opt.device != "cpu" check used above — confirm behaviour on CPU runs.
print("memory_final = ", torch.cuda.memory_allocated(device=opt.device) / 1e6)
toc = time.time()
time_taken = (toc - tic) / 60.0
# seeds[:-1] drops the trailing comma.
print(
(
"Samples finished in {0:.2f} minutes and exported to "
+ sample_path
+ "\n Seeds used = "
+ seeds[:-1]
).format(time_taken)
)

View File

@ -1,7 +1,7 @@
import argparse, os, sys, glob, random import argparse, os, re
import torch import torch
import numpy as np import numpy as np
import copy from random import randint
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image
from tqdm import tqdm, trange from tqdm import tqdm, trange
@ -13,6 +13,10 @@ from pytorch_lightning import seed_everything
from torch import autocast from torch import autocast
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger
from transformers import logging
# from samplers import CompVisDenoiser
logging.set_verbosity_error()
def chunk(it, size): def chunk(it, size):
@ -30,33 +34,22 @@ def load_model_from_config(ckpt, verbose=False):
config = "optimizedSD/v1-inference.yaml" config = "optimizedSD/v1-inference.yaml"
ckpt = "models/ldm/stable-diffusion-v1/model.ckpt" DEFAULT_CKPT = "models/ldm/stable-diffusion-v1/model.ckpt"
device = "cuda"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--prompt", "--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render"
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
) )
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples")
parser.add_argument( parser.add_argument(
"--skip_grid", "--skip_grid",
action='store_true', action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
) )
parser.add_argument( parser.add_argument(
"--skip_save", "--skip_save",
action='store_true', action="store_true",
help="do not save individual samples. For speed measurements.", help="do not save individual samples. For speed measurements.",
) )
parser.add_argument( parser.add_argument(
@ -68,7 +61,7 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
"--fixed_code", "--fixed_code",
action='store_true', action="store_true",
help="if enabled, uses the same starting code across samples ", help="if enabled, uses the same starting code across samples ",
) )
parser.add_argument( parser.add_argument(
@ -125,6 +118,12 @@ parser.add_argument(
default=7.5, default=7.5,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
) )
parser.add_argument(
"--device",
type=str,
default="cuda",
help="specify GPU (cuda/cuda:0/cuda:1/...)",
)
parser.add_argument( parser.add_argument(
"--from-file", "--from-file",
type=str, type=str,
@ -133,165 +132,216 @@ parser.add_argument(
parser.add_argument( parser.add_argument(
"--seed", "--seed",
type=int, type=int,
default=42, default=None,
help="the seed (for reproducible sampling)", help="the seed (for reproducible sampling)",
) )
parser.add_argument( parser.add_argument(
"--small_batch", "--unet_bs",
action='store_true', type=int,
help="Reduce inference time when generate a smaller batch of images", default=1,
help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )",
) )
parser.add_argument( parser.add_argument(
"--precision", "--turbo",
action="store_true",
help="Reduces inference time on the expense of 1GB VRAM",
)
parser.add_argument(
"--precision",
type=str, type=str,
help="evaluate at this precision", help="evaluate at this precision",
choices=["full", "autocast"], choices=["full", "autocast"],
default="autocast" default="autocast"
) )
parser.add_argument(
"--format",
type=str,
help="output image format",
choices=["jpg", "png"],
default="png",
)
parser.add_argument(
"--sampler",
type=str,
help="sampler",
choices=["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"],
default="plms",
)
parser.add_argument(
"--ckpt",
type=str,
help="path to checkpoint of model",
default=DEFAULT_CKPT,
)
opt = parser.parse_args() opt = parser.parse_args()
tic = time.time() tic = time.time()
os.makedirs(opt.outdir, exist_ok=True) os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir outpath = opt.outdir
sample_path = os.path.join(outpath, "samples", "_".join(opt.prompt.split())[:255])
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
if opt.seed == None:
opt.seed = randint(0, 1000000)
seed_everything(opt.seed) seed_everything(opt.seed)
sd = load_model_from_config(f"{ckpt}") # Logging
li = [] logger(vars(opt), log_csv = "logs/txt2img_logs.csv")
lo = []
sd = load_model_from_config(f"{opt.ckpt}")
li, lo = [], []
for key, value in sd.items(): for key, value in sd.items():
sp = key.split('.') sp = key.split(".")
if(sp[0]) == 'model': if (sp[0]) == "model":
if('input_blocks' in sp): if "input_blocks" in sp:
li.append(key) li.append(key)
elif('middle_block' in sp): elif "middle_block" in sp:
li.append(key) li.append(key)
elif('time_embed' in sp): elif "time_embed" in sp:
li.append(key) li.append(key)
else: else:
lo.append(key) lo.append(key)
for key in li: for key in li:
sd['model1.' + key[6:]] = sd.pop(key) sd["model1." + key[6:]] = sd.pop(key)
for key in lo: for key in lo:
sd['model2.' + key[6:]] = sd.pop(key) sd["model2." + key[6:]] = sd.pop(key)
config = OmegaConf.load(f"{config}") config = OmegaConf.load(f"{config}")
config.modelUNet.params.ddim_steps = opt.ddim_steps
if opt.small_batch:
config.modelUNet.params.small_batch = True
else:
config.modelUNet.params.small_batch = False
model = instantiate_from_config(config.modelUNet) model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False) _, _ = model.load_state_dict(sd, strict=False)
model.eval() model.eval()
model.unet_bs = opt.unet_bs
model.cdevice = opt.device
model.turbo = opt.turbo
modelCS = instantiate_from_config(config.modelCondStage) modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False) _, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval() modelCS.eval()
modelCS.cond_stage_model.device = opt.device
modelFS = instantiate_from_config(config.modelFirstStage) modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False) _, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval() modelFS.eval()
del sd
if opt.precision == "autocast": if opt.device != "cpu" and opt.precision == "autocast":
model.half() model.half()
modelCS.half() modelCS.half()
start_code = None start_code = None
if opt.fixed_code: if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=opt.device)
batch_size = opt.n_samples batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file: if not opt.from_file:
assert opt.prompt is not None
prompt = opt.prompt prompt = opt.prompt
assert prompt is not None print(f"Using prompt: {prompt}")
data = [batch_size * [prompt]] data = [batch_size * [prompt]]
else: else:
print(f"reading prompts from {opt.from_file}") print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f: with open(opt.from_file, "r") as f:
data = f.read().splitlines() text = f.read()
data = list(chunk(data, batch_size)) print(f"Using prompt: {text.strip()}")
data = text.splitlines()
data = batch_size * list(data)
data = list(chunk(sorted(data), batch_size))
precision_scope = autocast if opt.precision=="autocast" else nullcontext if opt.precision == "autocast" and opt.device != "cpu":
precision_scope = autocast
else:
precision_scope = nullcontext
seeds = ""
with torch.no_grad(): with torch.no_grad():
all_samples = list() all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"): for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"): for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"):
modelCS.to(device) sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150]
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
with precision_scope("cuda"):
modelCS.to(opt.device)
uc = None uc = None
if opt.scale != 1.0: if opt.scale != 1.0:
uc = modelCS.get_learned_conditioning(batch_size * [""]) uc = modelCS.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple): if isinstance(prompts, tuple):
prompts = list(prompts) prompts = list(prompts)
c = modelCS.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
mem = torch.cuda.memory_allocated()/1e6
modelCS.to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
subprompts, weights = split_weighted_subprompts(prompts[0])
if len(subprompts) > 1:
c = torch.zeros_like(uc)
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(len(subprompts)):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = modelCS.get_learned_conditioning(prompts)
samples_ddim = model.sample(S=opt.ddim_steps, shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
modelFS.to(device) if opt.device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelCS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
samples_ddim = model.sample(
S=opt.ddim_steps,
conditioning=c,
seed=opt.seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code,
sampler = opt.sampler,
)
modelFS.to(opt.device)
print(samples_ddim.shape)
print("saving images") print("saving images")
for i in range(batch_size): for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0)) x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
# for x_sample in x_samples_ddim: x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c')
Image.fromarray(x_sample.astype(np.uint8)).save( Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png")) os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}")
)
seeds += str(opt.seed) + ","
opt.seed += 1
base_count += 1 base_count += 1
if opt.device != "cpu":
mem = torch.cuda.memory_allocated()/1e6 mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu") modelFS.to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem): while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1) time.sleep(1)
# if not opt.skip_grid:
# all_samples.append(x_samples_ddim)
del samples_ddim del samples_ddim
print("memory_final = ", torch.cuda.memory_allocated()/1e6) print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
# if not skip_grid:
# # additionally, save as grid
# grid = torch.stack(all_samples, 0)
# grid = rearrange(grid, 'n b c h w -> (n b) c h w')
# grid = make_grid(grid, nrow=n_rows)
# # to image
# grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
# Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
# grid_count += 1
toc = time.time() toc = time.time()
time_taken = (toc-tic)/60.0 time_taken = (toc - tic) / 60.0
print(("Your samples are ready in {0:.2f} minutes and waiting for you here \n" + sample_path).format(time_taken)) print(
(
"Samples finished in {0:.2f} minutes and exported to "
+ sample_path
+ "\n Seeds used = "
+ seeds[:-1]
).format(time_taken)
)

252
optimizedSD/samplers.py Normal file
View File

@ -0,0 +1,252 @@
from scipy import integrate
import torch
from tqdm.auto import trange, tqdm
import torch.nn as nn
def append_zero(x):
    """Return ``x`` with one zero element (same dtype and device) appended."""
    zero = x.new_zeros([1])
    return torch.cat((x, zero))
def append_dims(x, target_dims):
    """Add trailing singleton dimensions to ``x`` until it has ``target_dims`` dims.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # Ellipsis keeps the existing dims; each None adds one axis of size 1 at the end.
    index = (...,) + (None,) * missing
    return x[index]
def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral sampling step into its two parts.

    Returns ``(sigma_down, sigma_up)``: the noise level to step down to and
    the amount of noise to add back afterwards.
    """
    variance_to = sigma_to ** 2
    sigma_up = (variance_to * (sigma_from ** 2 - variance_to) / sigma_from ** 2) ** 0.5
    sigma_down = (variance_to - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
class DiscreteSchedule(nn.Module):
    """Converts between continuous noise levels (sigmas) and positions in a
    discrete table of noise levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('sigmas', sigmas)
        self.quantize = quantize

    def get_sigmas(self, n=None):
        """Return the sigma table in reverse order, followed by a final zero.

        With ``n`` given, ``n`` evenly spaced table positions from the last
        index down to 0 are interpolated instead of returning the full table.
        """
        if n is not None:
            last_index = len(self.sigmas) - 1
            positions = torch.linspace(last_index, 0, n, device=self.sigmas.device)
            return append_zero(self.t_to_sigma(positions))
        return append_zero(self.sigmas.flip(0))

    def sigma_to_t(self, sigma, quantize=None):
        """Map sigma values to table positions.

        When quantizing, the index of the closest table entry is returned;
        otherwise the position is interpolated linearly between the two
        closest entries. ``quantize=None`` falls back to ``self.quantize``.
        """
        if quantize is None:
            quantize = self.quantize
        # dists[k, m] = |sigma[m] - table[k]|
        dists = torch.abs(sigma - self.sigmas[:, None])
        if quantize:
            return torch.argmin(dists, dim=0).view(sigma.shape)
        nearest_two = torch.topk(dists, dim=0, k=2, largest=False).indices
        low_idx, high_idx = torch.sort(nearest_two, dim=0)[0]
        low = self.sigmas[low_idx]
        high = self.sigmas[high_idx]
        w = ((low - sigma) / (low - high)).clamp(0, 1)
        return ((1 - w) * low_idx + w * high_idx).view(sigma.shape)

    def t_to_sigma(self, t):
        """Linearly interpolate the sigma table at fractional positions ``t``."""
        t = t.float()
        below = t.floor().long()
        above = t.ceil().long()
        frac = t.frac()
        return (1 - frac) * self.sigmas[below] + frac * self.sigmas[above]
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """Wrapper for discrete-schedule DDPM models that output eps (the
    predicted noise).

    NOTE(review): ``inner_model`` is read by ``get_eps`` but is not assigned
    anywhere in this class; it presumably has to be attached by the caller
    before ``forward`` is used — confirm.
    """

    def __init__(self, alphas_cumprod, quantize):
        # sigma = sqrt((1 - alpha_bar) / alpha_bar)
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return ``(c_out, c_in)``, the output and input scale factors for ``sigma``."""
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return -sigma, c_in

    def get_eps(self, *args, **kwargs):
        """Call the wrapped model and return its noise prediction."""
        return self.inner_model(*args, **kwargs)

    def forward(self, input, sigma, **kwargs):
        """Return the denoised estimate ``input + eps * c_out``."""
        # Broadcast both scale factors against the input's trailing dims.
        c_out, c_in = (append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        scaled_input = input * c_in
        eps = self.get_eps(scaled_input, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """Wrapper for CompVis diffusion models: eps comes from ``apply_model``."""

    def __init__(self, alphas_cumprod, quantize=False, device='cpu'):
        # `device` is accepted but not used by this constructor.
        super().__init__(alphas_cumprod, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        """Forward all arguments to the wrapped model's ``apply_model``."""
        wrapped = self.inner_model
        return wrapped.apply_model(*args, **kwargs)
def to_d(x, sigma, denoised):
    """Turn a denoiser output into a Karras ODE derivative: ``(x - denoised) / sigma``."""
    residual = x - denoised
    # sigma gets trailing singleton dims so it broadcasts against x.
    return residual / append_dims(sigma, x.ndim)
def get_ancestral_step(sigma_from, sigma_to):
    """Compute ``(sigma_down, sigma_up)`` for one ancestral sampling step.

    ``sigma_down`` is the noise level to step down to and ``sigma_up`` the
    amount of noise to add afterwards. This module defines a function of the
    same name earlier; this later definition is the one bound at import time.
    """
    from_sq, to_sq = sigma_from ** 2, sigma_to ** 2
    sigma_up = (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    sigma_down = (to_sq - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Euler-step sampler (Algorithm 2 of Karras et al., 2022).

    ``model`` is called as ``model(x, sigma_batch, **extra_args)`` and its
    result is used as the denoised estimate. One step is taken per
    consecutive pair in ``sigmas``; ``callback``, when given, receives a dict
    describing each step. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        # Churn is applied only while sigma lies inside [s_tmin, s_tmax].
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Noise is drawn on every step, even when gamma == 0 leaves it unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Euler step from sigma_hat to the next sigma.
        x = x + d * (sigma_next - sigma_hat)
    return x
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampler built on Euler steps.

    Each iteration steps down to ``sigma_down`` with an Euler step and then
    adds fresh Gaussian noise scaled by ``sigma_up``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigma, denoised)
        # Euler step down to sigma_down, then re-noise.
        x = x + d * (sigma_down - sigma)
        x = x + torch.randn_like(x) * sigma_up
    return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Heun-step sampler (Algorithm 2 of Karras et al., 2022).

    Every step evaluates the model a second time at the Euler-predicted point
    and averages the two derivatives, except when the next sigma is 0, where
    a plain Euler step is taken. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Noise is drawn on every step, even when gamma == 0 leaves it unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigma_next - sigma_hat
        if sigma_next == 0:
            # Last step: plain Euler, no second model evaluation at sigma 0.
            x = x + d * dt
            continue
        # Heun's method: average the slope at both ends of the step.
        x_2 = x + d * dt
        denoised_2 = model(x_2, sigma_next * s_in, **extra_args)
        d_2 = to_d(x_2, sigma_next, denoised_2)
        x = x + ((d + d_2) / 2) * dt
    return x
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Sampler inspired by DPM-Solver-2 and Algorithm 2 of Karras et al. (2022).

    Each step is a midpoint-method step: the derivative is re-evaluated at an
    intermediate sigma and used for the full step. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Noise is drawn on every step, even when gamma == 0 leaves it unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Midpoint chosen according to a rho=3 Karras schedule.
        sigma_mid = ((sigma_hat ** (1 / 3) + sigma_next ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_hat)
        denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_next - sigma_hat)
    return x
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampler with DPM-Solver-inspired second-order steps.

    Each iteration takes a midpoint-method step down to ``sigma_down`` and
    then adds fresh Gaussian noise scaled by ``sigma_up``. Returns the final ``x``.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = model(x, sigma * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigma, denoised)
        # Midpoint chosen according to a rho=3 Karras schedule.
        sigma_mid = ((sigma ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma)
        denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_down - sigma)
        x = x + torch.randn_like(x) * sigma_up
    return x
def linear_multistep_coeff(order, t, i, j):
    """Integrate one Lagrange basis polynomial over the step ``t[i] -> t[i + 1]``.

    The polynomial is the basis function of node ``t[i - j]`` among the
    ``order`` nodes ``t[i], t[i - 1], ..., t[i - order + 1]``.

    Raises:
        ValueError: if fewer than ``order`` nodes are available at step ``i``.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    other_nodes = [k for k in range(order) if k != j]

    def basis(tau):
        value = 1.
        for k in other_nodes:
            value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    integral, _abs_err = integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)
    return integral
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Sample with a linear multistep method over the given sigma schedule.

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``; its
            result is used as the denoised estimate.
        x: Starting tensor; dimension 0 is taken as the batch size.
        sigmas: Tensor of noise levels; one step is taken per consecutive pair.
        extra_args: Optional keyword arguments forwarded to ``model``.
        callback: Optional callable that receives a dict describing each step.
        disable: Passed through to ``trange``.
        order: Maximum number of stored derivatives combined in one step.

    Returns:
        The tensor ``x`` after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Copy the schedule to the CPU once. The coefficient integration indexes
    # it many times per step and the schedule does not change while sampling,
    # so repeating the transfer for every coefficient is wasted work.
    sigmas_cpu = sigmas.cpu()
    ds = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        # Keep only the most recent `order` derivatives.
        if len(ds) > order:
            ds.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # The first steps use a lower order until enough history exists.
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        # coeffs[0] belongs to the newest derivative, hence reversed(ds).
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x

View File

@ -0,0 +1,280 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from ldm.modules.diffusionmodules.util import checkpoint
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
def uniq(arr):
    """Return a keys view of the distinct elements of ``arr`` in first-seen order."""
    seen = dict.fromkeys(arr, True)
    return seen.keys()
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function it is called and its result is
    returned, so the fallback is only computed when it is needed.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def max_neg_value(t):
    """Return the most negative finite value representable in ``t``'s dtype."""
    limits = torch.finfo(t.dtype)
    return -limits.max
def init_(tensor):
    """Fill ``tensor`` in place with uniform values in ``[-b, b]`` and return it.

    The bound is ``b = 1 / sqrt(size of the last dimension)``.
    """
    last_dim = tensor.shape[-1]
    bound = 1 / math.sqrt(last_dim)
    tensor.uniform_(-bound, bound)
    return tensor
# feedforward
class GEGLU(nn.Module):
    """GELU-gated linear unit: one projection produces a value half and a
    gate half, and the output is ``value * gelu(gate)``."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # A single linear layer yields both halves, hence 2 * dim_out features.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Feed-forward block: input projection, dropout, output projection.

    The hidden width is ``int(dim * mult)``. With ``glu=True`` the input
    projection is a ``GEGLU``; otherwise it is a linear layer followed by
    GELU. ``dim_out`` falls back to ``dim`` when it is not given.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the module."""
    for param in module.parameters():
        # detach() gives a view that shares storage with the parameter, so
        # zeroing it zeroes the parameter itself.
        param.detach().zero_()
    return module
def Normalize(in_channels):
    """Build a 32-group ``GroupNorm`` over ``in_channels`` channels (eps 1e-6, affine)."""
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
    """Linear attention over the spatial positions of a ``(b, c, h, w)`` input.

    Keys are softmax-normalised over positions and contracted with the values
    first, so no position-by-position attention matrix is formed.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # One 1x1 convolution yields queries, keys and values together.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        packed = self.to_qkv(x)
        q, k, v = rearrange(packed, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        # Normalise the keys across spatial positions.
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=height, w=width)
        return self.to_out(attended)
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is normalised, projected to queries/keys/values with 1x1
    convolutions, attended, projected back and added to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1 convolution that keeps the channel count unchanged.
            return torch.nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        """Return ``x`` plus its self-attended projection (same shape as ``x``)."""
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        scores = torch.einsum('bij,bjk->bik', q, k)
        # Scale by 1/sqrt(channels) before normalising over the key positions.
        scores = scores * (int(c)**(-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        weights = rearrange(weights, 'b i j -> b j i')
        attended = torch.einsum('bij,bjk->bik', v, weights)
        attended = rearrange(attended, 'b c (h w) -> b c h w', h=h)
        attended = self.proj_out(attended)

        return x+attended
class CrossAttention(nn.Module):
    """Multi-head attention evaluated in slices of the fused ``(batch * heads)`` axis.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` (or from ``x`` itself when no context is given, which makes
    this self-attention). Only ``att_step`` rows of the fused axis have their
    attention matrix materialised at any one time.

    Args:
        query_dim: feature size of ``x``.
        context_dim: feature size of ``context``; falls back to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout probability applied after the output projection.
        att_step: number of ``(batch * heads)`` rows processed per slice; must
            be a positive integer. It does not have to divide ``batch * heads``.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads
        self.att_step = att_step

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    @staticmethod
    def _chunked_attention(q, k, v, scale, att_step, mask=None):
        """Scaled dot-product attention computed ``att_step`` rows of dim 0 at a time.

        Args:
            q: queries, shape ``(B, n, d)``.
            k: keys, shape ``(B, m, d)``.
            v: values, shape ``(B, m, e)``.
            scale: factor applied to the raw dot products.
            att_step: rows of dim 0 handled per iteration (last slice may be shorter).
            mask: optional boolean tensor indexable along dim 0 like ``q`` and
                broadcastable to ``(rows, n, m)``; ``False`` entries are excluded.

        Returns:
            Tensor of shape ``(B, n, e)`` with the dtype and device of ``q``.

        Raises:
            ValueError: if ``att_step`` is smaller than 1.
        """
        if att_step < 1:
            raise ValueError(f"att_step must be a positive integer, got {att_step!r}")

        limit = q.shape[0]
        # Match q's dtype so half-precision runs neither pay for a float32
        # buffer nor feed a float32 tensor into half-precision output weights.
        out = torch.zeros(limit, q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        # Slice by explicit bounds so the slice size is always att_step (or the
        # shorter remainder), whatever the relation between att_step and limit.
        for start in range(0, limit, att_step):
            stop = min(start + att_step, limit)
            scores = einsum('b i d, b j d -> b i j', q[start:stop], k[start:stop]) * scale
            if mask is not None:
                scores.masked_fill_(~mask[start:stop], -torch.finfo(scores.dtype).max)
            # attention, what we cannot get enough of, by chunks
            scores = scores.softmax(dim=-1)
            out[start:stop] = einsum('b i j, b j d -> b i d', scores, v[start:stop])
            # Drop the slice's attention matrix before the next one is allocated.
            del scores
        return out

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context``.

        Args:
            x: query input of shape ``(b, n, query_dim)``.
            context: key/value input of shape ``(b, m, context_dim)``; ``x`` is
                used when it is ``None``.
            mask: optional boolean mask whose leading dim is the batch; it is
                flattened to ``(b, m)`` and ``False`` key positions are
                excluded from the softmax.

        Returns:
            Tensor of shape ``(b, n, query_dim)``.

        Raises:
            ValueError: if ``self.att_step`` is smaller than 1.
        """
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        # Fold the heads into the batch axis: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        if exists(mask):
            # One mask row per (batch, head) pair, broadcast over the query axis.
            mask = rearrange(mask, 'b ... -> b (...)')
            mask = repeat(mask, 'b j -> (b h) () j', h=h)

        out = self._chunked_attention(q, k, v, self.scale, self.att_step, mask=mask)
        del q, k, v

        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is applied to a layer-normalised input and
    added back to its un-normalised input.

    Args:
        dim: feature size of the token sequence.
        n_heads: number of attention heads.
        d_head: feature size of each head.
        dropout: dropout probability forwarded to the sub-layers.
        context_dim: feature size of the conditioning sequence for ``attn2``.
        gated_ff: forwarded as ``glu`` to the feed-forward layer.
        checkpoint: flag forwarded to the ``checkpoint`` helper in ``forward``.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        # is a self-attention
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # is self-attn if context is none
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        """Run ``_forward`` through the ``checkpoint`` helper."""
        inputs = (x, context)
        return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        """Apply the three residual sub-layers in order and return the result."""
        self_attended = self.attn1(self.norm1(x))
        x = self_attended + x
        cross_attended = self.attn2(self.norm2(x), context=context)
        x = cross_attended + x
        fed_forward = self.ff(self.norm3(x))
        x = fed_forward + x
        return x
class SpatialTransformer(nn.Module):
    """
    Transformer wrapper for image-shaped ``(b, c, h, w)`` inputs.
    The input is normalised and projected to ``n_heads * d_head`` channels,
    flattened to a ``(b, h*w, c)`` token sequence, passed through ``depth``
    transformer blocks, reshaped back to an image, projected to
    ``in_channels`` and added to the original input.
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels,
                                 inner_dim,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)

        blocks = []
        for _ in range(depth):
            blocks.append(
                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        # The output projection has its parameters zeroed at construction.
        out_conv = nn.Conv2d(inner_dim,
                             in_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
        self.proj_out = zero_module(out_conv)

    def forward(self, x, context=None):
        """Return ``x`` plus the transformer output; ``context`` is passed to every block."""
        # note: if no context is given, cross-attention defaults to self-attention
        height, width = x.shape[2], x.shape[3]
        residual = x
        tokens = self.proj_in(self.norm(x))
        tokens = rearrange(tokens, 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            tokens = block(tokens, context=context)
        image = rearrange(tokens, 'b (h w) c -> b c h w', h=height, w=width)
        image = self.proj_out(image)
        return image + residual