model manager

model manager
This commit is contained in:
hlky 2022-10-04 15:25:47 +01:00
parent 1e7bdfe3f3
commit 3851fcc537
No known key found for this signature in database
GPG Key ID: 55A99F1E80D907D5
9 changed files with 151 additions and 88 deletions

View File

@ -322,55 +322,82 @@ model_manager:
stable_diffusion: stable_diffusion:
model_name: "Stable Diffusion v1.4" model_name: "Stable Diffusion v1.4"
save_location: "./models/ldm/stable-diffusion-v1" save_location: "./models/ldm/stable-diffusion-v1"
download_link: "https://huggingface.co/CompVis/stable-diffusion-v-1-4-original" files:
model_ckpt:
file_name: "model.ckpt"
download_link: "https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media"
gfpgan: gfpgan:
model_name: "GFPGAN v1.4" model_name: "GFPGAN"
save_location: "./models/gfpgan" save_location: "./models/gfpgan"
download_link: "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth" files:
gfpgan_1_4:
file_name: "GFPGANv1.4.pth"
download_link: "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth"
resnet_50:
file_name: "detection_Resnet50_Final.pth"
save_location: "./gfpgan/weights"
download_link: "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth"
parsing_parsenet:
file_name: "parsing_parsenet.pth"
save_location: "./gfpgan/weights"
download_link: "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth"
realesrgan_x4plus: realesrgan:
model_name: "RealESRGAN_x4plus" model_name: "RealESRGAN"
save_location: "./models/realesrgan" save_location: "./models/realesrgan"
download_link: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth" files:
x4plus:
realesrgan_x4plus_anime_6b: file_name: "RealESRGAN_x4plus.pth"
model_name: "RealESRGAN_x4plus_anime_6B" download_link: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
save_location: "./models/realesrgan" x4plus_anime_6b:
download_link: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth" file_name: "RealESRGAN_x4plus_anime_6B.pth"
download_link: "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
waifu_diffusion: waifu_diffusion:
model_name: "Waifu Diffusion v1.2" model_name: "Waifu Diffusion v1.2"
save_location: "./models/custom" save_location: "./models/custom"
download_link: "https://huggingface.co/hakurei/waifu-diffusion" files:
waifu_diffusion:
waifu-diffusion_pruned: file_name: "waifu-diffusion.ckpt"
model_name: "Waifu Diffusion v1.2 Pruned" download_link: "https://huggingface.co/crumb/pruned-waifu-diffusion/resolve/main/model-pruned.ckpt"
save_location: "./models/custom"
download_link: "https://huggingface.co/crumb/pruned-waifu-diffusion"
trinart_stable_diffusion: trinart_stable_diffusion:
model_name: "TrinArt Stable Diffusion v2" model_name: "TrinArt Stable Diffusion v2"
save_location: "./models/custom" save_location: "./models/custom"
download_link: "https://huggingface.co/naclbit/trinart_stable_diffusion_v2" files:
trinart:
file_name: "trinart.ckpt"
download_link: "https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step95000.ckpt"
stable_diffusion_concept_library: stable_diffusion_concept_library:
model_name: "Stable Diffusion Concept Library" model_name: "Stable Diffusion Concept Library"
save_location: "./models/custom/sd-concepts-library" save_location: "./models/custom/sd-concepts-library/"
download_link: "https://github.com/sd-webui/sd-concepts-library" files:
concept_library:
file_name: ""
download_link: "https://github.com/sd-webui/sd-concepts-library"
blip_model: blip_model:
model_name: "Blip Model" model_name: "Blip Model"
save_location: "./models/blip" save_location: "./models/blip"
download_link: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth" files:
blip:
file_name: "model__base_caption.pth"
download_link: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth"
lds_project_file: ldsr:
model_name: "LDSR `project.yaml`" model_name: "Latent Diffusion Super Resolution (LDSR)"
save_location: "./models/ldsr" save_location: "./models/ldsr"
download_link: "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" files:
project_yaml:
ldsr_model: file_name: "project.yaml"
model_name: "LDSR `model.cpkt`" download_link: "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
save_location: "./models/ldsr"
download_link: "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" ldsr_model:
file_name: "model.ckpt"
download_link: "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"

