mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-16 07:41:38 +03:00
Fixed bug in txt2vid not allowing you to use custom models.
Fixed the script that convert ckpt models to diffusers not working with some models if they had invalid or broken keys, these keys are now ignored so we can still convert the model.
This commit is contained in:
parent
54e7da7721
commit
0d734219a8
@ -30,7 +30,7 @@ except ImportError:
|
|||||||
from diffusers import (
|
from diffusers import (
|
||||||
AutoencoderKL,
|
AutoencoderKL,
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
DPMSolverMultistepScheduler,
|
#DPMSolverMultistepScheduler,
|
||||||
EulerAncestralDiscreteScheduler,
|
EulerAncestralDiscreteScheduler,
|
||||||
EulerDiscreteScheduler,
|
EulerDiscreteScheduler,
|
||||||
LDMTextToImagePipeline,
|
LDMTextToImagePipeline,
|
||||||
@ -627,8 +627,11 @@ def convert_ldm_clip_checkpoint(checkpoint):
|
|||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("cond_stage_model.transformer"):
|
if key.startswith("cond_stage_model.transformer"):
|
||||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||||
|
|
||||||
text_model.load_state_dict(text_model_dict)
|
try:
|
||||||
|
text_model.load_state_dict(text_model_dict)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
return text_model
|
return text_model
|
||||||
|
|
||||||
@ -748,4 +751,4 @@ if __name__ == "__main__":
|
|||||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||||
|
|
||||||
pipe.save_pretrained(args.dump_path)
|
pipe.save_pretrained(args.dump_path)
|
@ -54,7 +54,8 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
|||||||
#from stable_diffusion_videos import StableDiffusionWalkPipeline
|
#from stable_diffusion_videos import StableDiffusionWalkPipeline
|
||||||
|
|
||||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
|
||||||
PNDMScheduler
|
PNDMScheduler, DDPMScheduler, FlaxPNDMScheduler, FlaxDDIMScheduler, \
|
||||||
|
FlaxDDPMScheduler, FlaxKarrasVeScheduler, IPNDMScheduler, KarrasVeScheduler
|
||||||
|
|
||||||
from diffusers.configuration_utils import FrozenDict
|
from diffusers.configuration_utils import FrozenDict
|
||||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||||
@ -1125,16 +1126,18 @@ def load_diffusers_model(weights_path,torch_device):
|
|||||||
|
|
||||||
if weights_path == "runwayml/stable-diffusion-v1-5":
|
if weights_path == "runwayml/stable-diffusion-v1-5":
|
||||||
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
|
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
|
||||||
|
else:
|
||||||
|
model_path = weights_path
|
||||||
|
|
||||||
if not os.path.exists(model_path + "/model_index.json"):
|
if not os.path.exists(model_path + "/model_index.json"):
|
||||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||||
weights_path,
|
weights_path,
|
||||||
use_local_file=True,
|
use_local_file=True,
|
||||||
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
use_auth_token=st.session_state["defaults"].general.huggingface_token,
|
||||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||||
#custom_pipeline="interpolate_stable_diffusion",
|
#custom_pipeline="interpolate_stable_diffusion",
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1142,11 +1145,11 @@ def load_diffusers_model(weights_path,torch_device):
|
|||||||
else:
|
else:
|
||||||
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||||
model_path,
|
model_path,
|
||||||
use_local_file=True,
|
use_local_file=True,
|
||||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
|
||||||
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
safety_checker=None, # Very important for videos...lots of false positives while interpolating
|
||||||
#custom_pipeline="interpolate_stable_diffusion",
|
#custom_pipeline="interpolate_stable_diffusion",
|
||||||
)
|
)
|
||||||
|
|
||||||
server_state["pipe"].unet.to(torch_device)
|
server_state["pipe"].unet.to(torch_device)
|
||||||
@ -1358,8 +1361,30 @@ def txt2vid(
|
|||||||
klms_scheduler = LMSDiscreteScheduler(
|
klms_scheduler = LMSDiscreteScheduler(
|
||||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||||
)
|
)
|
||||||
|
|
||||||
|
#flaxddims_scheduler = FlaxDDIMScheduler(
|
||||||
|
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||||
|
#)
|
||||||
|
|
||||||
|
#flaxddpms_scheduler = FlaxDDPMScheduler(
|
||||||
|
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||||
|
#)
|
||||||
|
|
||||||
|
#flaxpndms_scheduler = FlaxPNDMScheduler(
|
||||||
|
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||||
|
#)
|
||||||
|
|
||||||
|
ddpms_scheduler = DDPMScheduler(
|
||||||
|
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||||
|
)
|
||||||
|
|
||||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
|
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
|
||||||
|
klms=klms_scheduler,
|
||||||
|
ddpms=ddpms_scheduler,
|
||||||
|
#flaxddims=flaxddims_scheduler,
|
||||||
|
#flaxddpms=flaxddpms_scheduler,
|
||||||
|
#flaxpndms=flaxpndms_scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
with st.session_state["progress_bar_text"].container():
|
with st.session_state["progress_bar_text"].container():
|
||||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||||
@ -1721,7 +1746,9 @@ def layout():
|
|||||||
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||||
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
||||||
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
|
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
|
||||||
scheduler_name_list = ["klms", "ddim"]
|
scheduler_name_list = ["klms", "ddim", "ddpms",
|
||||||
|
#"flaxddims", "flaxddpms", "flaxpndms"
|
||||||
|
]
|
||||||
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
|
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
|
||||||
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
|
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user