2022-09-03 12:08:45 +03:00
|
|
|
import torch
|
2022-10-03 00:31:19 +03:00
|
|
|
from torch.nn.functional import silu
|
2022-09-03 12:08:45 +03:00
|
|
|
|
2022-10-02 15:03:39 +03:00
|
|
|
import modules.textual_inversion.textual_inversion
|
2022-12-10 09:17:39 +03:00
|
|
|
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
2022-11-26 16:45:57 +03:00
|
|
|
from modules.hypernetworks import hypernetwork
|
2022-12-10 09:17:39 +03:00
|
|
|
from modules.shared import cmd_opts
|
2022-12-31 18:06:35 +03:00
|
|
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
2022-11-26 16:10:46 +03:00
|
|
|
|
2022-09-05 01:41:20 +03:00
|
|
|
import ldm.modules.attention
|
2022-09-13 14:29:56 +03:00
|
|
|
import ldm.modules.diffusionmodules.model
|
2022-12-02 15:47:02 +03:00
|
|
|
import ldm.modules.diffusionmodules.openaimodel
|
2022-11-11 18:20:18 +03:00
|
|
|
import ldm.models.diffusion.ddim
|
|
|
|
import ldm.models.diffusion.plms
|
2022-11-26 16:10:46 +03:00
|
|
|
import ldm.modules.encoders.modules
|
2022-09-13 14:29:56 +03:00
|
|
|
|
2022-10-02 15:03:39 +03:00
|
|
|
# Original ldm implementations, captured at import time before anything in this
# module patches them. undo_optimizations() restores the nonlinearity and the
# AttnBlock forward from these; attention_CrossAttention_forward is not read in
# this module's visible code (TODO confirm whether other modules use it).
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
# The no-op replacements accept **kwargs as well, so any print(..., flush=True)
# style call inside those modules is swallowed instead of raising TypeError.
ldm.modules.attention.print = lambda *args, **kwargs: None
ldm.modules.diffusionmodules.model.print = lambda *args, **kwargs: None
|
2022-10-15 16:59:37 +03:00
|
|
|
|
2022-12-10 09:14:30 +03:00
|
|
|
|
2022-10-02 15:03:39 +03:00
|
|
|
def apply_optimizations():
    """Install at most one cross-attention implementation chosen from cmd_opts.

    Previous patches are first reverted with undo_optimizations(). Then the SiLU
    nonlinearity and sd_hijack_unet's ``th`` replacement are installed
    unconditionally. After that, at most one attention implementation is
    installed, checked in this priority order: xformers, sub-quadratic,
    split-attention v1, InvokeAI, Doggettx.

    Returns:
        'xformers', 'sub-quadratic', 'V1', 'InvokeAI' or 'Doggettx', naming the
        attention implementation that was installed, or None if none was selected.
    """
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    cross_attention = ldm.modules.attention.CrossAttention
    attn_block = ldm.modules.diffusionmodules.model.AttnBlock

    def xformers_requested_and_supported():
        # cmd_opts.force_enable_xformers bypasses every check below.
        if cmd_opts.force_enable_xformers:
            return True
        if not (cmd_opts.xformers and shared.xformers_available and torch.version.cuda):
            return False
        # Only GPUs whose CUDA compute capability lies within [6.0, 9.0] qualify.
        return (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)

    if xformers_requested_and_supported():
        print("Applying xformers cross attention optimization.")
        cross_attention.forward = sd_hijack_optimizations.xformers_attention_forward
        attn_block.forward = sd_hijack_optimizations.xformers_attnblock_forward
        return 'xformers'

    if cmd_opts.opt_sub_quad_attention:
        print("Applying sub-quadratic cross attention optimization.")
        cross_attention.forward = sd_hijack_optimizations.sub_quad_attention_forward
        attn_block.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
        return 'sub-quadratic'

    if cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        # v1 only swaps CrossAttention; AttnBlock keeps the forward that
        # undo_optimizations() restored.
        cross_attention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        return 'V1'

    # Both remaining split-attention variants are vetoed by this flag.
    if cmd_opts.disable_opt_split_attention:
        return None

    # InvokeAI is used when asked for explicitly, or as the fallback when neither
    # opt_split_attention is set nor CUDA is available.
    if cmd_opts.opt_split_attention_invokeai or not (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (InvokeAI).")
        cross_attention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
        return 'InvokeAI'

    if cmd_opts.opt_split_attention or torch.cuda.is_available():
        print("Applying cross attention optimization (Doggettx).")
        cross_attention.forward = sd_hijack_optimizations.split_cross_attention_forward
        attn_block.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
        return 'Doggettx'

    return None
|
2022-09-13 14:29:56 +03:00
|
|
|
|
|
|
|
|
2022-10-02 15:03:39 +03:00
|
|
|
def undo_optimizations():
    """Reset the functions that apply_optimizations() swaps out.

    CrossAttention.forward is reset to ``hypernetwork.attention_CrossAttention_forward``,
    not to the module-level ``attention_CrossAttention_forward`` captured at import.
    The nonlinearity and AttnBlock.forward go back to the originals captured at
    import. ``ldm.modules.diffusionmodules.openaimodel.th`` is left untouched.
    """
    diffusion_model = ldm.modules.diffusionmodules.model

    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    diffusion_model.nonlinearity = diffusionmodules_model_nonlinearity
    diffusion_model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
2022-09-13 14:29:56 +03:00
|
|
|
|
2022-09-03 12:08:45 +03:00
|
|
|
|
2022-11-20 06:35:26 +03:00
|
|
|
def fix_checkpoint():
    """Point three ldm block classes' ``forward`` at sd_hijack_checkpoint's versions.

    Patched classes, in order: attention.BasicTransformerBlock,
    openaimodel.ResBlock, openaimodel.AttentionBlock.
    """
    openaimodel = ldm.modules.diffusionmodules.openaimodel
    patches = (
        (ldm.modules.attention.BasicTransformerBlock, sd_hijack_checkpoint.BasicTransformerBlock_forward),
        (openaimodel.ResBlock, sd_hijack_checkpoint.ResBlock_forward),
        (openaimodel.AttentionBlock, sd_hijack_checkpoint.AttentionBlock_forward),
    )
    for target_cls, replacement_forward in patches:
        target_cls.forward = replacement_forward
|
2022-10-08 14:25:47 +03:00
|
|
|
|
2022-12-31 18:06:35 +03:00
|
|
|
|
2022-09-03 12:08:45 +03:00
|
|
|
class StableDiffusionModelHijack:
    """Patches a loaded Stable Diffusion model in place and tracks the patch state.

    hijack() does four things:
    - wraps the model's text encoder (``cond_stage_model``) and its token-embedding
      layer;
    - installs the attention optimizations and the checkpoint forwards;
    - records every submodule, for circular padding.

    undo_hijack() unwraps the text encoder again.
    """

    # Per-forward embedding substitutions. EmbeddingsWithFixes.forward reads this
    # and resets it to None; it is presumably set by the text-encoder wrappers
    # (TODO confirm).
    fixes = None
    # User-facing messages, reset by clear_comments(). This is a class-level list
    # until clear_comments() rebinds it on the instance.
    comments = []
    # Flat list of every module in the hijacked model, set by hijack().
    layers = None
    # Whether the Conv2d layers in ``layers`` currently use circular padding.
    circular_enabled = False
    # The wrapped text encoder installed by hijack().
    clip = None
    # Value returned by apply_optimizations() during the last hijack().
    optimization_method = None

    # Created once, when the class body executes (i.e. at module import).
    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)

    def hijack(self, m):
        """Wrap m's text encoder, apply optimizations, and index all submodules.

        Args:
            m: the loaded model. Its ``cond_stage_model`` is replaced by a wrapper
                when it is one of the three recognised encoder types; otherwise
                it is left as is.
        """
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            # NOTE(review): the wrapper is stored under a *new* attribute,
            # ``token_embedding``, while ``word_embeddings`` itself stays unwrapped.
            # Confirm the XLM-R embedding layer actually routes lookups through it.
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        self.optimization_method = apply_optimizations()

        self.clip = m.cond_stage_model

        fix_checkpoint()

        def flatten(el):
            # Pre-order list of el and all of its descendants.
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        """Reverse hijack(): unwrap the text encoder and restore its token embedding.

        This also turns circular padding off and clears ``layers`` and ``clip``.
        """
        # hijack() has already replaced cond_stage_model with a wrapper, so dispatch
        # on the wrapper types. Checking the raw XLM-R type here (as before) never
        # matched a hijacked model, so XLM-R models were never unwrapped.
        if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            # NOTE(review): the ``token_embedding`` attribute that hijack() added to
            # roberta.embeddings is left in place.
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        """Set padding_mode to 'circular' (enable) or 'zeros' on every Conv2d in ``layers``.

        Does nothing if the requested state is already active. Requires
        ``layers`` to be populated (by hijack()) whenever the state changes.
        """
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        """Reset the collected comments to a fresh per-instance list."""
        self.comments = []

    def get_prompt_lengths(self, text):
        """Return (token_count, target_count) for ``text``, as computed by the wrapped encoder."""
        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)