View File

@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils. # base webui import and utils.
from sd_utils import * from sd_utils import *
import wget
# streamlit imports # streamlit imports
@ -26,6 +26,22 @@ from sd_utils import *
# end of imports # end of imports
#--------------------------------------------------------------------------------------------------------------- #---------------------------------------------------------------------------------------------------------------
def download_file(file_name, file_path, file_url):
if not os.path.exists(file_path):
os.makedirs(file_path)
if not os.path.exists(file_path + '/' + file_name):
print('Downloading ' + file_name + '...')
# TODO - add progress bar in streamlit
wget.download(url=file_url, out=file_path + '/' + file_name)
else:
print(file_name + ' already exists.')
def download_model(models, model_name):
""" Download all files from model_list[model_name] """
for file in models[model_name]:
download_file(file['file_name'], file['file_path'], file['file_url'])
return
def layout(): def layout():
#search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="") #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
@ -44,4 +60,29 @@ def layout():
col1.write(x) # index col1.write(x) # index
col2.write(models[model_name]['model_name']) col2.write(models[model_name]['model_name'])
col3.write(models[model_name]['save_location']) col3.write(models[model_name]['save_location'])
col4.write(models[model_name]['download_link']) with col4:
files_exist = 0
for file in models[model_name]['files']:
if "save_location" in models[model_name]['files'][file]:
os.path.exists(models[model_name]['files'][file]['save_location'] + '/' + models[model_name]['files'][file]['file_name'])
files_exist += 1
elif os.path.exists(models[model_name]['save_location'] + '/' + models[model_name]['files'][file]['file_name']):
files_exist += 1
files_needed = []
for file in models[model_name]['files']:
if "save_location" in models[model_name]['files'][file]:
if not os.path.exists(models[model_name]['files'][file]['save_location'] + '/' + models[model_name]['files'][file]['file_name']):
files_needed.append(file)
elif not os.path.exists(models[model_name]['save_location'] + '/' + models[model_name]['files'][file]['file_name']):
files_needed.append(file)
if len(files_needed) > 0:
if st.button('Download', key=models[model_name]['model_name'], help='Download ' + models[model_name]['model_name']):
for file in files_needed:
if "save_location" in models[model_name]['files'][file]:
download_file(models[model_name]['files'][file]['file_name'], models[model_name]['files'][file]['save_location'], models[model_name]['files'][file]['download_link'])
else:
download_file(models[model_name]['files'][file]['file_name'], models[model_name]['save_location'], models[model_name]['files'][file]['download_link'])
else:
st.empty()
else:
st.write('')

View File

