2022-10-02 15:03:39 +03:00
import math
2022-10-08 17:02:18 +03:00
import sys
import traceback
2022-12-27 16:50:55 +03:00
import psutil
2022-10-08 17:02:18 +03:00
2022-10-02 15:03:39 +03:00
import torch
from torch import einsum
2022-10-08 16:33:39 +03:00
2022-10-02 15:03:39 +03:00
from ldm . util import default
from einops import rearrange
2023-01-25 08:23:10 +03:00
from modules import shared , errors , devices
2022-10-11 15:51:22 +03:00
from modules . hypernetworks import hypernetwork
2022-10-11 11:09:51 +03:00
2022-12-27 16:50:55 +03:00
from . sub_quadratic_attention import efficient_dot_product_attention
2022-10-07 10:17:52 +03:00
2022-10-08 19:25:10 +03:00
# xformers is optional: it is imported only when requested on the command line, and an
# import failure is reported on stderr instead of being propagated.
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
    else:
        shared.xformers_available = True
2022-10-02 15:03:39 +03:00
2022-12-27 16:50:55 +03:00
def get_available_vram():
    """Return an estimate, in bytes, of the memory available for new tensors on ``shared.device``.

    On CUDA this is the driver-reported free memory plus the memory PyTorch has reserved
    but is not actively using. On any other device type it is the available system RAM
    reported by psutil.
    """
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        # Query the same device the allocator stats came from; torch.cuda.current_device()
        # is not necessarily shared.device on multi-GPU machines.
        mem_free_cuda, _ = torch.cuda.mem_get_info(shared.device)
        mem_free_torch = mem_reserved - mem_active
        return mem_free_cuda + mem_free_torch
    return psutil.virtual_memory().available