|
2022-09-03 12:08:45 +03:00
|
|
|
|
|
|
|
|
|
|
|
class EmbeddingsWithFixes(torch.nn.Module):
    """Token-embedding wrapper that splices custom embedding vectors into its output.

    ``embeddings`` is any object with a ``fixes`` attribute; in this module it is
    the StableDiffusionModelHijack instance. Before a forward pass, ``fixes`` may
    hold one list per batch item of ``(offset, embedding)`` pairs. Each
    ``embedding.vec`` overwrites the rows that follow position ``offset``.
    """

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped  # the original embedding module
        self.embeddings = embeddings  # object holding the pending ``fixes``

    def forward(self, input_ids):
        """Embed ``input_ids`` with the wrapped module, then apply and clear pending fixes.

        Returns:
            The wrapped module's output, with the requested rows replaced. The
            sequence length, device and dtype of each item are unchanged.
        """
        # Consume the fixes so they apply to exactly one forward pass.
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if not batch_fixes or all(len(fixes) == 0 for fixes in batch_fixes):
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                # Match the surrounding rows' device/dtype. Otherwise torch.cat either
                # fails or promotes the item (and the stacked batch) to a dtype the
                # wrapped module does not produce. ``.to`` keeps the autograd graph.
                emb = embedding.vec.to(device=tensor.device, dtype=tensor.dtype)
                # Truncate the vector so the sequence length stays the same.
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
|
2022-09-03 12:08:45 +03:00
|
|
|
|
|
|
|
|
2022-09-05 02:16:36 +03:00
|
|
|
def add_circular_option_to_conv_2d():
    """Monkeypatch torch.nn.Conv2d so every layer constructed afterwards uses circular padding.

    Layers that already exist are not affected. The previous constructor is kept
    in a closure, so repeated calls stack wrappers but still produce circular
    padding.
    """
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        # Override through kwargs rather than passing an extra keyword. The old
        # form raised "got multiple values for keyword argument 'padding_mode'"
        # whenever the caller supplied padding_mode itself.
        kwargs['padding_mode'] = 'circular'
        return conv2d_constructor(self, *args, **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
|
2022-09-05 01:41:20 +03:00
|
|
|
|
|
|
|
|
2022-09-03 12:08:45 +03:00
|
|
|
# Module-level StableDiffusionModelHijack instance; presumably the shared hijack
# state that other modules import (TODO confirm).
model_hijack = StableDiffusionModelHijack()
|
2022-11-11 18:20:18 +03:00
|
|
|
|
|
|
|
|
|
|
|
def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.

    Replacement ``register_buffer`` for the samplers. A value whose type is exactly
    torch.Tensor and which is not already on ``devices.device`` is moved there,
    and cast to float32 when that device is MPS. The value is then stored as the
    attribute ``name``. All other values are stored unchanged.
    """
    on_other_device = type(attr) == torch.Tensor and attr.device != devices.device
    if on_other_device:
        target_dtype = torch.float32 if devices.device.type == 'mps' else None
        attr = attr.to(device=devices.device, dtype=target_dtype)

    setattr(self, name, attr)
|
|
|
|
|
|
|
|
|
|
|
|
# Route both samplers' register_buffer through the device/dtype-fixing version above.
setattr(ldm.models.diffusion.ddim.DDIMSampler, 'register_buffer', register_buffer)
setattr(ldm.models.diffusion.plms.PLMSSampler, 'register_buffer', register_buffer)
|