Removed the suppress_st_warning argument from st.experimental_memo, as it was causing issues on the latest Streamlit version. (#1788)

Alejandro Gil 2023-06-11 23:19:09 -07:00 committed by GitHub
commit 2ca54a6934
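
The change itself is small. Below is a minimal sketch of the before/after decorator usage, assuming a Streamlit release in which st.experimental_memo no longer accepts suppress_st_warning (load_GFPGAN is one of the functions decorated in the diff below):

import streamlit as st

# Old form: newer Streamlit releases no longer accept the keyword, so the
# decorator call itself fails (typically with a TypeError about an unexpected
# keyword argument) before the function can ever be cached.
#@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)

# New form: only the arguments st.experimental_memo still supports.
@st.experimental_memo(persist="disk", show_spinner=False)
def load_GFPGAN(model_name='GFPGANv1.4'):
    ...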

@@ -99,7 +99,7 @@ shutup.please()
if "defaults" in st.session_state:
if st.session_state["defaults"].general.use_cudnn:
torch.backends.cudnn.benchmark = True
-        torch.backends.cudnn.enabled = True
+        torch.backends.cudnn.enabled = True
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -294,7 +294,7 @@ def merge(file1, file2, out, weight):
file2 += ".ckpt"
if not(out.endswith(".ckpt")):
out += ".ckpt"
-    try:
+    try:
#Load Models
model_0 = torch.load(file1)
model_1 = torch.load(file2)
@@ -310,7 +310,7 @@ def merge(file1, file2, out, weight):
theta_0[key] = theta_1[key]
torch.save(model_0, out)
except:
-        logger.error("Error in merging")
+        logger.error("Error in merging")
def human_readable_size(size, decimal_places=3):
@@ -1325,7 +1325,7 @@ def torch_gc():
torch.cuda.ipc_collect()
@retry(tries=5)
-#@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
+#@st.experimental_memo(persist="disk", show_spinner=False)
def load_GFPGAN(model_name='GFPGANv1.4'):
#model_name = 'GFPGANv1.3'
@@ -1629,10 +1629,10 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
#
@retry(tries=5)
def generation_callback(img, i=0):
# try to do garbage collection before decoding the image
torch_gc()
if "update_preview_frequency" not in st.session_state:
raise StopException
@@ -1757,7 +1757,7 @@ def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
return v2
#
-@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
+@st.experimental_memo(persist="disk", show_spinner=False)
def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list):
"""Find the optimal update_preview_frequency value maximizing
performance while minimizing the time between updates."""
@@ -2420,7 +2420,7 @@ def process_images(
else: # just behave like usual
c = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelCS"]).get_learned_conditioning(prompts)
shape = [opt_C, height // opt_f, width // opt_f]
if st.session_state['defaults'].general.optimized:
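
If the same script needs to keep working on both older and newer Streamlit releases, a hypothetical alternative (not what this commit does, just a sketch) is to build the decorator arguments at runtime and pass suppress_st_warning only when the installed st.experimental_memo still declares it:

import inspect
import streamlit as st

# Hypothetical compatibility shim: include suppress_st_warning only if the
# installed st.experimental_memo still lists it among its parameters.
memo_kwargs = dict(persist="disk", show_spinner=False)
if "suppress_st_warning" in inspect.signature(st.experimental_memo).parameters:
    memo_kwargs["suppress_st_warning"] = True

@st.experimental_memo(**memo_kwargs)
def load_GFPGAN(model_name='GFPGANv1.4'):
    ...

The commit itself simply drops the argument, which keeps both decorators compatible with the newer API.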