2022-10-02 15:03:39 +03:00
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    """Cross-attention that computes scores two rows of the (batch*heads) axis at a time.

    Results are written into a preallocated buffer, so only a pair of score matrices is
    alive at once. ``mask`` is accepted for signature compatibility but is not used.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)

    ctx_k, ctx_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key = self.to_k(ctx_k)
    value = self.to_v(ctx_v)
    # Drop references as early as possible (presumably to let memory be reclaimed sooner).
    del context, ctx_k, ctx_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (query, key, value))
    del query, key, value

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        result = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for start in range(0, q.shape[0], 2):
            rows = slice(start, start + 2)

            scores = einsum('b i d, b j d -> b i j', q[rows], k[rows])
            scores *= self.scale
            probs = scores.softmax(dim=-1)
            del scores

            result[rows] = einsum('b i j, b j d -> b i d', probs, v[rows])
            del probs
        del q, k, v

    result = result.to(out_dtype)

    merged = rearrange(result, '(b h) n d -> b n (h d)', h=heads)
    del result

    return self.to_out(merged)
2022-10-11 11:09:51 +03:00
# taken from https://github.com/Doggettx/stable-diffusion and modified
2022-10-02 15:03:39 +03:00
def split_cross_attention_forward(self, x, context=None, mask=None):
    """Cross-attention that slices the query tokens so the score matrix fits in free memory.

    The number of slices is the smallest power of two that brings the estimated memory
    requirement under the free memory reported by get_available_vram(). Needing more than
    64 slices is treated as out-of-memory. ``mask`` is accepted for signature
    compatibility but is not used.

    Raises:
        RuntimeError: if more than 64 slices would be required.
    """
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        # v keeps its original precision on MPS; everything else is upcast to float32.
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        # Pre-scale k so the per-slice score einsum needs no extra multiply.
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        # Bytes of the full (b*h, q_tokens, k_tokens) score matrix, times an empirical
        # headroom factor (larger for 2-byte dtypes).
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier

        steps = 1
        if mem_required > mem_free_total:
            # With no free memory at all, force the out-of-memory error below instead of
            # dividing by zero.
            steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total)) if mem_free_total > 0 else math.inf

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / gb:0.1f}GB free, Have: {mem_free_total / gb:0.1f}GB free')

        # Ceil division: a token count that is not divisible by `steps` is still split into
        # `steps` slices (the last one shorter) instead of falling back to one unsliced pass.
        slice_size = max(1, -(-q.shape[1] // steps))
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
2022-10-11 05:48:54 +03:00
2022-12-20 01:25:14 +03:00
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
2022-12-27 16:50:55 +03:00
# Total system RAM in whole GiB (rounded down); read by the MPS heuristics in this file.
mem_total_gb = psutil.virtual_memory().total >> 30
2022-10-11 05:48:54 +03:00
def einsum_op_compvis(q, k, v):
    """Unsliced attention: softmax(q @ k^T) @ v for each batch entry.

    q is (b, i, d), k is (b, j, d) and v is (b, j, e); the result is (b, i, e).
    No scaling is applied here. The softmax is computed in the dtype of the score tensor.
    """
    scores = einsum('b i d, b j d -> b i j', q, k)
    probs = scores.softmax(dim=-1, dtype=scores.dtype)
    return einsum('b i j, b j d -> b i d', probs, v)
def einsum_op_slice_0(q, k, v, slice_size):
    """Run einsum_op_compvis on chunks of ``slice_size`` along dim 0 and stitch the results.

    q, k and v are sliced together along dim 0; the output buffer has q's dtype and device.
    """
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for lo in range(0, q.shape[0], slice_size):
        rows = slice(lo, lo + slice_size)
        out[rows] = einsum_op_compvis(q[rows], k[rows], v[rows])
    return out
def einsum_op_slice_1(q, k, v, slice_size):
    """Run einsum_op_compvis on chunks of ``slice_size`` query positions (dim 1).

    Only q is sliced; every chunk attends over the full k and v.
    """
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for lo in range(0, q.shape[1], slice_size):
        cols = slice(lo, lo + slice_size)
        out[:, cols] = einsum_op_compvis(q[:, cols], k, v)
    return out
def einsum_op_mps_v1(q, k, v):
    """MPS strategy: unsliced for small inputs, otherwise slice along the query dimension.

    The slice size targets roughly 2**30 / (q.shape[0] * q.shape[1]) query positions.
    """
    rows = q.shape[0] * q.shape[1]
    if rows > 2 ** 16:  # (512x512) max q.shape[1]: 4096
        slice_size = math.floor(2 ** 30 / rows)
        # NOTE(review): multiples of 4096 are deliberately avoided; the reason is not
        # documented here (presumably an MPS quirk) — confirm before changing.
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)
    return einsum_op_compvis(q, k, v)
def einsum_op_mps_v2(q, k, v):
    """MPS strategy: unsliced only with more than 8 GiB RAM and small inputs, else one row at a time."""
    if mem_total_gb <= 8 or q.shape[0] * q.shape[1] > 2 ** 16:
        return einsum_op_slice_0(q, k, v, 1)
    return einsum_op_compvis(q, k, v)
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    """Attention that splits the work so each score tensor stays near ``max_tensor_mb`` MiB.

    The full score tensor holds q.shape[0] * q.shape[1] * k.shape[1] elements of q's
    element size. If that fits, attention runs unsliced. Otherwise the work is divided into
    a power-of-two number of chunks: along dim 0 when that dimension has at least that
    many entries, and along the query dimension (dim 1) otherwise.
    """
    score_bytes = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    size_mb = score_bytes >> 20
    if size_mb > max_tensor_mb:
        # Smallest power of two strictly greater than int((size_mb - 1) / max_tensor_mb).
        chunks = 2 ** int((size_mb - 1) / max_tensor_mb).bit_length()
        if chunks > q.shape[0]:
            return einsum_op_slice_1(q, k, v, max(q.shape[1] // chunks, 1))
        return einsum_op_slice_0(q, k, v, q.shape[0] // chunks)
    return einsum_op_compvis(q, k, v)
2022-10-11 10:32:11 +03:00
def einsum_op_cuda(q, k, v):
    """CUDA strategy: size the slices from the free memory on q's device.

    Free memory is the driver-reported free bytes plus bytes PyTorch has reserved but is
    not actively using.
    """
    stats = torch.cuda.memory_stats(q.device)
    free_in_torch_cache = stats['reserved_bytes.all.current'] - stats['active_bytes.all.current']
    free_on_device, _total = torch.cuda.mem_get_info(q.device)
    free_bytes = free_on_device + free_in_torch_cache
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, free_bytes / 3.3 / (1 << 20))
2022-10-11 10:32:11 +03:00
2022-10-11 05:48:54 +03:00
def einsum_op(q, k, v):
    """Dispatch attention to a slicing strategy chosen by q's device type."""
    device_type = q.device.type
    if device_type == 'cuda':
        return einsum_op_cuda(q, k, v)
    if device_type == 'mps':
        prefer_v1 = mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2 ** 18
        return einsum_op_mps_v1(q, k, v) if prefer_v1 else einsum_op_mps_v2(q, k, v)
    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
    """Cross-attention that delegates the score computation to einsum_op().

    einsum_op() picks a slicing strategy by device type. ``mask`` is accepted for
    signature compatibility but is not used.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)
    ctx_k, ctx_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key = self.to_k(ctx_k)
    value = self.to_v(ctx_v)
    del context, ctx_k, ctx_v, x

    out_dtype = query.dtype
    if shared.opts.upcast_attn:
        query, key = query.float(), key.float()
        # On MPS the values keep their original precision.
        if value.device.type != 'mps':
            value = value.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        key = key * self.scale
        query, key, value = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (query, key, value))
        attn = einsum_op(query, key, value)

    attn = attn.to(out_dtype)
    return self.to_out(rearrange(attn, '(b h) n d -> b n (h d)', h=heads))
2022-10-11 06:55:48 +03:00
# -- End of code from https://github.com/invoke-ai/InvokeAI --
2022-10-11 05:48:54 +03:00
2022-12-27 16:50:55 +03:00
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
2023-01-07 00:42:47 +03:00
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
2022-12-27 16:50:55 +03:00
def sub_quad_attention_forward(self, x, context=None, mask=None):
    """Cross-attention computed with sub_quad_attention(); masks are not supported.

    Chunking parameters come from shared.cmd_opts and ``use_checkpoint`` follows
    ``self.training``. The output goes through the projection and dropout in ``self.to_out``.
    """
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)
    ctx_k, ctx_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key = self.to_k(ctx_k)
    value = self.to_v(ctx_v)
    del context, ctx_k, ctx_v, x

    def split_heads(t):
        # Assuming a 3-D (b, n, heads*d) input: -> (b, n, heads, d) -> (b, heads, n, d) -> (b*heads, n, d)
        return t.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)

    query = split_heads(query)
    key = split_heads(key)
    value = split_heads(value)

    out_dtype = query.dtype
    if shared.opts.upcast_attn:
        query, key = query.float(), key.float()

    hidden = sub_quad_attention(
        query,
        key,
        value,
        q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size,
        kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size,
        chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )
    hidden = hidden.to(out_dtype)

    # Merge heads back: (b*heads, n, d) -> (b, heads, n, d) -> (b, n, heads, d) -> (b, n, heads*d)
    hidden = hidden.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    return dropout(out_proj(hidden))
2023-01-06 09:01:51 +03:00
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    """Attention via efficient_dot_product_attention with memory-aware chunk sizes.

    Args:
        q, k, v: 3-D tensors (unpacked below as (batch*heads, tokens, channels)).
        q_chunk_size: query chunk size used when the full score matrix exceeds the budget.
        kv_chunk_size: key/value chunk size used in that case (forwarded as-is).
        kv_chunk_size_min: lower bound for the kv chunk size. When None, it is derived
            from the memory budget; 0 means no lower bound.
        chunk_threshold: memory budget as a percentage of get_available_vram(). None picks
            90% on MPS and 70% elsewhere. 0 disables the budget, so the given chunk sizes
            are always used.
        use_checkpoint: forwarded to efficient_dot_product_attention.

    When the full q @ k^T score matrix fits in the budget, both chunk sizes are set to the
    full token counts so everything is computed as a single chunk.
    """
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        budget_fraction = 0.9 if q.device.type == 'mps' else 0.7
        chunk_threshold_bytes = int(get_available_vram() * budget_fraction)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path.
        # (Must rebind q_chunk_size, the name actually passed below.)
        q_chunk_size = q_tokens
        kv_chunk_size = k_tokens

    # Autocast is turned off only when q and v dtypes differ (e.g. after upcasting q/k).
    with devices.without_autocast(disable=q.dtype == v.dtype):
        return efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
2022-12-27 16:50:55 +03:00
2023-01-23 16:40:20 +03:00
def get_xformers_flash_attention_op(q, k, v):
    """Return xformers' flash-attention op when enabled and supported for these inputs, else None.

    Any error while probing is reported once through errors.display_once and treated as
    "not supported".
    """
    if shared.cmd_opts.xformers_flash_attention:
        try:
            op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
            forward_op, _backward_op = op
            inputs = xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)
            if forward_op.supports(inputs):
                return op
        except Exception as e:
            errors.display_once(e, "enabling flash attention")
    return None
2022-10-07 05:21:49 +03:00
def xformers_attention_forward(self, x, context=None, mask=None):
    """Cross-attention via xformers.ops.memory_efficient_attention.

    Tensors are passed to xformers as (batch, tokens, heads, head_dim). The flash-attention
    op is used when get_xformers_flash_attention_op() returns one. ``mask`` is not used.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)
    ctx_k, ctx_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    key = self.to_k(ctx_k)
    value = self.to_v(ctx_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=heads) for t in (query, key, value))
    del query, key, value

    out_dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
    out = rearrange(out.to(out_dtype), 'b n h d -> b n (h d)', h=heads)
    return self.to_out(out)
