cleaned code, added option to save models to drive for faster loading, untested!!!

This commit is contained in:
aedh carrick 2022-11-06 18:11:52 -06:00
parent d508b14984
commit 4be125a7d5

View File

@ -276,6 +276,7 @@
{
"cell_type": "code",
"source": [
"#@title { display-mode: \"form\" }\n",
"#@markdown WebUI repo (and branch)\n",
"repo_name = \"Sygil-Dev/sygil-webui\" #@param {type:\"string\"}\n",
"repo_branch = \"dev\" #@param {type:\"string\"}\n",
@ -302,7 +303,9 @@
"REALESRGAN = True #@param {type:\"boolean\"}\n",
"LDSR = True #@param {type:\"boolean\"}\n",
"BLIP_MODEL = False #@param {type:\"boolean\"}\n",
"\n"
"\n",
"#@markdown Save models to Google Drive for faster loading in future (Be warned! Make sure you have enough space!)\n",
"SAVE_MODELS = False #@param {type:\"boolean\"}"
],
"metadata": {
"id": "OXn96M9deVtF"
@ -363,8 +366,7 @@
"!git clone {REPO_URL}\n",
"%cd {PATH_TO_REPO}\n",
"!git checkout {repo_branch}\n",
"!git pull\n",
"!wget -O arial.ttf https://github.com/matomo-org/travis-scripts/blob/master/fonts/Arial.ttf?raw=true"
"!git pull"
],
"metadata": {
"id": "pZHGf03Vp305"
@ -442,19 +444,27 @@
"with open('configs/webui/webui_streamlit.yaml') as f:\n",
" dataMap = yaml.safe_load(f)\n",
"models = dataMap['model_manager']['models']\n",
"existing_models = []\n",
"\n",
"# copy script from model manager\n",
"import requests, time\n",
"import requests, time, shutil\n",
"from requests.auth import HTTPBasicAuth\n",
"\n",
"if MODEL_DIR != \"\":\n",
" MODEL_DIR = os.path.join('/content/drive/MyDrive', MODEL_DIR)\n",
"else:\n",
" MODEL_DIR = '/content/drive/MyDrive'\n",
"\n",
"def download_file(file_name, file_path, file_url):\n",
" os.makedirs(file_path, exist_ok=True)\n",
" link_path = os.path.join(MODEL_DIR, file_name)\n",
" full_path = os.path.join(file_path, file_name)\n",
" if os.path.exists(os.path.join(MODEL_DIR , file_name)):\n",
" if os.path.exists(link_path):\n",
" print( file_name + \" found in Google Drive\")\n",
" if not os.path.exists(full_path):\n",
" print( file_name + \"found in Google Drive\")\n",
" print( \"Creating symlink...\")\n",
" os.symlink(os.path.join(MODEL_DIR , file_name), full_path)\n",
" print( \" creating symlink...\")\n",
" os.symlink(link_path, full_path)\n",
" print( \" symlink already exists\")\n",
" elif not os.path.exists(full_path):\n",
" print( \"Downloading \" + file_name + \"...\", end=\"\" )\n",
" token = None\n",
@ -472,10 +482,15 @@
" print( \".\", end=\"\" )\n",
" print( \"done\" )\n",
" print( \" \" + file_name + \" downloaded to \\'\" + file_path + \"\\'\" )\n",
" if SAVE_MODELS and os.path.exists(MODEL_DIR):\n",
" shutil.copy2(full_path,MODEL_DIR)\n",
" print( \" Saved \" + file_name + \" to \" + MODEL_DIR)\n",
" except:\n",
" print( \"Failed to download \" + file_name + \".\" )\n",
" return\n",
" else:\n",
" print( full_path + \" already exists.\" )\n",
" existing_models.append(file_name)\n",
"\n",
"# download models in list\n",
"for model in download_list:\n",
@ -492,18 +507,18 @@
"\n",
"# add custom models not in list\n",
"CUSTOM_MODEL_DIR = os.path.join(PATH_TO_REPO, 'models/custom')\n",
"if MODEL_DIR != \"\":\n",
" MODEL_DIR = os.path.join('/content/drive/MyDrive', MODEL_DIR)\n",
" if os.path.exists(MODEL_DIR):\n",
" custom_models = os.listdir(MODEL_DIR)\n",
" custom_models = [m for m in custom_models if os.path.isfile(MODEL_DIR + '/' + m)]\n",
" os.makedirs(CUSTOM_MODEL_DIR, exist_ok=True)\n",
" print( \"Custom model(s) found: \" )\n",
" for m in custom_models:\n",
" full_model_path = os.path.join(CUSTOM_MODEL_DIR, m)\n",
" if not os.path.exists(full_model_path):\n",
" print( \" \" + m )\n",
" os.symlink(os.path.join(MODEL_DIR , m), full_model_path)\n",
"if os.path.exists(MODEL_DIR):\n",
" custom_models = os.listdir(MODEL_DIR)\n",
" custom_models = [m for m in custom_models if os.path.isfile(MODEL_DIR + '/' + m)]\n",
" os.makedirs(CUSTOM_MODEL_DIR, exist_ok=True)\n",
" print( \"Custom model(s) found: \" )\n",
" for m in custom_models:\n",
" if m in existing_models:\n",
" continue\n",
" full_path = os.path.join(CUSTOM_MODEL_DIR, m)\n",
" if not os.path.exists(full_model_path):\n",
" print( \" \" + m )\n",
" os.symlink(os.path.join(MODEL_DIR , m), full_path)\n",
"\n",
"# get custom config file if it exists\n",
"if CONFIG_DIR != \"\":\n",