mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-13 18:02:31 +03:00
Update attention.py
This commit is contained in:
parent
90a922c320
commit
c459bfaacc
@ -150,13 +150,14 @@ class SpatialSelfAttention(nn.Module):
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.att_step = att_step
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
@ -174,23 +175,48 @@ class CrossAttention(nn.Module):
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
limit = k.shape[0]
|
||||
att_step = self.att_step
|
||||
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
|
||||
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
|
||||
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
q_chunks.reverse()
|
||||
k_chunks.reverse()
|
||||
v_chunks.reverse()
|
||||
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
del k, q, v
|
||||
for i in range (0, limit, att_step):
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
q_buffer = q_chunks.pop()
|
||||
k_buffer = k_chunks.pop()
|
||||
v_buffer = v_chunks.pop()
|
||||
sim_buffer = einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
||||
|
||||
del k_buffer, q_buffer
|
||||
'''
|
||||
if exists(mask):
|
||||
mask_buffer = rearrange(mask[i:i+att_step,:,:,:], 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim_buffer.dtype).max
|
||||
mask_buffer = repeat(mask_buffer, 'b j -> (b h) () j', h=h)
|
||||
sim_buffer.masked_fill_(~mask_buffer, max_neg_value)
|
||||
'''
|
||||
# attention, what we cannot get enough of, by chunks
|
||||
|
||||
sim_buffer = sim_buffer.softmax(dim=-1)
|
||||
|
||||
sim_buffer = einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
||||
del v_buffer
|
||||
sim[i:i+att_step,:,:] = sim_buffer
|
||||
|
||||
del sim_buffer
|
||||
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(sim)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
@ -258,4 +284,4 @@ class SpatialTransformer(nn.Module):
|
||||
x = block(x, context=context)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
return x + x_in
|
||||
|
Loading…
Reference in New Issue
Block a user