@ -83,10 +83,10 @@ def layout():
help="Default model path. Default: 'models/ldm/stable-diffusion-v1/model.ckpt'") help="Default model path. Default: 'models/ldm/stable-diffusion-v1/model.ckpt'")
st.session_state['defaults'].general.GFPGAN_dir = st.text_input("Default GFPGAN directory", value=st.session_state['defaults'].general.GFPGAN_dir, st.session_state['defaults'].general.GFPGAN_dir = st.text_input("Default GFPGAN directory", value=st.session_state['defaults'].general.GFPGAN_dir,
help="Default GFPGAN directory. Default: './src/gfpgan'") help="Default GFPGAN directory. Default: './models/gfpgan'")
st.session_state['defaults'].general.RealESRGAN_dir = st.text_input("Default RealESRGAN directory", value=st.session_state['defaults'].general.RealESRGAN_dir, st.session_state['defaults'].general.RealESRGAN_dir = st.text_input("Default RealESRGAN directory", value=st.session_state['defaults'].general.RealESRGAN_dir,
help="Default GFPGAN directory. Default: './src/realesrgan'") help="Default GFPGAN directory. Default: './models/realesrgan'")
RealESRGAN_model_list = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"] RealESRGAN_model_list = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"]
st.session_state['defaults'].general.RealESRGAN_model = st.selectbox("RealESRGAN model", RealESRGAN_model_list, st.session_state['defaults'].general.RealESRGAN_model = st.selectbox("RealESRGAN model", RealESRGAN_model_list,

View File

@ -52,7 +52,7 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0, variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
write_info_files:bool = True, separate_prompts:bool = False, normalize_prompt_weights:bool = True, write_info_files:bool = True, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.3', save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.4',
use_RealESRGAN: bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", use_RealESRGAN: bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
use_LDSR: bool = True, LDSR_model: str = "model", use_LDSR: bool = True, LDSR_model: str = "model",
loopback: bool = False, loopback: bool = False,
@ -167,7 +167,7 @@ def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None,
init_image = 2. * image - 1. init_image = 2. * image - 1.
init_image = init_image.to(server_state["device"]) init_image = init_image.to(server_state["device"])
init_latent = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).get_first_stage_encoding((server_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space init_latent = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).get_first_stage_encoding((server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).encode_first_stage(init_image)) # move to latent space
if st.session_state['defaults'].general.optimized: if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6 mem = torch.cuda.memory_allocated()/1e6

View File

@ -226,11 +226,11 @@ def interrogate(image, models):
if model_name not in server_state["clip_models"]: if model_name not in server_state["clip_models"]:
if model_name == 'ViT-H-14': if model_name == 'ViT-H-14':
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='user_data/model_cache/clip') server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip')
elif model_name == 'ViT-g-14': elif model_name == 'ViT-g-14':
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='user_data/model_cache/clip') server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='models/clip')
else: else:
server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = clip.load(model_name, device=device, download_root='user_data/model_cache/clip') server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = clip.load(model_name, device=device, download_root='models/clip')
server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval() server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda() images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()

View File

