mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-14 06:35:14 +03:00
Fixed bug in txt2vid not allowing you to use custom models.
Fixed the script that convert ckpt models to diffusers not working with some models if they had invalid or broken keys, these keys are now ignored so we can still convert the model.
This commit is contained in:
parent
54e7da7721
commit
0d734219a8
@ -30,7 +30,7 @@ except ImportError:
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
#DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LDMTextToImagePipeline,
|
||||
@ -627,8 +627,11 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
try:
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return text_model
|
||||
|
||||
@ -748,4 +751,4 @@ if __name__ == "__main__":
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
pipe.save_pretrained(args.dump_path)
|
||||
pipe.save_pretrained(args.dump_path)
|
@ -54,7 +54,8 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
||||
#from stable_diffusion_videos import StableDiffusionWalkPipeline
|
||||
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
|
||||
PNDMScheduler
|
||||
PNDMScheduler, DDPMScheduler, FlaxPNDMScheduler, FlaxDDIMScheduler, \
|
||||
FlaxDDPMScheduler, FlaxKarrasVeScheduler, IPNDMScheduler, KarrasVeScheduler
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
@ -1125,16 +1126,18 @@ def load_diffusers_model(weights_path,torch_device):
|
||||
|
||||
if weights_path == "runwayml/stable-diffusion-v1-5":
|
||||
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
|
||||
else:
|
||||
model_path = weights_path
|
||||
|
||||
if not os.path.exists(model_path + "/model_index.json"):
|
||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
#custom_pipeline="interpolate_stable_diffusion",
|
||||
use_local_file=True,
|
||||
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
#custom_pipeline="interpolate_stable_diffusion",
|
||||
|
||||
)
|
||||
|
||||
@ -1142,11 +1145,11 @@ def load_diffusers_model(weights_path,torch_device):
|
||||
else:
|
||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
use_local_file=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
#custom_pipeline="interpolate_stable_diffusion",
|
||||
use_local_file=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||
#custom_pipeline="interpolate_stable_diffusion",
|
||||
)
|
||||
|
||||
server_state["pipe"].unet.to(torch_device)
|
||||
@ -1358,8 +1361,30 @@ def txt2vid(
|
||||
klms_scheduler = LMSDiscreteScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
)
|
||||
|
||||
#flaxddims_scheduler = FlaxDDIMScheduler(
|
||||
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
#)
|
||||
|
||||
#flaxddpms_scheduler = FlaxDDPMScheduler(
|
||||
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
#)
|
||||
|
||||
#flaxpndms_scheduler = FlaxPNDMScheduler(
|
||||
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
#)
|
||||
|
||||
ddpms_scheduler = DDPMScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
)
|
||||
|
||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
|
||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
|
||||
klms=klms_scheduler,
|
||||
ddpms=ddpms_scheduler,
|
||||
#flaxddims=flaxddims_scheduler,
|
||||
#flaxddpms=flaxddpms_scheduler,
|
||||
#flaxpndms=flaxpndms_scheduler,
|
||||
)
|
||||
|
||||
with st.session_state["progress_bar_text"].container():
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
@ -1721,7 +1746,9 @@ def layout():
|
||||
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
||||
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
|
||||
scheduler_name_list = ["klms", "ddim"]
|
||||
scheduler_name_list = ["klms", "ddim", "ddpms",
|
||||
#"flaxddims", "flaxddpms", "flaxpndms"
|
||||
]
|
||||
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
|
||||
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user