From 67f4c27ea26b94978f6c6f006f14a3871233c177 Mon Sep 17 00:00:00 2001 From: ZeroCool940711 Date: Tue, 11 Oct 2022 21:52:28 -0700 Subject: [PATCH] Improved hot reloading for some model options like optimized and float16. --- requirements.txt | 6 ++++++ scripts/sd_utils.py | 48 +++++++++++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7ab31cf..d7fa462 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,8 +34,14 @@ streamlit-tensorboard==0.0.2 hydralit==1.0.14 hydralit_components==1.0.10 stqdm==0.0.4 + +# txt2vid +stable-diffusion-videos==0.5.3 diffusers==0.4.1 +# img2img inpainting +streamlit-drawable-canvas==0.9.2 + # Img2text ftfy==6.1.1 fairscale==0.4.4 diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py index fa3f6cd..9ef0ca0 100644 --- a/scripts/sd_utils.py +++ b/scripts/sd_utils.py @@ -25,6 +25,7 @@ from streamlit import StopException, StreamlitAPIException #streamlit components section from streamlit_server_state import server_state, server_state_lock import hydralit_components as hc +import streamlit_nested_layout #other imports @@ -64,6 +65,7 @@ import piexif.helper from tqdm import trange from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import ismap +import librosa # Temp imports @@ -229,14 +231,6 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m if "progress_bar_text" in st.session_state: st.session_state["progress_bar_text"].text("") - - # Generate random run ID - # Used to link runs linked w/ continue_prev_run which is not yet implemented - # Use URL and filesystem safe version just in case. - st.session_state["run_id"] = base64.urlsafe_b64encode( - os.urandom(6) - ).decode("ascii") - # check what models we want to use and if the they are already loaded. with server_state_lock["LDSR"]: if use_LDSR: @@ -306,8 +300,32 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m if "model" in server_state: if "model" in server_state and server_state["loaded_model"] == custom_model: - # TODO: check if the optimized mode was changed? - print("Model already loaded") + # if the float16 or no_half options have changed since the last time the model was loaded then we need to reload the model. + if ("float16" in server_state and server_state['float16'] != st.session_state['defaults'].general.use_float16) \ + or ("no_half" in server_state and server_state['no_half'] != st.session_state['defaults'].general.no_half) \ + or ("optimized" in server_state and server_state['optimized'] != st.session_state['defaults'].general.optimized): + + print ("Model options changed, deleting the model from memory.") + del server_state['float16'] + del server_state['no_half'] + + del server_state["model"] + del server_state["modelCS"] + del server_state["modelFS"] + del server_state["loaded_model"] + + del server_state['optimized'] + + server_state['float16'] = st.session_state['defaults'].general.use_float16 + server_state['no_half'] = st.session_state['defaults'].general.no_half + server_state['optimized'] = st.session_state['defaults'].general.optimized + + load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"], + use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] , + use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"], + CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"]) + else: + print("Model already loaded") return else: @@ -329,7 +347,7 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m del st.session_state['textual_inversion'] # At this point the model is either - # not loaded yet or have been evicted: + # not loaded yet or have been deleted from memory: # load new model into memory server_state["custom_model"] = custom_model @@ -342,6 +360,10 @@ def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_m server_state["modelFS"] = modelFS server_state["loaded_model"] = custom_model + server_state['float16'] = st.session_state['defaults'].general.use_float16 + server_state['no_half'] = st.session_state['defaults'].general.no_half + server_state['optimized'] = st.session_state['defaults'].general.optimized + #trying to disable multiprocessing as it makes it so streamlit cant stop when the # model is loaded in memory and you need to kill the process sometimes. @@ -2253,7 +2275,9 @@ def process_images( load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) torch_gc() - cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + + with torch.autocast('cuda'): + cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) gfpgan_sample = restored_img[:,:,::-1] gfpgan_image = Image.fromarray(gfpgan_sample)