@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils. # base webui import and utils.
#from webui_streamlit import st #from webui_streamlit import st
import gfpgan
import hydralit as st import hydralit as st
@ -219,7 +220,7 @@ def human_readable_size(size, decimal_places=3):
return f"{size:.{decimal_places}f}{unit}" return f"{size:.{decimal_places}f}{unit}"
def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_model='GFPGANv1.3', use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus", def load_models(use_LDSR = False, LDSR_model='model', use_GFPGAN=False, GFPGAN_model='GFPGANv1.4', use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus",
CustomModel_available=False, custom_model="Stable Diffusion v1.4"): CustomModel_available=False, custom_model="Stable Diffusion v1.4"):
"""Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """ """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
@ -1193,7 +1194,6 @@ def load_GFPGAN(model_name='GFPGANv1.4'):
sys.path.append(os.path.abspath(st.session_state['defaults'].general.GFPGAN_dir)) sys.path.append(os.path.abspath(st.session_state['defaults'].general.GFPGAN_dir))
from gfpgan import GFPGANer from gfpgan import GFPGANer
with server_state_lock['GFPGAN']: with server_state_lock['GFPGAN']:
if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu: if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
server_state['GFPGAN'] = GFPGANer(model_path=model_path, upscale=1, arch='clean', server_state['GFPGAN'] = GFPGANer(model_path=model_path, upscale=1, arch='clean',
@ -1221,7 +1221,7 @@ def load_RealESRGAN(model_name: str):
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
} }
model_path = os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') model_path = os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
model_path = os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, model_name + '.pth') model_path = os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, model_name + '.pth')
@ -1756,15 +1756,19 @@ def GFPGAN_available():
# Allow for custom models to be used instead of the default one, # Allow for custom models to be used instead of the default one,
# an example would be Waifu-Diffusion or any other fine tune of stable diffusion # an example would be Waifu-Diffusion or any other fine tune of stable diffusion
st.session_state["GFPGAN_models"]:sorted = [] st.session_state["GFPGAN_models"]:sorted = []
model = st.session_state["defaults"].model_manager.models.gfpgan
files_available = 0
for file in model['files']:
if "save_location" in model['files'][file]:
if os.path.exists(os.path.join(model['files'][file]['save_location'], model['files'][file]['file_name'] )):
files_available += 1
elif os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
base_name = os.path.splitext(model['files'][file]['file_name'])[0]
if "GFPGANv" in base_name:
st.session_state["GFPGAN_models"].append(base_name)
files_available += 1
for root, dirs, files in os.walk(st.session_state['defaults'].general.GFPGAN_dir): if len(st.session_state["GFPGAN_models"]) > 0 and files_available == len(model['files']):
for file in files:
if os.path.splitext(file)[1] == '.pth':
st.session_state["GFPGAN_models"].append(os.path.splitext(file)[0])
#print (len(st.session_state["GFPGAN_models"]))
#with server_state_lock["GFPGAN_available"]:
if len(st.session_state["GFPGAN_models"]) > 0:
st.session_state["GFPGAN_available"] = True st.session_state["GFPGAN_available"] = True
else: else:
st.session_state["GFPGAN_available"] = False st.session_state["GFPGAN_available"] = False
@ -1776,14 +1780,13 @@ def RealESRGAN_available():
# Allow for custom models to be used instead of the default one, # Allow for custom models to be used instead of the default one,
# an example would be Waifu-Diffusion or any other fine tune of stable diffusion # an example would be Waifu-Diffusion or any other fine tune of stable diffusion
st.session_state["RealESRGAN_models"]:sorted = [] st.session_state["RealESRGAN_models"]:sorted = []
model = st.session_state["defaults"].model_manager.models.realesrgan
for file in model['files']:
if os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
base_name = os.path.splitext(model['files'][file]['file_name'])[0]
st.session_state["RealESRGAN_models"].append(base_name)
for root, dirs, files in os.walk(st.session_state['defaults'].general.RealESRGAN_dir): if len(st.session_state["RealESRGAN_models"]) > 0:
for file in files:
if os.path.splitext(file)[1] == '.pth':
st.session_state["RealESRGAN_models"].append(os.path.splitext(file)[0])
#with server_state_lock["RealESRGAN_available"]:
if len(st.session_state["RealESRGAN_models"]) > 0:
st.session_state["RealESRGAN_available"] = True st.session_state["RealESRGAN_available"] = True
else: else:
st.session_state["RealESRGAN_available"] = False st.session_state["RealESRGAN_available"] = False
@ -1794,19 +1797,22 @@ def LDSR_available():
# Allow for custom models to be used instead of the default one, # Allow for custom models to be used instead of the default one,
# an example would be Waifu-Diffusion or any other fine tune of stable diffusion # an example would be Waifu-Diffusion or any other fine tune of stable diffusion
st.session_state["LDSR_models"]:sorted = [] st.session_state["LDSR_models"]:sorted = []
files_available = 0
for root, dirs, files in os.walk(st.session_state['defaults'].general.LDSR_dir): model = st.session_state["defaults"].model_manager.models.ldsr
for file in files: for file in model['files']:
if os.path.splitext(file)[1] == '.ckpt': if os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
st.session_state["LDSR_models"].append(os.path.splitext(file)[0]) base_name = os.path.splitext(model['files'][file]['file_name'])[0]
extension = os.path.splitext(model['files'][file]['file_name'])[1]
#print (st.session_state['defaults'].general.LDSR_dir) if extension == ".ckpt":
#print (st.session_state["LDSR_models"]) st.session_state["LDSR_models"].append(base_name)
#with server_state_lock["LDSR_available"]: files_available += 1
if len(st.session_state["LDSR_models"]) > 0: if files_available == len(model['files']):
st.session_state["LDSR_available"] = True st.session_state["LDSR_available"] = True
else: else:
st.session_state["LDSR_available"] = False st.session_state["LDSR_available"] = False
def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale, def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,

View File

@ -59,17 +59,6 @@ class plugin_info():
isTab = True isTab = True
displayPriority = 1 displayPriority = 1
if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
server_state["GFPGAN_available"] = True
else:
server_state["GFPGAN_available"] = False
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].txt2vid.RealESRGAN_model}.pth")):
server_state["RealESRGAN_available"] = True
else:
server_state["RealESRGAN_available"] = False
# #
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ -235,11 +224,11 @@ def load_diffusers_model(weights_path,torch_device):
del st.session_state["weights_path"] del st.session_state["weights_path"]
st.session_state["weights_path"] = weights_path st.session_state["weights_path"] = weights_path
# if folder "user_data/model_cache/stable-diffusion-v1-4" exists, load the model from there # if folder "models/diffusers/stable-diffusion-v1-4" exists, load the model from there
if weights_path == "CompVis/stable-diffusion-v1-4": if weights_path == "CompVis/stable-diffusion-v1-4":
model_path = os.path.join("user_data", "model_cache", "stable-diffusion-v1-4") model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-4")
elif weights_path == "hakurei/waifu-diffusion": elif weights_path == "hakurei/waifu-diffusion":
model_path = os.path.join("user_data", "model_cache", "waifu-diffusion") model_path = os.path.join("models", "diffusers", "waifu-diffusion")
if not os.path.exists(model_path + "/model_index.json"): if not os.path.exists(model_path + "/model_index.json"):
server_state["pipe"] = StableDiffusionPipeline.from_pretrained( server_state["pipe"] = StableDiffusionPipeline.from_pretrained(

View File

@ -369,8 +369,8 @@ def torch_gc():
def load_LDSR(checking=False): def load_LDSR(checking=False):
model_name = 'model' model_name = 'model'
yaml_name = 'project' yaml_name = 'project'
model_path = os.path.join(LDSR_dir, 'experiments/pretrained_models', model_name + '.ckpt') model_path = os.path.join(LDSR_dir, model_name + '.ckpt')
yaml_path = os.path.join(LDSR_dir, 'experiments/pretrained_models', yaml_name + '.yaml') yaml_path = os.path.join(LDSR_dir, yaml_name + '.yaml')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise Exception("LDSR model not found at path "+model_path) raise Exception("LDSR model not found at path "+model_path)
if not os.path.isfile(yaml_path): if not os.path.isfile(yaml_path):
@ -384,7 +384,7 @@ def load_LDSR(checking=False):
return LDSRObject return LDSRObject
def load_GFPGAN(checking=False): def load_GFPGAN(checking=False):
model_name = 'GFPGANv1.3' model_name = 'GFPGANv1.3'
model_path = os.path.join(GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') model_path = os.path.join(GFPGAN_dir, model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise Exception("GFPGAN model not found at path "+model_path) raise Exception("GFPGAN model not found at path "+model_path)
if checking == True: if checking == True:
@ -407,7 +407,7 @@ def load_RealESRGAN(model_name: str, checking = False):
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
} }
model_path = os.path.join(RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth') model_path = os.path.join(RealESRGAN_dir, model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise Exception(model_name+".pth not found at path "+model_path) raise Exception(model_name+".pth not found at path "+model_path)
if checking == True: if checking == True:

View File

@ -100,13 +100,13 @@ def layout():
# check if the models exist on their respective folders # check if the models exist on their respective folders
with server_state_lock["GFPGAN_available"]: with server_state_lock["GFPGAN_available"]:
if os.path.exists(os.path.join(st.session_state["defaults"].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")): if os.path.exists(os.path.join(st.session_state["defaults"].general.GFPGAN_dir, f"{st.session_state['defaults'].general.GFPGAN_model}.pth")):
server_state["GFPGAN_available"] = True server_state["GFPGAN_available"] = True
else: else:
server_state["GFPGAN_available"] = False server_state["GFPGAN_available"] = False
with server_state_lock["RealESRGAN_available"]: with server_state_lock["RealESRGAN_available"]:
if os.path.exists(os.path.join(st.session_state["defaults"].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")): if os.path.exists(os.path.join(st.session_state["defaults"].general.RealESRGAN_dir, f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
server_state["RealESRGAN_available"] = True server_state["RealESRGAN_available"] = True
else: else:
server_state["RealESRGAN_available"] = False server_state["RealESRGAN_available"] = False