Revert "Made sure GFPGAN and RealESRGAN are on server_state. (#1319)"

This reverts commit 1fd28eed1e.
This commit is contained in:
ZeroCool 2022-09-25 21:26:17 -07:00 committed by GitHub
parent 1fd28eed1e
commit c8a8d6cea9
2 changed files with 87 additions and 92 deletions

View File

@ -686,15 +686,13 @@ def load_GFPGAN():
sys.path.append(os.path.abspath(st.session_state['defaults'].general.GFPGAN_dir))
from gfpgan import GFPGANer
with server_state_lock['GFPGAN']:
if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
server_state['GFPGAN'] = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
elif st.session_state['defaults'].general.extra_models_gpu:
server_state['GFPGAN'] = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gfpgan_gpu}"))
else:
server_state['GFPGAN'] = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
return server_state['GFPGAN']
if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
elif st.session_state['defaults'].general.extra_models_gpu:
instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gfpgan_gpu}"))
else:
instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
return instance
@retry(tries=5)
def load_RealESRGAN(model_name: str):
@ -711,18 +709,17 @@ def load_RealESRGAN(model_name: str):
sys.path.append(os.path.abspath(st.session_state['defaults'].general.RealESRGAN_dir))
from realesrgan import RealESRGANer
with server_state_lock['RealESRGAN']:
if st.session_state['defaults'].general.esrgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
server_state['RealESRGAN'] = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half
server_state['RealESRGAN'].device = torch.device('cpu')
server_state['RealESRGAN'].model.to('cpu')
elif st.session_state['defaults'].general.extra_models_gpu:
server_state['RealESRGAN'] = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.esrgan_gpu}"))
else:
server_state['RealESRGAN'] = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
server_state['RealESRGAN'].model.name = model_name
if st.session_state['defaults'].general.esrgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half
instance.device = torch.device('cpu')
instance.model.to('cpu')
elif st.session_state['defaults'].general.extra_models_gpu:
instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.esrgan_gpu}"))
else:
instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
instance.model.name = model_name
return server_state['RealESRGAN']
return instance
#
@retry(tries=5)
@ -731,7 +728,6 @@ def load_LDSR(checking=False):
yaml_name = 'project'
model_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', model_name + '.ckpt')
yaml_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', yaml_name + '.yaml')
if not os.path.isfile(model_path):
raise Exception("LDSR model not found at path "+model_path)
if not os.path.isfile(yaml_path):
@ -742,7 +738,6 @@ def load_LDSR(checking=False):
sys.path.append(os.path.abspath(st.session_state['defaults'].general.LDSR_dir))
from LDSR import LDSR
LDSRObject = LDSR(model_path, yaml_path)
return LDSRObject
#

View File

