2022-10-02 15:03:39 +03:00
import math
2022-10-08 17:02:18 +03:00
import sys
import traceback
2022-12-27 16:50:55 +03:00
import psutil
2022-10-08 17:02:18 +03:00
2022-10-02 15:03:39 +03:00
import torch
from torch import einsum
2022-10-08 16:33:39 +03:00
2022-10-02 15:03:39 +03:00
from ldm . util import default
from einops import rearrange
2023-05-19 00:03:27 +03:00
from modules import shared , errors , devices , sub_quadratic_attention
2022-10-11 15:51:22 +03:00
from modules . hypernetworks import hypernetwork
2022-10-11 11:09:51 +03:00
2023-05-18 22:48:28 +03:00
import ldm . modules . attention
import ldm . modules . diffusionmodules . model
# Reference to the stock AttnBlock.forward, captured when this module is imported,
# so that SdOptimization.undo() can put it back after an optimization replaced it.
diffusionmodules_model_AttnBlock_forward = ldm . modules . diffusionmodules . model . AttnBlock . forward
class SdOptimization:
    """Base description of one selectable attention optimization.

    Subclasses override the class attributes below and ``apply()``. The base
    class reports itself as available and applies nothing.
    """

    # The annotations are string literals on purpose: a bare ``str | None`` in a
    # class body is evaluated when the class is created and raises TypeError on
    # Python versions older than 3.10.
    name: "str | None" = None  # identifier, used as the first part of title()
    label: "str | None" = None  # optional description appended by title()
    cmd_opt: "str | None" = None  # presumably the matching command-line option name -- confirm with callers
    priority: int = 0  # numeric rank; how it is compared is decided by callers

    def title(self):
        """Return ``name`` alone, or ``"<name> - <label>"`` when a label is set."""
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        """Report whether this optimization can be used; the base answer is True."""
        return True

    def apply(self):
        """Install the optimized forward functions; the base class installs nothing."""
        pass

    def undo(self):
        """Restore the forward functions that optimizations replace.

        CrossAttention.forward is reset to the hypernetwork-aware forward and
        AttnBlock.forward to the reference saved at module import.
        """
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
class SdOptimizationXformers(SdOptimization):
    """Optimization that swaps in the xformers-based forward functions."""

    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        """Return a truthy value when xformers may be used.

        The force flag wins outright. Otherwise xformers must have been
        imported, torch must report a CUDA version, and the compute capability
        of ``shared.device`` must lie between (6, 0) and (9, 0) inclusive.
        The value returned is the deciding operand itself, not necessarily a bool.
        """
        forced = shared.cmd_opts.force_enable_xformers
        if forced:
            return forced

        return (
            shared.xformers_available
            and torch.version.cuda
            and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)
        )

    def apply(self):
        """Point CrossAttention and AttnBlock at the xformers forward functions."""
        cross_attention_cls = ldm.modules.attention.CrossAttention
        cross_attention_cls.forward = xformers_attention_forward

        attn_block_cls = ldm.modules.diffusionmodules.model.AttnBlock
        attn_block_cls.forward = xformers_attnblock_forward
class SdOptimizationSdpNoMem(SdOptimization):
    """Scaled-dot-product attention with the memory-efficient kernel disabled."""

    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 90

    def is_available(self):
        """Return True when torch exposes a callable ``scaled_dot_product_attention``."""
        sdp_function = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
        # A missing attribute yields None, which is not callable.
        return callable(sdp_function)

    def apply(self):
        """Point CrossAttention and AttnBlock at the no-mem SDP forward functions."""
        cross_attention_cls = ldm.modules.attention.CrossAttention
        cross_attention_cls.forward = scaled_dot_product_no_mem_attention_forward

        attn_block_cls = ldm.modules.diffusionmodules.model.AttnBlock
        attn_block_cls.forward = sdp_no_mem_attnblock_forward
class SdOptimizationSdp(SdOptimizationSdpNoMem):
    """Scaled-dot-product attention; availability check is inherited from the no-mem variant."""

    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 80

    def apply(self):
        """Point CrossAttention and AttnBlock at the SDP forward functions."""
        cross_attention_cls = ldm.modules.attention.CrossAttention
        cross_attention_cls.forward = scaled_dot_product_attention_forward

        attn_block_cls = ldm.modules.diffusionmodules.model.AttnBlock
        attn_block_cls.forward = sdp_attnblock_forward
