Fixed bug in txt2vid not allowing you to use custom models.

Fixed the script that convert ckpt models to diffusers not working with some models if they had invalid or broken keys, these keys are now ignored so we can still convert the model.
This commit is contained in:
ZeroCool940711 2022-11-23 19:56:40 -07:00
parent 54e7da7721
commit 0d734219a8
No known key found for this signature in database
GPG Key ID: 4E4072992B5BC640
2 changed files with 48 additions and 18 deletions

View File

@ -30,7 +30,7 @@ except ImportError:
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, #DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler, EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler, EulerDiscreteScheduler,
LDMTextToImagePipeline, LDMTextToImagePipeline,
@ -628,7 +628,10 @@ def convert_ldm_clip_checkpoint(checkpoint):
if key.startswith("cond_stage_model.transformer"): if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
text_model.load_state_dict(text_model_dict) try:
text_model.load_state_dict(text_model_dict)
except RuntimeError:
pass
return text_model return text_model

View File

@ -54,7 +54,8 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
#from stable_diffusion_videos import StableDiffusionWalkPipeline #from stable_diffusion_videos import StableDiffusionWalkPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \ from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
PNDMScheduler PNDMScheduler, DDPMScheduler, FlaxPNDMScheduler, FlaxDDIMScheduler, \
FlaxDDPMScheduler, FlaxKarrasVeScheduler, IPNDMScheduler, KarrasVeScheduler
from diffusers.configuration_utils import FrozenDict from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
@ -1125,16 +1126,18 @@ def load_diffusers_model(weights_path,torch_device):
if weights_path == "runwayml/stable-diffusion-v1-5": if weights_path == "runwayml/stable-diffusion-v1-5":
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5") model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
else:
model_path = weights_path
if not os.path.exists(model_path + "/model_index.json"): if not os.path.exists(model_path + "/model_index.json"):
server_state["pipe"] = StableDiffusionPipeline.from_pretrained( server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path, weights_path,
use_local_file=True, use_local_file=True,
use_auth_token=st.session_state["defaults"].general.huggingface_token, use_auth_token=st.session_state["defaults"].general.huggingface_token,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None, revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion", #custom_pipeline="interpolate_stable_diffusion",
) )
@ -1142,11 +1145,11 @@ def load_diffusers_model(weights_path,torch_device):
else: else:
server_state["pipe"] = StableDiffusionPipeline.from_pretrained( server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
model_path, model_path,
use_local_file=True, use_local_file=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None, revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion", #custom_pipeline="interpolate_stable_diffusion",
) )
server_state["pipe"].unet.to(torch_device) server_state["pipe"].unet.to(torch_device)
@ -1359,7 +1362,29 @@ def txt2vid(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
) )
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) #flaxddims_scheduler = FlaxDDIMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
#flaxddpms_scheduler = FlaxDDPMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
#flaxpndms_scheduler = FlaxPNDMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
ddpms_scheduler = DDPMScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
klms=klms_scheduler,
ddpms=ddpms_scheduler,
#flaxddims=flaxddims_scheduler,
#flaxddpms=flaxddpms_scheduler,
#flaxpndms=flaxpndms_scheduler,
)
with st.session_state["progress_bar_text"].container(): with st.session_state["progress_bar_text"].container():
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]): with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
@ -1721,7 +1746,9 @@ def layout():
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
#sampler_name = st.selectbox("Sampling method", sampler_name_list, #sampler_name = st.selectbox("Sampling method", sampler_name_list,
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler") #index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
scheduler_name_list = ["klms", "ddim"] scheduler_name_list = ["klms", "ddim", "ddpms",
#"flaxddims", "flaxddpms", "flaxpndms"
]
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list, scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms") index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")