@ -47,14 +47,14 @@ class plugin_info():
if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
server_state["GFPGAN_available"] = True
GFPGAN_available = True
else:
server_state["GFPGAN_available"] = False
GFPGAN_available = False
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].txt2vid.RealESRGAN_model}.pth")):
server_state["RealESRGAN_available"] = True
RealESRGAN_available = True
else:
server_state["RealESRGAN_available"] = False
RealESRGAN_available = False
#
# -----------------------------------------------------------------------------
@ -484,7 +484,7 @@ def txt2vid(
with autocast("cuda"):
image = diffuse(server_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
if st.session_state["save_individual_images"] and not server_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
#im = Image.fromarray(image)
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
image.save(outpath, quality=quality)
@ -498,8 +498,8 @@ def txt2vid(
#
#try:
#if st.session_state["use_GFPGAN"] and server_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if st.session_state["use_GFPGAN"] and server_state["GFPGAN"] is not None:
#if server_state["use_GFPGAN"] and server_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
if server_state["use_GFPGAN"] and server_state["GFPGAN"] is not None:
#print("Running GFPGAN on image ...")
st.session_state["progress_bar_text"].text("Running GFPGAN on image ...")
#skip_save = True # #287 >_>
@ -714,12 +714,12 @@ def layout():
help="Do loop")
st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
if server_state["GFPGAN_available"]:
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
if GFPGAN_available:
server_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
else:
st.session_state["use_GFPGAN"] = False
server_state["use_GFPGAN"] = False
if server_state["RealESRGAN_available"]:
if RealESRGAN_available:
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
@ -743,9 +743,9 @@ def layout():
if generate_button:
#print("Loading models")
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
#load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
#load_models(False, server_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
if st.session_state["use_GFPGAN"]:
if server_state["use_GFPGAN"]:
if "GFPGAN" in st.session_state:
print("GFPGAN already loaded")
else:
@ -762,63 +762,63 @@ def layout():
if "GFPGAN" in st.session_state:
del server_state["GFPGAN"]
#try:
# run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
num_inference_steps=st.session_state.num_inference_steps,
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
seeds=seed, quality=100, eta=0.0, width=width,
height=height, weights_path=custom_model, scheduler=scheduler_name,
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
beta_schedule=beta_scheduler_type, starting_image=None)
try:
# run video generation
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
num_inference_steps=st.session_state.num_inference_steps,
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
seeds=seed, quality=100, eta=0.0, width=width,
height=height, weights_path=custom_model, scheduler=scheduler_name,
disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
beta_schedule=beta_scheduler_type, starting_image=None)
#message.success('Done!', icon="✅")
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="")
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
#if 'latestVideos' in st.session_state:
#for i in video:
##push the new image to the list of latest images and remove the oldest one
##remove the last index from the list\
#st.session_state['latestVideos'].pop()
##add the new image to the start of the list
#st.session_state['latestVideos'].insert(0, i)
#PlaceHolder.empty()
#with PlaceHolder.container():
#col1, col2, col3 = st.columns(3)
#col1_cont = st.container()
#col2_cont = st.container()
#col3_cont = st.container()
#with col1_cont:
#with col1:
#st.image(st.session_state['latestVideos'][0])
#st.image(st.session_state['latestVideos'][3])
#st.image(st.session_state['latestVideos'][6])
#with col2_cont:
#with col2:
#st.image(st.session_state['latestVideos'][1])
#st.image(st.session_state['latestVideos'][4])
#st.image(st.session_state['latestVideos'][7])
#with col3_cont:
#with col3:
#st.image(st.session_state['latestVideos'][2])
#st.image(st.session_state['latestVideos'][5])
#st.image(st.session_state['latestVideos'][8])
#historyGallery = st.empty()
## check if output_images length is the same as seeds length
#with gallery_tab:
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
#message.success('Done!', icon="✅")
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="")
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
#if 'latestVideos' in st.session_state:
#for i in video:
##push the new image to the list of latest images and remove the oldest one
##remove the last index from the list\
#st.session_state['latestVideos'].pop()
##add the new image to the start of the list
#st.session_state['latestVideos'].insert(0, i)
#PlaceHolder.empty()
#with PlaceHolder.container():
#col1, col2, col3 = st.columns(3)
#col1_cont = st.container()
#col2_cont = st.container()
#col3_cont = st.container()
#with col1_cont:
#with col1:
#st.image(st.session_state['latestVideos'][0])
#st.image(st.session_state['latestVideos'][3])
#st.image(st.session_state['latestVideos'][6])
#with col2_cont:
#with col2:
#st.image(st.session_state['latestVideos'][1])
#st.image(st.session_state['latestVideos'][4])
#st.image(st.session_state['latestVideos'][7])
#with col3_cont:
#with col3:
#st.image(st.session_state['latestVideos'][2])
#st.image(st.session_state['latestVideos'][5])
#st.image(st.session_state['latestVideos'][8])
#historyGallery = st.empty()
## check if output_images length is the same as seeds length
#with gallery_tab:
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
#except (StopException, KeyError):
#print(f"Received Streamlit StopException")
except (StopException, KeyError):
print(f"Received Streamlit StopException")