class SdOptimizationSubQuad(SdOptimization):
    """Optimization that installs the sub-quadratic attention forward functions."""

    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        """Point CrossAttention and AttnBlock at the sub-quadratic forward functions."""
        cross_attention_cls = ldm.modules.attention.CrossAttention
        cross_attention_cls.forward = sub_quad_attention_forward

        attn_block_cls = ldm.modules.diffusionmodules.model.AttnBlock
        attn_block_cls.forward = sub_quad_attnblock_forward
class SdOptimizationV1(SdOptimization):
    """Optimization that installs the v1 split cross-attention forward."""

    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        """Replace CrossAttention.forward only; AttnBlock.forward is not touched here."""
        cross_attention_cls = ldm.modules.attention.CrossAttention
        cross_attention_cls.forward = split_cross_attention_forward_v1
class SdOptimizationInvokeAI(SdOptimization):
    """Optimization that installs the InvokeAI-derived split cross-attention forward."""

    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        """Priority is 10 when CUDA is available and 1000 when it is not."""
        if torch.cuda.is_available():
            return 10
        return 1000

    def apply(self):
        """Replace CrossAttention.forward only; AttnBlock.forward is not touched here."""
        cross_attention_cls = ldm.modules.attention.CrossAttention
        cross_attention_cls.forward = split_cross_attention_forward_invokeAI
class SdOptimizationDoggettx(SdOptimization):
    """Optimization that installs the Doggettx-derived split attention forwards."""

    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 20

    def apply(self):
        """Point CrossAttention and AttnBlock at the memory-sliced forward functions."""
        cross_attention_cls = ldm.modules.attention.CrossAttention
        cross_attention_cls.forward = split_cross_attention_forward

        attn_block_cls = ldm.modules.diffusionmodules.model.AttnBlock
        attn_block_cls.forward = cross_attention_attnblock_forward
def list_optimizers(res):
    """Append one fresh instance of each optimization class in this module to ``res``.

    ``res`` is mutated in place through ``extend``; nothing is returned.
    """
    optimizer_classes = (
        SdOptimizationXformers,
        SdOptimizationSdpNoMem,
        SdOptimizationSdp,
        SdOptimizationSubQuad,
        SdOptimizationV1,
        SdOptimizationInvokeAI,
        SdOptimizationDoggettx,
    )
    # Build the complete list first so that ``res`` is extended in one step.
    res.extend([optimizer_cls() for optimizer_cls in optimizer_classes])
2022-12-27 16:50:55 +03:00
2022-10-07 10:17:52 +03:00
2022-10-08 19:25:10 +03:00
# xformers is optional: import it only when requested on the command line, and
# report (rather than propagate) any failure so the rest of the module still loads.
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
    else:
        shared.xformers_available = True
2022-10-02 15:03:39 +03:00
2022-12-27 16:50:55 +03:00
def get_available_vram():
    """Return an estimate, in bytes, of memory available for new allocations.

    For a CUDA ``shared.device`` this is the free memory reported by the driver
    plus the part of torch's reserved pool that is not currently active. For
    any other device type it is the available system RAM reported by psutil.
    """
    if shared.device.type != 'cuda':
        return psutil.virtual_memory().available

    stats = torch.cuda.memory_stats(shared.device)
    reserved_bytes = stats['reserved_bytes.all.current']
    active_bytes = stats['active_bytes.all.current']

    # NOTE(review): the stats above are read for shared.device while the free-memory
    # query below uses the current device -- confirm these always refer to the same GPU.
    free_on_device, _ = torch.cuda.mem_get_info(torch.cuda.current_device())

    # Reserved-but-inactive memory can be reused by torch without asking the driver.
    reusable_in_pool = reserved_bytes - active_bytes
    return free_on_device + reusable_in_pool
2022-10-02 15:03:39 +03:00
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    """Cross-attention forward that processes the (batch*heads) axis two rows at a time.

    ``context`` falls back to ``x`` when not given, and the loaded hypernetworks
    are applied to it before the key/value projections. ``mask`` is accepted
    for interface compatibility but is not used. Returns ``self.to_out``
    applied to the attention result, cast back to the dtype of the query
    projection.
    """
    heads = self.heads

    q_proj = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_proj = self.to_k(context_k)
    v_proj = self.to_v(context_v)
    del context, context_k, context_v, x

    # Fold the head axis into the batch axis: (b, n, h*d) -> (b*h, n, d).
    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (q_proj, k_proj, v_proj))
    del q_proj, k_proj, v_proj

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        attended = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for start in range(0, q.shape[0], 2):
            stop = start + 2
            scores = einsum('b i d, b j d -> b i j', q[start:stop], k[start:stop])
            scores *= self.scale

            probs = scores.softmax(dim=-1)
            del scores

            attended[start:stop] = einsum('b i j, b j d -> b i d', probs, v[start:stop])
            del probs
        del q, k, v

    attended = attended.to(out_dtype)

    # Unfold the heads again: (b*h, n, d) -> (b, n, h*d).
    merged = rearrange(attended, '(b h) n d -> b n (h d)', h=heads)
    del attended

    return self.to_out(merged)
