You can now change the optimized mode, float16 and no_half options on the Settings page without having to restart the app. (#1503)

This commit is contained in:
Alejandro Gil 2022-10-10 20:46:46 -07:00 committed by GitHub
commit 994341efba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -224,11 +224,15 @@ def load_diffusers_model(weights_path,torch_device):
try:
with server_state_lock["pipe"]:
if "pipe" not in server_state:
if ("weights_path" in st.session_state) and st.session_state["weights_path"] != weights_path:
if "weights_path" in st.session_state and st.session_state["weights_path"] != weights_path:
del st.session_state["weights_path"]
st.session_state["weights_path"] = weights_path
# if folder "models/diffusers/stable-diffusion-v1-4" exists, load the model from there
st.session_state['float16'] = st.session_state['defaults'].general.use_float16
st.session_state['no_half'] = st.session_state['defaults'].general.no_half
st.session_state['optimized'] = st.session_state['defaults'].general.optimized
#if folder "models/diffusers/stable-diffusion-v1-4" exists, load the model from there
if weights_path == "CompVis/stable-diffusion-v1-4":
model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-4")
@ -261,12 +265,36 @@ def load_diffusers_model(weights_path,torch_device):
print("Tx2Vid Model Loaded")
else:
print("Tx2Vid Model already Loaded")
except (EnvironmentError, OSError):
st.session_state["progress_bar_text"].error(
"You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
)
raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
# if the float16 or no_half options have changed since the last time the model was loaded then we need to reload the model.
if ("float16" in st.session_state and st.session_state['float16'] != st.session_state['defaults'].general.use_float16) \
or ("no_half" in st.session_state and st.session_state['no_half'] != st.session_state['defaults'].general.no_half) \
or ("optimized" in st.session_state and st.session_state['optimized'] != st.session_state['defaults'].general.optimized):
del st.session_state['float16']
del st.session_state['no_half']
with server_state_lock["pipe"]:
del server_state["pipe"]
torch_gc()
del st.session_state['optimized']
st.session_state['float16'] = st.session_state['defaults'].general.use_float16
st.session_state['no_half'] = st.session_state['defaults'].general.no_half
st.session_state['optimized'] = st.session_state['defaults'].general.optimized
load_diffusers_model(weights_path, torch_device)
else:
print("Tx2Vid Model already Loaded")
except (EnvironmentError, OSError) as e:
if "huggingface_token" not in st.session_state or st.session_state["defaults"].general.huggingface_token == "None":
st.session_state["progress_bar_text"].error(
"You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
)
raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
else:
st.session_state["progress_bar_text"].error(e)
#
def save_video_to_disk(frames, seeds, sanitized_prompt, fps=6,save_video=True, outdir='outputs'):
if save_video:
@ -413,14 +441,14 @@ def txt2vid(
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
if "pipe" not in server_state:
with st.session_state["progress_bar_text"].container():
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
if "model" in st.session_state:
del st.session_state["model"]
load_diffusers_model(weights_path, torch_device)
else:
print("Model already loaded")
#if "pipe" not in server_state:
with st.session_state["progress_bar_text"].container():
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
if "model" in st.session_state:
del st.session_state["model"]
load_diffusers_model(weights_path, torch_device)
#else:
#print("Model already loaded")
if "pipe" not in server_state:
print('wtf')