From 7ccfc86397b4d3b46ce298cac76028bb889ed96b Mon Sep 17 00:00:00 2001
From: ZeroCool940711
Date: Sat, 3 Dec 2022 09:33:43 -0700
Subject: [PATCH] Improved optimized mode speed and VRAM usage.

---
 optimizedSD/ddpm.py              | 490 +++++++++++++++++++++++--------
 optimizedSD/diffusers_txt2img.py |  13 +
 optimizedSD/openaimodelSplit.py  |   2 +-
 optimizedSD/optimized_img2img.py | 362 +++++++++++++++++++++++
 optimizedSD/optimized_txt2img.py | 248 +++++++++-------
 optimizedSD/samplers.py          | 252 ++++++++++++++++
 optimizedSD/splitAttention.py    | 280 ++++++++++++++++++
 7 files changed, 1430 insertions(+), 217 deletions(-)
 create mode 100644 optimizedSD/diffusers_txt2img.py
 create mode 100644 optimizedSD/optimized_img2img.py
 create mode 100644 optimizedSD/samplers.py
 create mode 100644 optimizedSD/splitAttention.py

diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index 776fdf1..7ddea6a 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -6,7 +6,8 @@ https://github.com/CompVis/taming-transformers -- merci
 """

-import time
+import time, math
+from tqdm.auto import trange, tqdm
 import torch
 from einops import rearrange
 from tqdm import tqdm
@@ -21,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
 from ldm.modules.diffusionmodules.util import make_beta_schedule
 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
 from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-
+from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff

 def disabled_train(self):
     """Overwrite model.train with this function to make sure train/eval mode
@@ -92,7 +93,6 @@ class DDPM(pl.LightningModule):
                                        cosine_s=cosine_s)
         alphas = 1. - betas
         alphas_cumprod = np.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
@@ -104,7 +104,6 @@ class DDPM(pl.LightningModule):

         self.register_buffer('betas', to_torch(betas))
         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
-        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))


 class FirstStage(DDPM):
@@ -403,7 +402,7 @@ class UNet(DDPM):
             h,emb,hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
             bs = cond.shape[0]

-            assert bs%2 == 0
+            # assert bs%2 == 0
             lenhs = len(hs)

             for i in range(step,bs,step):
@@ -446,15 +445,14 @@ class UNet(DDPM):
         self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                   num_ddpm_timesteps=self.num_timesteps,verbose=verbose)
-        alphas_cumprod = self.alphas_cumprod
-        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+
+        assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+
         to_torch = lambda x: x.to(self.cdevice)

-        self.register_buffer1('betas', to_torch(self.betas))
-        self.register_buffer1('alphas_cumprod', to_torch(alphas_cumprod))
-        self.register_buffer1('alphas_cumprod_prev', to_torch(self.alphas_cumprod_prev))
-        self.register_buffer1('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1.
- alphas_cumprod.cpu()))) - + self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod)) # ddim sampling parameters ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, @@ -463,25 +461,21 @@ class UNet(DDPM): self.register_buffer1('ddim_alphas', ddim_alphas) self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - self.ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer1('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + @torch.no_grad() def sample(self, S, - batch_size, - shape, - seed, - conditioning=None, + conditioning, + x0=None, + shape = None, + seed=1234, callback=None, img_callback=None, quantize_x0=False, eta=0., mask=None, - x0=None, + sampler = "plms", temperature=1., noise_dropout=0., score_corrector=None, @@ -492,41 +486,74 @@ class UNet(DDPM): unconditional_guidance_scale=1., unconditional_conditioning=None, ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) - - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') + if(self.turbo): self.model1.to(self.cdevice) self.model2.to(self.cdevice) - samples = self.plms_sampling(conditioning, size, seed, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) + if x0 is None: + batch_size, b1, b2, b3 = shape + img_shape = (1, b1, b2, b3) + tens = [] + print("seeds used = ", [seed+s for s in range(batch_size)]) + for _ in range(batch_size): + torch.manual_seed(seed) + tens.append(torch.randn(img_shape, device=self.cdevice)) + seed+=1 + noise = torch.cat(tens) + del tens + + x_latent = noise if x0 is None else x0 + # sampling + + if sampler == "plms": + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) + print(f'Data shape for PLMS sampling is {shape}') + samples = self.plms_sampling(conditioning, batch_size, x_latent, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + + elif sampler == "ddim": + samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, + 
unconditional_conditioning=unconditional_conditioning, + mask = mask,init_latent=x_T,use_original_steps=False) + + elif sampler == "euler": + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) + samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale) + elif sampler == "euler_a": + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) + samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale) + + elif sampler == "dpm2": + samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale) + elif sampler == "heun": + samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale) + + elif sampler == "dpm2_a": + samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale) + + + elif sampler == "lms": + samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale) if(self.turbo): self.model1.to("cpu") @@ -535,36 +562,17 @@ class UNet(DDPM): return samples @torch.no_grad() - def plms_sampling(self, cond, shape, seed, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, + def plms_sampling(self, cond,b, img, + ddim_use_original_steps=False, + callback=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.betas.device - b = shape[0] - if x_T is None: - _, b1, b2, b3 = shape - img_shape = (1, b1, b2, b3) - tens = [] - print("seeds used = ", [seed+s for s in range(b)]) - for _ in range(b): - torch.manual_seed(seed) - tens.append(torch.randn(img_shape, device=device)) - seed+=1 - img = torch.cat(tens) - del tens - else: - img = x_T - - if timesteps is None: - timesteps = self.num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + timesteps = self.ddim_timesteps + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] print(f"Running PLMS Sampling with {total_steps} timesteps") iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) @@ -618,10 +626,10 @@ class UNet(DDPM): return e_t - alphas = self.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.sqrt_one_minus_alphas_cumprod if 
use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + alphas = self.ddim_alphas + alphas_prev = self.ddim_alphas_prev + sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas def get_x_prev_and_pred_x0(e_t, index): # select parameters corresponding to the currently considered timestep @@ -664,17 +672,11 @@ class UNet(DDPM): @torch.no_grad() - def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None, mask=None): + def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None): # fast, but does not allow for exact reconstruction # t serves as an index to gather the correct alphas self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False) - - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) if noise is None: b0, b1, b2, b3 = x0.shape @@ -687,50 +689,53 @@ class UNet(DDPM): seed+=1 noise = torch.cat(tens) del tens - if mask is not None: - noise = noise*mask return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(self.cdevice), t, x0.shape) * noise) + extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise) @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - mask = None,use_original_steps=False): + def add_noise(self, x0, t): - - if(self.turbo): - self.model1.to(self.cdevice) - self.model2.to(self.cdevice) + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + noise = torch.randn(x0.shape, device=x0.device) - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + # print(extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape), + # extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape)) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise) + + + @torch.no_grad() + def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + mask = None,init_latent=None,use_original_steps=False): + + timesteps = self.ddim_timesteps timesteps = timesteps[:t_start] - time_range = np.flip(timesteps) total_steps = timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent - # x0 = x_latent + x0 = init_latent for i, step in enumerate(iterator): index = total_steps - i - 1 - ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - # if mask is not None: - # x_dec = x0 * mask + (1. - mask) * x_dec + if mask is not None: + # x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice)) + x0_noisy = x0 + x_dec = x0_noisy* mask + (1. 
- mask) * x_dec x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) - # if mask is not None: - # return x0 * mask + (1. - mask) * x_dec - if(self.turbo): - self.model1.to("cpu") - self.model2.to("cpu") + if mask is not None: + return x0 * mask + (1. - mask) * x_dec return x_dec + @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, @@ -743,7 +748,6 @@ class UNet(DDPM): x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) - # print("xin shape = ", x_in.shape) e_t_uncond, e_t = self.apply_model(x_in, t_in, c_in).chunk(2) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) @@ -751,10 +755,10 @@ class UNet(DDPM): assert self.model.parameterization == "eps" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + alphas = self.ddim_alphas + alphas_prev = self.ddim_alphas_prev + sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) @@ -771,4 +775,256 @@ class UNet(DDPM): if noise_dropout > 0.: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev \ No newline at end of file + return x_prev + + + @torch.no_grad() + def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = (sigmas[i] * (gamma + 1)).half() + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + + s_i = sigma_hat * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + @torch.no_grad() + def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + + + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + + s_i = sigmas[i] * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + x = x + torch.randn_like(x) * sigma_up + return x + + + + @torch.no_grad() + def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + + cvd = CompVisDenoiser(alphas_cumprod=ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = (sigmas[i] * (gamma + 1)).half() + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + + s_i = sigma_hat * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + s_i = sigmas[i + 1] * s_in + x_in = torch.cat([x_2] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + + @torch.no_grad() + def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + + s_i = sigma_hat * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + + d = to_d(x, sigma_hat, denoised) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + + s_i = sigma_mid * s_in + x_in = torch.cat([x_2] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + return x + + + @torch.no_grad() + def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None): + """Ancestral sampling with DPM-Solver inspired second-order steps.""" + extra_args = {} if extra_args is None else extra_args + + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + + s_i = sigmas[i] * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + + s_i = sigma_mid * s_in + x_in = torch.cat([x_2] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised_2 = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + torch.randn_like(x) * sigma_up + return x + + + @torch.no_grad() + def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4): + extra_args = {} if extra_args is None else extra_args + s_in = 
x.new_ones([x.shape[0]]) + + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + + s_i = sigmas[i] * s_in + x_in = torch.cat([x] * 2) + t_in = torch.cat([s_i] * 2) + cond_in = torch.cat([unconditional_conditioning, cond]) + c_out, c_in = [append_dims(tmp, x_in.ndim) for tmp in cvd.get_scalings(t_in)] + eps = self.apply_model(x_in * c_in, cvd.sigma_to_t(t_in), cond_in) + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + + d = to_d(x, sigmas[i], denoised) + ds.append(d) + if len(ds) > order: + ds.pop(0) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + return x diff --git a/optimizedSD/diffusers_txt2img.py b/optimizedSD/diffusers_txt2img.py new file mode 100644 index 0000000..80fbb97 --- /dev/null +++ b/optimizedSD/diffusers_txt2img.py @@ -0,0 +1,13 @@ +import torch +from diffusers import LDMTextToImagePipeline + +pipe = LDMTextToImagePipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True) + +prompt = "19th Century wooden engraving of Elon musk" + +seed = torch.manual_seed(1024) +images = pipe([prompt], batch_size=1, num_inference_steps=50, guidance_scale=7, generator=seed,torch_device="cpu" )["sample"] + +# save images +for idx, image in enumerate(images): + image.save(f"image-{idx}.png") diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py index 2136ada..7a32ffe 100644 --- a/optimizedSD/openaimodelSplit.py +++ b/optimizedSD/openaimodelSplit.py @@ -13,7 +13,7 @@ from ldm.modules.diffusionmodules.util import ( normalization, timestep_embedding, ) -from ldm.modules.attention import SpatialTransformer +from .splitAttention import SpatialTransformer class AttentionPool2d(nn.Module): diff --git a/optimizedSD/optimized_img2img.py b/optimizedSD/optimized_img2img.py new file mode 100644 index 0000000..24f3338 --- /dev/null +++ b/optimizedSD/optimized_img2img.py @@ -0,0 +1,362 @@ +import argparse, os, re +import torch +import numpy as np +from random import randint +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext +from einops import rearrange, repeat +from ldm.util import instantiate_from_config +from optimUtils import split_weighted_subprompts, logger +from transformers import logging +import pandas as pd +logging.set_verbosity_error() + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + return sd + + +def load_img(path, h0, w0): + + image = Image.open(path).convert("RGB") + w, h = image.size + + print(f"loaded input image of size ({w}, {h}) from {path}") + if h0 is not None and w0 is not None: + h, w = h0, w0 + + w, h = map(lambda x: x - x % 64, (w, h)) # resize 
to integer multiple of 32 + + print(f"New image size ({w}, {h})") + image = image.resize((w, h), resample=Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +config = "optimizedSD/v1-inference.yaml" +ckpt = "models/ldm/stable-diffusion-v1/model.ckpt" + +parser = argparse.ArgumentParser() + +parser.add_argument( + "--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render" +) +parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples") +parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image") + +parser.add_argument( + "--skip_grid", + action="store_true", + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", +) +parser.add_argument( + "--skip_save", + action="store_true", + help="do not save individual samples. For speed measurements.", +) +parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", +) + +parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", +) +parser.add_argument( + "--n_iter", + type=int, + default=1, + help="sample this often", +) +parser.add_argument( + "--H", + type=int, + default=None, + help="image height, in pixel space", +) +parser.add_argument( + "--W", + type=int, + default=None, + help="image width, in pixel space", +) +parser.add_argument( + "--strength", + type=float, + default=0.75, + help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", +) +parser.add_argument( + "--n_samples", + type=int, + default=5, + help="how many samples to produce for each given prompt. A.k.a. 
batch size", +) +parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", +) +parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", +) +parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", +) +parser.add_argument( + "--seed", + type=int, + default=None, + help="the seed (for reproducible sampling)", +) +parser.add_argument( + "--device", + type=str, + default="cuda", + help="CPU or GPU (cuda/cuda:0/cuda:1/...)", +) +parser.add_argument( + "--unet_bs", + type=int, + default=1, + help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )", +) +parser.add_argument( + "--turbo", + action="store_true", + help="Reduces inference time on the expense of 1GB VRAM", +) +parser.add_argument( + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" +) +parser.add_argument( + "--format", + type=str, + help="output image format", + choices=["jpg", "png"], + default="png", +) +parser.add_argument( + "--sampler", + type=str, + help="sampler", + choices=["ddim"], + default="ddim", +) +opt = parser.parse_args() + +tic = time.time() +os.makedirs(opt.outdir, exist_ok=True) +outpath = opt.outdir +grid_count = len(os.listdir(outpath)) - 1 + +if opt.seed == None: + opt.seed = randint(0, 1000000) +seed_everything(opt.seed) + +# Logging +logger(vars(opt), log_csv = "logs/img2img_logs.csv") + +sd = load_model_from_config(f"{ckpt}") +li, lo = [], [] +for key, value in sd.items(): + sp = key.split(".") + if (sp[0]) == "model": + if "input_blocks" in sp: + li.append(key) + elif "middle_block" in sp: + li.append(key) + elif "time_embed" in sp: + li.append(key) + else: + lo.append(key) +for key in li: + sd["model1." + key[6:]] = sd.pop(key) +for key in lo: + sd["model2." + key[6:]] = sd.pop(key) + +config = OmegaConf.load(f"{config}") + +assert os.path.isfile(opt.init_img) +init_image = load_img(opt.init_img, opt.H, opt.W).to(opt.device) + +model = instantiate_from_config(config.modelUNet) +_, _ = model.load_state_dict(sd, strict=False) +model.eval() +model.cdevice = opt.device +model.unet_bs = opt.unet_bs +model.turbo = opt.turbo + +modelCS = instantiate_from_config(config.modelCondStage) +_, _ = modelCS.load_state_dict(sd, strict=False) +modelCS.eval() +modelCS.cond_stage_model.device = opt.device + +modelFS = instantiate_from_config(config.modelFirstStage) +_, _ = modelFS.load_state_dict(sd, strict=False) +modelFS.eval() +del sd +if opt.device != "cpu" and opt.precision == "autocast": + model.half() + modelCS.half() + modelFS.half() + init_image = init_image.half() + +batch_size = opt.n_samples +n_rows = opt.n_rows if opt.n_rows > 0 else batch_size +if not opt.from_file: + assert opt.prompt is not None + prompt = opt.prompt + data = [batch_size * [prompt]] + +else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = batch_size * list(data) + data = list(chunk(sorted(data), batch_size)) + +modelFS.to(opt.device) + +init_image = repeat(init_image, "1 ... 
-> b ...", b=batch_size) +init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space + +if opt.device != "cpu": + mem = torch.cuda.memory_allocated(device=opt.device) / 1e6 + modelFS.to("cpu") + while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem: + time.sleep(1) + + +assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]" +t_enc = int(opt.strength * opt.ddim_steps) +print(f"target t_enc is {t_enc} steps") + + +if opt.precision == "autocast" and opt.device != "cpu": + precision_scope = autocast +else: + precision_scope = nullcontext + +seeds = "" +with torch.no_grad(): + + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + + sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150] + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + + with precision_scope("cuda"): + modelCS.to(opt.device) + uc = None + if opt.scale != 1.0: + uc = modelCS.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + + subprompts, weights = split_weighted_subprompts(prompts[0]) + if len(subprompts) > 1: + c = torch.zeros_like(uc) + totalWeight = sum(weights) + # normalize each "sub prompt" and add it + for i in range(len(subprompts)): + weight = weights[i] + # if not skip_normalize: + weight = weight / totalWeight + c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) + else: + c = modelCS.get_learned_conditioning(prompts) + + if opt.device != "cpu": + mem = torch.cuda.memory_allocated(device=opt.device) / 1e6 + modelCS.to("cpu") + while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem: + time.sleep(1) + + # encode (scaled latent) + z_enc = model.stochastic_encode( + init_latent, + torch.tensor([t_enc] * batch_size).to(opt.device), + opt.seed, + opt.ddim_eta, + opt.ddim_steps, + ) + # decode it + samples_ddim = model.sample( + t_enc, + c, + z_enc, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + sampler = opt.sampler + ) + + modelFS.to(opt.device) + print("saving images") + for i in range(batch_size): + + x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0)) + x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}") + ) + seeds += str(opt.seed) + "," + opt.seed += 1 + base_count += 1 + + if opt.device != "cpu": + mem = torch.cuda.memory_allocated(device=opt.device) / 1e6 + modelFS.to("cpu") + while torch.cuda.memory_allocated(device=opt.device) / 1e6 >= mem: + time.sleep(1) + + del samples_ddim + print("memory_final = ", torch.cuda.memory_allocated(device=opt.device) / 1e6) + +toc = time.time() + +time_taken = (toc - tic) / 60.0 + +print( + ( + "Samples finished in {0:.2f} minutes and exported to " + + sample_path + + "\n Seeds used = " + + seeds[:-1] + ).format(time_taken) +) diff --git a/optimizedSD/optimized_txt2img.py b/optimizedSD/optimized_txt2img.py index 8e07294..c829182 100644 --- a/optimizedSD/optimized_txt2img.py +++ b/optimizedSD/optimized_txt2img.py @@ -1,7 +1,7 @@ -import argparse, os, sys, glob, random +import argparse, os, re import torch import numpy as np -import copy +from random import randint from omegaconf import OmegaConf 
from PIL import Image from tqdm import tqdm, trange @@ -13,6 +13,10 @@ from pytorch_lightning import seed_everything from torch import autocast from contextlib import contextmanager, nullcontext from ldm.util import instantiate_from_config +from optimUtils import split_weighted_subprompts, logger +from transformers import logging +# from samplers import CompVisDenoiser +logging.set_verbosity_error() def chunk(it, size): @@ -30,33 +34,22 @@ def load_model_from_config(ckpt, verbose=False): config = "optimizedSD/v1-inference.yaml" -ckpt = "models/ldm/stable-diffusion-v1/model.ckpt" -device = "cuda" +DEFAULT_CKPT = "models/ldm/stable-diffusion-v1/model.ckpt" parser = argparse.ArgumentParser() parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="a painting of a virus monster playing guitar", - help="the prompt to render" -) -parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + "--prompt", type=str, nargs="?", default="a painting of a virus monster playing guitar", help="the prompt to render" ) +parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples") parser.add_argument( "--skip_grid", - action='store_true', + action="store_true", help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", ) parser.add_argument( "--skip_save", - action='store_true', + action="store_true", help="do not save individual samples. For speed measurements.", ) parser.add_argument( @@ -68,7 +61,7 @@ parser.add_argument( parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across samples ", ) parser.add_argument( @@ -125,6 +118,12 @@ parser.add_argument( default=7.5, help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", ) +parser.add_argument( + "--device", + type=str, + default="cuda", + help="specify GPU (cuda/cuda:0/cuda:1/...)", +) parser.add_argument( "--from-file", type=str, @@ -133,165 +132,216 @@ parser.add_argument( parser.add_argument( "--seed", type=int, - default=42, + default=None, help="the seed (for reproducible sampling)", ) parser.add_argument( - "--small_batch", - action='store_true', - help="Reduce inference time when generate a smaller batch of images", + "--unet_bs", + type=int, + default=1, + help="Slightly reduces inference time at the expense of high VRAM (value > 1 not recommended )", ) parser.add_argument( - "--precision", + "--turbo", + action="store_true", + help="Reduces inference time on the expense of 1GB VRAM", +) +parser.add_argument( + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) +parser.add_argument( + "--format", + type=str, + help="output image format", + choices=["jpg", "png"], + default="png", +) +parser.add_argument( + "--sampler", + type=str, + help="sampler", + choices=["ddim", "plms","heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"], + default="plms", +) +parser.add_argument( + "--ckpt", + type=str, + help="path to checkpoint of model", + default=DEFAULT_CKPT, +) opt = parser.parse_args() tic = time.time() os.makedirs(opt.outdir, exist_ok=True) outpath = opt.outdir - -sample_path = os.path.join(outpath, "samples", "_".join(opt.prompt.split())[:255]) -os.makedirs(sample_path, exist_ok=True) -base_count = len(os.listdir(sample_path)) grid_count = len(os.listdir(outpath)) - 1 + +if opt.seed == None: + 
opt.seed = randint(0, 1000000) seed_everything(opt.seed) -sd = load_model_from_config(f"{ckpt}") -li = [] -lo = [] +# Logging +logger(vars(opt), log_csv = "logs/txt2img_logs.csv") + +sd = load_model_from_config(f"{opt.ckpt}") +li, lo = [], [] for key, value in sd.items(): - sp = key.split('.') - if(sp[0]) == 'model': - if('input_blocks' in sp): + sp = key.split(".") + if (sp[0]) == "model": + if "input_blocks" in sp: li.append(key) - elif('middle_block' in sp): + elif "middle_block" in sp: li.append(key) - elif('time_embed' in sp): + elif "time_embed" in sp: li.append(key) else: lo.append(key) for key in li: - sd['model1.' + key[6:]] = sd.pop(key) + sd["model1." + key[6:]] = sd.pop(key) for key in lo: - sd['model2.' + key[6:]] = sd.pop(key) + sd["model2." + key[6:]] = sd.pop(key) config = OmegaConf.load(f"{config}") -config.modelUNet.params.ddim_steps = opt.ddim_steps - -if opt.small_batch: - config.modelUNet.params.small_batch = True -else: - config.modelUNet.params.small_batch = False - - model = instantiate_from_config(config.modelUNet) _, _ = model.load_state_dict(sd, strict=False) model.eval() - +model.unet_bs = opt.unet_bs +model.cdevice = opt.device +model.turbo = opt.turbo + modelCS = instantiate_from_config(config.modelCondStage) _, _ = modelCS.load_state_dict(sd, strict=False) modelCS.eval() - +modelCS.cond_stage_model.device = opt.device + modelFS = instantiate_from_config(config.modelFirstStage) _, _ = modelFS.load_state_dict(sd, strict=False) modelFS.eval() +del sd -if opt.precision == "autocast": +if opt.device != "cpu" and opt.precision == "autocast": model.half() modelCS.half() start_code = None if opt.fixed_code: - start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=opt.device) batch_size = opt.n_samples n_rows = opt.n_rows if opt.n_rows > 0 else batch_size if not opt.from_file: + assert opt.prompt is not None prompt = opt.prompt - assert prompt is not None + print(f"Using prompt: {prompt}") data = [batch_size * [prompt]] else: print(f"reading prompts from {opt.from_file}") with open(opt.from_file, "r") as f: - data = f.read().splitlines() - data = list(chunk(data, batch_size)) + text = f.read() + print(f"Using prompt: {text.strip()}") + data = text.splitlines() + data = batch_size * list(data) + data = list(chunk(sorted(data), batch_size)) -precision_scope = autocast if opt.precision=="autocast" else nullcontext +if opt.precision == "autocast" and opt.device != "cpu": + precision_scope = autocast +else: + precision_scope = nullcontext + +seeds = "" with torch.no_grad(): all_samples = list() for n in trange(opt.n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"): - with precision_scope("cuda"): - modelCS.to(device) + + sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompts[0])))[:150] + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + + with precision_scope("cuda"): + modelCS.to(opt.device) uc = None if opt.scale != 1.0: uc = modelCS.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) - - c = modelCS.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - mem = torch.cuda.memory_allocated()/1e6 - modelCS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) + subprompts, weights = split_weighted_subprompts(prompts[0]) + if len(subprompts) > 1: + c = torch.zeros_like(uc) + 
totalWeight = sum(weights) + # normalize each "sub prompt" and add it + for i in range(len(subprompts)): + weight = weights[i] + # if not skip_normalize: + weight = weight / totalWeight + c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) + else: + c = modelCS.get_learned_conditioning(prompts) - samples_ddim = model.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) + shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f] - modelFS.to(device) + if opt.device != "cpu": + mem = torch.cuda.memory_allocated() / 1e6 + modelCS.to("cpu") + while torch.cuda.memory_allocated() / 1e6 >= mem: + time.sleep(1) + + samples_ddim = model.sample( + S=opt.ddim_steps, + conditioning=c, + seed=opt.seed, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code, + sampler = opt.sampler, + ) + + modelFS.to(opt.device) + + print(samples_ddim.shape) print("saving images") for i in range(batch_size): - + x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0)) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - # for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) + os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.{opt.format}") + ) + seeds += str(opt.seed) + "," + opt.seed += 1 base_count += 1 - - mem = torch.cuda.memory_allocated()/1e6 - modelFS.to("cpu") - while(torch.cuda.memory_allocated()/1e6 >= mem): - time.sleep(1) - - # if not opt.skip_grid: - # all_samples.append(x_samples_ddim) + if opt.device != "cpu": + mem = torch.cuda.memory_allocated() / 1e6 + modelFS.to("cpu") + while torch.cuda.memory_allocated() / 1e6 >= mem: + time.sleep(1) del samples_ddim - print("memory_final = ", torch.cuda.memory_allocated()/1e6) - - # if not skip_grid: - # # additionally, save as grid - # grid = torch.stack(all_samples, 0) - # grid = rearrange(grid, 'n b c h w -> (n b) c h w') - # grid = make_grid(grid, nrow=n_rows) - - # # to image - # grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() - # Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - # grid_count += 1 + print("memory_final = ", torch.cuda.memory_allocated() / 1e6) toc = time.time() -time_taken = (toc-tic)/60.0 +time_taken = (toc - tic) / 60.0 -print(("Your samples are ready in {0:.2f} minutes and waiting for you here \n" + sample_path).format(time_taken)) \ No newline at end of file +print( + ( + "Samples finished in {0:.2f} minutes and exported to " + + sample_path + + "\n Seeds used = " + + seeds[:-1] + ).format(time_taken) +) diff --git a/optimizedSD/samplers.py b/optimizedSD/samplers.py new file mode 100644 index 0000000..6a68e8e --- /dev/null +++ b/optimizedSD/samplers.py @@ -0,0 +1,252 @@ +from scipy import integrate +import torch +from tqdm.auto import trange, tqdm +import torch.nn as nn + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + +def get_ancestral_step(sigma_from, sigma_to): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +class DiscreteSchedule(nn.Module): + """A mapping between continuous noise levels (sigmas) and a list of discrete noise + levels.""" + + def __init__(self, sigmas, quantize): + super().__init__() + self.register_buffer('sigmas', sigmas) + self.quantize = quantize + + def get_sigmas(self, n=None): + if n is None: + return append_zero(self.sigmas.flip(0)) + t_max = len(self.sigmas) - 1 + t = torch.linspace(t_max, 0, n, device=self.sigmas.device) + return append_zero(self.t_to_sigma(t)) + + def sigma_to_t(self, sigma, quantize=None): + quantize = self.quantize if quantize is None else quantize + dists = torch.abs(sigma - self.sigmas[:, None]) + if quantize: + return torch.argmin(dists, dim=0).view(sigma.shape) + low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] + low, high = self.sigmas[low_idx], self.sigmas[high_idx] + w = (low - sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + def t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + # print(low_idx, high_idx, w ) + return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx] + + +class DiscreteEpsDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output eps (the predicted + noise).""" + + def __init__(self, alphas_cumprod, quantize): + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.sigma_data = 1. 
+ + def get_scalings(self, sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_out, c_in + + def get_eps(self, *args, **kwargs): + return self.inner_model(*args, **kwargs) + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return input + eps * c_out + +class CompVisDenoiser(DiscreteEpsDDPMDenoiser): + """A wrapper for CompVis diffusion models.""" + + def __init__(self, alphas_cumprod, quantize=False, device='cpu'): + super().__init__(alphas_cumprod, quantize=quantize) + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + + +def get_ancestral_step(sigma_from, sigma_to): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +@torch.no_grad() +def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + + +@torch.no_grad() +def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + x = x + torch.randn_like(x) * sigma_up + return x + + +@torch.no_grad() +def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
+ eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + +@torch.no_grad() +def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + return x + + +@torch.no_grad() +def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None): + """Ancestral sampling with DPM-Solver inspired second-order steps.""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + torch.randn_like(x) * sigma_up + return x + + +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + def fn(tau): + prod = 1. 
+ for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + + +@torch.no_grad() +def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + ds.append(d) + if len(ds) > order: + ds.pop(0) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + return x diff --git a/optimizedSD/splitAttention.py b/optimizedSD/splitAttention.py new file mode 100644 index 0000000..dbfd459 --- /dev/null +++ b/optimizedSD/splitAttention.py @@ -0,0 +1,280 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + self.att_step = att_step + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + + limit = k.shape[0] + att_step = self.att_step + q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0)) + k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0)) + v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0)) + + q_chunks.reverse() + k_chunks.reverse() + v_chunks.reverse() + sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + del k, q, v + for i in range (0, limit, att_step): + + q_buffer = q_chunks.pop() + k_buffer = k_chunks.pop() + v_buffer = v_chunks.pop() + sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale + + del k_buffer, q_buffer + # 
attention, what we cannot get enough of, by chunks + + sim_buffer = sim_buffer.softmax(dim=-1) + + sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) + del v_buffer + sim[i:i+att_step,:,:] = sim_buffer + + del sim_buffer + sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) + return self.to_out(sim) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in
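Note (illustrative, not part of the diff): the chunked CrossAttention introduced in splitAttention.py never materialises the full (batch*heads, n_q, n_k) similarity tensor; it computes attention for `att_step` batch*heads slices at a time, which is where most of the peak-memory saving in this change comes from. A minimal sketch of calling the module directly, with hypothetical shapes and assuming the repository root is on PYTHONPATH:

# Minimal usage sketch (not part of the patch); shapes below are hypothetical.
import torch
from optimizedSD.splitAttention import CrossAttention

attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40, att_step=1)
x = torch.randn(1, 4096, 320)       # flattened 64x64 latent feature map (hypothetical)
context = torch.randn(1, 77, 768)   # text-encoder embedding (hypothetical)
with torch.no_grad():
    out = attn(x, context=context)  # similarity is computed one batch*heads slice at a time
print(out.shape)                    # torch.Size([1, 4096, 320])

A smaller att_step lowers peak memory at the cost of more sequential einsum calls; att_step=1 (the default) is the most conservative setting.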