Fixed bug in txt2vid not allowing you to use custom models.

Fixed the script that convert ckpt models to diffusers not working with some models if they had invalid or broken keys, these keys are now ignored so we can still convert the model.
This commit is contained in:
ZeroCool940711 2022-11-23 19:56:40 -07:00
parent 54e7da7721
commit 0d734219a8
No known key found for this signature in database
GPG Key ID: 4E4072992B5BC640
2 changed files with 48 additions and 18 deletions

View File

@ -30,7 +30,7 @@ except ImportError:
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
#DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LDMTextToImagePipeline,
@ -627,8 +627,11 @@ def convert_ldm_clip_checkpoint(checkpoint):
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
text_model.load_state_dict(text_model_dict)
try:
text_model.load_state_dict(text_model_dict)
except RuntimeError:
pass
return text_model
@ -748,4 +751,4 @@ if __name__ == "__main__":
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)
pipe.save_pretrained(args.dump_path)

View File

@ -54,7 +54,8 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
#from stable_diffusion_videos import StableDiffusionWalkPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
PNDMScheduler
PNDMScheduler, DDPMScheduler, FlaxPNDMScheduler, FlaxDDIMScheduler, \
FlaxDDPMScheduler, FlaxKarrasVeScheduler, IPNDMScheduler, KarrasVeScheduler
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
@ -1125,16 +1126,18 @@ def load_diffusers_model(weights_path,torch_device):
if weights_path == "runwayml/stable-diffusion-v1-5":
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
else:
model_path = weights_path
if not os.path.exists(model_path + "/model_index.json"):
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=st.session_state["defaults"].general.huggingface_token,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
use_local_file=True,
use_auth_token=st.session_state["defaults"].general.huggingface_token,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
)
@ -1142,11 +1145,11 @@ def load_diffusers_model(weights_path,torch_device):
else:
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
model_path,
use_local_file=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
use_local_file=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
)
server_state["pipe"].unet.to(torch_device)
@ -1358,8 +1361,30 @@ def txt2vid(
klms_scheduler = LMSDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)
#flaxddims_scheduler = FlaxDDIMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
#flaxddpms_scheduler = FlaxDDPMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
#flaxpndms_scheduler = FlaxPNDMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
ddpms_scheduler = DDPMScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
klms=klms_scheduler,
ddpms=ddpms_scheduler,
#flaxddims=flaxddims_scheduler,
#flaxddpms=flaxddpms_scheduler,
#flaxpndms=flaxpndms_scheduler,
)
with st.session_state["progress_bar_text"].container():
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
@ -1721,7 +1746,9 @@ def layout():
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
scheduler_name_list = ["klms", "ddim"]
scheduler_name_list = ["klms", "ddim", "ddpms",
#"flaxddims", "flaxddpms", "flaxpndms"
]
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")