diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py
index 375b12b..9572c7f 100644
--- a/scripts/convert_original_stable_diffusion_to_diffusers.py
+++ b/scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -30,7 +30,7 @@ except ImportError:
 from diffusers import (
     AutoencoderKL,
     DDIMScheduler,
-    DPMSolverMultistepScheduler,
+    #DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     LDMTextToImagePipeline,
@@ -627,8 +627,11 @@ def convert_ldm_clip_checkpoint(checkpoint):
     for key in keys:
         if key.startswith("cond_stage_model.transformer"):
             text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
-
-    text_model.load_state_dict(text_model_dict)
+
+    try:
+        text_model.load_state_dict(text_model_dict)
+    except RuntimeError:
+        pass
 
     return text_model
 
@@ -748,4 +751,4 @@ if __name__ == "__main__":
         tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
 
         pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-    pipe.save_pretrained(args.dump_path)
+    pipe.save_pretrained(args.dump_path)
\ No newline at end of file
diff --git a/scripts/txt2vid.py b/scripts/txt2vid.py
index b3afcc6..324cd94 100644
--- a/scripts/txt2vid.py
+++ b/scripts/txt2vid.py
@@ -54,7 +54,8 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
 #from stable_diffusion_videos import StableDiffusionWalkPipeline
 
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
-    PNDMScheduler
+    PNDMScheduler, DDPMScheduler, FlaxPNDMScheduler, FlaxDDIMScheduler, \
+    FlaxDDPMScheduler, FlaxKarrasVeScheduler, IPNDMScheduler, KarrasVeScheduler
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -1125,16 +1126,18 @@ def load_diffusers_model(weights_path,torch_device):
 
     if weights_path == "runwayml/stable-diffusion-v1-5":
         model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
+    else:
+        model_path = weights_path
 
     if not os.path.exists(model_path + "/model_index.json"):
         server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
             weights_path,
-                        use_local_file=True,
-                        use_auth_token=st.session_state["defaults"].general.huggingface_token,
-                        torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
-                        revision="fp16" if not st.session_state['defaults'].general.no_half else None,
-                        safety_checker=None,  # Very important for videos...lots of false positives while interpolating
-                        #custom_pipeline="interpolate_stable_diffusion",
+            use_local_file=True,
+            use_auth_token=st.session_state["defaults"].general.huggingface_token,
+            torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
+            revision="fp16" if not st.session_state['defaults'].general.no_half else None,
+            safety_checker=None,  # Very important for videos...lots of false positives while interpolating
+            #custom_pipeline="interpolate_stable_diffusion",
         )
@@ -1142,11 +1145,11 @@ def load_diffusers_model(weights_path,torch_device):
     else:
         server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
             model_path,
-                        use_local_file=True,
-                        torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
-                        revision="fp16" if not st.session_state['defaults'].general.no_half else None,
-                        safety_checker=None,  # Very important for videos...lots of false positives while interpolating
-                        #custom_pipeline="interpolate_stable_diffusion",
+            use_local_file=True,
+            torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
+            revision="fp16" if not st.session_state['defaults'].general.no_half else None,
+            safety_checker=None,  # Very important for videos...lots of false positives while interpolating
+            #custom_pipeline="interpolate_stable_diffusion",
         )
 
     server_state["pipe"].unet.to(torch_device)
@@ -1358,8 +1361,30 @@ def txt2vid(
     klms_scheduler = LMSDiscreteScheduler(
         beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
     )
+
+    #flaxddims_scheduler = FlaxDDIMScheduler(
+        #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+    #)
+
+    #flaxddpms_scheduler = FlaxDDPMScheduler(
+        #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+    #)
+
+    #flaxpndms_scheduler = FlaxPNDMScheduler(
+        #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+    #)
+
+    ddpms_scheduler = DDPMScheduler(
+        beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+    )
 
-    SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
+    SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
+                      klms=klms_scheduler,
+                      ddpms=ddpms_scheduler,
+                      #flaxddims=flaxddims_scheduler,
+                      #flaxddpms=flaxddpms_scheduler,
+                      #flaxpndms=flaxpndms_scheduler,
+                      )
 
     with st.session_state["progress_bar_text"].container():
         with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
@@ -1721,7 +1746,9 @@ def layout():
         #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
         #sampler_name = st.selectbox("Sampling method", sampler_name_list,
         #index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
-        scheduler_name_list = ["klms", "ddim"]
+        scheduler_name_list = ["klms", "ddim", "ddpms",
+                               #"flaxddims", "flaxddpms", "flaxpndms"
+                               ]
         scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
                                       index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
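
Note (reviewer sketch, not part of the diff): the txt2vid.py changes register DDPMScheduler under the new "ddpms" key so it can be picked in the UI. Assuming the standard diffusers API, the new entry ends up being used roughly as below; the beta values are illustrative placeholders, since the script forwards its own beta_start/beta_end/beta_schedule variables.

    import torch
    from diffusers import StableDiffusionPipeline, DDPMScheduler

    # Illustrative beta values; txt2vid() passes its own beta_start/beta_end/beta_schedule.
    ddpms_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        safety_checker=None,  # disabled as in the script: interpolation produces many false positives
    )
    pipe.scheduler = ddpms_scheduler  # swap the selected scheduler in before generating frames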
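
Note (reviewer sketch, not part of the diff): the try/except RuntimeError: pass wrapped around text_model.load_state_dict() in convert_ldm_clip_checkpoint() silently discards a failed or partial load. If the intent is to tolerate mismatched keys, a non-strict load that still reports what was skipped would look roughly like this inside the same function:

    # Non-strict load: returns the missing/unexpected keys instead of raising on a mismatch.
    load_result = text_model.load_state_dict(text_model_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Partial CLIP text encoder load: missing={load_result.missing_keys}, unexpected={load_result.unexpected_keys}")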