move scale multiplication (#1232)

This commit is contained in:
Thomas Mello 2022-09-20 02:10:46 +03:00 committed by GitHub
parent 3705f296c6
commit c6b13a418f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -236,7 +236,7 @@ class CrossAttention(nn.Module):
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = min(q.shape[1], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)# * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
@ -249,7 +249,7 @@ class CrossAttention(nn.Module):
q = self.to_q(x)
context = default(context, x)
del x
k = self.to_k(context)
k = self.to_k(context) * self.scale
v = self.to_v(context)
del context
@ -329,4 +329,4 @@ class SpatialTransformer(nn.Module):
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
return x + x_in