Improved hot reloading for some model options like optimized and float16.

This commit is contained in:
ZeroCool940711 2022-10-11 21:52:28 -07:00 committed by Alejandro Gil
parent 1cc22b0984
commit 67f4c27ea2
2 changed files with 42 additions and 12 deletions

View File

@ -34,8 +34,14 @@ streamlit-tensorboard==0.0.2
hydralit==1.0.14
hydralit_components==1.0.10
stqdm==0.0.4
# txt2vid
stable-diffusion-videos==0.5.3
diffusers==0.4.1
# img2img inpainting
streamlit-drawable-canvas==0.9.2
# Img2text
ftfy==6.1.1
fairscale==0.4.4

View File

@ -25,6 +25,7 @@ from streamlit import StopException, StreamlitAPIException
#streamlit components section
from streamlit_server_state import server_state, server_state_lock
import hydralit_components as hc
import streamlit_nested_layout
#other imports
@ -64,6 +65,7 @@ import piexif.helper
from tqdm import trange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import librosa
# Temp imports
@ -229,14 +231,6 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].text("")
# Generate random run ID
# Used to link runs linked w/ continue_prev_run which is not yet implemented
# Use URL and filesystem safe version just in case.
st.session_state["run_id"] = base64.urlsafe_b64encode(
os.urandom(6)
).decode("ascii")
# check what models we want to use and if the they are already loaded.
with server_state_lock["LDSR"]:
if use_LDSR:
@ -306,8 +300,32 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m
if "model" in server_state:
if "model" in server_state and server_state["loaded_model"] == custom_model:
# TODO: check if the optimized mode was changed?
print("Model already loaded")
# if the float16 or no_half options have changed since the last time the model was loaded then we need to reload the model.
if ("float16" in server_state and server_state['float16'] != st.session_state['defaults'].general.use_float16) \
or ("no_half" in server_state and server_state['no_half'] != st.session_state['defaults'].general.no_half) \
or ("optimized" in server_state and server_state['optimized'] != st.session_state['defaults'].general.optimized):
print ("Model options changed, deleting the model from memory.")
del server_state['float16']
del server_state['no_half']
del server_state["model"]
del server_state["modelCS"]
del server_state["modelFS"]
del server_state["loaded_model"]
del server_state['optimized']
server_state['float16'] = st.session_state['defaults'].general.use_float16
server_state['no_half'] = st.session_state['defaults'].general.no_half
server_state['optimized'] = st.session_state['defaults'].general.optimized
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
else:
print("Model already loaded")
return
else:
@ -329,7 +347,7 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m
del st.session_state['textual_inversion']
# At this point the model is either
# not loaded yet or have been evicted:
# not loaded yet or have been deleted from memory:
# load new model into memory
server_state["custom_model"] = custom_model
@ -342,6 +360,10 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m
server_state["modelFS"] = modelFS
server_state["loaded_model"] = custom_model
server_state['float16'] = st.session_state['defaults'].general.use_float16
server_state['no_half'] = st.session_state['defaults'].general.no_half
server_state['optimized'] = st.session_state['defaults'].general.optimized
#trying to disable multiprocessing as it makes it so streamlit cant stop when the
# model is loaded in memory and you need to kill the process sometimes.
@ -2253,7 +2275,9 @@ def process_images(
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
torch_gc()
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
with torch.autocast('cuda'):
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)