Merge branch 'Sygil-Dev:dev' into dev

This commit is contained in:
aedh carrick 2022-11-24 12:13:44 -06:00 committed by GitHub
commit 20a7b6387a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 19 deletions

View File

@ -48,7 +48,7 @@ barfi==0.7.0
# Environment Dependencies for WebUI (flet)
# txt2vid
diffusers==0.6.0
diffusers==0.7.2
librosa==0.9.2
# img2img inpainting

View File

@ -30,7 +30,7 @@ except ImportError:
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
#DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LDMTextToImagePipeline,
@ -627,8 +627,11 @@ def convert_ldm_clip_checkpoint(checkpoint):
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
text_model.load_state_dict(text_model_dict)
try:
text_model.load_state_dict(text_model_dict)
except RuntimeError:
pass
return text_model
@ -748,4 +751,4 @@ if __name__ == "__main__":
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
pipe.save_pretrained(args.dump_path)
pipe.save_pretrained(args.dump_path)

View File

@ -54,7 +54,7 @@ from diffusers import StableDiffusionPipeline, DiffusionPipeline
#from stable_diffusion_videos import StableDiffusionWalkPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
PNDMScheduler
PNDMScheduler, DDPMScheduler
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
@ -1125,16 +1125,18 @@ def load_diffusers_model(weights_path,torch_device):
if weights_path == "runwayml/stable-diffusion-v1-5":
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
else:
model_path = weights_path
if not os.path.exists(model_path + "/model_index.json"):
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=st.session_state["defaults"].general.huggingface_token,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
use_local_file=True,
use_auth_token=st.session_state["defaults"].general.huggingface_token,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
)
@ -1142,11 +1144,11 @@ def load_diffusers_model(weights_path,torch_device):
else:
server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
model_path,
use_local_file=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
use_local_file=True,
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
revision="fp16" if not st.session_state['defaults'].general.no_half else None,
safety_checker=None, # Very important for videos...lots of false positives while interpolating
#custom_pipeline="interpolate_stable_diffusion",
)
server_state["pipe"].unet.to(torch_device)
@ -1358,8 +1360,30 @@ def txt2vid(
klms_scheduler = LMSDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)
#flaxddims_scheduler = FlaxDDIMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
#flaxddpms_scheduler = FlaxDDPMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
#flaxpndms_scheduler = FlaxPNDMScheduler(
#beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#)
ddpms_scheduler = DDPMScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
klms=klms_scheduler,
ddpms=ddpms_scheduler,
#flaxddims=flaxddims_scheduler,
#flaxddpms=flaxddpms_scheduler,
#flaxpndms=flaxpndms_scheduler,
)
with st.session_state["progress_bar_text"].container():
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
@ -1721,7 +1745,9 @@ def layout():
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
scheduler_name_list = ["klms", "ddim"]
scheduler_name_list = ["klms", "ddim", "ddpms",
#"flaxddims", "flaxddpms", "flaxpndms"
]
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")