diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index 4485c1e..59108f0 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -236,7 +236,7 @@ class CrossAttention(nn.Module):
         slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
         for i in range(0, q.shape[1], slice_size):
             end = min(q.shape[1], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)  # * self.scale
             s2 = s1.softmax(dim=-1, dtype=r1.dtype)
             del s1
             r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
@@ -249,7 +249,7 @@ class CrossAttention(nn.Module):
         q = self.to_q(x)
         context = default(context, x)
         del x
-        k = self.to_k(context)
+        k = self.to_k(context) * self.scale
         v = self.to_v(context)
         del context
@@ -329,4 +329,4 @@ class SpatialTransformer(nn.Module):
             x = block(x, context=context)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
         x = self.proj_out(x)
-        return x + x_in
\ No newline at end of file
+        return x + x_in
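
For reference (not part of the patch): the CrossAttention change only moves where the attention scale is applied. Because einsum('b i d, b j d -> b i j', q, k) is linear in k, multiplying k by self.scale once in forward() yields the same scores as multiplying each sliced score matrix by self.scale, while skipping one full-size multiplication per slice. A minimal sketch verifying the equivalence; the tensor shapes and the scale = dim_head ** -0.5 definition are illustrative assumptions:

# Sketch only: check that pre-scaling k matches scaling the score matrix.
import torch
from torch import einsum

b, n, m, dim_head = 2, 16, 16, 64        # illustrative batch / query / key / head dims
scale = dim_head ** -0.5                 # assumed to match CrossAttention's self.scale
q = torch.randn(b, n, dim_head)
k = torch.randn(b, m, dim_head)

# Old path: compute raw scores, then scale the full (b, n, m) matrix.
scores_old = einsum('b i d, b j d -> b i j', q, k) * scale
# New path: scale k once, then compute scores directly.
scores_new = einsum('b i d, b j d -> b i j', q, k * scale)

assert torch.allclose(scores_old, scores_new, atol=1e-6)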