mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-14 22:13:41 +03:00
move scale multiplication (#1232)
This commit is contained in:
parent
3705f296c6
commit
c6b13a418f
@ -236,7 +236,7 @@ class CrossAttention(nn.Module):
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = min(q.shape[1], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)# * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
@ -249,7 +249,7 @@ class CrossAttention(nn.Module):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
del x
|
||||
k = self.to_k(context)
|
||||
k = self.to_k(context) * self.scale
|
||||
v = self.to_v(context)
|
||||
del context
|
||||
|
||||
@ -329,4 +329,4 @@ class SpatialTransformer(nn.Module):
|
||||
x = block(x, context=context)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
return x + x_in
|
||||
|
Loading…
Reference in New Issue
Block a user