2022-10-11 11:09:51 +03:00
# taken from https://github.com/Doggettx/stable-diffusion and modified
2022-10-02 15:03:39 +03:00
def split_cross_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward that slices the query axis to fit the available memory.

    The size of the full attention-score tensor is compared with
    ``get_available_vram()``; when it does not fit, the query tokens are
    processed in a power-of-two number of slices. ``mask`` is accepted for
    interface compatibility but is not used.

    Raises:
        RuntimeError: when more than 64 slices would be needed.
    """
    heads = self.heads

    q_proj = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_proj = self.to_k(context_k)
    v_proj = self.to_v(context_v)

    out_dtype = q_proj.dtype
    if shared.opts.upcast_attn:
        q_proj = q_proj.float()
        k_proj = k_proj.float()
        # On 'mps' the value projection is deliberately left in its original dtype.
        if v_proj.device.type != 'mps':
            v_proj = v_proj.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # The attention scale is folded into the keys once, instead of into every score slice.
        k_proj = k_proj * self.scale

        del context, x

        # Fold the head axis into the batch axis: (b, n, h*d) -> (b*h, n, d).
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (q_proj, k_proj, v_proj))
        del q_proj, k_proj, v_proj

        attended = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        scores_bytes = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = scores_bytes * modifier
        steps = 1

        if mem_required > mem_free_total:
            # Smallest power of two that brings one slice under the free-memory estimate.
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        # When the token count is not divisible by ``steps`` everything runs as a single slice.
        if (q.shape[1] % steps) == 0:
            slice_size = q.shape[1] // steps
        else:
            slice_size = q.shape[1]

        for start in range(0, q.shape[1], slice_size):
            stop = start + slice_size
            scores = einsum('b i d, b j d -> b i j', q[:, start:stop], k)

            probs = scores.softmax(dim=-1, dtype=q.dtype)
            del scores

            attended[:, start:stop] = einsum('b i j, b j d -> b i d', probs, v)
            del probs

        del q, k, v

    attended = attended.to(out_dtype)

    # Unfold the heads again: (b*h, n, d) -> (b, n, h*d).
    merged = rearrange(attended, '(b h) n d -> b n (h d)', h=heads)
    del attended

    return self.to_out(merged)
2022-10-11 05:48:54 +03:00
2022-12-20 01:25:14 +03:00
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
2022-12-27 16:50:55 +03:00
# Total system RAM in whole GiB (floored), read once when the module is imported.
mem_total_gb = psutil.virtual_memory().total >> 30
2022-10-11 05:48:54 +03:00
def einsum_op_compvis(q, k, v):
    """Unsliced attention: softmax over ``q . k^T`` along the last axis, applied to ``v``.

    No scaling is applied here; the three inputs are 3-D tensors as required by
    the einsum equations.
    """
    scores = einsum('b i d, b j d -> b i j', q, k)
    weights = scores.softmax(dim=-1, dtype=scores.dtype)
    return einsum('b i j, b j d -> b i d', weights, v)
def einsum_op_slice_0(q, k, v, slice_size):
    """Run attention in chunks of ``slice_size`` rows along axis 0 of q, k and v."""
    result = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[0], slice_size):
        rows = slice(start, start + slice_size)
        result[rows] = einsum_op_compvis(q[rows], k[rows], v[rows])
    return result
def einsum_op_slice_1(q, k, v, slice_size):
    """Run attention in chunks of ``slice_size`` along axis 1 of q; k and v are used whole."""
    result = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for start in range(0, q.shape[1], slice_size):
        tokens = slice(start, start + slice_size)
        result[:, tokens] = einsum_op_compvis(q[:, tokens], k, v)
    return result
def einsum_op_mps_v1(q, k, v):
    """MPS strategy 1: unsliced for small inputs, otherwise sliced along axis 1 of q."""
    total_rows = q.shape[0] * q.shape[1]
    if total_rows <= 2 ** 16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)

    slice_size = math.floor(2 ** 30 / total_rows)
    # Slice sizes that are exact multiples of 4096 are avoided by stepping one down.
    if slice_size % 4096 == 0:
        slice_size -= 1
    return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v):
    """MPS strategy 2: one axis-0 row at a time unless the input is small and RAM exceeds 8 GiB."""
    too_little_ram = mem_total_gb <= 8
    if too_little_ram or q.shape[0] * q.shape[1] > 2 ** 16:
        return einsum_op_slice_0(q, k, v, 1)
    return einsum_op_compvis(q, k, v)
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    """Pick a slicing scheme that keeps the score tensor near ``max_tensor_mb`` MiB.

    The score tensor size is estimated from the shapes and element size of the
    inputs. If it fits, attention runs unsliced. Otherwise a power-of-two
    divisor is derived and applied along axis 0 when that axis is long enough,
    or along axis 1 of ``q`` if not.
    """
    score_elements = q.shape[0] * q.shape[1] * k.shape[1]
    size_mb = score_elements * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)

    # Power-of-two divisor derived from how many times the budget is exceeded.
    shift = int((size_mb - 1) / max_tensor_mb).bit_length()
    div = 1 << shift
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)

    # Never let the axis-1 slice size drop to zero.
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
2022-10-11 10:32:11 +03:00
def einsum_op_cuda(q, k, v):
    """CUDA strategy: size the slices from the memory currently free on ``q.device``."""
    stats = torch.cuda.memory_stats(q.device)
    reserved_bytes = stats['reserved_bytes.all.current']
    active_bytes = stats['active_bytes.all.current']
    free_on_device, _ = torch.cuda.mem_get_info(q.device)

    # Driver-reported free memory plus the inactive part of torch's reserved pool.
    mem_free_total = free_on_device + (reserved_bytes - active_bytes)

    # Divide factor of safety as there's copying and fragmentation
    budget_mb = mem_free_total / 3.3 / (1 << 20)
    return einsum_op_tensor_mem(q, k, v, budget_mb)
2022-10-11 10:32:11 +03:00
2022-10-11 05:48:54 +03:00
def einsum_op(q, k, v):
    """Dispatch attention to the strategy matching the device type of ``q``."""
    device_type = q.device.type

    if device_type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if device_type == 'mps':
        prefers_v1 = mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2 ** 18
        mps_strategy = einsum_op_mps_v1 if prefers_v1 else einsum_op_mps_v2
        return mps_strategy(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
    """Cross-attention forward that delegates the attention math to ``einsum_op``.

    ``context`` falls back to ``x`` when not given, and the loaded hypernetworks
    are applied to it before the key/value projections. ``mask`` is accepted
    for interface compatibility but is not used.
    """
    heads = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q = q.float()
        k = k.float()
        # On 'mps' the value projection is deliberately left in its original dtype.
        if v.device.type != 'mps':
            v = v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # einsum_op applies no scaling itself, so the scale is folded into the keys.
        k = k * self.scale

        # Fold the head axis into the batch axis: (b, n, h*d) -> (b*h, n, d).
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (q, k, v))
        attended = einsum_op(q, k, v)

    attended = attended.to(out_dtype)
    merged = rearrange(attended, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)
2022-10-11 06:55:48 +03:00
# -- End of code from https://github.com/invoke-ai/InvokeAI --
2022-10-11 05:48:54 +03:00
2022-12-27 16:50:55 +03:00
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
2023-01-07 00:42:47 +03:00
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
2022-12-27 16:50:55 +03:00
def sub_quad_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward built on ``sub_quad_attention``.

    ``mask`` must be None. Chunk sizes and the chunk threshold come from the
    ``sub_quad_*`` command-line options, and checkpointing follows
    ``self.training``. The result goes through both members of ``self.to_out``.
    """
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    heads = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    # Split the last axis into heads and fold them into the batch axis.
    q = q.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
    k = k.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
    v = v.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        # Only the query and key are upcast; the value keeps its dtype.
        q, k = q.float(), k.float()

    cmd_opts = shared.cmd_opts
    x = sub_quad_attention(
        q,
        k,
        v,
        q_chunk_size=cmd_opts.sub_quad_q_chunk_size,
        kv_chunk_size=cmd_opts.sub_quad_kv_chunk_size,
        chunk_threshold=cmd_opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )

    x = x.to(out_dtype)

    # Undo the head folding: separate the heads and merge them into the last axis.
    x = x.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    return dropout(out_proj(x))
