Changed the loading of the model on the txt2vid tab so the half models are only loaded if the no_half option on the config file is set to False.

This commit is contained in:
ZeroCool940711 2022-09-13 01:49:03 -07:00
parent 299cef698d
commit cbbf33d735

View File

@ -1910,6 +1910,8 @@ def txt2vid(
default_scheduler = PNDMScheduler( default_scheduler = PNDMScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
) )
# ------------------------------------------------------------------------------
#Schedulers
ddim_scheduler = DDIMScheduler( ddim_scheduler = DDIMScheduler(
beta_start=beta_start, beta_start=beta_start,
beta_end=beta_end, beta_end=beta_end,
@ -1917,14 +1919,15 @@ def txt2vid(
clip_sample=False, clip_sample=False,
set_alpha_to_one=False, set_alpha_to_one=False,
) )
klms_scheduler = LMSDiscreteScheduler( klms_scheduler = LMSDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
) )
#lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler) SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
# ------------------------------------------------------------------------------
#if weights_path == "Stable Diffusion v1.4": #if weights_path == "Stable Diffusion v1.4":
#weights_path = "CompVis/stable-diffusion-v1-4" #weights_path = "CompVis/stable-diffusion-v1-4"
#else: #else:
@ -1948,8 +1951,8 @@ def txt2vid(
weights_path, weights_path,
use_local_file=True, use_local_file=True,
use_auth_token=True, use_auth_token=True,
#torch_dtype=torch.float16, torch_dtype=torch.float16 if not defaults.general.no_half else None,
#revision="fp16" revision="fp16" if not defaults.general.no_half else None
) )
st.session_state["pipe"].unet.to(torch_device) st.session_state["pipe"].unet.to(torch_device)
@ -1958,24 +1961,6 @@ def txt2vid(
print("Tx2Vid Model Loaded") print("Tx2Vid Model Loaded")
else: else:
print("Tx2Vid Model already Loaded") print("Tx2Vid Model already Loaded")
except RuntimeError:
#del st.session_state["weights_path"]
#del st.session_state["pipe"]
st.session_state["weights_path"] = weights_path
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
#torch_dtype=torch.float16,
revision="fp16"
)
st.session_state["pipe"].unet.to(torch_device)
st.session_state["pipe"].vae.to(torch_device)
st.session_state["pipe"].text_encoder.to(torch_device)
print("Tx2Vid Model Loaded")
except: except:
#del st.session_state["weights_path"] #del st.session_state["weights_path"]
@ -1986,8 +1971,8 @@ def txt2vid(
weights_path, weights_path,
use_local_file=True, use_local_file=True,
use_auth_token=True, use_auth_token=True,
torch_dtype=torch.float16, torch_dtype=torch.float16 if not defaults.general.no_half else None,
#revision="fp16" revision="fp16" if not defaults.general.no_half else None
) )
st.session_state["pipe"].unet.to(torch_device) st.session_state["pipe"].unet.to(torch_device)