fix img2img alt for SD v2.x

This commit is contained in:
Alex "mcmonkey" Goodwin 2023-03-20 15:42:36 -07:00
parent a9fed7c364
commit 05ec128ca9

View File

@ -22,7 +22,12 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
x = p.init_latent x = p.init_latent
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
dnw = K.external.CompVisDenoiser(shared.sd_model) if shared.sd_model.parameterization == "v":
dnw = K.external.CompVisVDenoiser(shared.sd_model)
skip = 1
else:
dnw = K.external.CompVisDenoiser(shared.sd_model)
skip = 0
sigmas = dnw.get_sigmas(steps).flip(0) sigmas = dnw.get_sigmas(steps).flip(0)
shared.state.sampling_steps = steps shared.state.sampling_steps = steps
@ -37,7 +42,7 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
image_conditioning = torch.cat([p.image_conditioning] * 2) image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
t = dnw.sigma_to_t(sigma_in) t = dnw.sigma_to_t(sigma_in)
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in) eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)
@ -69,7 +74,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
x = p.init_latent x = p.init_latent
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
dnw = K.external.CompVisDenoiser(shared.sd_model) if shared.sd_model.parameterization == "v":
dnw = K.external.CompVisVDenoiser(shared.sd_model)
skip = 1
else:
dnw = K.external.CompVisDenoiser(shared.sd_model)
skip = 0
sigmas = dnw.get_sigmas(steps).flip(0) sigmas = dnw.get_sigmas(steps).flip(0)
shared.state.sampling_steps = steps shared.state.sampling_steps = steps
@ -84,7 +94,7 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
image_conditioning = torch.cat([p.image_conditioning] * 2) image_conditioning = torch.cat([p.image_conditioning] * 2)
cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]} cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)] c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]
if i == 1: if i == 1:
t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2)) t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))