2023-01-06 09:01:51 +03:00
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Run ``efficient_dot_product_attention`` with chunk sizes derived from free memory.

    ``chunk_threshold`` selects the byte budget for the full q*k matmul:
    ``None`` uses 90% of the available memory on 'mps' and 70% elsewhere, ``0``
    disables the budget, and any other value is taken as a percentage of the
    available memory. When the matmul fits the budget, ``kv_chunk_size`` is set
    to the number of key tokens so that everything runs in one chunk.
    """
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        default_fraction = 0.9 if q.device.type == 'mps' else 0.7
        chunk_threshold_bytes = int(get_available_vram() * default_fraction)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    has_budget = chunk_threshold_bytes is not None

    if kv_chunk_size_min is None and has_budget:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if has_budget and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    # The autocast guard is disabled (a no-op) when q and v already share a dtype.
    same_dtype = q.dtype == v.dtype
    with devices.without_autocast(disable=same_dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
2022-12-27 16:50:55 +03:00
2023-01-23 16:40:20 +03:00
def get_xformers_flash_attention_op(q, k, v):
    """Return the xformers flash-attention op when enabled and supported for these inputs.

    Returns None when the command-line option is off, when the forward op does
    not support the inputs, or when probing raised (the error is reported once
    through ``errors.display_once``).
    """
    if shared.cmd_opts.xformers_flash_attention:
        try:
            flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
            forward_op, _backward_op = flash_attention_op
            probe_inputs = xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)
            if forward_op.supports(probe_inputs):
                return flash_attention_op
        except Exception as e:
            errors.display_once(e, "enabling flash attention")

    return None
2022-10-07 05:21:49 +03:00
def xformers_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward using ``xformers.ops.memory_efficient_attention``.

    ``context`` falls back to ``x`` when not given, and the loaded hypernetworks
    are applied to it before the key/value projections. ``mask`` is accepted
    for interface compatibility but is not used.
    """
    heads = self.heads

    q_proj = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_proj = self.to_k(context_k)
    v_proj = self.to_v(context_v)

    # Heads get their own axis here instead of being folded into the batch axis.
    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=heads) for t in (q_proj, k_proj, v_proj))
    del q_proj, k_proj, v_proj

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    attention_op = get_xformers_flash_attention_op(q, k, v)
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=attention_op)

    out = out.to(out_dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=heads)
    return self.to_out(out)
