From 0d734219a863b8cf9b40d8e5006a187e34ff88d6 Mon Sep 17 00:00:00 2001 From: ZeroCool940711 Date: Wed, 23 Nov 2022 19:56:40 -0700 Subject: [PATCH] Fixed a bug in txt2vid not allowing you to use custom models. Fixed the script that converts ckpt models to diffusers not working with some models if they had invalid or broken keys; these keys are now ignored so we can still convert the model. --- ..._original_stable_diffusion_to_diffusers.py | 11 ++-- scripts/txt2vid.py | 55 ++++++++++++++----- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 375b12b..9572c7f 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -30,7 +30,7 @@ except ImportError: from diffusers import ( AutoencoderKL, DDIMScheduler, - DPMSolverMultistepScheduler, + #DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LDMTextToImagePipeline, @@ -627,8 +627,11 @@ def convert_ldm_clip_checkpoint(checkpoint): for key in keys: if key.startswith("cond_stage_model.transformer"): text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - text_model.load_state_dict(text_model_dict) + + try: + text_model.load_state_dict(text_model_dict) + except RuntimeError: + pass return text_model @@ -748,4 +751,4 @@ if __name__ == "__main__": tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) - pipe.save_pretrained(args.dump_path) + pipe.save_pretrained(args.dump_path) \ No newline at end of file diff --git a/scripts/txt2vid.py b/scripts/txt2vid.py index b3afcc6..324cd94 100644 --- a/scripts/txt2vid.py +++ b/scripts/txt2vid.py @@ -54,7 +54,8 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline #from 
stable_diffusion_videos import StableDiffusionWalkPipeline from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \ - PNDMScheduler + PNDMScheduler, DDPMScheduler, FlaxPNDMScheduler, FlaxDDIMScheduler, \ + FlaxDDPMScheduler, FlaxKarrasVeScheduler, IPNDMScheduler, KarrasVeScheduler from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL, UNet2DConditionModel @@ -1125,16 +1126,18 @@ def load_diffusers_model(weights_path,torch_device): if weights_path == "runwayml/stable-diffusion-v1-5": model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5") + else: + model_path = weights_path if not os.path.exists(model_path + "/model_index.json"): server_state["pipe"] = StableDiffusionPipeline.from_pretrained( weights_path, - use_local_file=True, - use_auth_token=st.session_state["defaults"].general.huggingface_token, - torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, - revision="fp16" if not st.session_state['defaults'].general.no_half else None, - safety_checker=None, # Very important for videos...lots of false positives while interpolating - #custom_pipeline="interpolate_stable_diffusion", + use_local_file=True, + use_auth_token=st.session_state["defaults"].general.huggingface_token, + torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, + revision="fp16" if not st.session_state['defaults'].general.no_half else None, + safety_checker=None, # Very important for videos...lots of false positives while interpolating + #custom_pipeline="interpolate_stable_diffusion", ) @@ -1142,11 +1145,11 @@ def load_diffusers_model(weights_path,torch_device): else: server_state["pipe"] = StableDiffusionPipeline.from_pretrained( model_path, - use_local_file=True, - torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, - revision="fp16" if not st.session_state['defaults'].general.no_half else None, - 
safety_checker=None, # Very important for videos...lots of false positives while interpolating - #custom_pipeline="interpolate_stable_diffusion", + use_local_file=True, + torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None, + revision="fp16" if not st.session_state['defaults'].general.no_half else None, + safety_checker=None, # Very important for videos...lots of false positives while interpolating + #custom_pipeline="interpolate_stable_diffusion", ) server_state["pipe"].unet.to(torch_device) @@ -1358,8 +1361,30 @@ def txt2vid( klms_scheduler = LMSDiscreteScheduler( beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule ) + + #flaxddims_scheduler = FlaxDDIMScheduler( + #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule + #) + + #flaxddpms_scheduler = FlaxDDPMScheduler( + #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule + #) + + #flaxpndms_scheduler = FlaxPNDMScheduler( + #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule + #) + + ddpms_scheduler = DDPMScheduler( + beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule + ) - SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) + SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, + klms=klms_scheduler, + ddpms=ddpms_scheduler, + #flaxddims=flaxddims_scheduler, + #flaxddpms=flaxddpms_scheduler, + #flaxpndms=flaxpndms_scheduler, + ) with st.session_state["progress_bar_text"].container(): with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]): @@ -1721,7 +1746,9 @@ def layout(): #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] #sampler_name = st.selectbox("Sampling method", sampler_name_list, #index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. 
Default: k_euler") - scheduler_name_list = ["klms", "ddim"] + scheduler_name_list = ["klms", "ddim", "ddpms", + #"flaxddims", "flaxddpms", "flaxpndms" + ] scheduler_name = st.selectbox("Scheduler:", scheduler_name_list, index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")