2023-03-06 22:33:13 +03:00
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
    """Cross-attention via torch.nn.functional.scaled_dot_product_attention.

    Queries come from x; keys and values come from context (defaulting to x). The
    projections are split into ``self.heads`` heads and run through SDPA. The heads are
    then merged and passed through ``self.to_out[0]`` (projection) and
    ``self.to_out[1]`` (dropout).
    """
    batch_size, sequence_length, _ = x.shape

    if mask is not None:
        # NOTE(review): prepare_attention_mask is not defined in this file; confirm the
        # attention module provides it before passing a mask.
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    # Take the per-head width from the projected queries rather than from x: the input
    # width may differ from heads * head_dim.
    head_dim = q_in.shape[-1] // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        # SDPA expects q, k and v to share a dtype, so v is upcast together with q and k.
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states
2022-10-02 15:03:39 +03:00
def cross_attention_attnblock_forward(self, x):
    """Spatial self-attention for an attention block, sliced over query positions.

    Normalizes x, projects it with self.q / self.k / self.v, attends over the h*w spatial
    positions and returns ``x + proj_out(attention)``. Query positions are processed in a
    power-of-two number of slices chosen so that the estimated memory requirement fits in
    the free memory reported by get_available_vram().
    """
    h_ = self.norm(x)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q = q1.reshape(b, c, h * w).permute(0, 2, 1)  # b,hw,c
    del q1
    k = k1.reshape(b, c, h * w)  # b,c,hw
    del k1
    v1 = v.reshape(b, c, h * w)  # b,c,hw -- loop-invariant, so reshaped once
    scale = int(c) ** (-0.5)

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))

    # Ceil division: a position count not divisible by `steps` is still split into `steps`
    # slices (the last one shorter) instead of falling back to one unsliced pass.
    slice_size = max(1, -(-q.shape[1] // steps))
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * scale
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del w4
    del v1

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
2022-10-08 11:55:02 +03:00
2022-10-17 22:18:59 +03:00
def xformers_attnblock_forward(self, x):
    """Spatial self-attention for an attention block via xformers.

    Returns ``x + proj_out(attention)``. If xformers raises NotImplementedError anywhere in
    the computation, the whole block is recomputed with cross_attention_attnblock_forward.
    """
    try:
        normed = self.norm(x)
        q, k, v = self.q(normed), self.k(normed), self.v(normed)
        _, _, height, _ = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))

        out_dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()

        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = rearrange(out.to(out_dtype), 'b (h w) c -> b c h w', h=height)
        return x + self.proj_out(out)
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)
2022-12-27 16:50:55 +03:00
def sub_quad_attnblock_forward(self, x):
    """Spatial self-attention for an attention block via sub_quad_attention().

    Chunking parameters come from shared.cmd_opts and ``use_checkpoint`` follows
    ``self.training``. Returns ``x + proj_out(attention)``.
    """
    normed = self.norm(x)
    q, k, v = self.q(normed), self.k(normed), self.v(normed)
    _, _, height, _ = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    out = sub_quad_attention(
        q,
        k,
        v,
        q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size,
        kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size,
        chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold,
        use_checkpoint=self.training,
    )
    out = rearrange(out, 'b (h w) c -> b c h w', h=height)
    return x + self.proj_out(out)