2023-03-06 22:33:13 +03:00
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
    """Cross-attention forward using ``torch.nn.functional.scaled_dot_product_attention``.

    ``x`` must be 3-D (batch, sequence, inner_dim). A given ``mask`` is prepared
    by ``self.prepare_attention_mask`` and reshaped to one slice per head.
    The result goes through both members of ``self.to_out``.
    """
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    heads = self.heads
    q_proj = self.to_q(x)

    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_proj = self.to_k(context_k)
    v_proj = self.to_v(context_v)

    head_dim = inner_dim // heads

    def split_heads(projected):
        # (batch, tokens, heads * head_dim) -> (batch, heads, tokens, head_dim)
        return projected.view(batch_size, -1, heads, head_dim).transpose(1, 2)

    q = split_heads(q_proj)
    k = split_heads(k_proj)
    v = split_heads(v_proj)

    del q_proj, k_proj, v_proj

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    hidden_states = hidden_states.to(out_dtype)

    # linear proj, then dropout
    linear_proj, dropout = self.to_out[0], self.to_out[1]
    hidden_states = linear_proj(hidden_states)
    hidden_states = dropout(hidden_states)
    return hidden_states
2023-03-10 10:19:36 +03:00
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
    """Run the SDP cross-attention forward with the memory-efficient kernel switched off."""
    kernel_selection = torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=True,
        enable_mem_efficient=False,
    )
    with kernel_selection:
        return scaled_dot_product_attention_forward(self, x, context, mask)
