Added cell to prefetch models. (#1619)

# Description

Added a config cell for the Colab instance. Now it can pre-fetch models.
Folder on Google drive can be specified. If models are found there, they
will be symlinked instead of downloaded. Any models found in folder, but
not in download list, will be symlinked to models/custom. Also added
comments to code and titles to cells.
Should just be able to enter settings in first cell, then hit 'run all'.
Well... twice.

# Checklist:

- [x] I have changed the base branch to `dev`
- [x] I have performed a self-review of my own code
- [x] I have commented my code in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
This commit is contained in:
Alejandro Gil 2022-10-29 10:02:12 -07:00 committed by GitHub
commit 5d3feebfd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -129,7 +129,7 @@
"\n",
"## Streamlit\n",
"\n",
"![](images/streamlit/streamlit-t2i.png)\n",
"![](https://github.com/aedhcarrick/sygil-webui/blob/patch-2/images/streamlit/streamlit-t2i.png?raw=1)\n",
"\n",
"**Features:**\n",
"\n",
@ -148,7 +148,7 @@
"\n",
"## Gradio\n",
"\n",
"![](images/gradio/gradio-t2i.png)\n",
"![](https://github.com/aedhcarrick/sygil-webui/blob/patch-2/images/gradio/gradio-t2i.png?raw=1)\n",
"\n",
"**Features:**\n",
"\n",
@ -166,7 +166,7 @@
"\n",
"### GFPGAN\n",
"\n",
"![](images/GFPGAN.png)\n",
"![](https://github.com/aedhcarrick/sygil-webui/blob/patch-2/images/GFPGAN.png?raw=1)\n",
"\n",
"Lets you improve faces in pictures using the GFPGAN model. There is a checkbox in every tab to use GFPGAN at 100%, and also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strong the effect is.\n",
"\n",
@ -176,7 +176,7 @@
"\n",
"### RealESRGAN\n",
"\n",
"![](images/RealESRGAN.png)\n",
"![](https://github.com/aedhcarrick/sygil-webui/blob/patch-2/images/RealESRGAN.png?raw=1)\n",
"\n",
"Lets you double the resolution of generated images. There is a checkbox in every tab to use RealESRGAN, and you can choose between the regular upscaler and the anime version.\n",
"There is also a separate tab for using RealESRGAN on any picture.\n",
@ -265,7 +265,54 @@
{
"cell_type": "markdown",
"source": [
"# Setup"
"# Config options for Colab instance"
],
"metadata": {
"id": "iegma7yteERV"
}
},
{
"cell_type": "code",
"source": [
"#@markdown WebUI repo (and branch)\n",
"repo_name = \"Sygil-Dev/sygil-webui\" #@param {type:\"string\"}\n",
"repo_branch = \"dev\" #@param {type:\"string\"}\n",
"\n",
"#@markdown Mount Google Drive\n",
"mount_google_drive = True #@param {type:\"boolean\"}\n",
"save_outputs_to_drive = True #@param {type:\"boolean\"}\n",
"#@markdown Folder in Google Drive to search for custom models\n",
"MODEL_DIR = \"\" #@param {type:\"string\"}\n",
"\n",
"#@markdown Enter auth token from Huggingface.co\n",
"#@markdown >(required for downloading stable diffusion model.)\n",
"HF_TOKEN = \"\" #@param {type:\"string\"}\n",
"\n",
"#@markdown Select which models to prefetch\n",
"STABLE_DIFFUSION = True #@param {type:\"boolean\"}\n",
"WAIFU_DIFFUSION = False #@param {type:\"boolean\"}\n",
"TRINART_SD = False #@param {type:\"boolean\"}\n",
"SD_WD_LD_TRINART_MERGED = False #@param {type:\"boolean\"}\n",
"GFPGAN = True #@param {type:\"boolean\"}\n",
"REALESRGAN = True #@param {type:\"boolean\"}\n",
"LDSR = True #@param {type:\"boolean\"}\n",
"BLIP_MODEL = False #@param {type:\"boolean\"}\n",
"\n"
],
"metadata": {
"id": "OXn96M9deVtF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Setup\n",
"\n",
">Runtime will crash when installing conda. This is normal as we are forcing a restart of the runtime from code.\n",
"\n",
">Just hit \"Run All\" again. 😑"
],
"metadata": {
"id": "IZjJSr-WPNxB"
@ -277,6 +324,7 @@
"id": "eq0-E5mjSpmP"
},
"source": [
"#@title Make sure we have access to GPU backend\n",
"!nvidia-smi -L"
],
"execution_count": null,
@ -285,14 +333,14 @@
{
"cell_type": "code",
"source": [
"#@title Install miniConda (mamba)\n",
"!pip install condacolab\n",
"import condacolab\n",
"condacolab.install_from_url(\"https://github.com/conda-forge/miniforge/releases/download/4.14.0-0/Mambaforge-4.14.0-0-Linux-x86_64.sh\")\n",
"\n",
"import condacolab\n",
"condacolab.check()\n",
"\n",
"# The runtime will crash after this, its normal as we are forcing a restart of the runtime from code. Just hit \"Run All\" again."
"# The runtime will crash here!!! Don't panic! We planned for this remember?"
],
"metadata": {
"id": "cDu33xkdJ5mD"
@ -303,9 +351,13 @@
{
"cell_type": "code",
"source": [
"!git clone https://github.com/Sygil-Dev/sygil-webui.git\n",
"%cd /content/sygil-webui/\n",
"!git checkout dev\n",
"#@title Clone webUI repo and download font\n",
"import os\n",
"REPO_URL = os.path.join('https://github.com', repo_name)\n",
"PATH_TO_REPO = os.path.join('/content', repo_name.split('/')[1])\n",
"!git clone {REPO_URL}\n",
"%cd {PATH_TO_REPO}\n",
"!git checkout {repo_branch}\n",
"!git pull\n",
"!wget -O arial.ttf https://github.com/matomo-org/travis-scripts/blob/master/fonts/Arial.ttf?raw=true"
],
@ -318,7 +370,10 @@
{
"cell_type": "code",
"source": [
"!mamba install cudatoolkit=11.3 git numpy=1.22.3 pip=20.3 python=3.8.5 pytorch=1.11.0 scikit-image=0.19.2 torchvision=0.12.0 -y"
"#@title Install dependencies\n",
"!mamba install cudatoolkit=11.3 git numpy=1.22.3 pip=20.3 python=3.8.5 pytorch=1.11.0 scikit-image=0.19.2 torchvision=0.12.0 -y\n",
"!python --version\n",
"!pip install -r requirements.txt"
],
"metadata": {
"id": "dmN2igp5Yk3z"
@ -329,52 +384,29 @@
{
"cell_type": "code",
"source": [
"#@title Install dependencies.\n",
"!python --version\n",
"!pip install -r requirements.txt"
],
"metadata": {
"id": "vXX0OaR8KyLQ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title Install localtunnel to openGoogle's ports\n",
"!npm install localtunnel"
],
"metadata": {
"id": "FHyVuT5aSM2G"
"id": "Nxaxfgo_F8Am"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#Launch the WebUI"
],
"metadata": {
"id": "csi6cj6gQZmC"
}
},
{
"cell_type": "code",
"source": [
"#@title Mount Google Drive\n",
"import os\n",
"mount_google_drive = True #@param {type:\"boolean\"}\n",
"save_outputs_to_drive = True #@param {type:\"boolean\"}\n",
"\n",
"#@title Mount Google Drive (if selected)\n",
"if mount_google_drive:\n",
" # Mount google drive to store your outputs.\n",
" # Mount google drive to store outputs.\n",
" from google.colab import drive\n",
" drive.mount('/content/drive/', force_remount=True)\n",
"\n",
"if save_outputs_to_drive:\n",
" os.makedirs(\"/content/drive/MyDrive/sygil-webui/outputs\", exist_ok=True)\n",
" os.symlink(\"/content/drive/MyDrive/sygil-webui/outputs\", \"/content/sygil-webui/outputs\", target_is_directory=True)\n"
" # Make symlink to redirect downloads\n",
" OUTPUT_PATH = os.path.join('/content/drive/MyDrive', repo_name.split('/')[1], 'outputs')\n",
" os.makedirs(OUTPUT_PATH, exist_ok=True)\n",
" os.symlink(OUTPUT_PATH, os.path.join(PATH_TO_REPO, 'outputs'), target_is_directory=True)\n"
],
"metadata": {
"id": "pcSWo9Zkzbsf"
@ -385,20 +417,114 @@
{
"cell_type": "code",
"source": [
"#@title Enter Huggingface token\n",
"!git config --global credential.helper store\n",
"!huggingface-cli login"
"#@title Pre-fetch models\n",
"%cd {PATH_TO_REPO}\n",
"# make list of models we want to download\n",
"model_list = {\n",
" 'stable_diffusion': f'{STABLE_DIFFUSION}',\n",
" 'waifu_diffusion': f'{WAIFU_DIFFUSION}',\n",
" 'trinart_stable_diffusion': f'{TRINART_SD}',\n",
" 'sd_wd_ld_trinart_merged': f'{SD_WD_LD_TRINART_MERGED}',\n",
" 'gfpgan': f'{GFPGAN}',\n",
" 'realesrgan': f'{REALESRGAN}',\n",
" 'ldsr': f'{LDSR}',\n",
" 'blip_model': f'{BLIP_MODEL}'}\n",
"download_list = {k for (k,v) in model_list.items() if v == 'True'}\n",
"\n",
"# get model info (file name, download link, save location)\n",
"import yaml\n",
"from pprint import pprint\n",
"with open('configs/webui/webui_streamlit.yaml') as f:\n",
" dataMap = yaml.safe_load(f)\n",
"models = dataMap['model_manager']['models']\n",
"\n",
"# copy script from model manager\n",
"import requests, time\n",
"from requests.auth import HTTPBasicAuth\n",
"\n",
"def download_file(file_name, file_path, file_url):\n",
" os.makedirs(file_path, exist_ok=True)\n",
" if os.path.exists(os.path.join(MODEL_DIR , file_name)):\n",
" print( file_name + \"found in Google Drive\")\n",
" print( \"Creating symlink...\")\n",
" os.symlink(os.path.join(MODEL_DIR , file_name), os.path.join(file_path, file_name))\n",
" elif not os.path.exists(os.path.join(file_path , file_name)):\n",
" print( \"Downloading \" + file_name + \"...\", end=\"\" )\n",
" token = None\n",
" if \"huggingface.co\" in file_url:\n",
" token = HTTPBasicAuth('token', HF_TOKEN)\n",
" try:\n",
" with requests.get(file_url, auth = token, stream=True) as r:\n",
" starttime = time.time()\n",
" r.raise_for_status()\n",
" with open(os.path.join(file_path, file_name), 'wb') as f:\n",
" for chunk in r.iter_content(chunk_size=8192):\n",
" f.write(chunk)\n",
" if ((time.time() - starttime) % 60.0) > 2 :\n",
" starttime = time.time()\n",
" print( \".\", end=\"\" )\n",
" print( \"done\" )\n",
" print( \" \" + file_name + \" downloaded to \\'\" + file_path + \"\\'\" )\n",
" except:\n",
" print( \"Failed to download \" + file_name + \".\" )\n",
" else:\n",
" print( file_name + \" already exists.\" )\n",
"\n",
"# download models in list\n",
"for model in download_list:\n",
" model_name = models[model]['model_name']\n",
" file_info = models[model]['files']\n",
" for file in file_info:\n",
" file_name = file_info[file]['file_name']\n",
" file_url = file_info[file]['download_link']\n",
" if 'save_location' in file_info[file]:\n",
" file_path = file_info[file]['save_location']\n",
" else: \n",
" file_path = models[model]['save_location']\n",
" download_file(file_name, file_path, file_url)\n",
"\n",
"# add custom models not in list\n",
"CUSTOM_MODEL_DIR = os.path.join(PATH_TO_REPO, 'models/custom')\n",
"if MODEL_DIR != \"\":\n",
" MODEL_DIR = os.path.join('/content/drive/MyDrive', MODEL_DIR)\n",
" if os.path.exists(MODEL_DIR):\n",
" custom_models = os.listdir(MODEL_DIR)\n",
" custom_models = [m for m in custom_models if os.path.isfile(MODEL_DIR + '/' + m)]\n",
" os.makedirs(CUSTOM_MODEL_DIR, exist_ok=True)\n",
" print( \"Custom model(s) found: \" )\n",
" for m in custom_models:\n",
" print( \" \" + m )\n",
" os.symlink(os.path.join(MODEL_DIR , m), os.path.join(CUSTOM_MODEL_DIR, m))\n",
"\n"
],
"metadata": {
"id": "IsbG7fvIrKwg"
"id": "vMdmh81J70yA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Launch the web ui server\n",
"### (optional) JS to prevent idle timeout:\n",
"Press 'F12' OR ('CTRL' + 'SHIFT' + 'I') OR right click on this website -> inspect. Then click on the console tab and paste in the following code.\n",
"```js,\n",
"function ClickConnect(){\n",
"console.log(\"Working\");\n",
"document.querySelector(\"colab-toolbar-button#connect\").click()\n",
"}\n",
"setInterval(ClickConnect,60000)\n",
"```"
],
"metadata": {
"id": "pjIjiCuJysJI"
}
},
{
"cell_type": "code",
"source": [
"#@title <-- Press play on the music player to keep the tab alive (Uses only 13MB of data)\n",
"#@title Press play on the music player to keep the tab alive (Uses only 13MB of data)\n",
"%%html\n",
"<b>Press play on the music player to keep the tab alive, then start your generation below (Uses only 13MB of data)</b><br/>\n",
"<audio src=\"https://henk.tech/colabkobold/silence.m4a\" controls>"
@ -409,27 +535,10 @@
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"JS to prevent idle timeout:\n",
"\n",
"Press F12 OR CTRL + SHIFT + I OR right click on this website -> inspect. Then click on the console tab and paste in the following code.\n",
"\n",
"function ClickConnect(){\n",
"console.log(\"Working\");\n",
"document.querySelector(\"colab-toolbar-button#connect\").click()\n",
"}\n",
"setInterval(ClickConnect,60000)"
],
"metadata": {
"id": "pjIjiCuJysJI"
}
},
{
"cell_type": "code",
"source": [
"#@title Open port 8501 and start Streamlit server. Open link in 'link.txt' file in file pane on left.\n",
"#@title Run localtunnel and start Streamlit server. ('Ctrl' + 'left click') on link in the 'link.txt' file. (/content/link.txt)\n",
"!npx localtunnel --port 8501 &>/content/link.txt &\n",
"!streamlit run scripts/webui_streamlit.py --theme.base dark --server.headless true 2>&1 | tee -a /content/log.txt"
],