Fixed bug in txt2vid not allowing you to use custom models.

Fixed the script that convert ckpt models to diffusers not working with some models if they had invalid or broken keys, these keys are now ignored so we can still convert the model.
This commit is contained in:
ZeroCool940711 2022-11-23 19:56:40 -07:00
parent 54e7da7721
commit 0d734219a8
No known key found for this signature in database
GPG Key ID: 4E4072992B5BC640
2 changed files with 48 additions and 18 deletions

View File

@ -30,7 +30,7 @@ except ImportError:
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
#DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LDMTextToImagePipeline,
@ -628,7 +628,10 @@ def convert_ldm_clip_checkpoint(checkpoint):
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
try:
text_model.load_state_dict(text_model_dict)
except RuntimeError:
pass
return text_model

View File

@ -54,7 +54,8 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
#from stable_diffusion_videos import StableDiffusionWalkPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
PNDMScheduler
PNDMScheduler, DDPMScheduler, FlaxPNDMScheduler, FlaxDDIMScheduler, \
FlaxDDPMScheduler, FlaxKarrasVeScheduler, IPNDMScheduler, KarrasVeScheduler
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
@ -1125,6 +1126,8 @@ def load_diffusers_model(weights_path,torch_device):
if weights_path == "runwayml/stable-diffusion-v1-5":
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
else:
model_path = weights_path
if not os.path.exists(model_path + "/model_index.json"):
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
@ -1359,7 +1362,29 @@ def txt2vid(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
#flaxddims_scheduler = FlaxDDIMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
#flaxddpms_scheduler = FlaxDDPMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
#flaxpndms_scheduler = FlaxPNDMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
ddpms_scheduler = DDPMScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
klms=klms_scheduler,
ddpms=ddpms_scheduler,
#flaxddims=flaxddims_scheduler,
#flaxddpms=flaxddpms_scheduler,
#flaxpndms=flaxpndms_scheduler,
)
with st.session_state["progress_bar_text"].container():
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
@ -1721,7 +1746,9 @@ def layout():
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
scheduler_name_list = ["klms", "ddim"]
scheduler_name_list = ["klms", "ddim", "ddpms",
#"flaxddims", "flaxddpms", "flaxpndms"
]
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")