2022-10-02 15:03:39 +03:00
def cross_attention_attnblock_forward(self, x):
    """AttnBlock forward that slices the query positions to fit the available memory.

    Queries, keys and values come from ``self.q``, ``self.k`` and ``self.v``
    applied to the normalised input; their output must be 4-D (b, c, h, w).
    Returns ``self.proj_out`` of the attention result with ``x`` added to it.
    """
    hidden = self.norm(x)
    q_map = self.q(hidden)
    k_map = self.k(hidden)
    v = self.v(hidden)

    # compute attention
    b, c, h, w = q_map.shape

    q_flat = q_map.reshape(b, c, h * w)
    del q_map

    q = q_flat.permute(0, 2, 1)   # b,hw,c
    del q_flat

    k = k_map.reshape(b, c, h * w)  # b,c,hw
    del k_map

    hidden = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    scores_bytes = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = scores_bytes * 2.5
    steps = 1

    if mem_required > mem_free_total:
        # Smallest power of two that brings one slice under the free-memory estimate.
        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

    # When the position count is not divisible by ``steps`` everything runs as one slice.
    if (q.shape[1] % steps) == 0:
        slice_size = q.shape[1] // steps
    else:
        slice_size = q.shape[1]

    for start in range(0, q.shape[1], slice_size):
        stop = start + slice_size

        raw_scores = torch.bmm(q[:, start:stop], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        scaled_scores = raw_scores * (int(c) ** (-0.5))
        del raw_scores
        weights = torch.nn.functional.softmax(scaled_scores, dim=2, dtype=q.dtype)
        del scaled_scores

        # attend to values
        v_flat = v.reshape(b, c, h * w)
        weights_t = weights.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del weights

        hidden[:, :, start:stop] = torch.bmm(v_flat, weights_t)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v_flat, weights_t

    spatial = hidden.reshape(b, c, h, w)
    del hidden

    projected = self.proj_out(spatial)
    del spatial

    # The residual is added in place on the projection output.
    projected += x

    return projected
2023-05-11 18:28:15 +03:00
2022-10-17 22:18:59 +03:00
def xformers_attnblock_forward(self, x):
    """AttnBlock forward using xformers, with a fallback to the sliced implementation.

    Any ``NotImplementedError`` raised along the way makes the call fall back to
    ``cross_attention_attnblock_forward``.
    """
    try:
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)
        b, c, h, w = q.shape

        # Flatten the spatial axes into a token axis: (b, c, h, w) -> (b, h*w, c).
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

        out_dtype = q.dtype
        if shared.opts.upcast_attn:
            # Only the query and key are upcast; the value keeps its dtype.
            q, k = q.float(), k.float()

        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        attention_op = get_xformers_flash_attention_op(q, k, v)
        out = xformers.ops.memory_efficient_attention(q, k, v, op=attention_op)
        out = out.to(out_dtype)

        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)
2022-12-27 16:50:55 +03:00
2023-03-10 20:48:41 +03:00
def sdp_attnblock_forward(self, x):
    """AttnBlock forward using ``torch.nn.functional.scaled_dot_product_attention``.

    Returns ``x`` plus ``self.proj_out`` of the attention result.
    """
    normed = self.norm(x)
    q = self.q(normed)
    k = self.k(normed)
    v = self.v(normed)
    b, c, h, w = q.shape

    # Flatten the spatial axes into a token axis: (b, c, h, w) -> (b, h*w, c).
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        # Only the query and key are upcast; the value keeps its dtype.
        q, k = q.float(), k.float()

    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(out_dtype)

    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out
def sdp_no_mem_attnblock_forward(self, x):
    """Run the SDP AttnBlock forward with the memory-efficient kernel switched off."""
    kernel_selection = torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=True,
        enable_mem_efficient=False,
    )
    with kernel_selection:
        return sdp_attnblock_forward(self, x)
2022-12-27 16:50:55 +03:00
def sub_quad_attnblock_forward(self, x):
    """AttnBlock forward built on ``sub_quad_attention``.

    Chunk sizes and the chunk threshold come from the ``sub_quad_*``
    command-line options, and checkpointing follows ``self.training``.
    Returns ``x`` plus ``self.proj_out`` of the attention result.
    """
    normed = self.norm(x)
    q = self.q(normed)
    k = self.k(normed)
    v = self.v(normed)
    b, c, h, w = q.shape

    # Flatten the spatial axes into a token axis: (b, c, h, w) -> (b, h*w, c).
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    cmd_opts = shared.cmd_opts
    out = sub_quad_attention(
        q,
        k,
        v,
        q_chunk_size=cmd_opts.sub_quad_q_chunk_size,
        kv_chunk_size=cmd_opts.sub_quad_kv_chunk_size,
        chunk_threshold=cmd_opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )

    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out