Changed the loading of the model on the txt2vid tab so the half models are only loaded if the no_half option on the config file is set to False.

This commit is contained in:
ZeroCool940711 2022-09-13 01:49:03 -07:00
parent 299cef698d
commit cbbf33d735

View File

@ -1910,6 +1910,8 @@ def txt2vid(
default_scheduler = PNDMScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
)
# ------------------------------------------------------------------------------
#Schedulers
ddim_scheduler = DDIMScheduler(
beta_start=beta_start,
beta_end=beta_end,
@ -1917,14 +1919,15 @@ def txt2vid(
clip_sample=False,
set_alpha_to_one=False,
)
klms_scheduler = LMSDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
#beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
#lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
# ------------------------------------------------------------------------------
#if weights_path == "Stable Diffusion v1.4":
#weights_path = "CompVis/stable-diffusion-v1-4"
#else:
@ -1948,8 +1951,8 @@ def txt2vid(
weights_path,
use_local_file=True,
use_auth_token=True,
#torch_dtype=torch.float16,
#revision="fp16"
torch_dtype=torch.float16 if not defaults.general.no_half else None,
revision="fp16" if not defaults.general.no_half else None
)
st.session_state["pipe"].unet.to(torch_device)
@ -1959,24 +1962,6 @@ def txt2vid(
else:
print("Tx2Vid Model already Loaded")
except RuntimeError:
#del st.session_state["weights_path"]
#del st.session_state["pipe"]
st.session_state["weights_path"] = weights_path
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
weights_path,
use_local_file=True,
use_auth_token=True,
#torch_dtype=torch.float16,
revision="fp16"
)
st.session_state["pipe"].unet.to(torch_device)
st.session_state["pipe"].vae.to(torch_device)
st.session_state["pipe"].text_encoder.to(torch_device)
print("Tx2Vid Model Loaded")
except:
#del st.session_state["weights_path"]
#del st.session_state["pipe"]
@ -1986,8 +1971,8 @@ def txt2vid(
weights_path,
use_local_file=True,
use_auth_token=True,
torch_dtype=torch.float16,
#revision="fp16"
torch_dtype=torch.float16 if not defaults.general.no_half else None,
revision="fp16" if not defaults.general.no_half else None
)
st.session_state["pipe"].unet.to(torch_device)