mirror of
https://github.com/Sygil-Dev/sygil-webui.git
synced 2024-12-14 22:13:41 +03:00
Started to move all the streamlit code and dependencies to its own folder in webui/streamlit
and moving the backend to use nataili, this should help make it self-contained and reduce the amount of code we have in other places which should also make it so the code is easier to understand and read.
This commit is contained in:
parent
0781ced89a
commit
ffd7883cb0
178
webui/streamlit/frontend/css/streamlit.main.css
Normal file
178
webui/streamlit/frontend/css/streamlit.main.css
Normal file
@ -0,0 +1,178 @@
|
||||
/*
|
||||
This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
Copyright 2022 Sygil-Dev team.
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
/***********************************************************
|
||||
* Additional CSS for streamlit builtin components *
|
||||
************************************************************/
|
||||
|
||||
/* Tab name (e.g. Text-to-Image) //improve legibility*/
|
||||
button[data-baseweb="tab"] {
|
||||
font-size: 25px;
|
||||
}
|
||||
|
||||
/* Image Container (only appear after run finished)//center the image, especially better looks in wide screen */
|
||||
.css-1kyxreq{
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
|
||||
/* Streamlit header */
|
||||
.css-1avcm0n {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
/* Main streamlit container (below header) //reduce the empty spaces*/
|
||||
.css-18e3th9 {
|
||||
padding-top: 1rem;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/***********************************************************
|
||||
* Additional CSS for streamlit custom/3rd party components *
|
||||
************************************************************/
|
||||
/* For stream_on_hover */
|
||||
section[data-testid="stSidebar"] > div:nth-of-type(1) {
|
||||
background-color: #111;
|
||||
}
|
||||
|
||||
button[kind="header"] {
|
||||
background-color: transparent;
|
||||
color: rgb(180, 167, 141);
|
||||
}
|
||||
|
||||
@media (hover) {
|
||||
/* header element */
|
||||
header[data-testid="stHeader"] {
|
||||
/* display: none;*/ /*suggested behavior by streamlit hover components*/
|
||||
pointer-events: none; /* disable interaction of the transparent background */
|
||||
}
|
||||
|
||||
/* The button on the streamlit navigation menu */
|
||||
button[kind="header"] {
|
||||
/* display: none;*/ /*suggested behavior by streamlit hover components*/
|
||||
pointer-events: auto; /* enable interaction of the button even if parents intereaction disabled */
|
||||
}
|
||||
|
||||
/* added to avoid main sectors (all element to the right of sidebar from) moving */
|
||||
section[data-testid="stSidebar"] {
|
||||
width: 3.5% !important;
|
||||
min-width: 3.5% !important;
|
||||
}
|
||||
|
||||
/* The navigation menu specs and size */
|
||||
section[data-testid="stSidebar"] > div {
|
||||
height: 100%;
|
||||
width: 2% !important;
|
||||
min-width: 100% !important;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
top: 0;
|
||||
left: 0;
|
||||
background-color: #111;
|
||||
overflow-x: hidden;
|
||||
transition: 0.5s ease-in-out;
|
||||
padding-top: 0px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
/* The navigation menu open and close on hover and size */
|
||||
section[data-testid="stSidebar"] > div:hover {
|
||||
width: 300px !important;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 272px) {
|
||||
section[data-testid="stSidebar"] > div {
|
||||
width: 15rem;
|
||||
}
|
||||
}
|
||||
|
||||
/***********************************************************
|
||||
* Additional CSS for other elements
|
||||
************************************************************/
|
||||
button[data-baseweb="tab"] {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
@media (min-width: 1200px){
|
||||
h1 {
|
||||
font-size: 1.75rem;
|
||||
}
|
||||
}
|
||||
#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 {
|
||||
width: 50rem;
|
||||
align-self: center;
|
||||
}
|
||||
div.gallery:hover {
|
||||
border: 1px solid #777;
|
||||
}
|
||||
.css-dg4u6x p {
|
||||
font-size: 0.8rem;
|
||||
text-align: center;
|
||||
position: relative;
|
||||
top: 6px;
|
||||
}
|
||||
|
||||
.row-widget.stButton {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
/********************************************************************
|
||||
Hide anchor links on titles
|
||||
*********************************************************************/
|
||||
/*
|
||||
.css-15zrgzn {
|
||||
display: none
|
||||
}
|
||||
.css-eczf16 {
|
||||
display: none
|
||||
}
|
||||
.css-jn99sy {
|
||||
display: none
|
||||
}
|
||||
|
||||
/* Make the text area widget have a similar height as the text input field */
|
||||
.st-dy{
|
||||
height: 54px;
|
||||
min-height: 25px;
|
||||
}
|
||||
.css-17useex{
|
||||
gap: 3px;
|
||||
|
||||
}
|
||||
|
||||
/* Remove some empty spaces to make the UI more compact. */
|
||||
.css-18e3th9{
|
||||
padding-left: 10px;
|
||||
padding-right: 30px;
|
||||
position: unset !important; /* Fixes the layout/page going up when an expander or another item is expanded and then collapsed */
|
||||
}
|
||||
.css-k1vhr4{
|
||||
padding-top: initial;
|
||||
}
|
||||
.css-ret2ud{
|
||||
padding-left: 10px;
|
||||
padding-right: 30px;
|
||||
gap: initial;
|
||||
display: initial;
|
||||
}
|
||||
|
||||
.css-w5z5an{
|
||||
gap: 1px;
|
||||
}
|
34
webui/streamlit/scripts/APIServer.py
Normal file
34
webui/streamlit/scripts/APIServer.py
Normal file
@ -0,0 +1,34 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
#from sd_utils import *
|
||||
from sd_utils import st
|
||||
# streamlit imports
|
||||
|
||||
#streamlit components section
|
||||
|
||||
#other imports
|
||||
#from fastapi import FastAPI
|
||||
#import uvicorn
|
||||
|
||||
# Temp imports
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def layout():
|
||||
st.info("Under Construction. :construction_worker:")
|
121
webui/streamlit/scripts/ModelManager.py
Normal file
121
webui/streamlit/scripts/ModelManager.py
Normal file
@ -0,0 +1,121 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
from sd_utils import st, logger
|
||||
# streamlit imports
|
||||
|
||||
|
||||
#other imports
|
||||
import os, requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from requests import HTTPError
|
||||
from stqdm import stqdm
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
def download_file(file_name, file_path, file_url):
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(file_path)
|
||||
|
||||
if not os.path.exists(os.path.join(file_path , file_name)):
|
||||
print('Downloading ' + file_name + '...')
|
||||
# TODO - add progress bar in streamlit
|
||||
# download file with `requests``
|
||||
if file_name == "Stable Diffusion v1.5":
|
||||
if "huggingface_token" not in st.session_state or st.session_state["defaults"].general.huggingface_token == "None":
|
||||
if "progress_bar_text" in st.session_state:
|
||||
st.session_state["progress_bar_text"].error(
|
||||
"You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
|
||||
)
|
||||
raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
|
||||
|
||||
try:
|
||||
with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token) if "huggingface.co" in file_url else None, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(os.path.join(file_path, file_name), 'wb') as f:
|
||||
for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"):
|
||||
f.write(chunk)
|
||||
except HTTPError as e:
|
||||
if "huggingface.co" in file_url:
|
||||
if "resolve"in file_url:
|
||||
repo_url = file_url.split("resolve")[0]
|
||||
|
||||
st.session_state["progress_bar_text"].error(
|
||||
f"You need to accept the license for the model in order to be able to download it. "
|
||||
f"Please visit {repo_url} and accept the lincense there, then try again to download the model.")
|
||||
|
||||
logger.error(e)
|
||||
|
||||
else:
|
||||
print(file_name + ' already exists.')
|
||||
|
||||
|
||||
def download_model(models, model_name):
|
||||
""" Download all files from model_list[model_name] """
|
||||
for file in models[model_name]:
|
||||
download_file(file['file_name'], file['file_path'], file['file_url'])
|
||||
return
|
||||
|
||||
|
||||
def layout():
|
||||
#search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
|
||||
|
||||
colms = st.columns((1, 3, 3, 5, 5))
|
||||
columns = ["№", 'Model Name', 'Save Location', "Download", 'Download Link']
|
||||
|
||||
models = st.session_state["defaults"].model_manager.models
|
||||
|
||||
for col, field_name in zip(colms, columns):
|
||||
# table header
|
||||
col.write(field_name)
|
||||
|
||||
for x, model_name in enumerate(models):
|
||||
col1, col2, col3, col4, col5 = st.columns((1, 3, 3, 3, 6))
|
||||
col1.write(x) # index
|
||||
col2.write(models[model_name]['model_name'])
|
||||
col3.write(models[model_name]['save_location'])
|
||||
with col4:
|
||||
files_exist = 0
|
||||
for file in models[model_name]['files']:
|
||||
if "save_location" in models[model_name]['files'][file]:
|
||||
os.path.exists(os.path.join(models[model_name]['files'][file]['save_location'] , models[model_name]['files'][file]['file_name']))
|
||||
files_exist += 1
|
||||
elif os.path.exists(os.path.join(models[model_name]['save_location'] , models[model_name]['files'][file]['file_name'])):
|
||||
files_exist += 1
|
||||
files_needed = []
|
||||
for file in models[model_name]['files']:
|
||||
if "save_location" in models[model_name]['files'][file]:
|
||||
if not os.path.exists(os.path.join(models[model_name]['files'][file]['save_location'] , models[model_name]['files'][file]['file_name'])):
|
||||
files_needed.append(file)
|
||||
elif not os.path.exists(os.path.join(models[model_name]['save_location'] , models[model_name]['files'][file]['file_name'])):
|
||||
files_needed.append(file)
|
||||
if len(files_needed) > 0:
|
||||
if st.button('Download', key=models[model_name]['model_name'], help='Download ' + models[model_name]['model_name']):
|
||||
for file in files_needed:
|
||||
if "save_location" in models[model_name]['files'][file]:
|
||||
download_file(models[model_name]['files'][file]['file_name'], models[model_name]['files'][file]['save_location'], models[model_name]['files'][file]['download_link'])
|
||||
else:
|
||||
download_file(models[model_name]['files'][file]['file_name'], models[model_name]['save_location'], models[model_name]['files'][file]['download_link'])
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
st.empty()
|
||||
else:
|
||||
st.write('✅')
|
||||
|
||||
#
|
899
webui/streamlit/scripts/Settings.py
Normal file
899
webui/streamlit/scripts/Settings.py
Normal file
@ -0,0 +1,899 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
from sd_utils import st, custom_models_available, logger, human_readable_size
|
||||
|
||||
# streamlit imports
|
||||
|
||||
# streamlit components section
|
||||
import streamlit_nested_layout
|
||||
from streamlit_server_state import server_state
|
||||
|
||||
# other imports
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
import os, toml
|
||||
|
||||
# end of imports
|
||||
# ---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
@logger.catch(reraise=True)
|
||||
def layout():
|
||||
#st.header("Settings")
|
||||
|
||||
with st.form("Settings"):
|
||||
general_tab, txt2img_tab, img2img_tab, img2txt_tab, txt2vid_tab, image_processing, textual_inversion_tab, concepts_library_tab = st.tabs(
|
||||
['General', "Text-To-Image", "Image-To-Image", "Image-To-Text", "Text-To-Video", "Image processing", "Textual Inversion", "Concepts Library"])
|
||||
|
||||
with general_tab:
|
||||
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
|
||||
|
||||
device_list = []
|
||||
device_properties = [(i, torch.cuda.get_device_properties(i)) for i in range(torch.cuda.device_count())]
|
||||
for device in device_properties:
|
||||
id = device[0]
|
||||
name = device[1].name
|
||||
total_memory = device[1].total_memory
|
||||
|
||||
device_list.append(f"{id}: {name} ({human_readable_size(total_memory, decimal_places=0)})")
|
||||
|
||||
with col1:
|
||||
st.title("General")
|
||||
st.session_state['defaults'].general.gpu = int(st.selectbox("GPU", device_list, index=st.session_state['defaults'].general.gpu,
|
||||
help=f"Select which GPU to use. Default: {device_list[0]}").split(":")[0])
|
||||
|
||||
st.session_state['defaults'].general.outdir = str(st.text_input("Output directory", value=st.session_state['defaults'].general.outdir,
|
||||
help="Relative directory on which the output images after a generation will be placed. Default: 'outputs'"))
|
||||
|
||||
# If we have custom models available on the "models/custom"
|
||||
# folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
|
||||
custom_models_available()
|
||||
|
||||
if server_state["CustomModel_available"]:
|
||||
st.session_state.defaults.general.default_model = st.selectbox("Default Model:", server_state["custom_models"],
|
||||
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
|
||||
help="Select the model you want to use. If you have placed custom models \
|
||||
on your 'models/custom' folder they will be shown here as well. The model name that will be shown here \
|
||||
is the same as the name the file for the model has on said folder, \
|
||||
it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
|
||||
else:
|
||||
st.session_state.defaults.general.default_model = st.selectbox("Default Model:", [st.session_state['defaults'].general.default_model],
|
||||
help="Select the model you want to use. If you have placed custom models \
|
||||
on your 'models/custom' folder they will be shown here as well. \
|
||||
The model name that will be shown here is the same as the name\
|
||||
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
|
||||
|
||||
st.session_state['defaults'].general.default_model_config = st.text_input("Default Model Config", value=st.session_state['defaults'].general.default_model_config,
|
||||
help="Default model config file for inference. Default: 'configs/stable-diffusion/v1-inference.yaml'")
|
||||
|
||||
st.session_state['defaults'].general.default_model_path = st.text_input("Default Model Config", value=st.session_state['defaults'].general.default_model_path,
|
||||
help="Default model path. Default: 'models/ldm/stable-diffusion-v1/model.ckpt'")
|
||||
|
||||
st.session_state['defaults'].general.GFPGAN_dir = st.text_input("Default GFPGAN directory", value=st.session_state['defaults'].general.GFPGAN_dir,
|
||||
help="Default GFPGAN directory. Default: './models/gfpgan'")
|
||||
|
||||
st.session_state['defaults'].general.RealESRGAN_dir = st.text_input("Default RealESRGAN directory", value=st.session_state['defaults'].general.RealESRGAN_dir,
|
||||
help="Default GFPGAN directory. Default: './models/realesrgan'")
|
||||
|
||||
RealESRGAN_model_list = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"]
|
||||
st.session_state['defaults'].general.RealESRGAN_model = st.selectbox("RealESRGAN model", RealESRGAN_model_list,
|
||||
index=RealESRGAN_model_list.index(st.session_state['defaults'].general.RealESRGAN_model),
|
||||
help="Default RealESRGAN model. Default: 'RealESRGAN_x4plus'")
|
||||
Upscaler_list = ["RealESRGAN", "LDSR"]
|
||||
st.session_state['defaults'].general.upscaling_method = st.selectbox("Upscaler", Upscaler_list, index=Upscaler_list.index(
|
||||
st.session_state['defaults'].general.upscaling_method), help="Default upscaling method. Default: 'RealESRGAN'")
|
||||
|
||||
with col2:
|
||||
st.title("Performance")
|
||||
|
||||
st.session_state["defaults"].general.gfpgan_cpu = st.checkbox("GFPGAN - CPU", value=st.session_state['defaults'].general.gfpgan_cpu,
|
||||
help="Run GFPGAN on the cpu. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.esrgan_cpu = st.checkbox("ESRGAN - CPU", value=st.session_state['defaults'].general.esrgan_cpu,
|
||||
help="Run ESRGAN on the cpu. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.extra_models_cpu = st.checkbox("Extra Models - CPU", value=st.session_state['defaults'].general.extra_models_cpu,
|
||||
help="Run extra models (GFGPAN/ESRGAN) on cpu. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.extra_models_gpu = st.checkbox("Extra Models - GPU", value=st.session_state['defaults'].general.extra_models_gpu,
|
||||
help="Run extra models (GFGPAN/ESRGAN) on gpu. \
|
||||
Check and save in order to be able to select the GPU that each model will use. Default: False")
|
||||
if st.session_state["defaults"].general.extra_models_gpu:
|
||||
st.session_state['defaults'].general.gfpgan_gpu = int(st.selectbox("GFGPAN GPU", device_list, index=st.session_state['defaults'].general.gfpgan_gpu,
|
||||
help=f"Select which GPU to use. Default: {device_list[st.session_state['defaults'].general.gfpgan_gpu]}",
|
||||
key="gfpgan_gpu").split(":")[0])
|
||||
|
||||
st.session_state["defaults"].general.esrgan_gpu = int(st.selectbox("ESRGAN - GPU", device_list, index=st.session_state['defaults'].general.esrgan_gpu,
|
||||
help=f"Select which GPU to use. Default: {device_list[st.session_state['defaults'].general.esrgan_gpu]}",
|
||||
key="esrgan_gpu").split(":")[0])
|
||||
|
||||
st.session_state["defaults"].general.no_half = st.checkbox("No Half", value=st.session_state['defaults'].general.no_half,
|
||||
help="DO NOT switch the model to 16-bit floats. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.use_cudnn = st.checkbox("Use cudnn", value=st.session_state['defaults'].general.use_cudnn,
|
||||
help="Switch the pytorch backend to use cudnn, this should help with fixing Nvidia 16xx cards getting"
|
||||
"a black or green image. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.use_float16 = st.checkbox("Use float16", value=st.session_state['defaults'].general.use_float16,
|
||||
help="Switch the model to 16-bit floats. Default: False")
|
||||
|
||||
|
||||
precision_list = ['full', 'autocast']
|
||||
st.session_state["defaults"].general.precision = st.selectbox("Precision", precision_list, index=precision_list.index(st.session_state['defaults'].general.precision),
|
||||
help="Evaluates at this precision. Default: autocast")
|
||||
|
||||
st.session_state["defaults"].general.optimized = st.checkbox("Optimized Mode", value=st.session_state['defaults'].general.optimized,
|
||||
help="Loads the model onto the device piecemeal instead of all at once to reduce VRAM usage\
|
||||
at the cost of performance. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.optimized_turbo = st.checkbox("Optimized Turbo Mode", value=st.session_state['defaults'].general.optimized_turbo,
|
||||
help="Alternative optimization mode that does not save as much VRAM but \
|
||||
runs siginificantly faster. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.optimized_config = st.text_input("Optimized Config", value=st.session_state['defaults'].general.optimized_config,
|
||||
help=f"Loads alternative optimized configuration for inference. \
|
||||
Default: optimizedSD/v1-inference.yaml")
|
||||
|
||||
st.session_state["defaults"].general.enable_attention_slicing = st.checkbox("Enable Attention Slicing", value=st.session_state['defaults'].general.enable_attention_slicing,
|
||||
help="Enable sliced attention computation. When this option is enabled, the attention module will \
|
||||
split the input tensor in slices, to compute attention in several steps. This is useful to save some \
|
||||
memory in exchange for a small speed decrease. Only works the txt2vid tab right now. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.enable_minimal_memory_usage = st.checkbox("Enable Minimal Memory Usage", value=st.session_state['defaults'].general.enable_minimal_memory_usage,
|
||||
help="Moves only unet to fp16 and to CUDA, while keepping lighter models on CPUs \
|
||||
(Not properly implemented and currently not working, check this \
|
||||
link 'https://github.com/huggingface/diffusers/pull/537' for more information on it ). Default: False")
|
||||
|
||||
# st.session_state["defaults"].general.update_preview = st.checkbox("Update Preview Image", value=st.session_state['defaults'].general.update_preview,
|
||||
# help="Enables the preview image to be updated and shown to the user on the UI during the generation.\
|
||||
# If checked, once you save the settings an option to specify the frequency at which the image is updated\
|
||||
# in steps will be shown, this is helpful to reduce the negative effect this option has on performance. \
|
||||
# Default: True")
|
||||
st.session_state["defaults"].general.update_preview = True
|
||||
st.session_state["defaults"].general.update_preview_frequency = st.number_input("Update Preview Frequency",
|
||||
min_value=0,
|
||||
value=st.session_state['defaults'].general.update_preview_frequency,
|
||||
help="Specify the frequency at which the image is updated in steps, this is helpful to reduce the \
|
||||
negative effect updating the preview image has on performance. Default: 10")
|
||||
|
||||
with col3:
|
||||
st.title("Others")
|
||||
st.session_state["defaults"].general.use_sd_concepts_library = st.checkbox("Use the Concepts Library", value=st.session_state['defaults'].general.use_sd_concepts_library,
|
||||
help="Use the embeds Concepts Library, if checked, once the settings are saved an option will\
|
||||
appear to specify the directory where the concepts are stored. Default: True)")
|
||||
|
||||
if st.session_state["defaults"].general.use_sd_concepts_library:
|
||||
st.session_state['defaults'].general.sd_concepts_library_folder = st.text_input("Concepts Library Folder",
|
||||
value=st.session_state['defaults'].general.sd_concepts_library_folder,
|
||||
help="Relative folder on which the concepts library embeds are stored. \
|
||||
Default: 'models/custom/sd-concepts-library'")
|
||||
|
||||
st.session_state['defaults'].general.LDSR_dir = st.text_input("LDSR Folder", value=st.session_state['defaults'].general.LDSR_dir,
|
||||
help="Folder where LDSR is located. Default: './models/ldsr'")
|
||||
|
||||
st.session_state["defaults"].general.save_metadata = st.checkbox("Save Metadata", value=st.session_state['defaults'].general.save_metadata,
|
||||
help="Save metadata on the output image. Default: True")
|
||||
save_format_list = ["png","jpg", "jpeg","webp"]
|
||||
st.session_state["defaults"].general.save_format = st.selectbox("Save Format", save_format_list, index=save_format_list.index(st.session_state['defaults'].general.save_format),
|
||||
help="Format that will be used whens saving the output images. Default: 'png'")
|
||||
|
||||
st.session_state["defaults"].general.skip_grid = st.checkbox("Skip Grid", value=st.session_state['defaults'].general.skip_grid,
|
||||
help="Skip saving the grid output image. Default: False")
|
||||
if not st.session_state["defaults"].general.skip_grid:
|
||||
|
||||
|
||||
st.session_state["defaults"].general.grid_quality = st.number_input("Grid Quality", value=st.session_state['defaults'].general.grid_quality,
|
||||
help="Format for saving the grid output image. Default: 95")
|
||||
|
||||
st.session_state["defaults"].general.skip_save = st.checkbox("Skip Save", value=st.session_state['defaults'].general.skip_save,
|
||||
help="Skip saving the output image. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.n_rows = st.number_input("Number of Grid Rows", value=st.session_state['defaults'].general.n_rows,
|
||||
help="Number of rows the grid wil have when saving the grid output image. Default: '-1'")
|
||||
|
||||
st.session_state["defaults"].general.no_verify_input = st.checkbox("Do not Verify Input", value=st.session_state['defaults'].general.no_verify_input,
|
||||
help="Do not verify input to check if it's too long. Default: False")
|
||||
|
||||
st.session_state["defaults"].general.show_percent_in_tab_title = st.checkbox("Show Percent in tab title", value=st.session_state['defaults'].general.show_percent_in_tab_title,
|
||||
help="Add the progress percent value to the page title on the tab on your browser. "
|
||||
"This is useful in case you need to know how the generation is going while doign something else"
|
||||
"in another tab on your browser. Default: True")
|
||||
|
||||
st.session_state["defaults"].general.enable_suggestions = st.checkbox("Enable Suggestions Box", value=st.session_state['defaults'].general.enable_suggestions,
|
||||
help="Adds a suggestion box under the prompt when clicked. Default: True")
|
||||
|
||||
st.session_state["defaults"].daisi_app.running_on_daisi_io = st.checkbox("Running on Daisi.io?", value=st.session_state['defaults'].daisi_app.running_on_daisi_io,
|
||||
help="Specify if we are running on app.Daisi.io . Default: False")
|
||||
|
||||
with col4:
|
||||
st.title("Streamlit Config")
|
||||
|
||||
default_theme_list = ["light", "dark"]
|
||||
st.session_state["defaults"].general.default_theme = st.selectbox("Default Theme", default_theme_list, index=default_theme_list.index(st.session_state['defaults'].general.default_theme),
|
||||
help="Defaut theme to use as base for streamlit. Default: dark")
|
||||
st.session_state["streamlit_config"]["theme"]["base"] = st.session_state["defaults"].general.default_theme
|
||||
|
||||
|
||||
if not st.session_state['defaults'].admin.hide_server_setting:
|
||||
with st.expander("Server", True):
|
||||
|
||||
st.session_state["streamlit_config"]['server']['headless'] = st.checkbox("Run Headless", help="If false, will attempt to open a browser window on start. \
|
||||
Default: false unless (1) we are on a Linux box where DISPLAY is unset, \
|
||||
or (2) we are running in the Streamlit Atom plugin.")
|
||||
|
||||
st.session_state["streamlit_config"]['server']['port'] = st.number_input("Port", value=st.session_state["streamlit_config"]['server']['port'],
|
||||
help="The port where the server will listen for browser connections. Default: 8501")
|
||||
|
||||
st.session_state["streamlit_config"]['server']['baseUrlPath'] = st.text_input("Base Url Path", value=st.session_state["streamlit_config"]['server']['baseUrlPath'],
|
||||
help="The base path for the URL where Streamlit should be served from. Default: '' ")
|
||||
|
||||
st.session_state["streamlit_config"]['server']['enableCORS'] = st.checkbox("Enable CORS", value=st.session_state['streamlit_config']['server']['enableCORS'],
|
||||
help="Enables support for Cross-Origin Request Sharing (CORS) protection, for added security. \
|
||||
Due to conflicts between CORS and XSRF, if `server.enableXsrfProtection` is on and `server.enableCORS` \
|
||||
is off at the same time, we will prioritize `server.enableXsrfProtection`. Default: true")
|
||||
|
||||
st.session_state["streamlit_config"]['server']['enableXsrfProtection'] = st.checkbox("Enable Xsrf Protection",
|
||||
value=st.session_state['streamlit_config']['server']['enableXsrfProtection'],
|
||||
help="Enables support for Cross-Site Request Forgery (XSRF) protection, \
|
||||
for added security. Due to conflicts between CORS and XSRF, \
|
||||
if `server.enableXsrfProtection` is on and `server.enableCORS` is off at \
|
||||
the same time, we will prioritize `server.enableXsrfProtection`. Default: true")
|
||||
|
||||
st.session_state["streamlit_config"]['server']['maxUploadSize'] = st.number_input("Max Upload Size", value=st.session_state["streamlit_config"]['server']['maxUploadSize'],
|
||||
help="Max size, in megabytes, for files uploaded with the file_uploader. Default: 200")
|
||||
|
||||
st.session_state["streamlit_config"]['server']['maxMessageSize'] = st.number_input("Max Message Size", value=st.session_state["streamlit_config"]['server']['maxUploadSize'],
|
||||
help="Max size, in megabytes, of messages that can be sent via the WebSocket connection. Default: 200")
|
||||
|
||||
st.session_state["streamlit_config"]['server']['enableWebsocketCompression'] = st.checkbox("Enable Websocket Compression",
|
||||
value=st.session_state["streamlit_config"]['server']['enableWebsocketCompression'],
|
||||
help=" Enables support for websocket compression. Default: false")
|
||||
if not st.session_state['defaults'].admin.hide_browser_setting:
|
||||
with st.expander("Browser", expanded=True):
|
||||
st.session_state["streamlit_config"]['browser']['serverAddress'] = st.text_input("Server Address",
|
||||
value=st.session_state["streamlit_config"]['browser']['serverAddress'] if "serverAddress" in st.session_state["streamlit_config"] else "localhost",
|
||||
help="Internet address where users should point their browsers in order \
|
||||
to connect to the app. Can be IP address or DNS name and path.\
|
||||
This is used to: - Set the correct URL for CORS and XSRF protection purposes. \
|
||||
- Show the URL on the terminal - Open the browser. Default: 'localhost'")
|
||||
|
||||
st.session_state["defaults"].general.streamlit_telemetry = st.checkbox("Enable Telemetry", value=st.session_state['defaults'].general.streamlit_telemetry,
|
||||
help="Enables or Disables streamlit telemetry. Default: False")
|
||||
st.session_state["streamlit_config"]["browser"]["gatherUsageStats"] = st.session_state["defaults"].general.streamlit_telemetry
|
||||
|
||||
st.session_state["streamlit_config"]['browser']['serverPort'] = st.number_input("Server Port", value=st.session_state["streamlit_config"]['browser']['serverPort'],
|
||||
help="Port where users should point their browsers in order to connect to the app. \
|
||||
This is used to: - Set the correct URL for CORS and XSRF protection purposes. \
|
||||
- Show the URL on the terminal - Open the browser \
|
||||
Default: whatever value is set in server.port.")
|
||||
|
||||
with col5:
|
||||
st.title("Huggingface")
|
||||
st.session_state["defaults"].general.huggingface_token = st.text_input("Huggingface Token", value=st.session_state['defaults'].general.huggingface_token, type="password",
|
||||
help="Your Huggingface Token, it's used to download the model for the diffusers library which \
|
||||
is used on the Text To Video tab. This token will be saved to your user config file\
|
||||
and WILL NOT be share with us or anyone. You can get your access token \
|
||||
at https://huggingface.co/settings/tokens. Default: None")
|
||||
|
||||
st.title("Stable Horde")
|
||||
st.session_state["defaults"].general.stable_horde_api = st.text_input("Stable Horde Api", value=st.session_state["defaults"].general.stable_horde_api, type="password",
|
||||
help="First Register an account at https://stablehorde.net/register which will generate for you \
|
||||
an API key. Store that key somewhere safe. \n \
|
||||
If you do not want to register, you can use `0000000000` as api_key to connect anonymously.\
|
||||
However anonymous accounts have the lowest priority when there's too many concurrent requests! \
|
||||
To increase your priority you will need a unique API key and then to increase your Kudos \
|
||||
read more about them at https://dbzer0.com/blog/the-kudos-based-economy-for-the-koboldai-horde/.")
|
||||
|
||||
with txt2img_tab:
|
||||
col1, col2, col3, col4, col5 = st.columns(5, gap='medium')
|
||||
|
||||
with col1:
|
||||
st.title("Slider Parameters")
|
||||
|
||||
# Width
|
||||
st.session_state["defaults"].txt2img.width.value = st.number_input("Default Image Width", value=st.session_state['defaults'].txt2img.width.value,
|
||||
help="Set the default width for the generated image. Default is: 512")
|
||||
|
||||
st.session_state["defaults"].txt2img.width.min_value = st.number_input("Minimum Image Width", value=st.session_state['defaults'].txt2img.width.min_value,
|
||||
help="Set the default minimum value for the width slider. Default is: 64")
|
||||
|
||||
st.session_state["defaults"].txt2img.width.max_value = st.number_input("Maximum Image Width", value=st.session_state['defaults'].txt2img.width.max_value,
|
||||
help="Set the default maximum value for the width slider. Default is: 2048")
|
||||
|
||||
# Height
|
||||
st.session_state["defaults"].txt2img.height.value = st.number_input("Default Image Height", value=st.session_state['defaults'].txt2img.height.value,
|
||||
help="Set the default height for the generated image. Default is: 512")
|
||||
|
||||
st.session_state["defaults"].txt2img.height.min_value = st.number_input("Minimum Image Height", value=st.session_state['defaults'].txt2img.height.min_value,
|
||||
help="Set the default minimum value for the height slider. Default is: 64")
|
||||
|
||||
st.session_state["defaults"].txt2img.height.max_value = st.number_input("Maximum Image Height", value=st.session_state['defaults'].txt2img.height.max_value,
|
||||
help="Set the default maximum value for the height slider. Default is: 2048")
|
||||
|
||||
with col2:
|
||||
# CFG
|
||||
st.session_state["defaults"].txt2img.cfg_scale.value = st.number_input("Default CFG Scale", value=st.session_state['defaults'].txt2img.cfg_scale.value,
|
||||
help="Set the default value for the CFG Scale. Default is: 7.5")
|
||||
|
||||
st.session_state["defaults"].txt2img.cfg_scale.min_value = st.number_input("Minimum CFG Scale Value", value=st.session_state['defaults'].txt2img.cfg_scale.min_value,
|
||||
help="Set the default minimum value for the CFG scale slider. Default is: 1")
|
||||
|
||||
st.session_state["defaults"].txt2img.cfg_scale.step = st.number_input("CFG Slider Steps", value=st.session_state['defaults'].txt2img.cfg_scale.step,
|
||||
help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
|
||||
# Sampling Steps
|
||||
st.session_state["defaults"].txt2img.sampling_steps.value = st.number_input("Default Sampling Steps", value=st.session_state['defaults'].txt2img.sampling_steps.value,
|
||||
help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
|
||||
|
||||
st.session_state["defaults"].txt2img.sampling_steps.min_value = st.number_input("Minimum Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2img.sampling_steps.min_value,
|
||||
help="Set the default minimum value for the sampling steps slider. Default is: 1")
|
||||
|
||||
st.session_state["defaults"].txt2img.sampling_steps.step = st.number_input("Sampling Slider Steps",
|
||||
value=st.session_state['defaults'].txt2img.sampling_steps.step,
|
||||
help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
|
||||
|
||||
with col3:
|
||||
st.title("General Parameters")
|
||||
|
||||
# Batch Count
|
||||
st.session_state["defaults"].txt2img.batch_count.value = st.number_input("Batch count", value=st.session_state['defaults'].txt2img.batch_count.value,
|
||||
help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
st.session_state["defaults"].txt2img.batch_size.value = st.number_input("Batch size", value=st.session_state.defaults.txt2img.batch_size.value,
|
||||
help="How many images are at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
|
||||
takes to finish generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
st.session_state["defaults"].txt2img.default_sampler = st.selectbox("Default Sampler",
|
||||
default_sampler_list, index=default_sampler_list.index(
|
||||
st.session_state['defaults'].txt2img.default_sampler),
|
||||
help="Defaut sampler to use for txt2img. Default: k_euler")
|
||||
|
||||
st.session_state['defaults'].txt2img.seed = st.text_input("Default Seed", value=st.session_state['defaults'].txt2img.seed, help="Default seed.")
|
||||
|
||||
with col4:
|
||||
|
||||
st.session_state["defaults"].txt2img.separate_prompts = st.checkbox("Separate Prompts",
|
||||
value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate Prompts. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2img.normalize_prompt_weights = st.checkbox("Normalize Prompt Weights",
|
||||
value=st.session_state['defaults'].txt2img.normalize_prompt_weights,
|
||||
help="Choose to normalize prompt weights. Default: True")
|
||||
|
||||
st.session_state["defaults"].txt2img.save_individual_images = st.checkbox("Save Individual Images",
|
||||
value=st.session_state['defaults'].txt2img.save_individual_images,
|
||||
help="Choose to save individual images. Default: True")
|
||||
|
||||
st.session_state["defaults"].txt2img.save_grid = st.checkbox("Save Grid Images", value=st.session_state['defaults'].txt2img.save_grid,
|
||||
help="Choose to save the grid images. Default: True")
|
||||
|
||||
st.session_state["defaults"].txt2img.group_by_prompt = st.checkbox("Group By Prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
|
||||
help="Choose to save images grouped by their prompt. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2img.save_as_jpg = st.checkbox("Save As JPG", value=st.session_state['defaults'].txt2img.save_as_jpg,
|
||||
help="Choose to save images as jpegs. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2img.write_info_files = st.checkbox("Write Info Files For Images", value=st.session_state['defaults'].txt2img.write_info_files,
|
||||
help="Choose to write the info files along with the generated images. Default: True")
|
||||
|
||||
st.session_state["defaults"].txt2img.use_GFPGAN = st.checkbox(
|
||||
"Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Choose to use GFPGAN. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2img.use_upscaling = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling,
|
||||
help="Choose to turn on upscaling by default. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2img.update_preview = True
|
||||
st.session_state["defaults"].txt2img.update_preview_frequency = st.number_input("Preview Image Update Frequency",
|
||||
min_value=0,
|
||||
value=st.session_state['defaults'].txt2img.update_preview_frequency,
|
||||
help="Set the default value for the frrquency of the preview image updates. Default is: 10")
|
||||
|
||||
with col5:
|
||||
st.title("Variation Parameters")
|
||||
|
||||
st.session_state["defaults"].txt2img.variant_amount.value = st.number_input("Default Variation Amount",
|
||||
value=st.session_state['defaults'].txt2img.variant_amount.value,
|
||||
help="Set the default variation to use. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].txt2img.variant_amount.min_value = st.number_input("Minimum Variation Amount",
|
||||
value=st.session_state['defaults'].txt2img.variant_amount.min_value,
|
||||
help="Set the default minimum value for the variation slider. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].txt2img.variant_amount.max_value = st.number_input("Maximum Variation Amount",
|
||||
value=st.session_state['defaults'].txt2img.variant_amount.max_value,
|
||||
help="Set the default maximum value for the variation slider. Default is: 1.0")
|
||||
|
||||
st.session_state["defaults"].txt2img.variant_amount.step = st.number_input("Variation Slider Steps",
|
||||
value=st.session_state['defaults'].txt2img.variant_amount.step,
|
||||
help="Set the default value for the number of steps on the variation slider. Default is: 1")
|
||||
|
||||
st.session_state['defaults'].txt2img.variant_seed = st.text_input("Default Variation Seed", value=st.session_state['defaults'].txt2img.variant_seed,
|
||||
help="Default variation seed.")
|
||||
|
||||
with img2img_tab:
|
||||
col1, col2, col3, col4, col5 = st.columns(5, gap='medium')
|
||||
|
||||
with col1:
|
||||
st.title("Image Editing")
|
||||
|
||||
# Denoising
|
||||
st.session_state["defaults"].img2img.denoising_strength.value = st.number_input("Default Denoising Amount",
|
||||
value=st.session_state['defaults'].img2img.denoising_strength.value,
|
||||
help="Set the default denoising to use. Default is: 0.75")
|
||||
|
||||
st.session_state["defaults"].img2img.denoising_strength.min_value = st.number_input("Minimum Denoising Amount",
|
||||
value=st.session_state['defaults'].img2img.denoising_strength.min_value,
|
||||
help="Set the default minimum value for the denoising slider. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].img2img.denoising_strength.max_value = st.number_input("Maximum Denoising Amount",
|
||||
value=st.session_state['defaults'].img2img.denoising_strength.max_value,
|
||||
help="Set the default maximum value for the denoising slider. Default is: 1.0")
|
||||
|
||||
st.session_state["defaults"].img2img.denoising_strength.step = st.number_input("Denoising Slider Steps",
|
||||
value=st.session_state['defaults'].img2img.denoising_strength.step,
|
||||
help="Set the default value for the number of steps on the denoising slider. Default is: 0.01")
|
||||
|
||||
# Masking
|
||||
st.session_state["defaults"].img2img.mask_mode = st.number_input("Default Mask Mode", value=st.session_state['defaults'].img2img.mask_mode,
|
||||
help="Set the default mask mode to use. 0 = Keep Masked Area, 1 = Regenerate Masked Area. Default is: 0")
|
||||
|
||||
st.session_state["defaults"].img2img.mask_restore = st.checkbox("Default Mask Restore", value=st.session_state['defaults'].img2img.mask_restore,
|
||||
help="Mask Restore. Default: False")
|
||||
|
||||
st.session_state["defaults"].img2img.resize_mode = st.number_input("Default Resize Mode", value=st.session_state['defaults'].img2img.resize_mode,
|
||||
help="Set the default resizing mode. 0 = Just Resize, 1 = Crop and Resize, 3 = Resize and Fill. Default is: 0")
|
||||
|
||||
with col2:
|
||||
st.title("Slider Parameters")
|
||||
|
||||
# Width
|
||||
st.session_state["defaults"].img2img.width.value = st.number_input("Default Outputted Image Width", value=st.session_state['defaults'].img2img.width.value,
|
||||
help="Set the default width for the generated image. Default is: 512")
|
||||
|
||||
st.session_state["defaults"].img2img.width.min_value = st.number_input("Minimum Outputted Image Width", value=st.session_state['defaults'].img2img.width.min_value,
|
||||
help="Set the default minimum value for the width slider. Default is: 64")
|
||||
|
||||
st.session_state["defaults"].img2img.width.max_value = st.number_input("Maximum Outputted Image Width", value=st.session_state['defaults'].img2img.width.max_value,
|
||||
help="Set the default maximum value for the width slider. Default is: 2048")
|
||||
|
||||
# Height
|
||||
st.session_state["defaults"].img2img.height.value = st.number_input("Default Outputted Image Height", value=st.session_state['defaults'].img2img.height.value,
|
||||
help="Set the default height for the generated image. Default is: 512")
|
||||
|
||||
st.session_state["defaults"].img2img.height.min_value = st.number_input("Minimum Outputted Image Height", value=st.session_state['defaults'].img2img.height.min_value,
|
||||
help="Set the default minimum value for the height slider. Default is: 64")
|
||||
|
||||
st.session_state["defaults"].img2img.height.max_value = st.number_input("Maximum Outputted Image Height", value=st.session_state['defaults'].img2img.height.max_value,
|
||||
help="Set the default maximum value for the height slider. Default is: 2048")
|
||||
|
||||
# CFG
|
||||
st.session_state["defaults"].img2img.cfg_scale.value = st.number_input("Default Img2Img CFG Scale", value=st.session_state['defaults'].img2img.cfg_scale.value,
|
||||
help="Set the default value for the CFG Scale. Default is: 7.5")
|
||||
|
||||
st.session_state["defaults"].img2img.cfg_scale.min_value = st.number_input("Minimum Img2Img CFG Scale Value",
|
||||
value=st.session_state['defaults'].img2img.cfg_scale.min_value,
|
||||
help="Set the default minimum value for the CFG scale slider. Default is: 1")
|
||||
|
||||
with col3:
|
||||
st.session_state["defaults"].img2img.cfg_scale.step = st.number_input("Img2Img CFG Slider Steps",
|
||||
value=st.session_state['defaults'].img2img.cfg_scale.step,
|
||||
help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
|
||||
|
||||
# Sampling Steps
|
||||
st.session_state["defaults"].img2img.sampling_steps.value = st.number_input("Default Img2Img Sampling Steps",
|
||||
value=st.session_state['defaults'].img2img.sampling_steps.value,
|
||||
help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
|
||||
|
||||
st.session_state["defaults"].img2img.sampling_steps.min_value = st.number_input("Minimum Img2Img Sampling Steps",
|
||||
value=st.session_state['defaults'].img2img.sampling_steps.min_value,
|
||||
help="Set the default minimum value for the sampling steps slider. Default is: 1")
|
||||
|
||||
st.session_state["defaults"].img2img.sampling_steps.step = st.number_input("Img2Img Sampling Slider Steps",
|
||||
value=st.session_state['defaults'].img2img.sampling_steps.step,
|
||||
help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
|
||||
|
||||
# Batch Count
|
||||
st.session_state["defaults"].img2img.batch_count.value = st.number_input("Img2img Batch count", value=st.session_state["defaults"].img2img.batch_count.value,
|
||||
help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
st.session_state["defaults"].img2img.batch_size.value = st.number_input("Img2img Batch size", value=st.session_state["defaults"].img2img.batch_size.value,
|
||||
help="How many images are at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
|
||||
takes to finish generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
with col4:
|
||||
# Inference Steps
|
||||
st.session_state["defaults"].img2img.num_inference_steps.value = st.number_input("Default Inference Steps",
|
||||
value=st.session_state['defaults'].img2img.num_inference_steps.value,
|
||||
help="Set the default number of inference steps to use. Default is: 200")
|
||||
|
||||
st.session_state["defaults"].img2img.num_inference_steps.min_value = st.number_input("Minimum Sampling Steps",
|
||||
value=st.session_state['defaults'].img2img.num_inference_steps.min_value,
|
||||
help="Set the default minimum value for the inference steps slider. Default is: 10")
|
||||
|
||||
st.session_state["defaults"].img2img.num_inference_steps.max_value = st.number_input("Maximum Sampling Steps",
|
||||
value=st.session_state['defaults'].img2img.num_inference_steps.max_value,
|
||||
help="Set the default maximum value for the inference steps slider. Default is: 500")
|
||||
|
||||
st.session_state["defaults"].img2img.num_inference_steps.step = st.number_input("Inference Slider Steps",
|
||||
value=st.session_state['defaults'].img2img.num_inference_steps.step,
|
||||
help="Set the default value for the number of steps on the inference steps slider.\
|
||||
Default is: 10")
|
||||
|
||||
# Find Noise Steps
|
||||
st.session_state["defaults"].img2img.find_noise_steps.value = st.number_input("Default Find Noise Steps",
|
||||
value=st.session_state['defaults'].img2img.find_noise_steps.value,
|
||||
help="Set the default number of find noise steps to use. Default is: 100")
|
||||
|
||||
st.session_state["defaults"].img2img.find_noise_steps.min_value = st.number_input("Minimum Find Noise Steps",
|
||||
value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
|
||||
help="Set the default minimum value for the find noise steps slider. Default is: 0")
|
||||
|
||||
st.session_state["defaults"].img2img.find_noise_steps.step = st.number_input("Find Noise Slider Steps",
|
||||
value=st.session_state['defaults'].img2img.find_noise_steps.step,
|
||||
help="Set the default value for the number of steps on the find noise steps slider. \
|
||||
Default is: 100")
|
||||
|
||||
with col5:
|
||||
st.title("General Parameters")
|
||||
|
||||
default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
st.session_state["defaults"].img2img.sampler_name = st.selectbox("Default Img2Img Sampler", default_sampler_list,
|
||||
index=default_sampler_list.index(st.session_state['defaults'].img2img.sampler_name),
|
||||
help="Defaut sampler to use for img2img. Default: k_euler")
|
||||
|
||||
st.session_state['defaults'].img2img.seed = st.text_input("Default Img2Img Seed", value=st.session_state['defaults'].img2img.seed, help="Default seed.")
|
||||
|
||||
st.session_state["defaults"].img2img.separate_prompts = st.checkbox("Separate Img2Img Prompts", value=st.session_state['defaults'].img2img.separate_prompts,
|
||||
help="Separate Prompts. Default: False")
|
||||
|
||||
st.session_state["defaults"].img2img.normalize_prompt_weights = st.checkbox("Normalize Img2Img Prompt Weights",
|
||||
value=st.session_state['defaults'].img2img.normalize_prompt_weights,
|
||||
help="Choose to normalize prompt weights. Default: True")
|
||||
|
||||
st.session_state["defaults"].img2img.save_individual_images = st.checkbox("Save Individual Img2Img Images",
|
||||
value=st.session_state['defaults'].img2img.save_individual_images,
|
||||
help="Choose to save individual images. Default: True")
|
||||
|
||||
st.session_state["defaults"].img2img.save_grid = st.checkbox("Save Img2Img Grid Images",
|
||||
value=st.session_state['defaults'].img2img.save_grid, help="Choose to save the grid images. Default: True")
|
||||
|
||||
st.session_state["defaults"].img2img.group_by_prompt = st.checkbox("Group By Img2Img Prompt",
|
||||
value=st.session_state['defaults'].img2img.group_by_prompt,
|
||||
help="Choose to save images grouped by their prompt. Default: False")
|
||||
|
||||
st.session_state["defaults"].img2img.save_as_jpg = st.checkbox("Save Img2Img As JPG", value=st.session_state['defaults'].img2img.save_as_jpg,
|
||||
help="Choose to save images as jpegs. Default: False")
|
||||
|
||||
st.session_state["defaults"].img2img.write_info_files = st.checkbox("Write Info Files For Img2Img Images",
|
||||
value=st.session_state['defaults'].img2img.write_info_files,
|
||||
help="Choose to write the info files along with the generated images. Default: True")
|
||||
|
||||
st.session_state["defaults"].img2img.use_GFPGAN = st.checkbox(
|
||||
"Img2Img Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Choose to use GFPGAN. Default: False")
|
||||
|
||||
st.session_state["defaults"].img2img.use_RealESRGAN = st.checkbox("Img2Img Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN,
|
||||
help="Choose to use RealESRGAN. Default: False")
|
||||
|
||||
st.session_state["defaults"].img2img.update_preview = True
|
||||
st.session_state["defaults"].img2img.update_preview_frequency = st.number_input("Img2Img Preview Image Update Frequency",
|
||||
min_value=0,
|
||||
value=st.session_state['defaults'].img2img.update_preview_frequency,
|
||||
help="Set the default value for the frrquency of the preview image updates. Default is: 10")
|
||||
|
||||
st.title("Variation Parameters")
|
||||
|
||||
st.session_state["defaults"].img2img.variant_amount = st.number_input("Default Img2Img Variation Amount",
|
||||
value=st.session_state['defaults'].img2img.variant_amount,
|
||||
help="Set the default variation to use. Default is: 0.0")
|
||||
|
||||
# I THINK THESE ARE MISSING FROM THE CONFIG FILE
|
||||
# st.session_state["defaults"].img2img.variant_amount.min_value = st.number_input("Minimum Img2Img Variation Amount",
|
||||
# value=st.session_state['defaults'].img2img.variant_amount.min_value, help="Set the default minimum value for the variation slider. Default is: 0.0"))
|
||||
|
||||
# st.session_state["defaults"].img2img.variant_amount.max_value = st.number_input("Maximum Img2Img Variation Amount",
|
||||
# value=st.session_state['defaults'].img2img.variant_amount.max_value, help="Set the default maximum value for the variation slider. Default is: 1.0"))
|
||||
|
||||
# st.session_state["defaults"].img2img.variant_amount.step = st.number_input("Img2Img Variation Slider Steps",
|
||||
# value=st.session_state['defaults'].img2img.variant_amount.step, help="Set the default value for the number of steps on the variation slider. Default is: 1"))
|
||||
|
||||
st.session_state['defaults'].img2img.variant_seed = st.text_input("Default Img2Img Variation Seed",
|
||||
value=st.session_state['defaults'].img2img.variant_seed, help="Default variation seed.")
|
||||
|
||||
with img2txt_tab:
|
||||
col1 = st.columns(1, gap="large")
|
||||
|
||||
st.title("Image-To-Text")
|
||||
|
||||
st.session_state["defaults"].img2txt.batch_size = st.number_input("Default Img2Txt Batch Size", value=st.session_state['defaults'].img2txt.batch_size,
|
||||
help="Set the default batch size for Img2Txt. Default is: 420?")
|
||||
|
||||
st.session_state["defaults"].img2txt.blip_image_eval_size = st.number_input("Default Blip Image Size Evaluation",
|
||||
value=st.session_state['defaults'].img2txt.blip_image_eval_size,
|
||||
help="Set the default value for the blip image evaluation size. Default is: 512")
|
||||
|
||||
with txt2vid_tab:
|
||||
col1, col2, col3, col4, col5 = st.columns(5, gap="medium")
|
||||
|
||||
with col1:
|
||||
st.title("Slider Parameters")
|
||||
|
||||
# Width
|
||||
st.session_state["defaults"].txt2vid.width.value = st.number_input("Default txt2vid Image Width",
|
||||
value=st.session_state['defaults'].txt2vid.width.value,
|
||||
help="Set the default width for the generated image. Default is: 512")
|
||||
|
||||
st.session_state["defaults"].txt2vid.width.min_value = st.number_input("Minimum txt2vid Image Width",
|
||||
value=st.session_state['defaults'].txt2vid.width.min_value,
|
||||
help="Set the default minimum value for the width slider. Default is: 64")
|
||||
|
||||
st.session_state["defaults"].txt2vid.width.max_value = st.number_input("Maximum txt2vid Image Width",
|
||||
value=st.session_state['defaults'].txt2vid.width.max_value,
|
||||
help="Set the default maximum value for the width slider. Default is: 2048")
|
||||
|
||||
# Height
|
||||
st.session_state["defaults"].txt2vid.height.value = st.number_input("Default txt2vid Image Height",
|
||||
value=st.session_state['defaults'].txt2vid.height.value,
|
||||
help="Set the default height for the generated image. Default is: 512")
|
||||
|
||||
st.session_state["defaults"].txt2vid.height.min_value = st.number_input("Minimum txt2vid Image Height",
|
||||
value=st.session_state['defaults'].txt2vid.height.min_value,
|
||||
help="Set the default minimum value for the height slider. Default is: 64")
|
||||
|
||||
st.session_state["defaults"].txt2vid.height.max_value = st.number_input("Maximum txt2vid Image Height",
|
||||
value=st.session_state['defaults'].txt2vid.height.max_value,
|
||||
help="Set the default maximum value for the height slider. Default is: 2048")
|
||||
|
||||
# CFG
|
||||
st.session_state["defaults"].txt2vid.cfg_scale.value = st.number_input("Default txt2vid CFG Scale",
|
||||
value=st.session_state['defaults'].txt2vid.cfg_scale.value,
|
||||
help="Set the default value for the CFG Scale. Default is: 7.5")
|
||||
|
||||
st.session_state["defaults"].txt2vid.cfg_scale.min_value = st.number_input("Minimum txt2vid CFG Scale Value",
|
||||
value=st.session_state['defaults'].txt2vid.cfg_scale.min_value,
|
||||
help="Set the default minimum value for the CFG scale slider. Default is: 1")
|
||||
|
||||
st.session_state["defaults"].txt2vid.cfg_scale.step = st.number_input("txt2vid CFG Slider Steps",
|
||||
value=st.session_state['defaults'].txt2vid.cfg_scale.step,
|
||||
help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
|
||||
|
||||
with col2:
|
||||
# Sampling Steps
|
||||
st.session_state["defaults"].txt2vid.sampling_steps.value = st.number_input("Default txt2vid Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2vid.sampling_steps.value,
|
||||
help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
|
||||
|
||||
st.session_state["defaults"].txt2vid.sampling_steps.min_value = st.number_input("Minimum txt2vid Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2vid.sampling_steps.min_value,
|
||||
help="Set the default minimum value for the sampling steps slider. Default is: 1")
|
||||
|
||||
st.session_state["defaults"].txt2vid.sampling_steps.step = st.number_input("txt2vid Sampling Slider Steps",
|
||||
value=st.session_state['defaults'].txt2vid.sampling_steps.step,
|
||||
help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
|
||||
|
||||
# Batch Count
|
||||
st.session_state["defaults"].txt2vid.batch_count.value = st.number_input("txt2vid Batch count", value=st.session_state['defaults'].txt2vid.batch_count.value,
|
||||
help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
st.session_state["defaults"].txt2vid.batch_size.value = st.number_input("txt2vid Batch size", value=st.session_state.defaults.txt2vid.batch_size.value,
|
||||
help="How many images are at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
|
||||
takes to finish generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
# Inference Steps
|
||||
st.session_state["defaults"].txt2vid.num_inference_steps.value = st.number_input("Default Txt2Vid Inference Steps",
|
||||
value=st.session_state['defaults'].txt2vid.num_inference_steps.value,
|
||||
help="Set the default number of inference steps to use. Default is: 200")
|
||||
|
||||
st.session_state["defaults"].txt2vid.num_inference_steps.min_value = st.number_input("Minimum Txt2Vid Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2vid.num_inference_steps.min_value,
|
||||
help="Set the default minimum value for the inference steps slider. Default is: 10")
|
||||
|
||||
st.session_state["defaults"].txt2vid.num_inference_steps.max_value = st.number_input("Maximum Txt2Vid Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
|
||||
help="Set the default maximum value for the inference steps slider. Default is: 500")
|
||||
st.session_state["defaults"].txt2vid.num_inference_steps.step = st.number_input("Txt2Vid Inference Slider Steps",
|
||||
value=st.session_state['defaults'].txt2vid.num_inference_steps.step,
|
||||
help="Set the default value for the number of steps on the inference steps slider. Default is: 10")
|
||||
|
||||
with col3:
|
||||
st.title("General Parameters")
|
||||
|
||||
st.session_state['defaults'].txt2vid.default_model = st.text_input("Default Txt2Vid Model", value=st.session_state['defaults'].txt2vid.default_model,
|
||||
help="Default: CompVis/stable-diffusion-v1-4")
|
||||
|
||||
# INSERT CUSTOM_MODELS_LIST HERE
|
||||
|
||||
default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
st.session_state["defaults"].txt2vid.default_sampler = st.selectbox("Default txt2vid Sampler", default_sampler_list,
|
||||
index=default_sampler_list.index(st.session_state['defaults'].txt2vid.default_sampler),
|
||||
help="Defaut sampler to use for txt2vid. Default: k_euler")
|
||||
|
||||
st.session_state['defaults'].txt2vid.seed = st.text_input("Default txt2vid Seed", value=st.session_state['defaults'].txt2vid.seed, help="Default seed.")
|
||||
|
||||
st.session_state['defaults'].txt2vid.scheduler_name = st.text_input("Default Txt2Vid Scheduler",
|
||||
value=st.session_state['defaults'].txt2vid.scheduler_name, help="Default scheduler.")
|
||||
|
||||
st.session_state["defaults"].txt2vid.separate_prompts = st.checkbox("Separate txt2vid Prompts",
|
||||
value=st.session_state['defaults'].txt2vid.separate_prompts, help="Separate Prompts. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2vid.normalize_prompt_weights = st.checkbox("Normalize txt2vid Prompt Weights",
|
||||
value=st.session_state['defaults'].txt2vid.normalize_prompt_weights,
|
||||
help="Choose to normalize prompt weights. Default: True")
|
||||
|
||||
st.session_state["defaults"].txt2vid.save_individual_images = st.checkbox("Save Individual txt2vid Images",
|
||||
value=st.session_state['defaults'].txt2vid.save_individual_images,
|
||||
help="Choose to save individual images. Default: True")
|
||||
|
||||
st.session_state["defaults"].txt2vid.save_video = st.checkbox("Save Txt2Vid Video", value=st.session_state['defaults'].txt2vid.save_video,
|
||||
help="Choose to save the Txt2Vid video. Default: True")
|
||||
|
||||
st.session_state["defaults"].txt2vid.save_video_on_stop = st.checkbox("Save video on Stop", value=st.session_state['defaults'].txt2vid.save_video_on_stop,
|
||||
help="Save a video with all the images generated as frames when we hit the stop button \
|
||||
during a generation.")
|
||||
|
||||
st.session_state["defaults"].txt2vid.group_by_prompt = st.checkbox("Group By txt2vid Prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
|
||||
help="Choose to save images grouped by their prompt. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2vid.save_as_jpg = st.checkbox("Save txt2vid As JPG", value=st.session_state['defaults'].txt2vid.save_as_jpg,
|
||||
help="Choose to save images as jpegs. Default: False")
|
||||
|
||||
# Need more info for the Help dialog...
|
||||
st.session_state["defaults"].txt2vid.do_loop = st.checkbox("Loop Generations", value=st.session_state['defaults'].txt2vid.do_loop,
|
||||
help="Choose to loop or something, IDK.... Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2vid.max_duration_in_seconds = st.number_input("Txt2Vid Max Duration in Seconds", value=st.session_state['defaults'].txt2vid.max_duration_in_seconds,
|
||||
help="Set the default value for the max duration in seconds for the video generated. Default is: 30")
|
||||
|
||||
st.session_state["defaults"].txt2vid.write_info_files = st.checkbox("Write Info Files For txt2vid Images", value=st.session_state['defaults'].txt2vid.write_info_files,
|
||||
help="Choose to write the info files along with the generated images. Default: True")
|
||||
|
||||
st.session_state["defaults"].txt2vid.use_GFPGAN = st.checkbox("txt2vid Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN,
|
||||
help="Choose to use GFPGAN. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2vid.use_RealESRGAN = st.checkbox("txt2vid Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
|
||||
help="Choose to use RealESRGAN. Default: False")
|
||||
|
||||
st.session_state["defaults"].txt2vid.update_preview = True
|
||||
st.session_state["defaults"].txt2vid.update_preview_frequency = st.number_input("txt2vid Preview Image Update Frequency",
|
||||
value=st.session_state['defaults'].txt2vid.update_preview_frequency,
|
||||
help="Set the default value for the frrquency of the preview image updates. Default is: 10")
|
||||
|
||||
with col4:
|
||||
st.title("Variation Parameters")
|
||||
|
||||
st.session_state["defaults"].txt2vid.variant_amount.value = st.number_input("Default txt2vid Variation Amount",
|
||||
value=st.session_state['defaults'].txt2vid.variant_amount.value,
|
||||
help="Set the default variation to use. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.variant_amount.min_value = st.number_input("Minimum txt2vid Variation Amount",
|
||||
value=st.session_state['defaults'].txt2vid.variant_amount.min_value,
|
||||
help="Set the default minimum value for the variation slider. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.variant_amount.max_value = st.number_input("Maximum txt2vid Variation Amount",
|
||||
value=st.session_state['defaults'].txt2vid.variant_amount.max_value,
|
||||
help="Set the default maximum value for the variation slider. Default is: 1.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.variant_amount.step = st.number_input("txt2vid Variation Slider Steps",
|
||||
value=st.session_state['defaults'].txt2vid.variant_amount.step,
|
||||
help="Set the default value for the number of steps on the variation slider. Default is: 1")
|
||||
|
||||
st.session_state['defaults'].txt2vid.variant_seed = st.text_input("Default txt2vid Variation Seed",
|
||||
value=st.session_state['defaults'].txt2vid.variant_seed, help="Default variation seed.")
|
||||
|
||||
with col5:
|
||||
st.title("Beta Parameters")
|
||||
|
||||
# Beta Start
|
||||
st.session_state["defaults"].txt2vid.beta_start.value = st.number_input("Default txt2vid Beta Start Value",
|
||||
value=st.session_state['defaults'].txt2vid.beta_start.value,
|
||||
help="Set the default variation to use. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.beta_start.min_value = st.number_input("Minimum txt2vid Beta Start Amount",
|
||||
value=st.session_state['defaults'].txt2vid.beta_start.min_value,
|
||||
help="Set the default minimum value for the variation slider. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.beta_start.max_value = st.number_input("Maximum txt2vid Beta Start Amount",
|
||||
value=st.session_state['defaults'].txt2vid.beta_start.max_value,
|
||||
help="Set the default maximum value for the variation slider. Default is: 1.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.beta_start.step = st.number_input("txt2vid Beta Start Slider Steps", value=st.session_state['defaults'].txt2vid.beta_start.step,
|
||||
help="Set the default value for the number of steps on the variation slider. Default is: 1")
|
||||
|
||||
st.session_state["defaults"].txt2vid.beta_start.format = st.text_input("Default txt2vid Beta Start Format", value=st.session_state['defaults'].txt2vid.beta_start.format,
|
||||
help="Set the default Beta Start Format. Default is: %.5\f")
|
||||
|
||||
# Beta End
|
||||
st.session_state["defaults"].txt2vid.beta_end.value = st.number_input("Default txt2vid Beta End Value", value=st.session_state['defaults'].txt2vid.beta_end.value,
|
||||
help="Set the default variation to use. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.beta_end.min_value = st.number_input("Minimum txt2vid Beta End Amount", value=st.session_state['defaults'].txt2vid.beta_end.min_value,
|
||||
help="Set the default minimum value for the variation slider. Default is: 0.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.beta_end.max_value = st.number_input("Maximum txt2vid Beta End Amount", value=st.session_state['defaults'].txt2vid.beta_end.max_value,
|
||||
help="Set the default maximum value for the variation slider. Default is: 1.0")
|
||||
|
||||
st.session_state["defaults"].txt2vid.beta_end.step = st.number_input("txt2vid Beta End Slider Steps", value=st.session_state['defaults'].txt2vid.beta_end.step,
|
||||
help="Set the default value for the number of steps on the variation slider. Default is: 1")
|
||||
|
||||
st.session_state["defaults"].txt2vid.beta_end.format = st.text_input("Default txt2vid Beta End Format", value=st.session_state['defaults'].txt2vid.beta_start.format,
|
||||
help="Set the default Beta Start Format. Default is: %.5\f")
|
||||
|
||||
with image_processing:
|
||||
col1, col2, col3, col4, col5 = st.columns(5, gap="large")
|
||||
|
||||
with col1:
|
||||
st.title("GFPGAN")
|
||||
|
||||
st.session_state["defaults"].gfpgan.strength = st.number_input("Default Img2Txt Batch Size", value=st.session_state['defaults'].gfpgan.strength,
|
||||
help="Set the default global strength for GFPGAN. Default is: 100")
|
||||
with col2:
|
||||
st.title("GoBig")
|
||||
with col3:
|
||||
st.title("RealESRGAN")
|
||||
with col4:
|
||||
st.title("LDSR")
|
||||
with col5:
|
||||
st.title("GoLatent")
|
||||
|
||||
with textual_inversion_tab:
|
||||
st.title("Textual Inversion")
|
||||
|
||||
st.session_state['defaults'].textual_inversion.pretrained_model_name_or_path = st.text_input("Default Textual Inversion Model Path",
|
||||
value=st.session_state['defaults'].textual_inversion.pretrained_model_name_or_path,
|
||||
help="Default: models/ldm/stable-diffusion-v1-4")
|
||||
|
||||
st.session_state['defaults'].textual_inversion.tokenizer_name = st.text_input("Default Img2Img Variation Seed", value=st.session_state['defaults'].textual_inversion.tokenizer_name,
|
||||
help="Default tokenizer seed.")
|
||||
|
||||
with concepts_library_tab:
|
||||
st.title("Concepts Library")
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
|
||||
with col1:
|
||||
st.session_state["defaults"].concepts_library.concepts_per_page = st.number_input("Concepts Per Page", value=st.session_state['defaults'].concepts_library.concepts_per_page,
|
||||
help="Number of concepts per page to show on the Concepts Library. Default: '12'")
|
||||
|
||||
# add space for the buttons at the bottom
|
||||
st.markdown("---")
|
||||
|
||||
# We need a submit button to save the Settings
|
||||
# as well as one to reset them to the defaults, just in case.
|
||||
_, _, save_button_col, reset_button_col, _, _ = st.columns([1, 1, 1, 1, 1, 1], gap="large")
|
||||
with save_button_col:
|
||||
save_button = st.form_submit_button("Save")
|
||||
|
||||
with reset_button_col:
|
||||
reset_button = st.form_submit_button("Reset")
|
||||
|
||||
if save_button:
|
||||
OmegaConf.save(config=st.session_state.defaults, f="configs/webui/userconfig_streamlit.yaml")
|
||||
loaded = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
|
||||
assert st.session_state.defaults == loaded
|
||||
|
||||
#
|
||||
if (os.path.exists(".streamlit/config.toml")):
|
||||
with open(".streamlit/config.toml", "w") as toml_file:
|
||||
toml.dump(st.session_state["streamlit_config"], toml_file)
|
||||
|
||||
if reset_button:
|
||||
st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
|
||||
st.experimental_rerun()
|
91
webui/streamlit/scripts/barfi_baklavajs.py
Normal file
91
webui/streamlit/scripts/barfi_baklavajs.py
Normal file
@ -0,0 +1,91 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sandbox-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
#from sd_utils import *
|
||||
from sd_utils import st
|
||||
# streamlit imports
|
||||
|
||||
#streamlit components section
|
||||
|
||||
#other imports
|
||||
from barfi import st_barfi, barfi_schemas, Block
|
||||
|
||||
# Temp imports
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def layout():
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
|
||||
#from barfi import st_barfi, Block
|
||||
|
||||
#add = Block(name='Addition')
|
||||
#sub = Block(name='Subtraction')
|
||||
#mul = Block(name='Multiplication')
|
||||
#div = Block(name='Division')
|
||||
|
||||
#barfi_result = st_barfi(base_blocks= [add, sub, mul, div])
|
||||
# or if you want to use a category to organise them in the frontend sub-menu
|
||||
#barfi_result = st_barfi(base_blocks= {'Op 1': [add, sub], 'Op 2': [mul, div]})
|
||||
|
||||
col1, col2, col3 = st.columns([1, 8, 1])
|
||||
|
||||
with col2:
|
||||
feed = Block(name='Feed')
|
||||
feed.add_output()
|
||||
def feed_func(self):
|
||||
self.set_interface(name='Output 1', value=4)
|
||||
feed.add_compute(feed_func)
|
||||
|
||||
splitter = Block(name='Splitter')
|
||||
splitter.add_input()
|
||||
splitter.add_output()
|
||||
splitter.add_output()
|
||||
def splitter_func(self):
|
||||
in_1 = self.get_interface(name='Input 1')
|
||||
value = (in_1/2)
|
||||
self.set_interface(name='Output 1', value=value)
|
||||
self.set_interface(name='Output 2', value=value)
|
||||
splitter.add_compute(splitter_func)
|
||||
|
||||
mixer = Block(name='Mixer')
|
||||
mixer.add_input()
|
||||
mixer.add_input()
|
||||
mixer.add_output()
|
||||
def mixer_func(self):
|
||||
in_1 = self.get_interface(name='Input 1')
|
||||
in_2 = self.get_interface(name='Input 2')
|
||||
value = (in_1 + in_2)
|
||||
self.set_interface(name='Output 1', value=value)
|
||||
mixer.add_compute(mixer_func)
|
||||
|
||||
result = Block(name='Result')
|
||||
result.add_input()
|
||||
def result_func(self):
|
||||
in_1 = self.get_interface(name='Input 1')
|
||||
result.add_compute(result_func)
|
||||
|
||||
load_schema = st.selectbox('Select a saved schema:', barfi_schemas())
|
||||
|
||||
compute_engine = st.checkbox('Activate barfi compute engine', value=False)
|
||||
|
||||
barfi_result = st_barfi(base_blocks=[feed, result, mixer, splitter],
|
||||
compute_engine=compute_engine, load_schema=load_schema)
|
||||
|
||||
if barfi_result:
|
||||
st.write(barfi_result)
|
@ -0,0 +1,134 @@
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
/* our style while dragging */
|
||||
.value-dragging
|
||||
{
|
||||
background-color: lightblue;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<!-- our fields -->
|
||||
<input type="number" value="0" min=0 max=200 step=10>
|
||||
<input type="number" value="0" min=-50 max=200 step=10>
|
||||
|
||||
<script>
|
||||
// iframe parent
|
||||
// var parentDoc = window.parent.document
|
||||
|
||||
// check for mouse pointer locking support, not a requirement but improves the overall experience
|
||||
var havePointerLock = 'pointerLockElement' in document ||
|
||||
'mozPointerLockElement' in document ||
|
||||
'webkitPointerLockElement' in document;
|
||||
|
||||
// the pointer locking exit function
|
||||
document.exitPointerLock = document.exitPointerLock || document.mozExitPointerLock || document.webkitExitPointerLock;
|
||||
|
||||
// how far should the mouse travel for a step 50 pixel
|
||||
var pixelPerStep = 50;
|
||||
|
||||
// how many steps did the mouse move in as float
|
||||
var movementDelta = 0.0;
|
||||
// value when drag started
|
||||
var lockedValue = 0;
|
||||
// minimum value from field
|
||||
var lockedMin = 0;
|
||||
// maximum value from field
|
||||
var lockedMax = 0;
|
||||
// how big should the field steps be
|
||||
var lockedStep = 0;
|
||||
// the currently locked in field
|
||||
var lockedField = null;
|
||||
|
||||
// register events and pointer locking on field
|
||||
RegisterField = (field) => {
|
||||
if(havePointerLock)
|
||||
field.requestPointerLock = field.requestPointerLock || field.mozRequestPointerLock || field.webkitRequestPointerLock;
|
||||
|
||||
field.title = "Click and hold middle mouse button\nmove mouse left to decrease\nmove right to increase";
|
||||
|
||||
field.addEventListener('mousedown', (e) => {
|
||||
onDragStart(e)
|
||||
});
|
||||
}
|
||||
|
||||
onDragStart = (e) => {
|
||||
// if middle mouse is down
|
||||
if(e.button === 1)
|
||||
{
|
||||
// save current field
|
||||
lockedField = e.target;
|
||||
// add class for styling
|
||||
lockedField.classList.add("value-dragging");
|
||||
// reset movement delta
|
||||
movementDelta = 0.0;
|
||||
// set to 0 if field is empty
|
||||
if(lockedField.value === '')
|
||||
lockedField.value = '0';
|
||||
|
||||
// save current field value
|
||||
lockedValue = parseInt(lockedField.value);
|
||||
|
||||
if(lockedField.min === '')
|
||||
lockedField.min = '-99999999';
|
||||
|
||||
if(lockedField.max === '')
|
||||
lockedField.max = '99999999';
|
||||
|
||||
if(lockedField.step === '')
|
||||
lockedField.step = '10';
|
||||
|
||||
lockedMin = parseInt(lockedField.min);
|
||||
lockedMax = parseInt(lockedField.max);
|
||||
|
||||
lockedStep = parseInt(lockedField.step);
|
||||
|
||||
// lock pointer if available
|
||||
if(havePointerLock)
|
||||
lockedField.requestPointerLock();
|
||||
|
||||
// add drag event
|
||||
document.addEventListener("mousemove", onDrag, false);
|
||||
// prevent event propagation
|
||||
e.preventDefault();
|
||||
}
|
||||
};
|
||||
|
||||
onDrag = (e) => {
|
||||
if(lockedField !== null)
|
||||
{
|
||||
// add movement to delta
|
||||
movementDelta += e.movementX / pixelPerStep;
|
||||
// set new value
|
||||
let value = lockedValue + Math.floor(Math.abs(movementDelta)) * lockedStep * Math.sign(movementDelta);
|
||||
lockedField.value = Math.min(Math.max(value, lockedMin), lockedMax);
|
||||
}
|
||||
};
|
||||
|
||||
document.addEventListener('mouseup', (e) => {
|
||||
// if middle mouse is up
|
||||
if(e.button === 1)
|
||||
{
|
||||
// release pointer lock if available
|
||||
if(havePointerLock)
|
||||
document.exitPointerLock();
|
||||
|
||||
if(lockedField !== null)
|
||||
{
|
||||
// stop drag event
|
||||
document.removeEventListener("mousemove", onDrag, false);
|
||||
// remove class for styling
|
||||
lockedField.classList.remove("value-dragging");
|
||||
// remove reference
|
||||
lockedField = null;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// find and register all input fields of type=number
|
||||
var list = document.querySelectorAll('input[type="number"]');
|
||||
list.forEach(RegisterField);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
@ -0,0 +1,11 @@
|
||||
import os
|
||||
import streamlit.components.v1 as components
|
||||
|
||||
def load(pixel_per_step = 50):
|
||||
parent_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file = os.path.join(parent_dir, "main.js")
|
||||
|
||||
with open(file) as f:
|
||||
javascript_main = f.read()
|
||||
javascript_main = javascript_main.replace("%%pixelPerStep%%",str(pixel_per_step))
|
||||
components.html(f"<script>{javascript_main}</script>")
|
@ -0,0 +1,192 @@
|
||||
// iframe parent
|
||||
var parentDoc = window.parent.document
|
||||
|
||||
// check for mouse pointer locking support, not a requirement but improves the overall experience
|
||||
var havePointerLock = 'pointerLockElement' in parentDoc ||
|
||||
'mozPointerLockElement' in parentDoc ||
|
||||
'webkitPointerLockElement' in parentDoc;
|
||||
|
||||
// the pointer locking exit function
|
||||
parentDoc.exitPointerLock = parentDoc.exitPointerLock || parentDoc.mozExitPointerLock || parentDoc.webkitExitPointerLock;
|
||||
|
||||
// how far should the mouse travel for a step in pixel
|
||||
var pixelPerStep = %%pixelPerStep%%;
|
||||
// how many steps did the mouse move in as float
|
||||
var movementDelta = 0.0;
|
||||
// value when drag started
|
||||
var lockedValue = 0.0;
|
||||
// minimum value from field
|
||||
var lockedMin = 0.0;
|
||||
// maximum value from field
|
||||
var lockedMax = 0.0;
|
||||
// how big should the field steps be
|
||||
var lockedStep = 0.0;
|
||||
// the currently locked in field
|
||||
var lockedField = null;
|
||||
|
||||
// lock box to just request pointer lock for one element
|
||||
var lockBox = document.createElement("div");
|
||||
lockBox.classList.add("lockbox");
|
||||
parentDoc.body.appendChild(lockBox);
|
||||
lockBox.requestPointerLock = lockBox.requestPointerLock || lockBox.mozRequestPointerLock || lockBox.webkitRequestPointerLock;
|
||||
|
||||
function Lock(field)
|
||||
{
|
||||
var rect = field.getBoundingClientRect();
|
||||
lockBox.style.left = (rect.left-2.5)+"px";
|
||||
lockBox.style.top = (rect.top-2.5)+"px";
|
||||
|
||||
lockBox.style.width = (rect.width+2.5)+"px";
|
||||
lockBox.style.height = (rect.height+5)+"px";
|
||||
|
||||
lockBox.requestPointerLock();
|
||||
}
|
||||
|
||||
function Unlock()
|
||||
{
|
||||
parentDoc.exitPointerLock();
|
||||
lockBox.style.left = "0px";
|
||||
lockBox.style.top = "0px";
|
||||
|
||||
lockBox.style.width = "0px";
|
||||
lockBox.style.height = "0px";
|
||||
lockedField.focus();
|
||||
}
|
||||
|
||||
parentDoc.addEventListener('mousedown', (e) => {
|
||||
// if middle is down
|
||||
if(e.button === 1)
|
||||
{
|
||||
if(e.target.tagName === 'INPUT' && e.target.type === 'number')
|
||||
{
|
||||
e.preventDefault();
|
||||
var field = e.target;
|
||||
if(havePointerLock)
|
||||
Lock(field);
|
||||
|
||||
// save current field
|
||||
lockedField = e.target;
|
||||
// add class for styling
|
||||
lockedField.classList.add("value-dragging");
|
||||
// reset movement delta
|
||||
movementDelta = 0.0;
|
||||
// set to 0 if field is empty
|
||||
if(lockedField.value === '')
|
||||
lockedField.value = 0.0;
|
||||
|
||||
// save current field value
|
||||
lockedValue = parseFloat(lockedField.value);
|
||||
|
||||
if(lockedField.min === '' || lockedField.min === '-Infinity')
|
||||
lockedMin = -99999999.0;
|
||||
else
|
||||
lockedMin = parseFloat(lockedField.min);
|
||||
|
||||
if(lockedField.max === '' || lockedField.max === 'Infinity')
|
||||
lockedMax = 99999999.0;
|
||||
else
|
||||
lockedMax = parseFloat(lockedField.max);
|
||||
|
||||
if(lockedField.step === '' || lockedField.step === 'Infinity')
|
||||
lockedStep = 1.0;
|
||||
else
|
||||
lockedStep = parseFloat(lockedField.step);
|
||||
|
||||
// lock pointer if available
|
||||
if(havePointerLock)
|
||||
Lock(lockedField);
|
||||
|
||||
// add drag event
|
||||
parentDoc.addEventListener("mousemove", onDrag, false);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
function onDrag(e)
|
||||
{
|
||||
if(lockedField !== null)
|
||||
{
|
||||
// add movement to delta
|
||||
movementDelta += e.movementX / pixelPerStep;
|
||||
if(lockedField === NaN)
|
||||
return;
|
||||
// set new value
|
||||
let value = lockedValue + Math.floor(Math.abs(movementDelta)) * lockedStep * Math.sign(movementDelta);
|
||||
lockedField.focus();
|
||||
lockedField.select();
|
||||
parentDoc.execCommand('insertText', false /*no UI*/, Math.min(Math.max(value, lockedMin), lockedMax));
|
||||
}
|
||||
}
|
||||
|
||||
parentDoc.addEventListener('mouseup', (e) => {
|
||||
// if mouse is up
|
||||
if(e.button === 1)
|
||||
{
|
||||
// release pointer lock if available
|
||||
if(havePointerLock)
|
||||
Unlock();
|
||||
|
||||
if(lockedField !== null && lockedField !== NaN)
|
||||
{
|
||||
// stop drag event
|
||||
parentDoc.removeEventListener("mousemove", onDrag, false);
|
||||
// remove class for styling
|
||||
lockedField.classList.remove("value-dragging");
|
||||
// remove reference
|
||||
lockedField = null;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// only execute once (even though multiple iframes exist)
|
||||
if(!parentDoc.hasOwnProperty("dragableInitialized"))
|
||||
{
|
||||
var parentCSS =
|
||||
`
|
||||
/* Make input-instruction not block mouse events */
|
||||
.input-instructions,.input-instructions > *{
|
||||
pointer-events: none;
|
||||
user-select: none;
|
||||
-moz-user-select: none;
|
||||
-khtml-user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-o-user-select: none;
|
||||
}
|
||||
|
||||
.lockbox {
|
||||
background-color: transparent;
|
||||
position: absolute;
|
||||
pointer-events: none;
|
||||
user-select: none;
|
||||
-moz-user-select: none;
|
||||
-khtml-user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-o-user-select: none;
|
||||
border-left: dotted 2px rgb(255,75,75);
|
||||
border-top: dotted 2px rgb(255,75,75);
|
||||
border-bottom: dotted 2px rgb(255,75,75);
|
||||
border-right: dotted 1px rgba(255,75,75,0.2);
|
||||
border-top-left-radius: 0.25rem;
|
||||
border-bottom-left-radius: 0.25rem;
|
||||
z-index: 1000;
|
||||
}
|
||||
`;
|
||||
|
||||
// get parent document head
|
||||
var head = parentDoc.getElementsByTagName('head')[0];
|
||||
// add style tag
|
||||
var s = document.createElement('style');
|
||||
// set type attribute
|
||||
s.setAttribute('type', 'text/css');
|
||||
// add css forwarded from python
|
||||
if (s.styleSheet) { // IE
|
||||
s.styleSheet.cssText = parentCSS;
|
||||
} else { // the world
|
||||
s.appendChild(document.createTextNode(parentCSS));
|
||||
}
|
||||
// add style to head
|
||||
head.appendChild(s);
|
||||
// set flag so this only runs once
|
||||
parentDoc["dragableInitialized"] = true;
|
||||
}
|
||||
|
@ -0,0 +1,46 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
import streamlit.components.v1 as components
|
||||
|
||||
# where to save the downloaded key_phrases
|
||||
key_phrases_file = "data/tags/key_phrases.json"
|
||||
# the loaded key phrase json as text
|
||||
key_phrases_json = ""
|
||||
# where to save the downloaded key_phrases
|
||||
thumbnails_file = "data/tags/thumbnails.json"
|
||||
# the loaded key phrase json as text
|
||||
thumbnails_json = ""
|
||||
|
||||
def init():
|
||||
global key_phrases_json, thumbnails_json
|
||||
with open(key_phrases_file) as f:
|
||||
key_phrases_json = f.read()
|
||||
with open(thumbnails_file) as f:
|
||||
thumbnails_json = f.read()
|
||||
|
||||
def suggestion_area(placeholder):
|
||||
# get component path
|
||||
parent_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
# get file paths
|
||||
javascript_file = os.path.join(parent_dir, "main.js")
|
||||
stylesheet_file = os.path.join(parent_dir, "main.css")
|
||||
parent_stylesheet_file = os.path.join(parent_dir, "parent.css")
|
||||
|
||||
# load file texts
|
||||
with open(javascript_file) as f:
|
||||
javascript_main = f.read()
|
||||
with open(stylesheet_file) as f:
|
||||
stylesheet_main = f.read()
|
||||
with open(parent_stylesheet_file) as f:
|
||||
parent_stylesheet = f.read()
|
||||
|
||||
# add suggestion area div box
|
||||
html = "<div id='scroll_area' class='st-bg'><div id='suggestion_area'>javascript failed</div></div>"
|
||||
# add loaded style
|
||||
html += f"<style>{stylesheet_main}</style>"
|
||||
# set default variables
|
||||
html += f"<script>var thumbnails = {thumbnails_json};\nvar keyPhrases = {key_phrases_json};\nvar parentCSS = `{parent_stylesheet}`;\nvar placeholder='{placeholder}';</script>"
|
||||
# add main java script
|
||||
html += f"\n<script>{javascript_main}</script>"
|
||||
# add component to site
|
||||
components.html(html, width=None, height=None, scrolling=True)
|
@ -0,0 +1,81 @@
|
||||
*
|
||||
{
|
||||
padding: 0px;
|
||||
margin: 0px;
|
||||
user-select: none;
|
||||
-moz-user-select: none;
|
||||
-khtml-user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-o-user-select: none;
|
||||
}
|
||||
|
||||
body
|
||||
{
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
padding-left: calc( 1em - 1px );
|
||||
padding-top: calc( 1em - 1px );
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* width */
|
||||
::-webkit-scrollbar {
|
||||
width: 7px;
|
||||
}
|
||||
|
||||
/* Track */
|
||||
::-webkit-scrollbar-track {
|
||||
background: rgb(10, 13, 19);
|
||||
}
|
||||
|
||||
/* Handle */
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: #6c6e72;
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
/* Handle on hover */
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: #6c6e72;
|
||||
}
|
||||
|
||||
#scroll_area
|
||||
{
|
||||
display: flex;
|
||||
overflow-x: hidden;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
#suggestion_area
|
||||
{
|
||||
overflow-x: hidden;
|
||||
width: calc( 100% - 2em - 2px );
|
||||
margin-bottom: calc( 1em + 13px );
|
||||
min-height: 50px;
|
||||
}
|
||||
|
||||
span
|
||||
{
|
||||
border: 1px solid rgba(250, 250, 250, 0.2);
|
||||
border-radius: 0.25rem;
|
||||
font-size: 1rem;
|
||||
font-family: "Source Sans Pro", sans-serif;
|
||||
|
||||
background-color: rgb(38, 39, 48);
|
||||
color: white;
|
||||
display: inline-block;
|
||||
padding: 0.5rem;
|
||||
margin-right: 3px;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
-moz-user-select: none;
|
||||
-khtml-user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-o-user-select: none;
|
||||
}
|
||||
|
||||
span:hover
|
||||
{
|
||||
color: rgb(255,75,75);
|
||||
border-color: rgb(255,75,75);
|
||||
}
|
1048
webui/streamlit/scripts/custom_components/sygil_suggestions/main.js
Normal file
1048
webui/streamlit/scripts/custom_components/sygil_suggestions/main.js
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,84 @@
|
||||
.suggestion-frame
|
||||
{
|
||||
position: absolute;
|
||||
|
||||
/* make as small as possible */
|
||||
margin: 0px;
|
||||
padding: 0px;
|
||||
min-height: 0px;
|
||||
line-height: 0;
|
||||
|
||||
/* animate transitions of the height property */
|
||||
-webkit-transition: height 1s;
|
||||
-moz-transition: height 1s;
|
||||
-ms-transition: height 1s;
|
||||
-o-transition: height 1s;
|
||||
transition: height 1s, border-bottom-width 1s;
|
||||
|
||||
/* block selection */
|
||||
user-select: none;
|
||||
-moz-user-select: none;
|
||||
-khtml-user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-o-user-select: none;
|
||||
|
||||
z-index: 700;
|
||||
|
||||
outline: 1px solid rgba(250, 250, 250, 0.2);
|
||||
outline-offset: 0px;
|
||||
border-radius: 0.25rem;
|
||||
background: rgb(14, 17, 23);
|
||||
|
||||
box-sizing: border-box;
|
||||
-moz-box-sizing: border-box;
|
||||
-webkit-box-sizing: border-box;
|
||||
border-bottom: solid 13px rgb(14, 17, 23) !important;
|
||||
border-left: solid 13px rgb(14, 17, 23) !important;
|
||||
}
|
||||
|
||||
#phrase-tooltip
|
||||
{
|
||||
display: none;
|
||||
pointer-events: none;
|
||||
position: absolute;
|
||||
border-bottom-left-radius: 0.5rem;
|
||||
border-top-right-radius: 0.5rem;
|
||||
border-bottom-right-radius: 0.5rem;
|
||||
border: solid rgb(255,75,75) 2px;
|
||||
background-color: rgb(38, 39, 48);
|
||||
color: rgb(255,75,75);
|
||||
font-size: 1rem;
|
||||
font-family: "Source Sans Pro", sans-serif;
|
||||
padding: 0.5rem;
|
||||
|
||||
cursor: default;
|
||||
user-select: none;
|
||||
-moz-user-select: none;
|
||||
-khtml-user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-o-user-select: none;
|
||||
z-index: 1000;
|
||||
}
|
||||
|
||||
#phrase-tooltip:has(img)
|
||||
{
|
||||
transform: scale(1.25, 1.25);
|
||||
-ms-transform: scale(1.25, 1.25);
|
||||
-webkit-transform: scale(1.25, 1.25);
|
||||
}
|
||||
|
||||
#phrase-tooltip>img
|
||||
{
|
||||
pointer-events: none;
|
||||
border-bottom-left-radius: 0.5rem;
|
||||
border-top-right-radius: 0.5rem;
|
||||
border-bottom-right-radius: 0.5rem;
|
||||
|
||||
cursor: default;
|
||||
user-select: none;
|
||||
-moz-user-select: none;
|
||||
-khtml-user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-o-user-select: none;
|
||||
z-index: 1500;
|
||||
}
|
752
webui/streamlit/scripts/img2img.py
Normal file
752
webui/streamlit/scripts/img2img.py
Normal file
@ -0,0 +1,752 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
from sd_utils import st, server_state, no_rerun, \
|
||||
custom_models_available, RealESRGAN_available, GFPGAN_available, LDSR_available
|
||||
#generation_callback, process_images, KDiffusionSampler, \
|
||||
#load_models, hc, seed_to_int, logger, \
|
||||
#resize_image, get_matched_noise, CFGMaskedDenoiser, ImageFilter, set_page_title
|
||||
|
||||
# streamlit imports
|
||||
from streamlit.runtime.scriptrunner import StopException
|
||||
|
||||
#other imports
|
||||
import cv2
|
||||
from PIL import Image, ImageOps
|
||||
import torch
|
||||
import k_diffusion as K
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import skimage
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
# streamlit components
|
||||
from custom_components import sygil_suggestions
|
||||
from streamlit_drawable_canvas import st_canvas
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
sygil_suggestions.init()
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
|
||||
mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
|
||||
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
|
||||
seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
|
||||
variant_amount: float = 0.0, variant_seed: int = None, ddim_eta:float = 0.0,
|
||||
write_info_files:bool = True, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
|
||||
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
|
||||
save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.4',
|
||||
use_RealESRGAN: bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
|
||||
use_LDSR: bool = True, LDSR_model: str = "model",
|
||||
loopback: bool = False,
|
||||
random_seed_loopback: bool = False
|
||||
):
|
||||
|
||||
outpath = st.session_state['defaults'].general.outdir_img2img
|
||||
seed = seed_to_int(seed)
|
||||
|
||||
batch_size = 1
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(server_state["model"])
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(server_state["model"])
|
||||
elif sampler_name == 'k_dpm_2_a':
|
||||
sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
|
||||
elif sampler_name == 'k_dpm_2':
|
||||
sampler = KDiffusionSampler(server_state["model"],'dpm_2')
|
||||
elif sampler_name == 'k_dpmpp_2m':
|
||||
sampler = KDiffusionSampler(server_state["model"],'dpmpp_2m')
|
||||
elif sampler_name == 'k_euler_a':
|
||||
sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
|
||||
elif sampler_name == 'k_euler':
|
||||
sampler = KDiffusionSampler(server_state["model"],'euler')
|
||||
elif sampler_name == 'k_heun':
|
||||
sampler = KDiffusionSampler(server_state["model"],'heun')
|
||||
elif sampler_name == 'k_lms':
|
||||
sampler = KDiffusionSampler(server_state["model"],'lms')
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
def process_init_mask(init_mask: Image):
|
||||
if init_mask.mode == "RGBA":
|
||||
init_mask = init_mask.convert('RGBA')
|
||||
background = Image.new('RGBA', init_mask.size, (0, 0, 0))
|
||||
init_mask = Image.alpha_composite(background, init_mask)
|
||||
init_mask = init_mask.convert('RGB')
|
||||
return init_mask
|
||||
|
||||
init_img = init_info
|
||||
init_mask = None
|
||||
if mask_mode == 0:
|
||||
if init_info_mask:
|
||||
init_mask = process_init_mask(init_info_mask)
|
||||
elif mask_mode == 1:
|
||||
if init_info_mask:
|
||||
init_mask = process_init_mask(init_info_mask)
|
||||
init_mask = ImageOps.invert(init_mask)
|
||||
elif mask_mode == 2:
|
||||
init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
init_mask = init_img_transparency
|
||||
init_mask = init_mask.convert("RGB")
|
||||
init_mask = resize_image(resize_mode, init_mask, width, height)
|
||||
init_mask = init_mask.convert("RGB")
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
|
||||
if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None:
|
||||
noise_q = 0.99
|
||||
color_variation = 0.0
|
||||
mask_blend_factor = 1.0
|
||||
|
||||
np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
|
||||
np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
|
||||
np_mask_rgb -= np.min(np_mask_rgb)
|
||||
np_mask_rgb /= np.max(np_mask_rgb)
|
||||
np_mask_rgb = 1. - np_mask_rgb
|
||||
np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
|
||||
blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
|
||||
blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
|
||||
#np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
|
||||
#np_mask_rgb = np_mask_rgb + blurred
|
||||
np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
|
||||
np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
|
||||
|
||||
noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
|
||||
blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
|
||||
noised = noise_rgb[:]
|
||||
blend_mask_rgb **= (2.)
|
||||
noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
|
||||
|
||||
np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
|
||||
ref_mask = np_mask_grey < 1e-3
|
||||
|
||||
all_mask = np.ones((height, width), dtype=bool)
|
||||
noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
|
||||
|
||||
init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
|
||||
st.session_state["editor_image"].image(init_img) # debug
|
||||
|
||||
def init():
|
||||
image = init_img.convert('RGB')
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask_channel = None
|
||||
if init_mask:
|
||||
alpha = resize_image(resize_mode, init_mask, width // 8, height // 8)
|
||||
mask_channel = alpha.split()[-1]
|
||||
|
||||
mask = None
|
||||
if mask_channel is not None:
|
||||
mask = np.array(mask_channel).astype(np.float32) / 255.0
|
||||
mask = (1 - mask)
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3)
|
||||
mask = torch.from_numpy(mask).to(server_state["device"])
|
||||
|
||||
if st.session_state['defaults'].general.optimized:
|
||||
server_state["modelFS"].to(server_state["device"] )
|
||||
|
||||
init_image = 2. * image - 1.
|
||||
init_image = init_image.to(server_state["device"])
|
||||
init_latent = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).get_first_stage_encoding((server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if st.session_state['defaults'].general.optimized:
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
server_state["modelFS"].to("cpu")
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
return init_latent, mask,
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
t_enc_steps = t_enc
|
||||
obliterate = False
|
||||
if ddim_steps == t_enc_steps:
|
||||
t_enc_steps = t_enc_steps - 1
|
||||
obliterate = True
|
||||
|
||||
if sampler_name != 'DDIM':
|
||||
x0, z_mask = init_data
|
||||
|
||||
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
||||
noise = x * sigmas[ddim_steps - t_enc_steps - 1]
|
||||
|
||||
xi = x0 + noise
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=xi.device)
|
||||
xi = (z_mask * noise) + ((1-z_mask) * xi)
|
||||
|
||||
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
|
||||
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
|
||||
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
|
||||
callback=generation_callback if not server_state["bridge"] else None)
|
||||
else:
|
||||
|
||||
x0, z_mask = init_data
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(server_state["device"] ))
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=z_enc.device)
|
||||
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
|
||||
|
||||
# decode it
|
||||
samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
z_mask=z_mask, x0=x0)
|
||||
return samples_ddim
|
||||
|
||||
|
||||
|
||||
if loopback:
|
||||
output_images, info = None, None
|
||||
history = []
|
||||
initial_seed = None
|
||||
|
||||
do_color_correction = False
|
||||
try:
|
||||
from skimage import exposure
|
||||
do_color_correction = True
|
||||
except:
|
||||
logger.error("Install scikit-image to perform color correction on loopback")
|
||||
|
||||
for i in range(n_iter):
|
||||
if do_color_correction and i == 0:
|
||||
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
|
||||
|
||||
# RealESRGAN can only run on the final iteration
|
||||
is_final_iteration = i == n_iter - 1
|
||||
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=1,
|
||||
n_iter=1,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
GFPGAN_model=GFPGAN_model,
|
||||
use_RealESRGAN=use_RealESRGAN and is_final_iteration, # Forcefully disable upscaling when using loopback
|
||||
realesrgan_model_name=RealESRGAN_model,
|
||||
use_LDSR=use_LDSR,
|
||||
LDSR_model_name=LDSR_model,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
init_img=init_img,
|
||||
init_mask=init_mask,
|
||||
mask_blur_strength=mask_blur_strength,
|
||||
mask_restore=mask_restore,
|
||||
denoising_strength=denoising_strength,
|
||||
noise_mode=noise_mode,
|
||||
find_noise_steps=find_noise_steps,
|
||||
resize_mode=resize_mode,
|
||||
uses_loopback=loopback,
|
||||
uses_random_seed_loopback=random_seed_loopback,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg
|
||||
)
|
||||
|
||||
if initial_seed is None:
|
||||
initial_seed = seed
|
||||
|
||||
input_image = init_img
|
||||
init_img = output_images[0]
|
||||
|
||||
if do_color_correction and correction_target is not None:
|
||||
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
||||
cv2.cvtColor(
|
||||
np.asarray(init_img),
|
||||
cv2.COLOR_RGB2LAB
|
||||
),
|
||||
correction_target,
|
||||
channel_axis=2
|
||||
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
||||
if mask_restore is True and init_mask is not None:
|
||||
color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
|
||||
color_mask = color_mask.convert('L')
|
||||
source_image = input_image.convert('RGB')
|
||||
target_image = init_img.convert('RGB')
|
||||
|
||||
init_img = Image.composite(source_image, target_image, color_mask)
|
||||
|
||||
if not random_seed_loopback:
|
||||
seed = seed + 1
|
||||
else:
|
||||
seed = seed_to_int(None)
|
||||
|
||||
denoising_strength = max(denoising_strength * 0.95, 0.1)
|
||||
history.append(init_img)
|
||||
|
||||
output_images = history
|
||||
seed = initial_seed
|
||||
|
||||
else:
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
GFPGAN_model=GFPGAN_model,
|
||||
use_RealESRGAN=use_RealESRGAN,
|
||||
realesrgan_model_name=RealESRGAN_model,
|
||||
use_LDSR=use_LDSR,
|
||||
LDSR_model_name=LDSR_model,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
init_img=init_img,
|
||||
init_mask=init_mask,
|
||||
mask_blur_strength=mask_blur_strength,
|
||||
denoising_strength=denoising_strength,
|
||||
noise_mode=noise_mode,
|
||||
find_noise_steps=find_noise_steps,
|
||||
mask_restore=mask_restore,
|
||||
resize_mode=resize_mode,
|
||||
uses_loopback=loopback,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg
|
||||
)
|
||||
|
||||
del sampler
|
||||
|
||||
return output_images, seed, info, stats
|
||||
|
||||
#
|
||||
def layout():
|
||||
with st.form("img2img-inputs"):
|
||||
st.session_state["generation_mode"] = "img2img"
|
||||
|
||||
img2img_input_col, img2img_generate_col = st.columns([10,1])
|
||||
with img2img_input_col:
|
||||
#prompt = st.text_area("Input Text","")
|
||||
placeholder = "A corgi wearing a top hat as an oil painting."
|
||||
prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state["defaults"].general.enable_suggestions:
|
||||
sygil_suggestions.suggestion_area(placeholder)
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state['defaults'].admin.global_negative_prompt:
|
||||
prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
|
||||
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
img2img_generate_col.write("")
|
||||
img2img_generate_col.write("")
|
||||
generate_button = img2img_generate_col.form_submit_button("Generate")
|
||||
|
||||
|
||||
# creating the page layout using columns
|
||||
col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([2,4,4], gap="medium")
|
||||
|
||||
with col1_img2img_layout:
|
||||
# If we have custom models available on the "models/custom"
|
||||
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
|
||||
custom_models_available()
|
||||
if server_state["CustomModel_available"]:
|
||||
st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
|
||||
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
|
||||
help="Select the model you want to use. This option is only available if you have custom models \
|
||||
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
|
||||
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.5")
|
||||
else:
|
||||
st.session_state["custom_model"] = "Stable Diffusion v1.5"
|
||||
|
||||
|
||||
st.session_state["sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].img2img.sampling_steps.value,
|
||||
min_value=st.session_state['defaults'].img2img.sampling_steps.min_value,
|
||||
step=st.session_state['defaults'].img2img.sampling_steps.step)
|
||||
|
||||
sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_dpmpp_2m", "k_heun", "PLMS", "DDIM"]
|
||||
st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list,
|
||||
index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.")
|
||||
|
||||
width = st.slider("Width:", min_value=st.session_state['defaults'].img2img.width.min_value, max_value=st.session_state['defaults'].img2img.width.max_value,
|
||||
value=st.session_state['defaults'].img2img.width.value, step=st.session_state['defaults'].img2img.width.step)
|
||||
height = st.slider("Height:", min_value=st.session_state['defaults'].img2img.height.min_value, max_value=st.session_state['defaults'].img2img.height.max_value,
|
||||
value=st.session_state['defaults'].img2img.height.value, step=st.session_state['defaults'].img2img.height.step)
|
||||
seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
|
||||
|
||||
cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):", min_value=st.session_state['defaults'].img2img.cfg_scale.min_value,
|
||||
value=st.session_state['defaults'].img2img.cfg_scale.value,
|
||||
step=st.session_state['defaults'].img2img.cfg_scale.step,
|
||||
help="How strongly the image should follow the prompt.")
|
||||
|
||||
st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength.value,
|
||||
min_value=st.session_state['defaults'].img2img.denoising_strength.min_value,
|
||||
max_value=st.session_state['defaults'].img2img.denoising_strength.max_value,
|
||||
step=st.session_state['defaults'].img2img.denoising_strength.step)
|
||||
|
||||
|
||||
mask_expander = st.empty()
|
||||
with mask_expander.expander("Inpainting/Outpainting"):
|
||||
mask_mode_list = ["Outpainting", "Inpainting", "Image alpha"]
|
||||
mask_mode = st.selectbox("Painting Mode", mask_mode_list, index=st.session_state["defaults"].img2img.mask_mode,
|
||||
help="Select how you want your image to be masked/painted.\"Inpainting\" modifies the image where the mask is white.\n\
|
||||
\"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
|
||||
)
|
||||
mask_mode = mask_mode_list.index(mask_mode)
|
||||
|
||||
|
||||
noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
|
||||
noise_mode = st.selectbox("Noise Mode", noise_mode_list, index=noise_mode_list.index(st.session_state['defaults'].img2img.noise_mode), help="")
|
||||
#noise_mode = noise_mode_list.index(noise_mode)
|
||||
find_noise_steps = st.number_input("Find Noise Steps", value=st.session_state['defaults'].img2img.find_noise_steps.value,
|
||||
min_value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
|
||||
step=st.session_state['defaults'].img2img.find_noise_steps.step)
|
||||
|
||||
# Specify canvas parameters in application
|
||||
drawing_mode = st.selectbox(
|
||||
"Drawing tool:",
|
||||
(
|
||||
"freedraw",
|
||||
"transform",
|
||||
#"line",
|
||||
"rect",
|
||||
"circle",
|
||||
#"polygon",
|
||||
),
|
||||
)
|
||||
|
||||
stroke_width = st.slider("Stroke width: ", 1, 100, 50)
|
||||
stroke_color = st.color_picker("Stroke color hex: ", value="#EEEEEE")
|
||||
bg_color = st.color_picker("Background color hex: ", "#7B6E6E")
|
||||
|
||||
display_toolbar = st.checkbox("Display toolbar", True)
|
||||
#realtime_update = st.checkbox("Update in realtime", True)
|
||||
|
||||
with st.expander("Batch Options"):
|
||||
st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].img2img.batch_count.value,
|
||||
help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
st.session_state["batch_size"] = st.number_input("Batch size", value=st.session_state.defaults.img2img.batch_size.value,
|
||||
help="How many images are at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
st.session_state["update_preview"] = st.session_state["defaults"].general.update_preview
|
||||
st.session_state["update_preview_frequency"] = st.number_input("Update Image Preview Frequency",
|
||||
min_value=0,
|
||||
value=st.session_state['defaults'].img2img.update_preview_frequency,
|
||||
help="Frequency in steps at which the the preview image is updated. By default the frequency \
|
||||
is set to 1 step.")
|
||||
#
|
||||
with st.expander("Advanced"):
|
||||
with st.expander("Output Settings"):
|
||||
separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts,
|
||||
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
|
||||
normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights,
|
||||
help="Ensure the sum of all weights add up to 1.0")
|
||||
loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from previous batch when creating next batch.")
|
||||
random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Random loopback seed")
|
||||
img2img_mask_restore = st.checkbox("Only modify regenerated parts of image",
|
||||
value=st.session_state['defaults'].img2img.mask_restore,
|
||||
help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail")
|
||||
save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images,
|
||||
help="Save each image generated before any filter or enhancement is applied.")
|
||||
save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.")
|
||||
group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt,
|
||||
help="Saves all the images with the same prompt into the same folder. \
|
||||
When using a prompt matrix each prompt combination will have its own folder.")
|
||||
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files,
|
||||
help="Save a file next to the image with informartion about the generation.")
|
||||
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
|
||||
|
||||
#
|
||||
# check if GFPGAN, RealESRGAN and LDSR are available.
|
||||
if "GFPGAN_available" not in st.session_state:
|
||||
GFPGAN_available()
|
||||
|
||||
if "RealESRGAN_available" not in st.session_state:
|
||||
RealESRGAN_available()
|
||||
|
||||
if "LDSR_available" not in st.session_state:
|
||||
LDSR_available()
|
||||
|
||||
if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
|
||||
with st.expander("Post-Processing"):
|
||||
face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
|
||||
with face_restoration_tab:
|
||||
# GFPGAN used for face restoration
|
||||
if st.session_state["GFPGAN_available"]:
|
||||
#with st.expander("Face Restoration"):
|
||||
#if st.session_state["GFPGAN_available"]:
|
||||
#with st.expander("GFPGAN"):
|
||||
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN,
|
||||
help="Uses the GFPGAN model to improve faces after the generation.\
|
||||
This greatly improve the quality and consistency of faces but uses\
|
||||
extra VRAM. Disable if you need the extra VRAM.")
|
||||
|
||||
st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
|
||||
index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
|
||||
|
||||
#st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
|
||||
|
||||
else:
|
||||
st.session_state["use_GFPGAN"] = False
|
||||
|
||||
with upscaling_tab:
|
||||
st.session_state['us_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].img2img.use_upscaling)
|
||||
|
||||
# RealESRGAN and LDSR used for upscaling.
|
||||
if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
|
||||
|
||||
upscaling_method_list = []
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
upscaling_method_list.append("RealESRGAN")
|
||||
if st.session_state["LDSR_available"]:
|
||||
upscaling_method_list.append("LDSR")
|
||||
|
||||
st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
|
||||
index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
|
||||
if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
|
||||
else 0)
|
||||
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
with st.expander("RealESRGAN"):
|
||||
if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['us_upscaling']:
|
||||
st.session_state["use_RealESRGAN"] = True
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
|
||||
index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
|
||||
#
|
||||
if st.session_state["LDSR_available"]:
|
||||
with st.expander("LDSR"):
|
||||
if st.session_state["upscaling_method"] == "LDSR" and st.session_state['us_upscaling']:
|
||||
st.session_state["use_LDSR"] = True
|
||||
else:
|
||||
st.session_state["use_LDSR"] = False
|
||||
|
||||
st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
|
||||
index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
|
||||
|
||||
st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].img2img.LDSR_config.sampling_steps,
|
||||
help="")
|
||||
|
||||
st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].img2img.LDSR_config.preDownScale,
|
||||
help="")
|
||||
|
||||
st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].img2img.LDSR_config.postDownScale,
|
||||
help="")
|
||||
|
||||
downsample_method_list = ['Nearest', 'Lanczos']
|
||||
st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
|
||||
index=downsample_method_list.index(st.session_state['defaults'].img2img.LDSR_config.downsample_method))
|
||||
|
||||
else:
|
||||
st.session_state["use_LDSR"] = False
|
||||
st.session_state["LDSR_model"] = "model"
|
||||
|
||||
with st.expander("Variant"):
|
||||
variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
|
||||
variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed,
|
||||
help="The seed to use when generating a variant, if left blank a random seed will be generated.")
|
||||
|
||||
|
||||
with col2_img2img_layout:
|
||||
editor_tab = st.tabs(["Editor"])
|
||||
|
||||
editor_image = st.empty()
|
||||
st.session_state["editor_image"] = editor_image
|
||||
|
||||
st.form_submit_button("Refresh")
|
||||
|
||||
#if "canvas" not in st.session_state:
|
||||
st.session_state["canvas"] = st.empty()
|
||||
|
||||
masked_image_holder = st.empty()
|
||||
image_holder = st.empty()
|
||||
|
||||
uploaded_images = st.file_uploader(
|
||||
"Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
|
||||
help="Upload an image which will be used for the image to image generation.",
|
||||
)
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images).convert('RGB')
|
||||
new_img = image.resize((width, height))
|
||||
#image_holder.image(new_img)
|
||||
|
||||
#mask_holder = st.empty()
|
||||
|
||||
#uploaded_masks = st.file_uploader(
|
||||
#"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
|
||||
#help="Upload an mask image which will be used for masking the image to image generation.",
|
||||
#)
|
||||
|
||||
#
|
||||
# Create a canvas component
|
||||
with st.session_state["canvas"]:
|
||||
st.session_state["uploaded_masks"] = st_canvas(
|
||||
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
||||
stroke_width=stroke_width,
|
||||
stroke_color=stroke_color,
|
||||
background_color=bg_color,
|
||||
background_image=image if uploaded_images else None,
|
||||
update_streamlit=True,
|
||||
width=width,
|
||||
height=height,
|
||||
drawing_mode=drawing_mode,
|
||||
initial_drawing=st.session_state["uploaded_masks"].json_data if "uploaded_masks" in st.session_state else None,
|
||||
display_toolbar= display_toolbar,
|
||||
key="full_app",
|
||||
)
|
||||
|
||||
#try:
|
||||
##print (type(st.session_state["uploaded_masks"]))
|
||||
#if st.session_state["uploaded_masks"] != None:
|
||||
#mask_expander.expander("Mask", expanded=True)
|
||||
#mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
|
||||
|
||||
#st.image(mask)
|
||||
|
||||
#if mask.mode == "RGBA":
|
||||
#mask = mask.convert('RGBA')
|
||||
#background = Image.new('RGBA', mask.size, (0, 0, 0))
|
||||
#mask = Image.alpha_composite(background, mask)
|
||||
#mask = mask.resize((width, height))
|
||||
#except AttributeError:
|
||||
#pass
|
||||
|
||||
with col3_img2img_layout:
|
||||
result_tab = st.tabs(["Result"])
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
preview_image = st.empty()
|
||||
st.session_state["preview_image"] = preview_image
|
||||
|
||||
#st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
message = st.empty()
|
||||
|
||||
#if uploaded_images:
|
||||
#image = Image.open(uploaded_images).convert('RGB')
|
||||
##img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
#new_img = image.resize((width, height))
|
||||
#st.image(new_img, use_column_width=True)
|
||||
|
||||
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
# load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
|
||||
with col3_img2img_layout:
|
||||
with no_rerun:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
|
||||
|
||||
if uploaded_images:
|
||||
#image = Image.fromarray(image).convert('RGBA')
|
||||
#new_img = image.resize((width, height))
|
||||
###img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
#image_holder.image(new_img)
|
||||
new_mask = None
|
||||
|
||||
if st.session_state["uploaded_masks"]:
|
||||
mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
|
||||
new_mask = mask.resize((width, height))
|
||||
|
||||
#masked_image_holder.image(new_mask)
|
||||
try:
|
||||
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
|
||||
mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
|
||||
sampler_name=st.session_state["sampler_name"], n_iter=st.session_state["batch_count"],
|
||||
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
|
||||
seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width,
|
||||
height=height, variant_amount=variant_amount,
|
||||
ddim_eta=st.session_state.defaults.img2img.ddim_eta, write_info_files=write_info_files,
|
||||
separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images, save_grid=save_grid,
|
||||
group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=st.session_state["use_GFPGAN"],
|
||||
GFPGAN_model=st.session_state["GFPGAN_model"],
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
loopback=loopback
|
||||
)
|
||||
|
||||
#show a message when the generation is complete.
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
except (StopException,
|
||||
#KeyError
|
||||
):
|
||||
logger.info(f"Received Streamlit StopException")
|
||||
# reset the page title so the percent doesnt stay on it confusing the user.
|
||||
set_page_title(f"Stable Diffusion Playground")
|
||||
|
||||
# this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
|
||||
# use the current col2 first tab to show the preview_img and update it as its generated.
|
||||
#preview_image.image(output_images, width=750)
|
||||
|
||||
#on import run init
|
460
webui/streamlit/scripts/img2txt.py
Normal file
460
webui/streamlit/scripts/img2txt.py
Normal file
@ -0,0 +1,460 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# ---------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
"""
|
||||
CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI.
|
||||
|
||||
# CLIP Interrogator by @pharmapsychotic
|
||||
Twitter: https://twitter.com/pharmapsychotic
|
||||
Github: https://github.com/pharmapsychotic/clip-interrogator
|
||||
|
||||
Description:
|
||||
What do the different OpenAI CLIP models see in an image? What might be a good text prompt to create similar images using CLIP guided diffusion
|
||||
or another text to image model? The CLIP Interrogator is here to get you answers!
|
||||
|
||||
Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following him on [twitter](https://twitter.com/pharmapsychotic).
|
||||
|
||||
And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
|
||||
|
||||
"""
|
||||
# ---------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# base webui import and utils.
|
||||
from sd_utils import st, logger, server_state, server_state_lock, random
|
||||
|
||||
# streamlit imports
|
||||
|
||||
# streamlit components section
|
||||
import streamlit_nested_layout
|
||||
|
||||
# other imports
|
||||
|
||||
import clip
|
||||
import open_clip
|
||||
import gc
|
||||
import os
|
||||
import pandas as pd
|
||||
#import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from ldm.models.blip import blip_decoder
|
||||
#import hashlib
|
||||
|
||||
# end of imports
|
||||
# ---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
blip_image_eval_size = 512
|
||||
|
||||
st.session_state["log"] = []
|
||||
|
||||
def load_blip_model():
|
||||
logger.info("Loading BLIP Model")
|
||||
if "log" not in st.session_state:
|
||||
st.session_state["log"] = []
|
||||
|
||||
st.session_state["log"].append("Loading BLIP Model")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if "blip_model" not in server_state:
|
||||
with server_state_lock['blip_model']:
|
||||
server_state["blip_model"] = blip_decoder(pretrained="models/blip/model__base_caption.pth",
|
||||
image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
|
||||
|
||||
server_state["blip_model"] = server_state["blip_model"].eval()
|
||||
|
||||
server_state["blip_model"] = server_state["blip_model"].to(device).half()
|
||||
|
||||
logger.info("BLIP Model Loaded")
|
||||
st.session_state["log"].append("BLIP Model Loaded")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
else:
|
||||
logger.info("BLIP Model already loaded")
|
||||
st.session_state["log"].append("BLIP Model already loaded")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
|
||||
def generate_caption(pil_image):
|
||||
|
||||
load_blip_model()
|
||||
|
||||
gpu_image = transforms.Compose([ # type: ignore
|
||||
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), # type: ignore
|
||||
transforms.ToTensor(), # type: ignore
|
||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) # type: ignore
|
||||
])(pil_image).unsqueeze(0).to(device).half()
|
||||
|
||||
with torch.no_grad():
|
||||
caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
|
||||
|
||||
return caption[0]
|
||||
|
||||
def load_list(filename):
|
||||
with open(filename, 'r', encoding='utf-8', errors='replace') as f:
|
||||
items = [line.strip() for line in f.readlines()]
|
||||
return items
|
||||
|
||||
def rank(model, image_features, text_array, top_count=1):
|
||||
top_count = min(top_count, len(text_array))
|
||||
text_tokens = clip.tokenize([text for text in text_array]).cuda()
|
||||
with torch.no_grad():
|
||||
text_features = model.encode_text(text_tokens).float()
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = torch.zeros((1, len(text_array))).to(device)
|
||||
for i in range(image_features.shape[0]):
|
||||
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
|
||||
similarity /= image_features.shape[0]
|
||||
|
||||
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
||||
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
||||
|
||||
|
||||
def clear_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
|
||||
batch_size = min(batch_size, len(text_array))
|
||||
batch_count = int(len(text_array) / batch_size)
|
||||
batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
|
||||
ranks = []
|
||||
for batch in batches:
|
||||
ranks += rank(model, image_features, batch)
|
||||
return ranks
|
||||
|
||||
def interrogate(image, models):
|
||||
load_blip_model()
|
||||
|
||||
logger.info("Generating Caption")
|
||||
st.session_state["log"].append("Generating Caption")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
caption = generate_caption(image)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
del server_state["blip_model"]
|
||||
clear_cuda()
|
||||
|
||||
logger.info("Caption Generated")
|
||||
st.session_state["log"].append("Caption Generated")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if len(models) == 0:
|
||||
logger.info(f"\n\n{caption}")
|
||||
return
|
||||
|
||||
table = []
|
||||
bests = [[('', 0)]]*7
|
||||
|
||||
logger.info("Ranking Text")
|
||||
st.session_state["log"].append("Ranking Text")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
for model_name in models:
|
||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
|
||||
logger.info(f"Interrogating with {model_name}...")
|
||||
st.session_state["log"].append(f"Interrogating with {model_name}...")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
if model_name not in server_state["clip_models"]:
|
||||
if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
|
||||
model_to_delete = []
|
||||
for model in server_state["clip_models"]:
|
||||
if model != model_name:
|
||||
model_to_delete.append(model)
|
||||
for model in model_to_delete:
|
||||
del server_state["clip_models"][model]
|
||||
del server_state["preprocesses"][model]
|
||||
clear_cuda()
|
||||
if model_name == 'ViT-H-14':
|
||||
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
|
||||
open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip')
|
||||
elif model_name == 'ViT-g-14':
|
||||
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
|
||||
open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='models/clip')
|
||||
else:
|
||||
server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = \
|
||||
clip.load(model_name, device=device, download_root='models/clip')
|
||||
server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
|
||||
|
||||
images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()
|
||||
|
||||
|
||||
image_features = server_state["clip_models"][model_name].encode_image(images).float()
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
clear_cuda()
|
||||
|
||||
ranks = []
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["mediums"]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, ["by "+artist for artist in server_state["artists"]]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"]))
|
||||
#ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["domains"]))
|
||||
#ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subreddits"]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"]))
|
||||
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["tags"]))
|
||||
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["genres"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["styles"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subjects"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["colors"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["moods"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["themes"]))
|
||||
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["keywords"]))
|
||||
|
||||
#print (bests)
|
||||
#print (ranks)
|
||||
|
||||
for i in range(len(ranks)):
|
||||
confidence_sum = 0
|
||||
for ci in range(len(ranks[i])):
|
||||
confidence_sum += ranks[i][ci][1]
|
||||
if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
|
||||
bests[i] = ranks[i]
|
||||
|
||||
for best in bests:
|
||||
best.sort(key=lambda x: x[1], reverse=True)
|
||||
# prune to 3
|
||||
best = best[:3]
|
||||
|
||||
row = [model_name]
|
||||
|
||||
for r in ranks:
|
||||
row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
|
||||
|
||||
#for rank in ranks:
|
||||
# rank.sort(key=lambda x: x[1], reverse=True)
|
||||
# row.append(f'{rank[0][0]} {rank[0][1]:.2f}%')
|
||||
|
||||
table.append(row)
|
||||
|
||||
if st.session_state["defaults"].general.optimized:
|
||||
del server_state["clip_models"][model_name]
|
||||
gc.collect()
|
||||
|
||||
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
|
||||
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]))
|
||||
|
||||
medium = bests[0][0][0]
|
||||
artist = bests[1][0][0]
|
||||
trending = bests[2][0][0]
|
||||
movement = bests[3][0][0]
|
||||
flavors = bests[4][0][0]
|
||||
#domains = bests[5][0][0]
|
||||
#subreddits = bests[6][0][0]
|
||||
techniques = bests[5][0][0]
|
||||
tags = bests[6][0][0]
|
||||
|
||||
|
||||
if caption.startswith(medium):
|
||||
st.session_state["text_result"][st.session_state["processed_image_count"]].code(
|
||||
f"\n\n{caption} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
|
||||
else:
|
||||
st.session_state["text_result"][st.session_state["processed_image_count"]].code(
|
||||
f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
|
||||
|
||||
logger.info("Finished Interrogating.")
|
||||
st.session_state["log"].append("Finished Interrogating.")
|
||||
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
|
||||
|
||||
|
||||
def img2txt():
|
||||
models = []
|
||||
|
||||
if st.session_state["ViT-L/14"]:
|
||||
models.append('ViT-L/14')
|
||||
if st.session_state["ViT-H-14"]:
|
||||
models.append('ViT-H-14')
|
||||
if st.session_state["ViT-g-14"]:
|
||||
models.append('ViT-g-14')
|
||||
|
||||
if st.session_state["ViTB32"]:
|
||||
models.append('ViT-B/32')
|
||||
if st.session_state['ViTB16']:
|
||||
models.append('ViT-B/16')
|
||||
|
||||
if st.session_state["ViTL14_336px"]:
|
||||
models.append('ViT-L/14@336px')
|
||||
if st.session_state["RN101"]:
|
||||
models.append('RN101')
|
||||
if st.session_state["RN50"]:
|
||||
models.append('RN50')
|
||||
if st.session_state["RN50x4"]:
|
||||
models.append('RN50x4')
|
||||
if st.session_state["RN50x16"]:
|
||||
models.append('RN50x16')
|
||||
if st.session_state["RN50x64"]:
|
||||
models.append('RN50x64')
|
||||
|
||||
# if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
|
||||
#image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
|
||||
# else:
|
||||
#image = Image.open(image_path_or_url).convert('RGB')
|
||||
|
||||
#thumb = st.session_state["uploaded_image"].image.copy()
|
||||
#thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
|
||||
# display(thumb)
|
||||
|
||||
st.session_state["processed_image_count"] = 0
|
||||
|
||||
for i in range(len(st.session_state["uploaded_image"])):
|
||||
|
||||
interrogate(st.session_state["uploaded_image"][i].pil_image, models=models)
|
||||
# increase counter.
|
||||
st.session_state["processed_image_count"] += 1
|
||||
#
|
||||
|
||||
|
||||
def layout():
|
||||
#set_page_title("Image-to-Text - Stable Diffusion WebUI")
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
#
|
||||
if "clip_models" not in server_state:
|
||||
server_state["clip_models"] = {}
|
||||
if "preprocesses" not in server_state:
|
||||
server_state["preprocesses"] = {}
|
||||
data_path = "data/"
|
||||
if "artists" not in server_state:
|
||||
server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
|
||||
if "flavors" not in server_state:
|
||||
server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000)
|
||||
if "mediums" not in server_state:
|
||||
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
|
||||
if "movements" not in server_state:
|
||||
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
|
||||
if "sites" not in server_state:
|
||||
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
|
||||
#server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
|
||||
#server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
|
||||
if "techniques" not in server_state:
|
||||
server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
|
||||
if "tags" not in server_state:
|
||||
server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
|
||||
#server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
|
||||
# server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
|
||||
# server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
|
||||
if "trending_list" not in server_state:
|
||||
server_state["trending_list"] = [site for site in server_state["sites"]]
|
||||
server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
|
||||
server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
|
||||
server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
|
||||
with st.form("img2txt-inputs"):
|
||||
st.session_state["generation_mode"] = "img2txt"
|
||||
|
||||
# st.write("---")
|
||||
# creating the page layout using columns
|
||||
col1, col2 = st.columns([1, 4], gap="large")
|
||||
|
||||
with col1:
|
||||
st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True)
|
||||
|
||||
with st.expander("CLIP models", expanded=True):
|
||||
st.session_state["ViT-L/14"] = st.checkbox("ViT-L/14", value=True, help="ViT-L/14 model.")
|
||||
st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
|
||||
st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
|
||||
|
||||
|
||||
|
||||
with st.expander("Others"):
|
||||
st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
|
||||
|
||||
st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
|
||||
st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
|
||||
st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
|
||||
st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
|
||||
st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
|
||||
st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
|
||||
st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
|
||||
st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
|
||||
|
||||
#
|
||||
# st.subheader("Logs:")
|
||||
|
||||
st.session_state["log_message"] = st.empty()
|
||||
st.session_state["log_message"].code('', language="")
|
||||
|
||||
with col2:
|
||||
st.subheader("Image")
|
||||
|
||||
image_col1, image_col2 = st.columns([10,25])
|
||||
with image_col1:
|
||||
refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
|
||||
|
||||
if st.session_state["uploaded_image"]:
|
||||
#print (type(st.session_state["uploaded_image"]))
|
||||
# if len(st.session_state["uploaded_image"]) == 1:
|
||||
st.session_state["input_image_preview"] = []
|
||||
st.session_state["input_image_preview_container"] = []
|
||||
st.session_state["prediction_table"] = []
|
||||
st.session_state["text_result"] = []
|
||||
|
||||
for i in range(len(st.session_state["uploaded_image"])):
|
||||
st.session_state["input_image_preview_container"].append(i)
|
||||
st.session_state["input_image_preview_container"][i] = st.empty()
|
||||
|
||||
with st.session_state["input_image_preview_container"][i].container():
|
||||
col1_output, col2_output = st.columns([2, 10], gap="medium")
|
||||
with col1_output:
|
||||
st.session_state["input_image_preview"].append(i)
|
||||
st.session_state["input_image_preview"][i] = st.empty()
|
||||
st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
|
||||
|
||||
st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
|
||||
|
||||
with st.session_state["input_image_preview_container"][i].container():
|
||||
|
||||
with col2_output:
|
||||
|
||||
st.session_state["prediction_table"].append(i)
|
||||
st.session_state["prediction_table"][i] = st.empty()
|
||||
st.session_state["prediction_table"][i].table()
|
||||
|
||||
st.session_state["text_result"].append(i)
|
||||
st.session_state["text_result"][i] = st.empty()
|
||||
st.session_state["text_result"][i].code("", language="")
|
||||
|
||||
else:
|
||||
#st.session_state["input_image_preview"].code('', language="")
|
||||
st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
|
||||
|
||||
with image_col2:
|
||||
#
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
# generate_col1.title("")
|
||||
# generate_col1.title("")
|
||||
generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images")
|
||||
|
||||
if generate_button:
|
||||
# if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe form session_state so that they are reloaded.
|
||||
if "model" in server_state and st.session_state["defaults"].general.optimized:
|
||||
del server_state["model"]
|
||||
if "pipe" in server_state and st.session_state["defaults"].general.optimized:
|
||||
del server_state["pipe"]
|
||||
if "RealESRGAN" in server_state and st.session_state["defaults"].general.optimized:
|
||||
del server_state["RealESRGAN"]
|
||||
if "GFPGAN" in server_state and st.session_state["defaults"].general.optimized:
|
||||
del server_state["GFPGAN"]
|
||||
|
||||
# run clip interrogator
|
||||
img2txt()
|
368
webui/streamlit/scripts/post_processing.py
Normal file
368
webui/streamlit/scripts/post_processing.py
Normal file
@ -0,0 +1,368 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sandbox-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
#from sd_utils import *
|
||||
from sd_utils import st, server_state, \
|
||||
RealESRGAN_available, GFPGAN_available, LDSR_available
|
||||
|
||||
# streamlit imports
|
||||
|
||||
#streamlit components section
|
||||
import hydralit_components as hc
|
||||
|
||||
#other imports
|
||||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
# Temp imports
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def post_process(use_GFPGAN=True, GFPGAN_model='', use_RealESRGAN=False, realesrgan_model_name="", use_LDSR=False, LDSR_model_name=""):
|
||||
|
||||
for i in range(len(st.session_state["uploaded_image"])):
|
||||
#st.session_state["uploaded_image"][i].pil_image
|
||||
|
||||
if use_GFPGAN and server_state["GFPGAN"] is not None and not use_RealESRGAN and not use_LDSR:
|
||||
if "progress_bar_text" in st.session_state:
|
||||
st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
|
||||
|
||||
if "progress_bar" in st.session_state:
|
||||
st.session_state["progress_bar"].progress(
|
||||
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
|
||||
|
||||
if server_state["GFPGAN"].name != GFPGAN_model:
|
||||
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
|
||||
|
||||
torch_gc()
|
||||
|
||||
with torch.autocast('cuda'):
|
||||
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
|
||||
gfpgan_sample = restored_img[:,:,::-1]
|
||||
gfpgan_image = Image.fromarray(gfpgan_sample)
|
||||
|
||||
#if st.session_state["GFPGAN_strenght"]:
|
||||
#gfpgan_sample = Image.blend(image, gfpgan_image, st.session_state["GFPGAN_strenght"])
|
||||
|
||||
gfpgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan'
|
||||
|
||||
gfpgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{gfpgan_filename}.png"))
|
||||
|
||||
#
|
||||
elif use_RealESRGAN and server_state["RealESRGAN"] is not None and not use_GFPGAN:
|
||||
if "progress_bar_text" in st.session_state:
|
||||
st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
|
||||
|
||||
if "progress_bar" in st.session_state:
|
||||
st.session_state["progress_bar"].progress(
|
||||
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
|
||||
|
||||
torch_gc()
|
||||
|
||||
if server_state["RealESRGAN"].model.name != realesrgan_model_name:
|
||||
#try_loading_RealESRGAN(realesrgan_model_name)
|
||||
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
|
||||
|
||||
output, img_mode = server_state["RealESRGAN"].enhance(st.session_state["uploaded_image"][i].pil_image)
|
||||
esrgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-esrgan4x'
|
||||
esrgan_sample = output[:,:,::-1]
|
||||
esrgan_image = Image.fromarray(esrgan_sample)
|
||||
|
||||
esrgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{esrgan_filename}.png"))
|
||||
|
||||
#
|
||||
elif use_LDSR and "LDSR" in server_state and not use_GFPGAN:
|
||||
logger.info ("Running LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
|
||||
if "progress_bar_text" in st.session_state:
|
||||
st.session_state["progress_bar_text"].text("Running LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
|
||||
if "progress_bar" in st.session_state:
|
||||
st.session_state["progress_bar"].progress(
|
||||
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
|
||||
|
||||
torch_gc()
|
||||
|
||||
if server_state["LDSR"].name != LDSR_model_name:
|
||||
#try_loading_RealESRGAN(realesrgan_model_name)
|
||||
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
|
||||
|
||||
result = server_state["LDSR"].superResolution(st.session_state["uploaded_image"][i].pil_image, ddimSteps = st.session_state["ldsr_sampling_steps"],
|
||||
preDownScale = st.session_state["preDownScale"], postDownScale = st.session_state["postDownScale"],
|
||||
downsample_method=st.session_state["downsample_method"])
|
||||
|
||||
ldsr_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-ldsr4x'
|
||||
|
||||
result.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{ldsr_filename}.png"))
|
||||
|
||||
#
|
||||
elif use_LDSR and "LDSR" in server_state and use_GFPGAN and "GFPGAN" in server_state:
|
||||
logger.info ("Running GFPGAN+LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
|
||||
if "progress_bar_text" in st.session_state:
|
||||
st.session_state["progress_bar_text"].text("Running GFPGAN+LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
|
||||
|
||||
if "progress_bar" in st.session_state:
|
||||
st.session_state["progress_bar"].progress(
|
||||
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
|
||||
|
||||
if server_state["GFPGAN"].name != GFPGAN_model:
|
||||
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
|
||||
|
||||
torch_gc()
|
||||
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
|
||||
gfpgan_sample = restored_img[:,:,::-1]
|
||||
gfpgan_image = Image.fromarray(gfpgan_sample)
|
||||
|
||||
if server_state["LDSR"].name != LDSR_model_name:
|
||||
#try_loading_RealESRGAN(realesrgan_model_name)
|
||||
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
|
||||
|
||||
#LDSR.superResolution(gfpgan_image, ddimSteps=100, preDownScale='None', postDownScale='None', downsample_method="Lanczos")
|
||||
result = server_state["LDSR"].superResolution(gfpgan_image, ddimSteps = st.session_state["ldsr_sampling_steps"],
|
||||
preDownScale = st.session_state["preDownScale"], postDownScale = st.session_state["postDownScale"],
|
||||
downsample_method=st.session_state["downsample_method"])
|
||||
|
||||
ldsr_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan-ldsr2x'
|
||||
|
||||
result.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{ldsr_filename}.png"))
|
||||
|
||||
elif use_RealESRGAN and server_state["RealESRGAN"] is not None and use_GFPGAN and server_state["GFPGAN"] is not None:
|
||||
if "progress_bar_text" in st.session_state:
|
||||
st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
|
||||
|
||||
if "progress_bar" in st.session_state:
|
||||
st.session_state["progress_bar"].progress(
|
||||
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
|
||||
|
||||
torch_gc()
|
||||
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
gfpgan_sample = restored_img[:,:,::-1]
|
||||
|
||||
if server_state["RealESRGAN"].model.name != realesrgan_model_name:
|
||||
#try_loading_RealESRGAN(realesrgan_model_name)
|
||||
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
|
||||
|
||||
output, img_mode = server_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
|
||||
gfpgan_esrgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan-esrgan4x'
|
||||
gfpgan_esrgan_sample = output[:,:,::-1]
|
||||
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
|
||||
|
||||
gfpgan_esrgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{gfpgan_esrgan_filename}.png"))
|
||||
|
||||
|
||||
|
||||
def layout():
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
#st.session_state["progress_bar_text"].info("Nothing but crickets here, try generating something first.")
|
||||
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
with st.form("post-processing-inputs"):
|
||||
# creating the page layout using columns
|
||||
col1, col2 = st.columns([1, 4], gap="medium")
|
||||
|
||||
with col1:
|
||||
st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True)
|
||||
|
||||
|
||||
# check if GFPGAN, RealESRGAN and LDSR are available.
|
||||
#if "GFPGAN_available" not in st.session_state:
|
||||
GFPGAN_available()
|
||||
|
||||
#if "RealESRGAN_available" not in st.session_state:
|
||||
RealESRGAN_available()
|
||||
|
||||
#if "LDSR_available" not in st.session_state:
|
||||
LDSR_available()
|
||||
|
||||
if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
|
||||
face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
|
||||
with face_restoration_tab:
|
||||
# GFPGAN used for face restoration
|
||||
if st.session_state["GFPGAN_available"]:
|
||||
#with st.expander("Face Restoration"):
|
||||
#if st.session_state["GFPGAN_available"]:
|
||||
#with st.expander("GFPGAN"):
|
||||
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN,
|
||||
help="Uses the GFPGAN model to improve faces after the generation.\
|
||||
This greatly improve the quality and consistency of faces but uses\
|
||||
extra VRAM. Disable if you need the extra VRAM.")
|
||||
|
||||
st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
|
||||
index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
|
||||
|
||||
#st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
|
||||
|
||||
else:
|
||||
st.session_state["use_GFPGAN"] = False
|
||||
|
||||
with upscaling_tab:
|
||||
st.session_state['use_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling)
|
||||
|
||||
# RealESRGAN and LDSR used for upscaling.
|
||||
if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
|
||||
|
||||
upscaling_method_list = []
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
upscaling_method_list.append("RealESRGAN")
|
||||
if st.session_state["LDSR_available"]:
|
||||
upscaling_method_list.append("LDSR")
|
||||
|
||||
#print (st.session_state["RealESRGAN_available"])
|
||||
st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
|
||||
index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
|
||||
if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
|
||||
else 0)
|
||||
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
with st.expander("RealESRGAN"):
|
||||
if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['use_upscaling']:
|
||||
st.session_state["use_RealESRGAN"] = True
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
|
||||
index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
|
||||
#
|
||||
if st.session_state["LDSR_available"]:
|
||||
with st.expander("LDSR"):
|
||||
if st.session_state["upscaling_method"] == "LDSR" and st.session_state['use_upscaling']:
|
||||
st.session_state["use_LDSR"] = True
|
||||
else:
|
||||
st.session_state["use_LDSR"] = False
|
||||
|
||||
st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
|
||||
index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
|
||||
|
||||
st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2img.LDSR_config.sampling_steps,
|
||||
help="")
|
||||
|
||||
st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.preDownScale,
|
||||
help="")
|
||||
|
||||
st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.postDownScale,
|
||||
help="")
|
||||
|
||||
downsample_method_list = ['Nearest', 'Lanczos']
|
||||
st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
|
||||
index=downsample_method_list.index(st.session_state['defaults'].txt2img.LDSR_config.downsample_method))
|
||||
|
||||
else:
|
||||
st.session_state["use_LDSR"] = False
|
||||
st.session_state["LDSR_model"] = "model"
|
||||
|
||||
#process = st.form_submit_button("Process Images", help="")
|
||||
|
||||
#
|
||||
with st.expander("Output Settings", True):
|
||||
#st.session_state['defaults'].post_processing.save_original_images = st.checkbox("Save input images.", value=st.session_state['defaults'].post_processing.save_original_images,
|
||||
#help="Save each original/input image next to the Post Processed image. "
|
||||
#"This might be helpful for comparing the before and after images.")
|
||||
|
||||
st.session_state['defaults'].post_processing.outdir_post_processing = st.text_input("Output Dir",value=st.session_state['defaults'].post_processing.outdir_post_processing,
|
||||
help="Folder where the images will be saved after post processing.")
|
||||
|
||||
with col2:
|
||||
st.subheader("Image")
|
||||
|
||||
image_col1, image_col2, image_col3 = st.columns([2, 2, 2], gap="small")
|
||||
with image_col1:
|
||||
refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image.')
|
||||
|
||||
if st.session_state["uploaded_image"]:
|
||||
#print (type(st.session_state["uploaded_image"]))
|
||||
# if len(st.session_state["uploaded_image"]) == 1:
|
||||
st.session_state["input_image_preview"] = []
|
||||
st.session_state["input_image_caption"] = []
|
||||
st.session_state["output_image_preview"] = []
|
||||
st.session_state["output_image_caption"] = []
|
||||
st.session_state["input_image_preview_container"] = []
|
||||
st.session_state["prediction_table"] = []
|
||||
st.session_state["text_result"] = []
|
||||
|
||||
for i in range(len(st.session_state["uploaded_image"])):
|
||||
st.session_state["input_image_preview_container"].append(i)
|
||||
st.session_state["input_image_preview_container"][i] = st.empty()
|
||||
|
||||
with st.session_state["input_image_preview_container"][i].container():
|
||||
col1_output, col2_output, col3_output = st.columns([2, 2, 2], gap="medium")
|
||||
with col1_output:
|
||||
st.session_state["output_image_caption"].append(i)
|
||||
st.session_state["output_image_caption"][i] = st.empty()
|
||||
#st.session_state["output_image_caption"][i] = st.session_state["uploaded_image"][i].name
|
||||
|
||||
st.session_state["input_image_caption"].append(i)
|
||||
st.session_state["input_image_caption"][i] = st.empty()
|
||||
#st.session_state["input_image_caption"][i].caption(")
|
||||
|
||||
st.session_state["input_image_preview"].append(i)
|
||||
st.session_state["input_image_preview"][i] = st.empty()
|
||||
st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
|
||||
|
||||
st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
|
||||
|
||||
with col2_output:
|
||||
st.session_state["output_image_preview"].append(i)
|
||||
st.session_state["output_image_preview"][i] = st.empty()
|
||||
|
||||
st.session_state["output_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
|
||||
|
||||
with st.session_state["input_image_preview_container"][i].container():
|
||||
|
||||
with col3_output:
|
||||
|
||||
#st.session_state["prediction_table"].append(i)
|
||||
#st.session_state["prediction_table"][i] = st.empty()
|
||||
#st.session_state["prediction_table"][i].table(pd.DataFrame(columns=["Model", "Filename", "Progress"]))
|
||||
|
||||
st.session_state["text_result"].append(i)
|
||||
st.session_state["text_result"][i] = st.empty()
|
||||
st.session_state["text_result"][i].code("", language="")
|
||||
|
||||
#else:
|
||||
##st.session_state["input_image_preview"].code('', language="")
|
||||
#st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
|
||||
|
||||
with image_col3:
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
process = st.form_submit_button("Process Images!")
|
||||
|
||||
if process:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
#load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
#use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
|
||||
#use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"])
|
||||
|
||||
if st.session_state["use_GFPGAN"]:
|
||||
load_GFPGAN(model_name=st.session_state["GFPGAN_model"])
|
||||
|
||||
if st.session_state["use_RealESRGAN"]:
|
||||
load_RealESRGAN(st.session_state["RealESRGAN_model"])
|
||||
|
||||
if st.session_state["use_LDSR"]:
|
||||
load_LDSR(st.session_state["LDSR_model"])
|
||||
|
||||
post_process(use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"],
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], realesrgan_model_name=st.session_state["RealESRGAN_model"],
|
||||
use_LDSR=st.session_state["use_LDSR"], LDSR_model_name=st.session_state["LDSR_model"])
|
260
webui/streamlit/scripts/sd_concept_library.py
Normal file
260
webui/streamlit/scripts/sd_concept_library.py
Normal file
@ -0,0 +1,260 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
from sd_utils import st
|
||||
|
||||
# streamlit imports
|
||||
import streamlit.components.v1 as components
|
||||
#other imports
|
||||
|
||||
import os, math
|
||||
from PIL import Image
|
||||
|
||||
# Temp imports
|
||||
#from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# Init Vuejs component
|
||||
_component_func = components.declare_component(
|
||||
"sd-concepts-browser", "./frontend/dists/concept-browser/dist")
|
||||
|
||||
|
||||
def sdConceptsBrowser(concepts, key=None):
|
||||
component_value = _component_func(concepts=concepts, key=key, default="")
|
||||
return component_value
|
||||
|
||||
|
||||
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
|
||||
def getConceptsFromPath(page, conceptPerPage, searchText=""):
|
||||
#print("getConceptsFromPath", "page:", page, "conceptPerPage:", conceptPerPage, "searchText:", searchText)
|
||||
# get the path where the concepts are stored
|
||||
path = os.path.join(
|
||||
os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder)
|
||||
acceptedExtensions = ('jpeg', 'jpg', "png")
|
||||
concepts = []
|
||||
|
||||
if os.path.exists(path):
|
||||
# List all folders (concepts) in the path
|
||||
folders = [f for f in os.listdir(
|
||||
path) if os.path.isdir(os.path.join(path, f))]
|
||||
filteredFolders = folders
|
||||
|
||||
# Filter the folders by the search text
|
||||
if searchText != "":
|
||||
filteredFolders = [
|
||||
f for f in folders if searchText.lower() in f.lower()]
|
||||
else:
|
||||
filteredFolders = []
|
||||
|
||||
conceptIndex = 1
|
||||
for folder in filteredFolders:
|
||||
# handle pagination
|
||||
if conceptIndex > (page * conceptPerPage):
|
||||
continue
|
||||
if conceptIndex <= ((page - 1) * conceptPerPage):
|
||||
conceptIndex += 1
|
||||
continue
|
||||
|
||||
concept = {
|
||||
"name": folder,
|
||||
"token": "<" + folder + ">",
|
||||
"images": [],
|
||||
"type": ""
|
||||
}
|
||||
|
||||
# type of concept is inside type_of_concept.txt
|
||||
typePath = os.path.join(path, folder, "type_of_concept.txt")
|
||||
binFile = os.path.join(path, folder, "learned_embeds.bin")
|
||||
|
||||
# Continue if the concept is not valid or the download has failed (no type_of_concept.txt or no binFile)
|
||||
if not os.path.exists(typePath) or not os.path.exists(binFile):
|
||||
continue
|
||||
|
||||
with open(typePath, "r") as f:
|
||||
concept["type"] = f.read()
|
||||
|
||||
# List all files in the concept/concept_images folder
|
||||
files = [f for f in os.listdir(os.path.join(path, folder, "concept_images")) if os.path.isfile(
|
||||
os.path.join(path, folder, "concept_images", f))]
|
||||
# Retrieve only the 4 first images
|
||||
for file in files:
|
||||
|
||||
# Skip if we already have 4 images
|
||||
if len(concept["images"]) >= 4:
|
||||
break
|
||||
|
||||
if file.endswith(acceptedExtensions):
|
||||
try:
|
||||
# Add a copy of the image to avoid file locking
|
||||
originalImage = Image.open(os.path.join(
|
||||
path, folder, "concept_images", file))
|
||||
|
||||
# Maintain the aspect ratio (max 200x200)
|
||||
resizedImage = originalImage.copy()
|
||||
resizedImage.thumbnail((200, 200), Image.Resampling.LANCZOS)
|
||||
|
||||
# concept["images"].append(resizedImage)
|
||||
|
||||
concept["images"].append(imageToBase64(resizedImage))
|
||||
# Close original image
|
||||
originalImage.close()
|
||||
except:
|
||||
print("Error while loading image", file, "in concept", folder, "(The file may be corrupted). Skipping it.")
|
||||
|
||||
concepts.append(concept)
|
||||
conceptIndex += 1
|
||||
# print all concepts name
|
||||
#print("Results:", [c["name"] for c in concepts])
|
||||
return concepts
|
||||
|
||||
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True)
|
||||
def imageToBase64(image):
|
||||
import io
|
||||
import base64
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
||||
return img_str
|
||||
|
||||
|
||||
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
|
||||
def getTotalNumberOfConcepts(searchText=""):
|
||||
# get the path where the concepts are stored
|
||||
path = os.path.join(
|
||||
os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder)
|
||||
concepts = []
|
||||
|
||||
if os.path.exists(path):
|
||||
# List all folders (concepts) in the path
|
||||
folders = [f for f in os.listdir(
|
||||
path) if os.path.isdir(os.path.join(path, f))]
|
||||
filteredFolders = folders
|
||||
|
||||
# Filter the folders by the search text
|
||||
if searchText != "":
|
||||
filteredFolders = [
|
||||
f for f in folders if searchText.lower() in f.lower()]
|
||||
else:
|
||||
filteredFolders = []
|
||||
return len(filteredFolders)
|
||||
|
||||
|
||||
|
||||
def layout():
|
||||
# 2 tabs, one for Concept Library and one for the Download Manager
|
||||
tab_library, tab_downloader = st.tabs(["Library", "Download Manager"])
|
||||
|
||||
# Concept Library
|
||||
with tab_library:
|
||||
downloaded_concepts_count = getTotalNumberOfConcepts()
|
||||
concepts_per_page = st.session_state["defaults"].concepts_library.concepts_per_page
|
||||
|
||||
if not "results" in st.session_state:
|
||||
st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, "")
|
||||
|
||||
# Pagination controls
|
||||
if not "cl_current_page" in st.session_state:
|
||||
st.session_state["cl_current_page"] = 1
|
||||
|
||||
# Search
|
||||
if not 'cl_search_text' in st.session_state:
|
||||
st.session_state["cl_search_text"] = ""
|
||||
|
||||
if not 'cl_search_results_count' in st.session_state:
|
||||
st.session_state["cl_search_results_count"] = downloaded_concepts_count
|
||||
|
||||
# Search bar
|
||||
_search_col, _refresh_col = st.columns([10, 2])
|
||||
with _search_col:
|
||||
search_text_input = st.text_input("Search", "", placeholder=f'Search for a concept ({downloaded_concepts_count} available)', label_visibility="hidden")
|
||||
if search_text_input != st.session_state["cl_search_text"]:
|
||||
# Search text has changed
|
||||
st.session_state["cl_search_text"] = search_text_input
|
||||
st.session_state["cl_current_page"] = 1
|
||||
st.session_state["cl_search_results_count"] = getTotalNumberOfConcepts(st.session_state["cl_search_text"])
|
||||
st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, st.session_state["cl_search_text"])
|
||||
|
||||
with _refresh_col:
|
||||
# Super weird fix to align the refresh button with the search bar ( Please streamlit, add css support.. )
|
||||
_refresh_col.write("")
|
||||
_refresh_col.write("")
|
||||
if st.button("Refresh concepts", key="refresh_concepts", help="Refresh the concepts folders. Use this if you have added new concepts manually or deleted some."):
|
||||
getTotalNumberOfConcepts.clear()
|
||||
getConceptsFromPath.clear()
|
||||
st.experimental_rerun()
|
||||
|
||||
|
||||
# Show results
|
||||
results_empty = st.empty()
|
||||
|
||||
# Pagination
|
||||
pagination_empty = st.empty()
|
||||
|
||||
# Layouts
|
||||
with pagination_empty:
|
||||
with st.container():
|
||||
if len(st.session_state["results"]) > 0:
|
||||
last_page = math.ceil(st.session_state["cl_search_results_count"] / concepts_per_page)
|
||||
_1, _2, _3, _4, _previous_page, _current_page, _next_page, _9, _10, _11, _12 = st.columns([1,1,1,1,1,2,1,1,1,1,1])
|
||||
|
||||
# Previous page
|
||||
with _previous_page:
|
||||
if st.button("Previous", key="cl_previous_page"):
|
||||
st.session_state["cl_current_page"] -= 1
|
||||
if st.session_state["cl_current_page"] <= 0:
|
||||
st.session_state["cl_current_page"] = last_page
|
||||
st.session_state["results"] = getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"])
|
||||
|
||||
# Current page
|
||||
with _current_page:
|
||||
_current_page_container = st.empty()
|
||||
|
||||
# Next page
|
||||
with _next_page:
|
||||
if st.button("Next", key="cl_next_page"):
|
||||
st.session_state["cl_current_page"] += 1
|
||||
if st.session_state["cl_current_page"] > last_page:
|
||||
st.session_state["cl_current_page"] = 1
|
||||
st.session_state["results"] = getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"])
|
||||
|
||||
# Current page
|
||||
with _current_page_container:
|
||||
st.markdown(f'<p style="text-align: center">Page {st.session_state["cl_current_page"]} of {last_page}</p>', unsafe_allow_html=True)
|
||||
# st.write(f"Page {st.session_state['cl_current_page']} of {last_page}", key="cl_current_page")
|
||||
|
||||
with results_empty:
|
||||
with st.container():
|
||||
if downloaded_concepts_count == 0:
|
||||
st.write("You don't have any concepts in your library ")
|
||||
st.markdown("To add concepts to your library, download some from the [sd-concepts-library](https://github.com/Sygil-Dev/sd-concepts-library) \
|
||||
repository and save the content of `sd-concepts-library` into ```./models/custom/sd-concepts-library``` or just create your own concepts :wink:.", unsafe_allow_html=False)
|
||||
else:
|
||||
if len(st.session_state["results"]) == 0:
|
||||
st.write("No concept found in the library matching your search: " + st.session_state["cl_search_text"])
|
||||
else:
|
||||
# display number of results
|
||||
if st.session_state["cl_search_text"]:
|
||||
st.write(f"Found {st.session_state['cl_search_results_count']} {'concepts' if st.session_state['cl_search_results_count'] > 1 else 'concept' } matching your search")
|
||||
sdConceptsBrowser(st.session_state['results'], key="results")
|
||||
|
||||
|
||||
with tab_downloader:
|
||||
st.write("Not implemented yet")
|
||||
|
||||
return False
|
405
webui/streamlit/scripts/sd_utils/__init__.py
Normal file
405
webui/streamlit/scripts/sd_utils/__init__.py
Normal file
@ -0,0 +1,405 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
#from webui_streamlit import st
|
||||
import hydralit as st
|
||||
|
||||
# streamlit imports
|
||||
from streamlit.runtime.scriptrunner import StopException
|
||||
#from streamlit.runtime.scriptrunner import script_run_context
|
||||
|
||||
#streamlit components section
|
||||
from streamlit_server_state import server_state, server_state_lock, no_rerun
|
||||
import hydralit_components as hc
|
||||
from hydralit import HydraHeadApp
|
||||
import streamlit_nested_layout
|
||||
#from streamlitextras.threader import lock, trigger_rerun, \
|
||||
#streamlit_thread, get_thread, \
|
||||
#last_trigger_time
|
||||
|
||||
#other imports
|
||||
|
||||
import warnings
|
||||
import json
|
||||
|
||||
import base64, cv2
|
||||
import os, sys, re, random, datetime, time, math, toml
|
||||
import gc
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageFilter
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from scipy import integrate
|
||||
import torch
|
||||
from torchdiffeq import odeint
|
||||
import k_diffusion as K
|
||||
import math, requests
|
||||
import mimetypes
|
||||
import numpy as np
|
||||
import pynvml
|
||||
import threading
|
||||
import torch, torchvision
|
||||
from torch import autocast
|
||||
from torchvision import transforms
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from contextlib import nullcontext
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import instantiate_from_config
|
||||
from retry import retry
|
||||
from slugify import slugify
|
||||
import skimage
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from tqdm import trange
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import ismap
|
||||
#from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from packaging import version
|
||||
from pathlib import Path
|
||||
from huggingface_hub import hf_hub_download
|
||||
import shutup
|
||||
|
||||
#import librosa
|
||||
from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger
|
||||
from nataili.esrgan import esrgan
|
||||
|
||||
|
||||
#try:
|
||||
#from realesrgan import RealESRGANer
|
||||
#from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
#except ImportError as e:
|
||||
#logger.error("You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
|
||||
#"pip install realesrgan")
|
||||
|
||||
# Temp imports
|
||||
#from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
# remove all the annoying python warnings.
|
||||
shutup.please()
|
||||
|
||||
# the following lines should help fixing an issue with nvidia 16xx cards.
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state["defaults"].general.use_cudnn:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
# disable diffusers telemetry
|
||||
os.environ["DISABLE_TELEMETRY"] = "YES"
|
||||
|
||||
# remove some annoying deprecation warnings that show every now and then.
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
|
||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||
opt_C = 4
|
||||
opt_f = 8
|
||||
|
||||
# The model manager loads and unloads the SD models and has features to download them or find their location
|
||||
#model_manager = ModelManager()
|
||||
|
||||
def load_configs():
|
||||
if not "defaults" in st.session_state:
|
||||
st.session_state["defaults"] = {}
|
||||
|
||||
st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
|
||||
|
||||
if (os.path.exists("configs/webui/userconfig_streamlit.yaml")):
|
||||
user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
|
||||
|
||||
if "version" in user_defaults.general:
|
||||
if version.parse(user_defaults.general.version) < version.parse(st.session_state["defaults"].general.version):
|
||||
logger.error("The version of the user config file is older than the version on the defaults config file. "
|
||||
"This means there were big changes we made on the config."
|
||||
"We are removing this file and recreating it from the defaults in order to make sure things work properly.")
|
||||
os.remove("configs/webui/userconfig_streamlit.yaml")
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
logger.error("The version of the user config file is older than the version on the defaults config file. "
|
||||
"This means there were big changes we made on the config."
|
||||
"We are removing this file and recreating it from the defaults in order to make sure things work properly.")
|
||||
os.remove("configs/webui/userconfig_streamlit.yaml")
|
||||
st.experimental_rerun()
|
||||
|
||||
try:
|
||||
st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults)
|
||||
except KeyError:
|
||||
st.experimental_rerun()
|
||||
else:
|
||||
OmegaConf.save(config=st.session_state.defaults, f="configs/webui/userconfig_streamlit.yaml")
|
||||
loaded = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
|
||||
assert st.session_state.defaults == loaded
|
||||
|
||||
if (os.path.exists(".streamlit/config.toml")):
|
||||
st.session_state["streamlit_config"] = toml.load(".streamlit/config.toml")
|
||||
|
||||
#if st.session_state["defaults"].daisi_app.running_on_daisi_io:
|
||||
#if os.path.exists("scripts/modeldownload.py"):
|
||||
#import modeldownload
|
||||
#modeldownload.updateModels()
|
||||
|
||||
if "keep_all_models_loaded" in st.session_state.defaults.general:
|
||||
with server_state_lock["keep_all_models_loaded"]:
|
||||
server_state["keep_all_models_loaded"] = st.session_state["defaults"].general.keep_all_models_loaded
|
||||
else:
|
||||
st.session_state["defaults"].general.keep_all_models_loaded = False
|
||||
with server_state_lock["keep_all_models_loaded"]:
|
||||
server_state["keep_all_models_loaded"] = st.session_state["defaults"].general.keep_all_models_loaded
|
||||
|
||||
load_configs()
|
||||
|
||||
#
|
||||
#if st.session_state["defaults"].debug.enable_hydralit:
|
||||
#navbar_theme = {'txc_inactive': '#FFFFFF','menu_background':'#0e1117','txc_active':'black','option_active':'red'}
|
||||
#app = st.HydraApp(title='Stable Diffusion WebUI', favicon="", use_cookie_cache=False, sidebar_state="expanded", layout="wide", navbar_theme=navbar_theme,
|
||||
#hide_streamlit_markers=False, allow_url_nav=True , clear_cross_app_sessions=False, use_loader=False)
|
||||
#else:
|
||||
#app = None
|
||||
|
||||
#
|
||||
grid_format = st.session_state["defaults"].general.save_format
|
||||
grid_lossless = False
|
||||
grid_quality = st.session_state["defaults"].general.grid_quality
|
||||
if grid_format == 'png':
|
||||
grid_ext = 'png'
|
||||
grid_format = 'png'
|
||||
elif grid_format in ['jpg', 'jpeg']:
|
||||
grid_quality = int(grid_format) if len(grid_format) > 1 else 100
|
||||
grid_ext = 'jpg'
|
||||
grid_format = 'jpeg'
|
||||
elif grid_format[0] == 'webp':
|
||||
grid_quality = int(grid_format) if len(grid_format) > 1 else 100
|
||||
grid_ext = 'webp'
|
||||
grid_format = 'webp'
|
||||
if grid_quality < 0: # e.g. webp:-100 for lossless mode
|
||||
grid_lossless = True
|
||||
grid_quality = abs(grid_quality)
|
||||
|
||||
#
|
||||
save_format = st.session_state["defaults"].general.save_format
|
||||
save_lossless = False
|
||||
save_quality = 100
|
||||
if save_format == 'png':
|
||||
save_ext = 'png'
|
||||
save_format = 'png'
|
||||
elif save_format in ['jpg', 'jpeg']:
|
||||
save_quality = int(save_format) if len(save_format) > 1 else 100
|
||||
save_ext = 'jpg'
|
||||
save_format = 'jpeg'
|
||||
elif save_format == 'webp':
|
||||
save_quality = int(save_format) if len(save_format) > 1 else 100
|
||||
save_ext = 'webp'
|
||||
save_format = 'webp'
|
||||
if save_quality < 0: # e.g. webp:-100 for lossless mode
|
||||
save_lossless = True
|
||||
save_quality = abs(save_quality)
|
||||
|
||||
# this should force GFPGAN and RealESRGAN onto the selected gpu as well
|
||||
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu)
|
||||
|
||||
|
||||
# functions to load css locally OR remotely starts here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection
|
||||
# TODO, maybe look into async loading the file especially for remote fetching
|
||||
def local_css(file_name):
|
||||
with open(file_name) as f:
|
||||
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
|
||||
|
||||
def remote_css(url):
|
||||
st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)
|
||||
|
||||
def load_css(isLocal, nameOrURL):
|
||||
if(isLocal):
|
||||
local_css(nameOrURL)
|
||||
else:
|
||||
remote_css(nameOrURL)
|
||||
|
||||
def set_page_title(title):
|
||||
"""
|
||||
Simple function to allows us to change the title dynamically.
|
||||
Normally you can use `st.set_page_config` to change the title but it can only be used once per app.
|
||||
"""
|
||||
|
||||
st.sidebar.markdown(unsafe_allow_html=True, body=f"""
|
||||
<iframe height=0 srcdoc="<script>
|
||||
const title = window.parent.document.querySelector('title') \
|
||||
|
||||
const oldObserver = window.parent.titleObserver
|
||||
if (oldObserver) {{
|
||||
oldObserver.disconnect()
|
||||
}} \
|
||||
|
||||
const newObserver = new MutationObserver(function(mutations) {{
|
||||
const target = mutations[0].target
|
||||
if (target.text !== '{title}') {{
|
||||
target.text = '{title}'
|
||||
}}
|
||||
}}) \
|
||||
|
||||
newObserver.observe(title, {{ childList: true }})
|
||||
window.parent.titleObserver = newObserver \
|
||||
|
||||
title.text = '{title}'
|
||||
</script>" />
|
||||
""")
|
||||
|
||||
class MemUsageMonitor(threading.Thread):
|
||||
stop_flag = False
|
||||
max_usage = 0
|
||||
total = -1
|
||||
|
||||
def __init__(self, name):
|
||||
threading.Thread.__init__(self)
|
||||
self.name = name
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
except:
|
||||
logger.debug(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n")
|
||||
return
|
||||
logger.info(f"[{self.name}] Recording memory usage...\n")
|
||||
# Missing context
|
||||
#handle = pynvml.nvmlDeviceGetHandleByIndex(st.session_state['defaults'].general.gpu)
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
|
||||
self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
|
||||
while not self.stop_flag:
|
||||
m = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
self.max_usage = max(self.max_usage, m.used)
|
||||
# logger.info(self.max_usage)
|
||||
time.sleep(0.1)
|
||||
logger.info(f"[{self.name}] Stopped recording.\n")
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
def read(self):
|
||||
return self.max_usage, self.total
|
||||
|
||||
def stop(self):
|
||||
self.stop_flag = True
|
||||
|
||||
def read_and_stop(self):
|
||||
self.stop_flag = True
|
||||
return self.max_usage, self.total
|
||||
|
||||
#
|
||||
def custom_models_available():
|
||||
with server_state_lock["custom_models"]:
|
||||
#
|
||||
# Allow for custom models to be used instead of the default one,
|
||||
# an example would be Waifu-Diffusion or any other fine tune of stable diffusion
|
||||
server_state["custom_models"]:sorted = []
|
||||
|
||||
for root, dirs, files in os.walk(os.path.join("models", "custom")):
|
||||
for file in files:
|
||||
if os.path.splitext(file)[1] == '.ckpt':
|
||||
server_state["custom_models"].append(os.path.splitext(file)[0])
|
||||
|
||||
with server_state_lock["CustomModel_available"]:
|
||||
if len(server_state["custom_models"]) > 0:
|
||||
server_state["CustomModel_available"] = True
|
||||
server_state["custom_models"].append("Stable Diffusion v1.5")
|
||||
else:
|
||||
server_state["CustomModel_available"] = False
|
||||
|
||||
#
|
||||
def GFPGAN_available():
|
||||
#with server_state_lock["GFPGAN_models"]:
|
||||
#
|
||||
|
||||
st.session_state["GFPGAN_models"]:sorted = []
|
||||
model = st.session_state["defaults"].model_manager.models.gfpgan
|
||||
|
||||
files_available = 0
|
||||
|
||||
for file in model['files']:
|
||||
if "save_location" in model['files'][file]:
|
||||
if os.path.exists(os.path.join(model['files'][file]['save_location'], model['files'][file]['file_name'] )):
|
||||
files_available += 1
|
||||
|
||||
elif os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
|
||||
base_name = os.path.splitext(model['files'][file]['file_name'])[0]
|
||||
if "GFPGANv" in base_name:
|
||||
st.session_state["GFPGAN_models"].append(base_name)
|
||||
files_available += 1
|
||||
|
||||
# we need to show the other models from previous verions that we have on the
|
||||
# same directory in case we want to see how they perform vs each other.
|
||||
for root, dirs, files in os.walk(st.session_state['defaults'].general.GFPGAN_dir):
|
||||
for file in files:
|
||||
if os.path.splitext(file)[1] == '.pth':
|
||||
if os.path.splitext(file)[0] not in st.session_state["GFPGAN_models"]:
|
||||
st.session_state["GFPGAN_models"].append(os.path.splitext(file)[0])
|
||||
|
||||
|
||||
if len(st.session_state["GFPGAN_models"]) > 0 and files_available == len(model['files']):
|
||||
st.session_state["GFPGAN_available"] = True
|
||||
else:
|
||||
st.session_state["GFPGAN_available"] = False
|
||||
st.session_state["use_GFPGAN"] = False
|
||||
st.session_state["GFPGAN_model"] = "GFPGANv1.4"
|
||||
|
||||
#
|
||||
def RealESRGAN_available():
|
||||
#with server_state_lock["RealESRGAN_models"]:
|
||||
|
||||
st.session_state["RealESRGAN_models"]:sorted = []
|
||||
model = st.session_state["defaults"].model_manager.models.realesrgan
|
||||
for file in model['files']:
|
||||
if os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
|
||||
base_name = os.path.splitext(model['files'][file]['file_name'])[0]
|
||||
st.session_state["RealESRGAN_models"].append(base_name)
|
||||
|
||||
if len(st.session_state["RealESRGAN_models"]) > 0:
|
||||
st.session_state["RealESRGAN_available"] = True
|
||||
else:
|
||||
st.session_state["RealESRGAN_available"] = False
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
#
|
||||
def LDSR_available():
|
||||
st.session_state["LDSR_models"]:sorted = []
|
||||
files_available = 0
|
||||
model = st.session_state["defaults"].model_manager.models.ldsr
|
||||
for file in model['files']:
|
||||
if os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
|
||||
base_name = os.path.splitext(model['files'][file]['file_name'])[0]
|
||||
extension = os.path.splitext(model['files'][file]['file_name'])[1]
|
||||
if extension == ".ckpt":
|
||||
st.session_state["LDSR_models"].append(base_name)
|
||||
files_available += 1
|
||||
if files_available == len(model['files']):
|
||||
st.session_state["LDSR_available"] = True
|
||||
else:
|
||||
st.session_state["LDSR_available"] = False
|
||||
st.session_state["use_LDSR"] = False
|
||||
st.session_state["LDSR_model"] = "model"
|
182
webui/streamlit/scripts/sd_utils/bridge.py
Normal file
182
webui/streamlit/scripts/sd_utils/bridge.py
Normal file
@ -0,0 +1,182 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# base webui import and utils.
|
||||
#import streamlit as st
|
||||
|
||||
# We import hydralit like this to replace the previous stuff
|
||||
# we had with native streamlit as it lets ur replace things 1:1
|
||||
from nataili.util import logger
|
||||
|
||||
# streamlit imports
|
||||
|
||||
#streamlit components section
|
||||
|
||||
#other imports
|
||||
import requests, time, json, base64
|
||||
from io import BytesIO
|
||||
|
||||
# import custom components
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
@logger.catch(reraise=True)
|
||||
def run_bridge(interval, api_key, horde_name, horde_url, priority_usernames, horde_max_pixels, horde_nsfw, horde_censor_nsfw, horde_blacklist, horde_censorlist):
|
||||
current_id = None
|
||||
current_payload = None
|
||||
loop_retry = 0
|
||||
# load the model for stable horde if its not in memory already
|
||||
# we should load it after we get the request from the API in
|
||||
# case the model is different from the loaded in memory but
|
||||
# for now we can load it here so its read right away.
|
||||
load_models(use_GFPGAN=True)
|
||||
while True:
|
||||
|
||||
if loop_retry > 10 and current_id:
|
||||
logger.info(f"Exceeded retry count {loop_retry} for generation id {current_id}. Aborting generation!")
|
||||
current_id = None
|
||||
current_payload = None
|
||||
current_generation = None
|
||||
loop_retry = 0
|
||||
elif current_id:
|
||||
logger.info(f"Retrying ({loop_retry}/10) for generation id {current_id}...")
|
||||
gen_dict = {
|
||||
"name": horde_name,
|
||||
"max_pixels": horde_max_pixels,
|
||||
"priority_usernames": priority_usernames,
|
||||
"nsfw": horde_nsfw,
|
||||
"blacklist": horde_blacklist,
|
||||
"models": ["stable_diffusion"],
|
||||
}
|
||||
headers = {"apikey": api_key}
|
||||
if current_id:
|
||||
loop_retry += 1
|
||||
else:
|
||||
try:
|
||||
pop_req = requests.post(horde_url + '/api/v2/generate/pop', json = gen_dict, headers = headers)
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.warning(f"Server {horde_url} unavailable during pop. Waiting 10 seconds...")
|
||||
time.sleep(10)
|
||||
continue
|
||||
except requests.exceptions.JSONDecodeError():
|
||||
logger.warning(f"Server {horde_url} unavailable during pop. Waiting 10 seconds...")
|
||||
time.sleep(10)
|
||||
continue
|
||||
try:
|
||||
pop = pop_req.json()
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.warning(f"Could not decode response from {horde_url} as json. Please inform its administrator!")
|
||||
time.sleep(interval)
|
||||
continue
|
||||
if pop == None:
|
||||
logger.warning(f"Something has gone wrong with {horde_url}. Please inform its administrator!")
|
||||
time.sleep(interval)
|
||||
continue
|
||||
if not pop_req.ok:
|
||||
message = pop['message']
|
||||
logger.warning(f"During gen pop, server {horde_url} responded with status code {pop_req.status_code}: {pop['message']}. Waiting for 10 seconds...")
|
||||
if 'errors' in pop:
|
||||
logger.debug(f"Detailed Request Errors: {pop['errors']}")
|
||||
time.sleep(10)
|
||||
continue
|
||||
if not pop.get("id"):
|
||||
skipped_info = pop.get('skipped')
|
||||
if skipped_info and len(skipped_info):
|
||||
skipped_info = f" Skipped Info: {skipped_info}."
|
||||
else:
|
||||
skipped_info = ''
|
||||
logger.info(f"Server {horde_url} has no valid generations to do for us.{skipped_info}")
|
||||
time.sleep(interval)
|
||||
continue
|
||||
current_id = pop['id']
|
||||
logger.info(f"Request with id {current_id} picked up. Initiating work...")
|
||||
current_payload = pop['payload']
|
||||
if 'toggles' in current_payload and current_payload['toggles'] == None:
|
||||
logger.error(f"Received Bad payload: {pop}")
|
||||
current_id = None
|
||||
current_payload = None
|
||||
current_generation = None
|
||||
loop_retry = 0
|
||||
time.sleep(10)
|
||||
continue
|
||||
|
||||
logger.debug(current_payload)
|
||||
current_payload['toggles'] = current_payload.get('toggles', [1,4])
|
||||
# In bridge-mode, matrix is prepared on the horde and split in multiple nodes
|
||||
if 0 in current_payload['toggles']:
|
||||
current_payload['toggles'].remove(0)
|
||||
if 8 not in current_payload['toggles']:
|
||||
if horde_censor_nsfw and not horde_nsfw:
|
||||
current_payload['toggles'].append(8)
|
||||
elif any(word in current_payload['prompt'] for word in horde_censorlist):
|
||||
current_payload['toggles'].append(8)
|
||||
|
||||
from txt2img import txt2img
|
||||
|
||||
|
||||
"""{'prompt': 'Centred Husky, inside spiral with circular patterns, trending on dribbble, knotwork, spirals, key patterns,
|
||||
zoomorphics, ', 'ddim_steps': 30, 'n_iter': 1, 'sampler_name': 'DDIM', 'cfg_scale': 16.0, 'seed': '3405278433', 'height': 512, 'width': 512}"""
|
||||
|
||||
#images, seed, info, stats = txt2img(**current_payload)
|
||||
images, seed, info, stats = txt2img(str(current_payload['prompt']), int(current_payload['ddim_steps']), str(current_payload['sampler_name']),
|
||||
int(current_payload['n_iter']), 1, float(current_payload["cfg_scale"]), str(current_payload["seed"]),
|
||||
int(current_payload["height"]), int(current_payload["width"]), save_grid=False, group_by_prompt=False,
|
||||
save_individual_images=False,write_info_files=False)
|
||||
|
||||
buffer = BytesIO()
|
||||
# We send as WebP to avoid using all the horde bandwidth
|
||||
images[0].save(buffer, format="WebP", quality=90)
|
||||
# logger.info(info)
|
||||
submit_dict = {
|
||||
"id": current_id,
|
||||
"generation": base64.b64encode(buffer.getvalue()).decode("utf8"),
|
||||
"api_key": api_key,
|
||||
"seed": seed,
|
||||
"max_pixels": horde_max_pixels,
|
||||
}
|
||||
current_generation = seed
|
||||
while current_id and current_generation != None:
|
||||
try:
|
||||
submit_req = requests.post(horde_url + '/api/v2/generate/submit', json = submit_dict, headers = headers)
|
||||
try:
|
||||
submit = submit_req.json()
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.error(f"Something has gone wrong with {horde_url} during submit. Please inform its administrator! (Retry {loop_retry}/10)")
|
||||
time.sleep(interval)
|
||||
continue
|
||||
if submit_req.status_code == 404:
|
||||
logger.info(f"The generation we were working on got stale. Aborting!")
|
||||
elif not submit_req.ok:
|
||||
logger.error(f"During gen submit, server {horde_url} responded with status code {submit_req.status_code}: {submit['message']}. Waiting for 10 seconds... (Retry {loop_retry}/10)")
|
||||
if 'errors' in submit:
|
||||
logger.debug(f"Detailed Request Errors: {submit['errors']}")
|
||||
time.sleep(10)
|
||||
continue
|
||||
else:
|
||||
logger.info(f'Submitted generation with id {current_id} and contributed for {submit_req.json()["reward"]}')
|
||||
current_id = None
|
||||
current_payload = None
|
||||
current_generation = None
|
||||
loop_retry = 0
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.warning(f"Server {horde_url} unavailable during submit. Waiting 10 seconds... (Retry {loop_retry}/10)")
|
||||
time.sleep(10)
|
||||
continue
|
||||
time.sleep(interval)
|
938
webui/streamlit/scripts/textual_inversion.py
Normal file
938
webui/streamlit/scripts/textual_inversion.py
Normal file
@ -0,0 +1,938 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
from sd_utils import st, set_page_title, seed_to_int
|
||||
|
||||
# streamlit imports
|
||||
from streamlit.runtime.scriptrunner import StopException
|
||||
from streamlit_tensorboard import st_tensorboard
|
||||
|
||||
#streamlit components section
|
||||
from streamlit_server_state import server_state
|
||||
|
||||
#other imports
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# Temp imports
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
#import datetime
|
||||
#from pathlib import Path
|
||||
#from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel#, PNDMScheduler
|
||||
from diffusers.optimization import get_scheduler
|
||||
#from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from pipelines.stable_diffusion.no_check import NoCheck
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from slugify import slugify
|
||||
import json
|
||||
import os#, subprocess
|
||||
#from io import StringIO
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
class TextualInversionDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
tokenizer,
|
||||
learnable_property="object", # [object, style]
|
||||
size=512,
|
||||
repeats=100,
|
||||
interpolation="bicubic",
|
||||
set="train",
|
||||
placeholder_token="*",
|
||||
center_crop=False,
|
||||
templates=None
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.tokenizer = tokenizer
|
||||
self.learnable_property = learnable_property
|
||||
self.size = size
|
||||
self.placeholder_token = placeholder_token
|
||||
self.center_crop = center_crop
|
||||
|
||||
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) if file_path.lower().endswith(('.png', '.jpg', '.jpeg'))]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
self._length = self.num_images
|
||||
|
||||
if set == "train":
|
||||
self._length = self.num_images * repeats
|
||||
|
||||
self.interpolation = {
|
||||
"linear": PIL.Image.LINEAR,
|
||||
"bilinear": PIL.Image.Resampling.BILINEAR,
|
||||
"bicubic": PIL.Image.Resampling.BICUBIC,
|
||||
"lanczos": PIL.Image.Resampling.LANCZOS,
|
||||
}[interpolation]
|
||||
|
||||
self.templates = templates
|
||||
self.cache = {}
|
||||
self.tokenized_templates = [self.tokenizer(
|
||||
text.format(self.placeholder_token),
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids[0] for text in self.templates]
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def get_example(self, image_path, flipped):
|
||||
if image_path in self.cache:
|
||||
return self.cache[image_path]
|
||||
|
||||
example = {}
|
||||
image = Image.open(image_path)
|
||||
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
# default to score-sde preprocessing
|
||||
img = np.array(image).astype(np.uint8)
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
|
||||
image = Image.fromarray(img)
|
||||
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||
image = transforms.RandomHorizontalFlip(p=1 if flipped else 0)(image)
|
||||
image = np.array(image).astype(np.uint8)
|
||||
image = (image / 127.5 - 1.0).astype(np.float32)
|
||||
example["key"] = "-".join([image_path, "-", str(flipped)])
|
||||
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
|
||||
|
||||
self.cache[image_path] = example
|
||||
return example
|
||||
|
||||
def __getitem__(self, i):
|
||||
flipped = random.choice([False, True])
|
||||
example = self.get_example(self.image_paths[i % self.num_images], flipped)
|
||||
example["input_ids"] = random.choice(self.tokenized_templates)
|
||||
return example
|
||||
|
||||
|
||||
def freeze_params(params):
|
||||
for param in params:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
def save_resume_file(basepath, extra = {}, config=''):
|
||||
info = {"args": config["args"]}
|
||||
info["args"].update(extra)
|
||||
|
||||
with open(f"{os.path.join(basepath, 'resume.json')}", "w") as f:
|
||||
#print (info)
|
||||
json.dump(info, f, indent=4)
|
||||
|
||||
with open(f"{basepath}/token_identifier.txt", "w") as f:
|
||||
f.write(f"{config['args']['placeholder_token']}")
|
||||
|
||||
with open(f"{basepath}/type_of_concept.txt", "w") as f:
|
||||
f.write(f"{config['args']['learnable_property']}")
|
||||
|
||||
config['args'] = info["args"]
|
||||
|
||||
return config['args']
|
||||
|
||||
class Checkpointer:
|
||||
def __init__(
|
||||
self,
|
||||
accelerator,
|
||||
vae,
|
||||
unet,
|
||||
tokenizer,
|
||||
placeholder_token,
|
||||
placeholder_token_id,
|
||||
templates,
|
||||
output_dir,
|
||||
random_sample_batches,
|
||||
sample_batch_size,
|
||||
stable_sample_batches,
|
||||
seed
|
||||
):
|
||||
self.accelerator = accelerator
|
||||
self.vae = vae
|
||||
self.unet = unet
|
||||
self.tokenizer = tokenizer
|
||||
self.placeholder_token = placeholder_token
|
||||
self.placeholder_token_id = placeholder_token_id
|
||||
self.templates = templates
|
||||
self.output_dir = output_dir
|
||||
self.seed = seed
|
||||
self.random_sample_batches = random_sample_batches
|
||||
self.sample_batch_size = sample_batch_size
|
||||
self.stable_sample_batches = stable_sample_batches
|
||||
|
||||
@torch.no_grad()
|
||||
def checkpoint(self, step, text_encoder, save_samples=True, path=None):
|
||||
print("Saving checkpoint for step %d..." % step)
|
||||
with torch.autocast("cuda"):
|
||||
if path is None:
|
||||
checkpoints_path = f"{self.output_dir}/checkpoints"
|
||||
os.makedirs(checkpoints_path, exist_ok=True)
|
||||
|
||||
unwrapped = self.accelerator.unwrap_model(text_encoder)
|
||||
|
||||
# Save a checkpoint
|
||||
learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
|
||||
learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
|
||||
|
||||
filename = f"%s_%d.bin" % (slugify(self.placeholder_token), step)
|
||||
if path is not None:
|
||||
torch.save(learned_embeds_dict, path)
|
||||
else:
|
||||
torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
|
||||
torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
|
||||
|
||||
del unwrapped
|
||||
del learned_embeds
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
|
||||
samples_path = f"{self.output_dir}/concept_images"
|
||||
os.makedirs(samples_path, exist_ok=True)
|
||||
|
||||
#if "checker" not in server_state['textual_inversion']:
|
||||
#with server_state_lock['textual_inversion']["checker"]:
|
||||
server_state['textual_inversion']["checker"] = NoCheck()
|
||||
|
||||
#if "unwrapped" not in server_state['textual_inversion']:
|
||||
# with server_state_lock['textual_inversion']["unwrapped"]:
|
||||
server_state['textual_inversion']["unwrapped"] = self.accelerator.unwrap_model(text_encoder)
|
||||
|
||||
#if "pipeline" not in server_state['textual_inversion']:
|
||||
# with server_state_lock['textual_inversion']["pipeline"]:
|
||||
# Save a sample image
|
||||
server_state['textual_inversion']["pipeline"] = StableDiffusionPipeline(
|
||||
text_encoder=server_state['textual_inversion']["unwrapped"],
|
||||
vae=self.vae,
|
||||
unet=self.unet,
|
||||
tokenizer=self.tokenizer,
|
||||
scheduler=LMSDiscreteScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
),
|
||||
safety_checker=NoCheck(),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
).to("cuda")
|
||||
|
||||
server_state['textual_inversion']["pipeline"].enable_attention_slicing()
|
||||
|
||||
if self.stable_sample_batches > 0:
|
||||
stable_latents = torch.randn(
|
||||
(self.sample_batch_size, server_state['textual_inversion']["pipeline"].unet.in_channels, height // 8, width // 8),
|
||||
device=server_state['textual_inversion']["pipeline"].device,
|
||||
generator=torch.Generator(device=server_state['textual_inversion']["pipeline"].device).manual_seed(self.seed),
|
||||
)
|
||||
|
||||
stable_prompts = [choice.format(self.placeholder_token) for choice in (self.templates * self.sample_batch_size)[:self.sample_batch_size]]
|
||||
|
||||
# Generate and save stable samples
|
||||
for i in range(0, self.stable_sample_batches):
|
||||
samples = server_state['textual_inversion']["pipeline"](
|
||||
prompt=stable_prompts,
|
||||
height=384,
|
||||
latents=stable_latents,
|
||||
width=384,
|
||||
guidance_scale=guidance_scale,
|
||||
eta=eta,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type='pil'
|
||||
)["sample"]
|
||||
|
||||
for idx, im in enumerate(samples):
|
||||
filename = f"stable_sample_%d_%d_step_%d.png" % (i+1, idx+1, step)
|
||||
im.save(f"{samples_path}/{filename}")
|
||||
del samples
|
||||
del stable_latents
|
||||
|
||||
prompts = [choice.format(self.placeholder_token) for choice in random.choices(self.templates, k=self.sample_batch_size)]
|
||||
# Generate and save random samples
|
||||
for i in range(0, self.random_sample_batches):
|
||||
samples = server_state['textual_inversion']["pipeline"](
|
||||
prompt=prompts,
|
||||
height=384,
|
||||
width=384,
|
||||
guidance_scale=guidance_scale,
|
||||
eta=eta,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type='pil'
|
||||
)["sample"]
|
||||
for idx, im in enumerate(samples):
|
||||
filename = f"step_%d_sample_%d_%d.png" % (step, i+1, idx+1)
|
||||
im.save(f"{samples_path}/{filename}")
|
||||
del samples
|
||||
|
||||
del server_state['textual_inversion']["checker"]
|
||||
del server_state['textual_inversion']["unwrapped"]
|
||||
del server_state['textual_inversion']["pipeline"]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
#@retry(RuntimeError, tries=5)
|
||||
def textual_inversion(config):
|
||||
print ("Running textual inversion.")
|
||||
|
||||
#if "pipeline" in server_state["textual_inversion"]:
|
||||
#del server_state['textual_inversion']["checker"]
|
||||
#del server_state['textual_inversion']["unwrapped"]
|
||||
#del server_state['textual_inversion']["pipeline"]
|
||||
#torch.cuda.empty_cache()
|
||||
|
||||
global_step_offset = 0
|
||||
|
||||
#print(config['args']['resume_from'])
|
||||
if config['args']['resume_from']:
|
||||
try:
|
||||
basepath = f"{config['args']['resume_from']}"
|
||||
|
||||
with open(f"{basepath}/resume.json", 'r') as f:
|
||||
state = json.load(f)
|
||||
|
||||
global_step_offset = state["args"].get("global_step", 0)
|
||||
|
||||
print("Resuming state from %s" % config['args']['resume_from'])
|
||||
print("We've trained %d steps so far" % global_step_offset)
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
pass
|
||||
else:
|
||||
basepath = f"{config['args']['output_dir']}/{slugify(config['args']['placeholder_token'])}"
|
||||
os.makedirs(basepath, exist_ok=True)
|
||||
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=config['args']['gradient_accumulation_steps'],
|
||||
mixed_precision=config['args']['mixed_precision']
|
||||
)
|
||||
|
||||
# If passed along, set the training seed.
|
||||
if config['args']['seed']:
|
||||
set_seed(config['args']['seed'])
|
||||
|
||||
#if "tokenizer" not in server_state["textual_inversion"]:
|
||||
# Load the tokenizer and add the placeholder token as a additional special token
|
||||
#with server_state_lock['textual_inversion']["tokenizer"]:
|
||||
if config['args']['tokenizer_name']:
|
||||
server_state['textual_inversion']["tokenizer"] = CLIPTokenizer.from_pretrained(config['args']['tokenizer_name'])
|
||||
elif config['args']['pretrained_model_name_or_path']:
|
||||
server_state['textual_inversion']["tokenizer"] = CLIPTokenizer.from_pretrained(
|
||||
config['args']['pretrained_model_name_or_path'] + '/tokenizer'
|
||||
)
|
||||
|
||||
# Add the placeholder token in tokenizer
|
||||
num_added_tokens = server_state['textual_inversion']["tokenizer"].add_tokens(config['args']['placeholder_token'])
|
||||
if num_added_tokens == 0:
|
||||
st.error(
|
||||
f"The tokenizer already contains the token {config['args']['placeholder_token']}. Please pass a different"
|
||||
" `placeholder_token` that is not already in the tokenizer."
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
token_ids = server_state['textual_inversion']["tokenizer"].encode(config['args']['initializer_token'], add_special_tokens=False)
|
||||
# Check if initializer_token is a single token or a sequence of tokens
|
||||
if len(token_ids) > 1:
|
||||
st.error("The initializer token must be a single token.")
|
||||
|
||||
initializer_token_id = token_ids[0]
|
||||
placeholder_token_id = server_state['textual_inversion']["tokenizer"].convert_tokens_to_ids(config['args']['placeholder_token'])
|
||||
|
||||
#if "text_encoder" not in server_state['textual_inversion']:
|
||||
# Load models and create wrapper for stable diffusion
|
||||
#with server_state_lock['textual_inversion']["text_encoder"]:
|
||||
server_state['textual_inversion']["text_encoder"] = CLIPTextModel.from_pretrained(
|
||||
config['args']['pretrained_model_name_or_path'] + '/text_encoder',
|
||||
)
|
||||
|
||||
#if "vae" not in server_state['textual_inversion']:
|
||||
#with server_state_lock['textual_inversion']["vae"]:
|
||||
server_state['textual_inversion']["vae"] = AutoencoderKL.from_pretrained(
|
||||
config['args']['pretrained_model_name_or_path'] + '/vae',
|
||||
)
|
||||
|
||||
#if "unet" not in server_state['textual_inversion']:
|
||||
#with server_state_lock['textual_inversion']["unet"]:
|
||||
server_state['textual_inversion']["unet"] = UNet2DConditionModel.from_pretrained(
|
||||
config['args']['pretrained_model_name_or_path'] + '/unet',
|
||||
)
|
||||
|
||||
base_templates = imagenet_style_templates_small if config['args']['learnable_property'] == "style" else imagenet_templates_small
|
||||
if config['args']['custom_templates']:
|
||||
templates = config['args']['custom_templates'].split(";")
|
||||
else:
|
||||
templates = base_templates
|
||||
|
||||
slice_size = server_state['textual_inversion']["unet"].config.attention_head_dim // 2
|
||||
server_state['textual_inversion']["unet"].set_attention_slice(slice_size)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
server_state['textual_inversion']["text_encoder"].resize_token_embeddings(len(server_state['textual_inversion']["tokenizer"]))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = server_state['textual_inversion']["text_encoder"].get_input_embeddings().weight.data
|
||||
|
||||
if "resume_checkpoint" in config['args']:
|
||||
if config['args']['resume_checkpoint'] is not None:
|
||||
token_embeds[placeholder_token_id] = torch.load(config['args']['resume_checkpoint'])[config['args']['placeholder_token']]
|
||||
else:
|
||||
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
|
||||
|
||||
# Freeze vae and unet
|
||||
freeze_params(server_state['textual_inversion']["vae"].parameters())
|
||||
freeze_params(server_state['textual_inversion']["unet"].parameters())
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
params_to_freeze = itertools.chain(
|
||||
server_state['textual_inversion']["text_encoder"].text_model.encoder.parameters(),
|
||||
server_state['textual_inversion']["text_encoder"].text_model.final_layer_norm.parameters(),
|
||||
server_state['textual_inversion']["text_encoder"].text_model.embeddings.position_embedding.parameters(),
|
||||
)
|
||||
freeze_params(params_to_freeze)
|
||||
|
||||
checkpointer = Checkpointer(
|
||||
accelerator=accelerator,
|
||||
vae=server_state['textual_inversion']["vae"],
|
||||
unet=server_state['textual_inversion']["unet"],
|
||||
tokenizer=server_state['textual_inversion']["tokenizer"],
|
||||
placeholder_token=config['args']['placeholder_token'],
|
||||
placeholder_token_id=placeholder_token_id,
|
||||
templates=templates,
|
||||
output_dir=basepath,
|
||||
sample_batch_size=config['args']['sample_batch_size'],
|
||||
random_sample_batches=config['args']['random_sample_batches'],
|
||||
stable_sample_batches=config['args']['stable_sample_batches'],
|
||||
seed=config['args']['seed']
|
||||
)
|
||||
|
||||
if config['args']['scale_lr']:
|
||||
config['args']['learning_rate'] = (
|
||||
config['args']['learning_rate'] * config[
|
||||
'args']['gradient_accumulation_steps'] * config['args']['train_batch_size'] * accelerator.num_processes
|
||||
)
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.AdamW(
|
||||
server_state['textual_inversion']["text_encoder"].get_input_embeddings().parameters(), # only optimize the embeddings
|
||||
lr=config['args']['learning_rate'],
|
||||
betas=(config['args']['adam_beta1'], config['args']['adam_beta2']),
|
||||
weight_decay=config['args']['adam_weight_decay'],
|
||||
eps=config['args']['adam_epsilon'],
|
||||
)
|
||||
|
||||
# TODO (patil-suraj): load scheduler using args
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
|
||||
)
|
||||
|
||||
train_dataset = TextualInversionDataset(
|
||||
data_root=config['args']['train_data_dir'],
|
||||
tokenizer=server_state['textual_inversion']["tokenizer"],
|
||||
size=config['args']['resolution'],
|
||||
placeholder_token=config['args']['placeholder_token'],
|
||||
repeats=config['args']['repeats'],
|
||||
learnable_property=config['args']['learnable_property'],
|
||||
center_crop=config['args']['center_crop'],
|
||||
set="train",
|
||||
templates=templates
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config['args']['train_batch_size'], shuffle=True)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config['args']['gradient_accumulation_steps'])
|
||||
if config['args']['max_train_steps'] is None:
|
||||
config['args']['max_train_steps'] = config['args']['num_train_epochs'] * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
config['args']['lr_scheduler'],
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=config['args']['lr_warmup_steps'] * config['args']['gradient_accumulation_steps'],
|
||||
num_training_steps=config['args']['max_train_steps'] * config['args']['gradient_accumulation_steps'],
|
||||
)
|
||||
|
||||
server_state['textual_inversion']["text_encoder"], optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
server_state['textual_inversion']["text_encoder"], optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# Move vae and unet to device
|
||||
server_state['textual_inversion']["vae"].to(accelerator.device)
|
||||
server_state['textual_inversion']["unet"].to(accelerator.device)
|
||||
|
||||
# Keep vae and unet in eval mode as we don't train these
|
||||
server_state['textual_inversion']["vae"].eval()
|
||||
server_state['textual_inversion']["unet"].eval()
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config['args']['gradient_accumulation_steps'])
|
||||
if overrode_max_train_steps:
|
||||
config['args']['max_train_steps'] = config['args']['num_train_epochs'] * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
config['args']['num_train_epochs'] = math.ceil(config['args']['max_train_steps'] / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("textual_inversion", config=config['args'])
|
||||
|
||||
# Train!
|
||||
total_batch_size = config['args']['train_batch_size'] * accelerator.num_processes * st.session_state[
|
||||
'textual_inversion']['args']['gradient_accumulation_steps']
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {config['args']['num_train_epochs']}")
|
||||
logger.info(f" Instantaneous batch size per device = {config['args']['train_batch_size']}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {config['args']['gradient_accumulation_steps']}")
|
||||
logger.info(f" Total optimization steps = {config['args']['max_train_steps']}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(config['args']['max_train_steps']), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
encoded_pixel_values_cache = {}
|
||||
|
||||
try:
|
||||
for epoch in range(config['args']['num_train_epochs']):
|
||||
server_state['textual_inversion']["text_encoder"].train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(server_state['textual_inversion']["text_encoder"]):
|
||||
# Convert images to latent space
|
||||
key = "|".join(batch["key"])
|
||||
if encoded_pixel_values_cache.get(key, None) is None:
|
||||
encoded_pixel_values_cache[key] = server_state['textual_inversion']["vae"].encode(batch["pixel_values"]).latent_dist
|
||||
latents = encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(latents.shape).to(latents.device)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = server_state['textual_inversion']["text_encoder"](batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = server_state['textual_inversion']["unet"](noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Zero out the gradients for all token embeddings except the newly added
|
||||
# embeddings for the concept, as we only want to optimize the concept embeddings
|
||||
if accelerator.num_processes > 1:
|
||||
grads = server_state['textual_inversion']["text_encoder"].module.get_input_embeddings().weight.grad
|
||||
else:
|
||||
grads = server_state['textual_inversion']["text_encoder"].get_input_embeddings().weight.grad
|
||||
# Get the index for tokens that we want to zero the grads for
|
||||
index_grads_to_zero = torch.arange(len(server_state['textual_inversion']["tokenizer"])) != placeholder_token_id
|
||||
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
#try:
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % config['args']['checkpoint_frequency'] == 0 and global_step > 0 and accelerator.is_main_process:
|
||||
checkpointer.checkpoint(global_step + global_step_offset, server_state['textual_inversion']["text_encoder"])
|
||||
save_resume_file(basepath, {
|
||||
"global_step": global_step + global_step_offset,
|
||||
"resume_checkpoint": f"{basepath}/checkpoints/last.bin"
|
||||
}, config)
|
||||
|
||||
checkpointer.save_samples(
|
||||
global_step + global_step_offset,
|
||||
server_state['textual_inversion']["text_encoder"],
|
||||
config['args']['resolution'], config['args'][
|
||||
'resolution'], 7.5, 0.0, config['args']['sample_steps'])
|
||||
|
||||
checkpointer.checkpoint(
|
||||
global_step + global_step_offset,
|
||||
server_state['textual_inversion']["text_encoder"],
|
||||
path=f"{basepath}/learned_embeds.bin"
|
||||
)
|
||||
#except KeyError:
|
||||
#raise StopException
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
#accelerator.log(logs, step=global_step)
|
||||
|
||||
#try:
|
||||
if global_step >= config['args']['max_train_steps']:
|
||||
break
|
||||
#except:
|
||||
#pass
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
print("Finished! Saving final checkpoint and resume state.")
|
||||
checkpointer.checkpoint(
|
||||
global_step + global_step_offset,
|
||||
server_state['textual_inversion']["text_encoder"],
|
||||
path=f"{basepath}/learned_embeds.bin"
|
||||
)
|
||||
|
||||
save_resume_file(basepath, {
|
||||
"global_step": global_step + global_step_offset,
|
||||
"resume_checkpoint": f"{basepath}/checkpoints/last.bin"
|
||||
}, config)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
except (KeyboardInterrupt, StopException) as e:
|
||||
print(f"Received Streamlit StopException or KeyboardInterrupt")
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print("Interrupted, saving checkpoint and resume state...")
|
||||
checkpointer.checkpoint(global_step + global_step_offset, server_state['textual_inversion']["text_encoder"])
|
||||
|
||||
config['args'] = save_resume_file(basepath, {
|
||||
"global_step": global_step + global_step_offset,
|
||||
"resume_checkpoint": f"{basepath}/checkpoints/last.bin"
|
||||
}, config)
|
||||
|
||||
|
||||
checkpointer.checkpoint(
|
||||
global_step + global_step_offset,
|
||||
server_state['textual_inversion']["text_encoder"],
|
||||
path=f"{basepath}/learned_embeds.bin"
|
||||
)
|
||||
|
||||
quit()
|
||||
|
||||
|
||||
def layout():
|
||||
|
||||
with st.form("textual-inversion"):
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
#parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
|
||||
set_page_title("Textual Inversion - Stable Diffusion Playground")
|
||||
|
||||
config_tab, output_tab, tensorboard_tab = st.tabs(["Textual Inversion Config", "Ouput", "TensorBoard"])
|
||||
|
||||
with config_tab:
|
||||
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
|
||||
|
||||
if "textual_inversion" not in st.session_state:
|
||||
st.session_state["textual_inversion"] = {}
|
||||
|
||||
if "textual_inversion" not in server_state:
|
||||
server_state["textual_inversion"] = {}
|
||||
|
||||
if "args" not in st.session_state["textual_inversion"]:
|
||||
st.session_state["textual_inversion"]["args"] = {}
|
||||
|
||||
|
||||
with col1:
|
||||
st.session_state["textual_inversion"]["args"]["pretrained_model_name_or_path"] = st.text_input("Pretrained Model Path",
|
||||
value=st.session_state["defaults"].textual_inversion.pretrained_model_name_or_path,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["tokenizer_name"] = st.text_input("Tokenizer Name",
|
||||
value=st.session_state["defaults"].textual_inversion.tokenizer_name,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["train_data_dir"] = st.text_input("train_data_dir", value="", help="A folder containing the training data.")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["placeholder_token"] = st.text_input("Placeholder Token", value="", help="A token to use as a placeholder for the concept.")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["initializer_token"] = st.text_input("Initializer Token", value="", help="A token to use as initializer word.")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["learnable_property"] = st.selectbox("Learnable Property", ["object", "style"], index=0, help="Choose between 'object' and 'style'")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["repeats"] = int(st.text_input("Number of times to Repeat", value=100, help="How many times to repeat the training data."))
|
||||
|
||||
with col2:
|
||||
st.session_state["textual_inversion"]["args"]["output_dir"] = st.text_input("Output Directory",
|
||||
value=str(os.path.join("outputs", "textual_inversion")),
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["seed"] = seed_to_int(st.text_input("Seed", value=0,
|
||||
help="A seed for reproducible training, if left empty a random one will be generated. Default: 0"))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["resolution"] = int(st.text_input("Resolution", value=512,
|
||||
help="The resolution for input images, all the images in the train/validation dataset will be resized to this resolution"))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["center_crop"] = st.checkbox("Center Image", value=True, help="Whether to center crop images before resizing to resolution")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["train_batch_size"] = int(st.text_input("Train Batch Size", value=1, help="Batch size (per device) for the training dataloader."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["num_train_epochs"] = int(st.text_input("Number of Steps to Train", value=100, help="Number of steps to train."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["max_train_steps"] = int(st.text_input("Max Number of Steps to Train", value=5000,
|
||||
help="Total number of training steps to perform. If provided, overrides 'Number of Steps to Train'."))
|
||||
|
||||
with col3:
|
||||
st.session_state["textual_inversion"]["args"]["gradient_accumulation_steps"] = int(st.text_input("Gradient Accumulation Steps", value=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["learning_rate"] = float(st.text_input("Learning Rate", value=5.0e-04,
|
||||
help="Initial learning rate (after the potential warmup period) to use."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["scale_lr"] = st.checkbox("Scale Learning Rate", value=True,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["lr_scheduler"] = st.text_input("Learning Rate Scheduler", value="constant",
|
||||
help=("The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',"
|
||||
" 'constant', 'constant_with_warmup']" ))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["lr_warmup_steps"] = int(st.text_input("Learning Rate Warmup Steps", value=500, help="Number of steps for the warmup in the lr scheduler."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["adam_beta1"] = float(st.text_input("Adam Beta 1", value=0.9, help="The beta1 parameter for the Adam optimizer."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["adam_beta2"] = float(st.text_input("Adam Beta 2", value=0.999, help="The beta2 parameter for the Adam optimizer."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["adam_weight_decay"] = float(st.text_input("Adam Weight Decay", value=1e-2, help="Weight decay to use."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["adam_epsilon"] = float(st.text_input("Adam Epsilon", value=1e-08, help="Epsilon value for the Adam optimizer"))
|
||||
|
||||
with col4:
|
||||
st.session_state["textual_inversion"]["args"]["mixed_precision"] = st.selectbox("Mixed Precision", ["no", "fp16", "bf16"], index=1,
|
||||
help="Whether to use mixed precision. Choose" "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU.")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["local_rank"] = int(st.text_input("Local Rank", value=1, help="For distributed training: local_rank"))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["checkpoint_frequency"] = int(st.text_input("Checkpoint Frequency", value=500, help="How often to save a checkpoint and sample image"))
|
||||
|
||||
# stable_sample_batches is crashing when saving the samples so for now I will disable it util its fixed.
|
||||
#st.session_state["textual_inversion"]["args"]["stable_sample_batches"] = int(st.text_input("Stable Sample Batches", value=0,
|
||||
#help="Number of fixed seed sample batches to generate per checkpoint"))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["stable_sample_batches"] = 0
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["random_sample_batches"] = int(st.text_input("Random Sample Batches", value=2,
|
||||
help="Number of random seed sample batches to generate per checkpoint"))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["sample_batch_size"] = int(st.text_input("Sample Batch Size", value=1, help="Number of samples to generate per batch"))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["sample_steps"] = int(st.text_input("Sample Steps", value=100,
|
||||
help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes."))
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["custom_templates"] = st.text_input("Custom Templates", value="",
|
||||
help="A semicolon-delimited list of custom template to use for samples, using {} as a placeholder for the concept.")
|
||||
with col5:
|
||||
st.session_state["textual_inversion"]["args"]["resume"] = st.checkbox(label="Resume Previous Run?", value=False,
|
||||
help="Resume previous run, if a valid resume.json file is on the output dir \
|
||||
it will be used, otherwise if the 'Resume From' field bellow contains a valid resume.json file \
|
||||
that one will be used.")
|
||||
|
||||
st.session_state["textual_inversion"]["args"]["resume_from"] = st.text_input(label="Resume From", help="Path to a directory to resume training from (ie, logs/token_name)")
|
||||
|
||||
#st.session_state["textual_inversion"]["args"]["resume_checkpoint"] = st.file_uploader("Resume Checkpoint", type=["bin"],
|
||||
#help="Path to a specific checkpoint to resume training from (ie, logs/token_name/checkpoints/something.bin).")
|
||||
|
||||
#st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"] = st.file_uploader("st.session_state["textual_inversion"] File", type=["json"],
|
||||
#help="Path to a JSON st.session_state["textual_inversion"]uration file containing arguments for invoking this script."
|
||||
#"If resume_from is given, its resume.json takes priority over this.")
|
||||
#
|
||||
#print (os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json"))
|
||||
#print (os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json")))
|
||||
if os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json")):
|
||||
st.session_state["textual_inversion"]["args"]["resume_from"] = os.path.join(
|
||||
st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"))
|
||||
#print (st.session_state["textual_inversion"]["args"]["resume_from"])
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"), "checkpoints","last.bin")):
|
||||
st.session_state["textual_inversion"]["args"]["resume_checkpoint"] = os.path.join(
|
||||
st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"), "checkpoints","last.bin")
|
||||
|
||||
#if "resume_from" in st.session_state["textual_inversion"]["args"]:
|
||||
#if st.session_state["textual_inversion"]["args"]["resume_from"]:
|
||||
#if os.path.exists(os.path.join(st.session_state["textual_inversion"]['args']['resume_from'], "resume.json")):
|
||||
#with open(os.path.join(st.session_state["textual_inversion"]['args']['resume_from'], "resume.json"), 'rt') as f:
|
||||
#try:
|
||||
#resume_json = json.load(f)["args"]
|
||||
#st.session_state["textual_inversion"]["args"] = OmegaConf.merge(st.session_state["textual_inversion"]["args"], resume_json)
|
||||
#st.session_state["textual_inversion"]["args"]["resume_from"] = os.path.join(
|
||||
#st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"))
|
||||
#except json.decoder.JSONDecodeError:
|
||||
#pass
|
||||
|
||||
#print(st.session_state["textual_inversion"]["args"])
|
||||
#print(st.session_state["textual_inversion"]["args"]['resume_from'])
|
||||
|
||||
#elif st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"] is not None:
|
||||
#with open(st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"], 'rt') as f:
|
||||
#args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)["args"]))
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != st.session_state["textual_inversion"]["args"]["local_rank"]:
|
||||
st.session_state["textual_inversion"]["args"]["local_rank"] = env_local_rank
|
||||
|
||||
if st.session_state["textual_inversion"]["args"]["train_data_dir"] is None:
|
||||
st.error("You must specify --train_data_dir")
|
||||
|
||||
if st.session_state["textual_inversion"]["args"]["pretrained_model_name_or_path"] is None:
|
||||
st.error("You must specify --pretrained_model_name_or_path")
|
||||
|
||||
if st.session_state["textual_inversion"]["args"]["placeholder_token"] is None:
|
||||
st.error("You must specify --placeholder_token")
|
||||
|
||||
if st.session_state["textual_inversion"]["args"]["initializer_token"] is None:
|
||||
st.error("You must specify --initializer_token")
|
||||
|
||||
if st.session_state["textual_inversion"]["args"]["output_dir"] is None:
|
||||
st.error("You must specify --output_dir")
|
||||
|
||||
# add a spacer and the submit button for the form.
|
||||
|
||||
st.session_state["textual_inversion"]["message"] = st.empty()
|
||||
st.session_state["textual_inversion"]["progress_bar"] = st.empty()
|
||||
|
||||
st.write("---")
|
||||
|
||||
submit = st.form_submit_button("Run",help="")
|
||||
if submit:
|
||||
if "pipe" in st.session_state:
|
||||
del st.session_state["pipe"]
|
||||
if "model" in st.session_state:
|
||||
del st.session_state["model"]
|
||||
|
||||
set_page_title("Running Textual Inversion - Stable Diffusion WebUI")
|
||||
#st.session_state["textual_inversion"]["message"].info("Textual Inversion Running. For more info check the progress on your console or the Ouput Tab.")
|
||||
|
||||
try:
|
||||
#try:
|
||||
# run textual inversion.
|
||||
config = st.session_state['textual_inversion']
|
||||
textual_inversion(config)
|
||||
#except RuntimeError:
|
||||
#if "pipeline" in server_state["textual_inversion"]:
|
||||
#del server_state['textual_inversion']["checker"]
|
||||
#del server_state['textual_inversion']["unwrapped"]
|
||||
#del server_state['textual_inversion']["pipeline"]
|
||||
|
||||
# run textual inversion.
|
||||
#config = st.session_state['textual_inversion']
|
||||
#textual_inversion(config)
|
||||
|
||||
set_page_title("Textual Inversion - Stable Diffusion WebUI")
|
||||
|
||||
except StopException:
|
||||
set_page_title("Textual Inversion - Stable Diffusion WebUI")
|
||||
print(f"Received Streamlit StopException")
|
||||
|
||||
st.session_state["textual_inversion"]["message"].empty()
|
||||
|
||||
#
|
||||
with output_tab:
|
||||
st.info("Under Construction. :construction_worker:")
|
||||
|
||||
#st.info("Nothing to show yet. Maybe try running some training first.")
|
||||
|
||||
#st.session_state["textual_inversion"]["preview_image"] = st.empty()
|
||||
#st.session_state["textual_inversion"]["progress_bar"] = st.empty()
|
||||
|
||||
|
||||
with tensorboard_tab:
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
|
||||
# Start TensorBoard
|
||||
st_tensorboard(logdir=os.path.join("outputs", "textual_inversion"), port=8888)
|
||||
|
708
webui/streamlit/scripts/txt2img.py
Normal file
708
webui/streamlit/scripts/txt2img.py
Normal file
@ -0,0 +1,708 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# base webui import and utils.
|
||||
from sd_utils import st, MemUsageMonitor, server_state, no_rerun, logger, set_page_title, \
|
||||
custom_models_available, RealESRGAN_available, GFPGAN_available, \
|
||||
LDSR_available
|
||||
#load_models, hc, seed_to_int, \
|
||||
#get_next_sequence_number, check_prompt_length, torch_gc, \
|
||||
#save_sample, generation_callback, process_images, \
|
||||
#KDiffusionSampler, \
|
||||
|
||||
# streamlit imports
|
||||
from streamlit.runtime.scriptrunner import StopException
|
||||
|
||||
#streamlit components section
|
||||
import streamlit_nested_layout #used to allow nested columns, just importing it is enought
|
||||
|
||||
#from streamlit.elements import image as STImage
|
||||
import streamlit.components.v1 as components
|
||||
#from streamlit.runtime.media_file_manager import media_file_manager
|
||||
from streamlit.elements.image import image_to_url
|
||||
|
||||
#other imports
|
||||
|
||||
import base64, uuid
|
||||
import os, sys, datetime, time
|
||||
from PIL import Image
|
||||
import requests
|
||||
from slugify import slugify
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from typing import Union
|
||||
from io import BytesIO
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
|
||||
# streamlit components
|
||||
from custom_components import sygil_suggestions
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
sygil_suggestions.init()
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
#
|
||||
# Dev mode (server)
|
||||
# _component_func = components.declare_component(
|
||||
# "sd-gallery",
|
||||
# url="http://localhost:3001",
|
||||
# )
|
||||
|
||||
# Init Vuejs component
|
||||
_component_func = components.declare_component(
|
||||
"sd-gallery", "./frontend/dists/sd-gallery/dist")
|
||||
|
||||
def sdGallery(images=[], key=None):
|
||||
component_value = _component_func(images=imgsToGallery(images), key=key, default="")
|
||||
return component_value
|
||||
|
||||
def imgsToGallery(images):
|
||||
urls = []
|
||||
for i in images:
|
||||
# random string for id
|
||||
random_id = str(uuid.uuid4())
|
||||
url = image_to_url(
|
||||
image=i,
|
||||
image_id= random_id,
|
||||
width=i.width,
|
||||
clamp=False,
|
||||
channels="RGB",
|
||||
output_format="PNG"
|
||||
)
|
||||
# image_io = BytesIO()
|
||||
# i.save(image_io, 'PNG')
|
||||
# width, height = i.size
|
||||
# image_id = "%s" % (str(images.index(i)))
|
||||
# (data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
|
||||
# this_file = media_file_manager.add(data, mimetype, image_id)
|
||||
# img_str = this_file.url
|
||||
urls.append(url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
class plugin_info():
|
||||
plugname = "txt2img"
|
||||
description = "Text to Image"
|
||||
isTab = True
|
||||
displayPriority = 1
|
||||
|
||||
@logger.catch(reraise=True)
|
||||
def stable_horde(outpath, prompt, seed, sampler_name, save_grid, batch_size,
|
||||
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, GFPGAN_model,
|
||||
use_RealESRGAN, realesrgan_model_name, use_LDSR,
|
||||
LDSR_model_name, ddim_eta, normalize_prompt_weights,
|
||||
save_individual_images, sort_samples, write_info_files,
|
||||
jpg_sample, variant_amount, variant_seed, api_key,
|
||||
nsfw=True, censor_nsfw=False):
|
||||
|
||||
log = []
|
||||
|
||||
log.append("Generating image with Stable Horde.")
|
||||
|
||||
st.session_state["progress_bar_text"].code('\n'.join(log), language='')
|
||||
|
||||
# start time after garbage collection (or before?)
|
||||
start_time = time.time()
|
||||
|
||||
# We will use this date here later for the folder name, need to start_time if not need
|
||||
run_start_dt = datetime.datetime.now()
|
||||
|
||||
mem_mon = MemUsageMonitor('MemMon')
|
||||
mem_mon.start()
|
||||
|
||||
os.makedirs(outpath, exist_ok=True)
|
||||
|
||||
sample_path = os.path.join(outpath, "samples")
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
|
||||
params = {
|
||||
"sampler_name": "k_euler",
|
||||
"toggles": [1,4],
|
||||
"cfg_scale": cfg_scale,
|
||||
"seed": str(seed),
|
||||
"width": width,
|
||||
"height": height,
|
||||
"seed_variation": variant_seed if variant_seed else 1,
|
||||
"steps": int(steps),
|
||||
"n": int(n_iter)
|
||||
# You can put extra params here if you wish
|
||||
}
|
||||
|
||||
final_submit_dict = {
|
||||
"prompt": prompt,
|
||||
"params": params,
|
||||
"nsfw": nsfw,
|
||||
"censor_nsfw": censor_nsfw,
|
||||
"trusted_workers": True,
|
||||
"workers": []
|
||||
}
|
||||
log.append(final_submit_dict)
|
||||
|
||||
headers = {"apikey": api_key}
|
||||
logger.debug(final_submit_dict)
|
||||
st.session_state["progress_bar_text"].code('\n'.join(str(log)), language='')
|
||||
|
||||
horde_url = "https://stablehorde.net"
|
||||
|
||||
submit_req = requests.post(f'{horde_url}/api/v2/generate/async', json = final_submit_dict, headers = headers)
|
||||
if submit_req.ok:
|
||||
submit_results = submit_req.json()
|
||||
logger.debug(submit_results)
|
||||
|
||||
log.append(submit_results)
|
||||
st.session_state["progress_bar_text"].code(''.join(str(log)), language='')
|
||||
|
||||
req_id = submit_results['id']
|
||||
is_done = False
|
||||
while not is_done:
|
||||
chk_req = requests.get(f'{horde_url}/api/v2/generate/check/{req_id}')
|
||||
if not chk_req.ok:
|
||||
logger.error(chk_req.text)
|
||||
return
|
||||
chk_results = chk_req.json()
|
||||
logger.info(chk_results)
|
||||
is_done = chk_results['done']
|
||||
time.sleep(1)
|
||||
retrieve_req = requests.get(f'{horde_url}/api/v2/generate/status/{req_id}')
|
||||
if not retrieve_req.ok:
|
||||
logger.error(retrieve_req.text)
|
||||
return
|
||||
results_json = retrieve_req.json()
|
||||
# logger.debug(results_json)
|
||||
results = results_json['generations']
|
||||
|
||||
output_images = []
|
||||
comments = []
|
||||
prompt_matrix_parts = []
|
||||
|
||||
if not st.session_state['defaults'].general.no_verify_input:
|
||||
try:
|
||||
check_prompt_length(prompt, comments)
|
||||
except:
|
||||
import traceback
|
||||
logger.info("Error verifying input:", file=sys.stderr)
|
||||
logger.info(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
all_prompts = batch_size * n_iter * [prompt]
|
||||
all_seeds = [seed + x for x in range(len(all_prompts))]
|
||||
|
||||
for iter in range(len(results)):
|
||||
b64img = results[iter]["img"]
|
||||
base64_bytes = b64img.encode('utf-8')
|
||||
img_bytes = base64.b64decode(base64_bytes)
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
|
||||
sanitized_prompt = slugify(prompt)
|
||||
|
||||
prompts = all_prompts[iter * batch_size:(iter + 1) * batch_size]
|
||||
#captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size]
|
||||
seeds = all_seeds[iter * batch_size:(iter + 1) * batch_size]
|
||||
|
||||
if sort_samples:
|
||||
full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt)
|
||||
|
||||
|
||||
sanitized_prompt = sanitized_prompt[:200-len(full_path)]
|
||||
sample_path_i = os.path.join(sample_path, sanitized_prompt)
|
||||
|
||||
#print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}")
|
||||
#print(os.path.join(os.getcwd(), sample_path_i))
|
||||
|
||||
os.makedirs(sample_path_i, exist_ok=True)
|
||||
base_count = get_next_sequence_number(sample_path_i)
|
||||
filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[iter]}"
|
||||
else:
|
||||
full_path = os.path.join(os.getcwd(), sample_path)
|
||||
sample_path_i = sample_path
|
||||
base_count = get_next_sequence_number(sample_path_i)
|
||||
filename = f"{base_count:05}-{steps}_{sampler_name}_{seed}_{sanitized_prompt}"[:200-len(full_path)] #same as before
|
||||
|
||||
save_sample(img, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
|
||||
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img=None,
|
||||
denoising_strength=0.75, resize_mode=None, uses_loopback=False, uses_random_seed_loopback=False,
|
||||
save_grid=save_grid,
|
||||
sort_samples=sampler_name, sampler_name=sampler_name, ddim_eta=ddim_eta, n_iter=n_iter,
|
||||
batch_size=batch_size, i=iter, save_individual_images=save_individual_images,
|
||||
model_name="Stable Diffusion v1.5")
|
||||
|
||||
output_images.append(img)
|
||||
|
||||
# update image on the UI so we can see the progress
|
||||
if "preview_image" in st.session_state:
|
||||
st.session_state["preview_image"].image(img)
|
||||
|
||||
if "progress_bar_text" in st.session_state:
|
||||
st.session_state["progress_bar_text"].empty()
|
||||
|
||||
#if len(results) > 1:
|
||||
#final_filename = f"{iter}_{filename}"
|
||||
#img.save(final_filename)
|
||||
#logger.info(f"Saved {final_filename}")
|
||||
else:
|
||||
if "progress_bar_text" in st.session_state:
|
||||
st.session_state["progress_bar_text"].error(submit_req.text)
|
||||
|
||||
logger.error(submit_req.text)
|
||||
|
||||
mem_max_used, mem_total = mem_mon.read_and_stop()
|
||||
time_diff = time.time()-start_time
|
||||
|
||||
info = f"""
|
||||
{prompt}
|
||||
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN else ''}{', '+realesrgan_model_name if use_RealESRGAN else ''}
|
||||
{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
|
||||
|
||||
stats = f'''
|
||||
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
|
||||
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
|
||||
|
||||
for comment in comments:
|
||||
info += "\n\n" + comment
|
||||
|
||||
#mem_mon.stop()
|
||||
#del mem_mon
|
||||
torch_gc()
|
||||
|
||||
return output_images, seed, info, stats
|
||||
|
||||
|
||||
#
|
||||
@logger.catch(reraise=True)
|
||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
|
||||
height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
|
||||
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
|
||||
save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.3', use_RealESRGAN: bool = False,
|
||||
RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", use_LDSR: bool = True, LDSR_model: str = "model",
|
||||
fp = None, variant_amount: float = 0.0,
|
||||
variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True,
|
||||
use_stable_horde: bool = False, stable_horde_key:str = "0000000000"):
|
||||
|
||||
outpath = st.session_state['defaults'].general.outdir_txt2img
|
||||
|
||||
seed = seed_to_int(seed)
|
||||
|
||||
if not use_stable_horde:
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(server_state["model"])
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(server_state["model"])
|
||||
elif sampler_name == 'k_dpm_2_a':
|
||||
sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
|
||||
elif sampler_name == 'k_dpm_2':
|
||||
sampler = KDiffusionSampler(server_state["model"],'dpm_2')
|
||||
elif sampler_name == 'k_dpmpp_2m':
|
||||
sampler = KDiffusionSampler(server_state["model"],'dpmpp_2m')
|
||||
elif sampler_name == 'k_euler_a':
|
||||
sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
|
||||
elif sampler_name == 'k_euler':
|
||||
sampler = KDiffusionSampler(server_state["model"],'euler')
|
||||
elif sampler_name == 'k_heun':
|
||||
sampler = KDiffusionSampler(server_state["model"],'heun')
|
||||
elif sampler_name == 'k_lms':
|
||||
sampler = KDiffusionSampler(server_state["model"],'lms')
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
def init():
|
||||
pass
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x,
|
||||
img_callback=generation_callback if not server_state["bridge"] else None,
|
||||
log_every_t=int(st.session_state.update_preview_frequency if not server_state["bridge"] else 100))
|
||||
|
||||
return samples_ddim
|
||||
|
||||
|
||||
if use_stable_horde:
|
||||
output_images, seed, info, stats = stable_horde(
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
outpath=outpath,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
GFPGAN_model=GFPGAN_model,
|
||||
use_RealESRGAN=use_RealESRGAN,
|
||||
realesrgan_model_name=RealESRGAN_model,
|
||||
use_LDSR=use_LDSR,
|
||||
LDSR_model_name=LDSR_model,
|
||||
ddim_eta=ddim_eta,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg,
|
||||
variant_amount=variant_amount,
|
||||
variant_seed=variant_seed,
|
||||
api_key=stable_horde_key
|
||||
)
|
||||
else:
|
||||
|
||||
#try:
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
GFPGAN_model=GFPGAN_model,
|
||||
use_RealESRGAN=use_RealESRGAN,
|
||||
realesrgan_model_name=RealESRGAN_model,
|
||||
use_LDSR=use_LDSR,
|
||||
LDSR_model_name=LDSR_model,
|
||||
ddim_eta=ddim_eta,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg,
|
||||
variant_amount=variant_amount,
|
||||
variant_seed=variant_seed,
|
||||
)
|
||||
|
||||
del sampler
|
||||
|
||||
return output_images, seed, info, stats
|
||||
|
||||
#except RuntimeError as e:
|
||||
#err = e
|
||||
#err_msg = f'CRASHED:<br><textarea rows="5" style="color:white;background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
|
||||
#stats = err_msg
|
||||
#return [], seed, 'err', stats
|
||||
|
||||
#
|
||||
@logger.catch(reraise=True)
|
||||
def layout():
|
||||
with st.form("txt2img-inputs"):
|
||||
st.session_state["generation_mode"] = "txt2img"
|
||||
|
||||
input_col1, generate_col1 = st.columns([10,1])
|
||||
|
||||
with input_col1:
|
||||
#prompt = st.text_area("Input Text","")
|
||||
placeholder = "A corgi wearing a top hat as an oil painting."
|
||||
prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state["defaults"].general.enable_suggestions:
|
||||
sygil_suggestions.suggestion_area(placeholder)
|
||||
|
||||
if "defaults" in st.session_state:
|
||||
if st.session_state['defaults'].admin.global_negative_prompt:
|
||||
prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
|
||||
|
||||
#print(prompt)
|
||||
|
||||
# creating the page layout using columns
|
||||
col1, col2, col3 = st.columns([2,5,2], gap="large")
|
||||
|
||||
with col1:
|
||||
width = st.slider("Width:", min_value=st.session_state['defaults'].txt2img.width.min_value, max_value=st.session_state['defaults'].txt2img.width.max_value,
|
||||
value=st.session_state['defaults'].txt2img.width.value, step=st.session_state['defaults'].txt2img.width.step)
|
||||
height = st.slider("Height:", min_value=st.session_state['defaults'].txt2img.height.min_value, max_value=st.session_state['defaults'].txt2img.height.max_value,
|
||||
value=st.session_state['defaults'].txt2img.height.value, step=st.session_state['defaults'].txt2img.height.step)
|
||||
cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):", min_value=st.session_state['defaults'].txt2img.cfg_scale.min_value,
|
||||
value=st.session_state['defaults'].txt2img.cfg_scale.value, step=st.session_state['defaults'].txt2img.cfg_scale.step,
|
||||
help="How strongly the image should follow the prompt.")
|
||||
|
||||
seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
|
||||
|
||||
with st.expander("Batch Options"):
|
||||
#batch_count = st.slider("Batch count.", min_value=st.session_state['defaults'].txt2img.batch_count.min_value, max_value=st.session_state['defaults'].txt2img.batch_count.max_value,
|
||||
#value=st.session_state['defaults'].txt2img.batch_count.value, step=st.session_state['defaults'].txt2img.batch_count.step,
|
||||
#help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
#batch_size = st.slider("Batch size", min_value=st.session_state['defaults'].txt2img.batch_size.min_value, max_value=st.session_state['defaults'].txt2img.batch_size.max_value,
|
||||
#value=st.session_state.defaults.txt2img.batch_size.value, step=st.session_state.defaults.txt2img.batch_size.step,
|
||||
#help="How many images are at once in a batch.\
|
||||
#It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
|
||||
#Default: 1")
|
||||
|
||||
st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].txt2img.batch_count.value,
|
||||
help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
st.session_state["batch_size"] = st.number_input("Batch size", value=st.session_state.defaults.txt2img.batch_size.value,
|
||||
help="How many images are at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes \
|
||||
to finish generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
|
||||
st.session_state["update_preview"] = st.session_state["defaults"].general.update_preview
|
||||
st.session_state["update_preview_frequency"] = st.number_input("Update Image Preview Frequency",
|
||||
min_value=0,
|
||||
value=st.session_state['defaults'].txt2img.update_preview_frequency,
|
||||
help="Frequency in steps at which the the preview image is updated. By default the frequency \
|
||||
is set to 10 step.")
|
||||
|
||||
with col2:
|
||||
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
|
||||
|
||||
with preview_tab:
|
||||
#st.write("Image")
|
||||
#Image for testing
|
||||
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
|
||||
#new_image = image.resize((175, 240))
|
||||
#preview_image = st.image(image)
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
st.session_state["preview_image"] = st.empty()
|
||||
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar_text"].info("Nothing but crickets here, try generating something first.")
|
||||
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
message = st.empty()
|
||||
|
||||
with gallery_tab:
|
||||
st.session_state["gallery"] = st.empty()
|
||||
#st.session_state["gallery"].info("Nothing but crickets here, try generating something first.")
|
||||
|
||||
with col3:
|
||||
# If we have custom models available on the "models/custom"
|
||||
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
|
||||
custom_models_available()
|
||||
|
||||
if server_state["CustomModel_available"]:
|
||||
st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
|
||||
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
|
||||
help="Select the model you want to use. This option is only available if you have custom models \
|
||||
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
|
||||
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.5")
|
||||
|
||||
st.session_state.sampling_steps = st.number_input("Sampling Steps", value=st.session_state.defaults.txt2img.sampling_steps.value,
|
||||
min_value=st.session_state.defaults.txt2img.sampling_steps.min_value,
|
||||
step=st.session_state['defaults'].txt2img.sampling_steps.step,
|
||||
help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
|
||||
|
||||
sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_dpmpp_2m", "k_heun", "PLMS", "DDIM"]
|
||||
sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
||||
index=sampler_name_list.index(st.session_state['defaults'].txt2img.default_sampler), help="Sampling method to use. Default: k_euler")
|
||||
|
||||
with st.expander("Advanced"):
|
||||
with st.expander("Stable Horde"):
|
||||
use_stable_horde = st.checkbox("Use Stable Horde", value=False, help="Use the Stable Horde to generate images. More info can be found at https://stablehorde.net/")
|
||||
stable_horde_key = st.text_input("Stable Horde Api Key", value=st.session_state['defaults'].general.stable_horde_api, type="password",
|
||||
help="Optional Api Key used for the Stable Horde Bridge, if no api key is added the horde will be used anonymously.")
|
||||
|
||||
with st.expander("Output Settings"):
|
||||
separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2img.separate_prompts,
|
||||
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
|
||||
|
||||
normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].txt2img.normalize_prompt_weights,
|
||||
help="Ensure the sum of all weights add up to 1.0")
|
||||
|
||||
save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].txt2img.save_individual_images,
|
||||
help="Save each image generated before any filter or enhancement is applied.")
|
||||
|
||||
save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].txt2img.save_grid, help="Save a grid with all the images generated into a single image.")
|
||||
group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
|
||||
help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
|
||||
|
||||
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files,
|
||||
help="Save a file next to the image with informartion about the generation.")
|
||||
|
||||
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
|
||||
|
||||
# check if GFPGAN, RealESRGAN and LDSR are available.
|
||||
#if "GFPGAN_available" not in st.session_state:
|
||||
GFPGAN_available()
|
||||
|
||||
#if "RealESRGAN_available" not in st.session_state:
|
||||
RealESRGAN_available()
|
||||
|
||||
#if "LDSR_available" not in st.session_state:
|
||||
LDSR_available()
|
||||
|
||||
if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
|
||||
with st.expander("Post-Processing"):
|
||||
face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
|
||||
with face_restoration_tab:
|
||||
# GFPGAN used for face restoration
|
||||
if st.session_state["GFPGAN_available"]:
|
||||
#with st.expander("Face Restoration"):
|
||||
#if st.session_state["GFPGAN_available"]:
|
||||
#with st.expander("GFPGAN"):
|
||||
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN,
|
||||
help="Uses the GFPGAN model to improve faces after the generation.\
|
||||
This greatly improve the quality and consistency of faces but uses\
|
||||
extra VRAM. Disable if you need the extra VRAM.")
|
||||
|
||||
st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
|
||||
index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
|
||||
|
||||
#st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
|
||||
|
||||
else:
|
||||
st.session_state["use_GFPGAN"] = False
|
||||
|
||||
with upscaling_tab:
|
||||
st.session_state['use_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling)
|
||||
|
||||
# RealESRGAN and LDSR used for upscaling.
|
||||
if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
|
||||
|
||||
upscaling_method_list = []
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
upscaling_method_list.append("RealESRGAN")
|
||||
if st.session_state["LDSR_available"]:
|
||||
upscaling_method_list.append("LDSR")
|
||||
|
||||
#print (st.session_state["RealESRGAN_available"])
|
||||
st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
|
||||
index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
|
||||
if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
|
||||
else 0)
|
||||
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
with st.expander("RealESRGAN"):
|
||||
if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['use_upscaling']:
|
||||
st.session_state["use_RealESRGAN"] = True
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
|
||||
index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
|
||||
#
|
||||
if st.session_state["LDSR_available"]:
|
||||
with st.expander("LDSR"):
|
||||
if st.session_state["upscaling_method"] == "LDSR" and st.session_state['use_upscaling']:
|
||||
st.session_state["use_LDSR"] = True
|
||||
else:
|
||||
st.session_state["use_LDSR"] = False
|
||||
|
||||
st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
|
||||
index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
|
||||
|
||||
st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2img.LDSR_config.sampling_steps,
|
||||
help="")
|
||||
|
||||
st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.preDownScale,
|
||||
help="")
|
||||
|
||||
st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.postDownScale,
|
||||
help="")
|
||||
|
||||
downsample_method_list = ['Nearest', 'Lanczos']
|
||||
st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
|
||||
index=downsample_method_list.index(st.session_state['defaults'].txt2img.LDSR_config.downsample_method))
|
||||
|
||||
else:
|
||||
st.session_state["use_LDSR"] = False
|
||||
st.session_state["LDSR_model"] = "model"
|
||||
|
||||
with st.expander("Variant"):
|
||||
variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2img.variant_amount.value,
|
||||
min_value=st.session_state['defaults'].txt2img.variant_amount.min_value, max_value=st.session_state['defaults'].txt2img.variant_amount.max_value,
|
||||
step=st.session_state['defaults'].txt2img.variant_amount.step)
|
||||
variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2img.seed,
|
||||
help="The seed to use when generating a variant, if left blank a random seed will be generated.")
|
||||
|
||||
#galleryCont = st.empty()
|
||||
|
||||
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
|
||||
generate_col1.write("")
|
||||
generate_col1.write("")
|
||||
generate_button = generate_col1.form_submit_button("Generate")
|
||||
|
||||
#
|
||||
if generate_button:
|
||||
|
||||
with col2:
|
||||
with no_rerun:
|
||||
if not use_stable_horde:
|
||||
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
|
||||
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
|
||||
|
||||
#print(st.session_state['use_RealESRGAN'])
|
||||
#print(st.session_state['use_LDSR'])
|
||||
try:
|
||||
|
||||
|
||||
output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, st.session_state["batch_count"], st.session_state["batch_size"],
|
||||
cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
|
||||
save_grid, group_by_prompt, save_as_jpg, st.session_state["use_GFPGAN"], st.session_state['GFPGAN_model'],
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
|
||||
variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files,
|
||||
use_stable_horde=use_stable_horde, stable_horde_key=stable_horde_key)
|
||||
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
with gallery_tab:
|
||||
logger.info(seeds)
|
||||
st.session_state["gallery"].text = ""
|
||||
sdGallery(output_images)
|
||||
|
||||
|
||||
except (StopException,
|
||||
#KeyError
|
||||
):
|
||||
print(f"Received Streamlit StopException")
|
||||
|
||||
# reset the page title so the percent doesnt stay on it confusing the user.
|
||||
set_page_title(f"Stable Diffusion Playground")
|
||||
|
||||
# this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
|
||||
# use the current col2 first tab to show the preview_img and update it as its generated.
|
||||
#preview_image.image(output_images)
|
||||
|
||||
|
2012
webui/streamlit/scripts/txt2vid.py
Normal file
2012
webui/streamlit/scripts/txt2vid.py
Normal file
File diff suppressed because it is too large
Load Diff
277
webui/streamlit/scripts/webui_streamlit.py
Normal file
277
webui/streamlit/scripts/webui_streamlit.py
Normal file
@ -0,0 +1,277 @@
|
||||
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
|
||||
|
||||
# Copyright 2022 Sygil-Dev team.
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# base webui import and utils.
|
||||
#import streamlit as st
|
||||
|
||||
# We import hydralit like this to replace the previous stuff
|
||||
# we had with native streamlit as it lets ur replace things 1:1
|
||||
from sd_utils import st, hc, load_configs, load_css, set_logger_verbosity,\
|
||||
logger, quiesce_logger, set_page_title, random
|
||||
|
||||
# streamlit imports
|
||||
import streamlit_nested_layout
|
||||
|
||||
#streamlit components section
|
||||
#from st_on_hover_tabs import on_hover_tabs
|
||||
#from streamlit_server_state import server_state, server_state_lock
|
||||
|
||||
#other imports
|
||||
import argparse
|
||||
#from sd_utils.bridge import run_bridge
|
||||
|
||||
# import custom components
|
||||
from custom_components import draggable_number_input
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
load_configs()
|
||||
|
||||
help = """
|
||||
A double dash (`--`) is used to separate streamlit arguments from app arguments.
|
||||
As a result using "streamlit run webui_streamlit.py --headless"
|
||||
will show the help for streamlit itself and not pass any argument to our app,
|
||||
we need to use "streamlit run webui_streamlit.py -- --headless"
|
||||
in order to pass a command argument to this app."""
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument("--headless", action='store_true', help="Don't launch web server, util if you just want to run the stable horde bridge.", default=False)
|
||||
|
||||
parser.add_argument("--bridge", action='store_true', help="don't launch web server, but make this instance into a Horde bridge.", default=False)
|
||||
parser.add_argument('--horde_api_key', action="store", required=False, type=str, help="The API key corresponding to the owner of this Horde instance")
|
||||
parser.add_argument('--horde_name', action="store", required=False, type=str, help="The server name for the Horde. It will be shown to the world and there can be only one.")
|
||||
parser.add_argument('--horde_url', action="store", required=False, type=str, help="The SH Horde URL. Where the bridge will pickup prompts and send the finished generations.")
|
||||
parser.add_argument('--horde_priority_usernames',type=str, action='append', required=False, help="Usernames which get priority use in this horde instance. The owner's username is always in this list.")
|
||||
parser.add_argument('--horde_max_power',type=int, required=False, help="How much power this instance has to generate pictures. Min: 2")
|
||||
parser.add_argument('--horde_sfw', action='store_true', required=False, help="Set to true if you do not want this worker generating NSFW images.")
|
||||
parser.add_argument('--horde_blacklist', nargs='+', required=False, help="List the words that you want to blacklist.")
|
||||
parser.add_argument('--horde_censorlist', nargs='+', required=False, help="List the words that you want to censor.")
|
||||
parser.add_argument('--horde_censor_nsfw', action='store_true', required=False, help="Set to true if you want this bridge worker to censor NSFW images.")
|
||||
parser.add_argument('--horde_model', action='store', required=False, help="Which model to run on this horde.")
|
||||
parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen in your screen")
|
||||
parser.add_argument('-q', '--quiet', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen in your screen")
|
||||
opt = parser.parse_args()
|
||||
|
||||
#with server_state_lock["bridge"]:
|
||||
#server_state["bridge"] = opt.bridge
|
||||
|
||||
@logger.catch(reraise=True)
|
||||
def layout():
|
||||
"""Layout functions to define all the streamlit layout here."""
|
||||
if not st.session_state["defaults"].debug.enable_hydralit:
|
||||
st.set_page_config(page_title="Stable Diffusion Playground", layout="wide", initial_sidebar_state="collapsed")
|
||||
|
||||
#app = st.HydraApp(title='Stable Diffusion WebUI', favicon="", sidebar_state="expanded", layout="wide",
|
||||
#hide_streamlit_markers=False, allow_url_nav=True , clear_cross_app_sessions=False)
|
||||
|
||||
|
||||
# load css as an external file, function has an option to local or remote url. Potential use when running from cloud infra that might not have access to local path.
|
||||
load_css(True, 'frontend/css/streamlit.main.css')
|
||||
|
||||
#
|
||||
# specify the primary menu definition
|
||||
menu_data = [
|
||||
{'id': 'Stable Diffusion', 'label': 'Stable Diffusion', 'icon': 'bi bi-grid-1x2-fill'},
|
||||
{'id': 'Train','label':"Train", 'icon': "bi bi-lightbulb-fill", 'submenu':[
|
||||
{'id': 'Textual Inversion', 'label': 'Textual Inversion', 'icon': 'bi bi-lightbulb-fill'},
|
||||
{'id': 'Fine Tunning', 'label': 'Fine Tunning', 'icon': 'bi bi-lightbulb-fill'},
|
||||
]},
|
||||
{'id': 'Model Manager', 'label': 'Model Manager', 'icon': 'bi bi-cloud-arrow-down-fill'},
|
||||
{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
|
||||
{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
|
||||
{'id': 'Barfi/BaklavaJS', 'label': 'Barfi/BaklavaJS', 'icon': 'bi bi-diagram-3-fill'},
|
||||
#{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
|
||||
]},
|
||||
{'id': 'Settings', 'label': 'Settings', 'icon': 'bi bi-gear-fill'},
|
||||
]
|
||||
|
||||
over_theme = {'txc_inactive': '#FFFFFF', "menu_background":'#000000'}
|
||||
|
||||
menu_id = hc.nav_bar(
|
||||
menu_definition=menu_data,
|
||||
#home_name='Home',
|
||||
#login_name='Logout',
|
||||
hide_streamlit_markers=False,
|
||||
override_theme=over_theme,
|
||||
sticky_nav=True,
|
||||
sticky_mode='pinned',
|
||||
)
|
||||
|
||||
#
|
||||
#if menu_id == "Home":
|
||||
#st.info("Under Construction. :construction_worker:")
|
||||
|
||||
if menu_id == "Stable Diffusion":
|
||||
# set the page url and title
|
||||
#st.experimental_set_query_params(page='stable-diffusion')
|
||||
try:
|
||||
set_page_title("Stable Diffusion Playground")
|
||||
except NameError:
|
||||
st.experimental_rerun()
|
||||
|
||||
txt2img_tab, img2img_tab, txt2vid_tab, img2txt_tab, post_processing_tab, concept_library_tab = st.tabs(["Text-to-Image", "Image-to-Image",
|
||||
#"Inpainting",
|
||||
"Text-to-Video", "Image-To-Text",
|
||||
"Post-Processing","Concept Library"])
|
||||
#with home_tab:
|
||||
#from home import layout
|
||||
#layout()
|
||||
|
||||
with txt2img_tab:
|
||||
from txt2img import layout
|
||||
layout()
|
||||
|
||||
with img2img_tab:
|
||||
from img2img import layout
|
||||
layout()
|
||||
|
||||
#with inpainting_tab:
|
||||
#from inpainting import layout
|
||||
#layout()
|
||||
|
||||
with txt2vid_tab:
|
||||
from txt2vid import layout
|
||||
layout()
|
||||
|
||||
with img2txt_tab:
|
||||
from img2txt import layout
|
||||
layout()
|
||||
|
||||
with post_processing_tab:
|
||||
from post_processing import layout
|
||||
layout()
|
||||
|
||||
with concept_library_tab:
|
||||
from sd_concept_library import layout
|
||||
layout()
|
||||
|
||||
#
|
||||
elif menu_id == 'Model Manager':
|
||||
set_page_title("Model Manager - Stable Diffusion Playground")
|
||||
|
||||
from ModelManager import layout
|
||||
layout()
|
||||
|
||||
elif menu_id == 'Textual Inversion':
|
||||
from textual_inversion import layout
|
||||
layout()
|
||||
|
||||
elif menu_id == 'Fine Tunning':
|
||||
#from textual_inversion import layout
|
||||
#layout()
|
||||
st.info("Under Construction. :construction_worker:")
|
||||
|
||||
elif menu_id == 'API Server':
|
||||
set_page_title("API Server - Stable Diffusion Playground")
|
||||
from APIServer import layout
|
||||
layout()
|
||||
|
||||
elif menu_id == 'Barfi/BaklavaJS':
|
||||
set_page_title("Barfi/BaklavaJS - Stable Diffusion Playground")
|
||||
from barfi_baklavajs import layout
|
||||
layout()
|
||||
|
||||
elif menu_id == 'Settings':
|
||||
set_page_title("Settings - Stable Diffusion Playground")
|
||||
|
||||
from Settings import layout
|
||||
layout()
|
||||
|
||||
# calling dragable input component module at the end, so it works on all pages
|
||||
draggable_number_input.load()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
set_logger_verbosity(opt.verbosity)
|
||||
quiesce_logger(opt.quiet)
|
||||
|
||||
if not opt.headless:
|
||||
layout()
|
||||
|
||||
#with server_state_lock["bridge"]:
|
||||
#if server_state["bridge"]:
|
||||
#try:
|
||||
#import bridgeData as cd
|
||||
#except ModuleNotFoundError as e:
|
||||
#logger.warning("No bridgeData found. Falling back to default where no CLI args are set.")
|
||||
#logger.debug(str(e))
|
||||
#except SyntaxError as e:
|
||||
#logger.warning("bridgeData found, but is malformed. Falling back to default where no CLI args are set.")
|
||||
#logger.debug(str(e))
|
||||
#except Exception as e:
|
||||
#logger.warning("No bridgeData found, use default where no CLI args are set")
|
||||
#logger.debug(str(e))
|
||||
#finally:
|
||||
#try: # check if cd exists (i.e. bridgeData loaded properly)
|
||||
#cd
|
||||
#except: # if not, create defaults
|
||||
#class temp(object):
|
||||
#def __init__(self):
|
||||
#random.seed()
|
||||
#self.horde_url = "https://stablehorde.net"
|
||||
## Give a cool name to your instance
|
||||
#self.horde_name = f"Automated Instance #{random.randint(-100000000, 100000000)}"
|
||||
## The api_key identifies a unique user in the horde
|
||||
#self.horde_api_key = "0000000000"
|
||||
## Put other users whose prompts you want to prioritize.
|
||||
## The owner's username is always included so you don't need to add it here, unless you want it to have lower priority than another user
|
||||
#self.horde_priority_usernames = []
|
||||
#self.horde_max_power = 8
|
||||
#self.nsfw = True
|
||||
#self.censor_nsfw = False
|
||||
#self.blacklist = []
|
||||
#self.censorlist = []
|
||||
#self.models_to_load = ["stable_diffusion"]
|
||||
#cd = temp()
|
||||
#horde_api_key = opt.horde_api_key if opt.horde_api_key else cd.horde_api_key
|
||||
#horde_name = opt.horde_name if opt.horde_name else cd.horde_name
|
||||
#horde_url = opt.horde_url if opt.horde_url else cd.horde_url
|
||||
#horde_priority_usernames = opt.horde_priority_usernames if opt.horde_priority_usernames else cd.horde_priority_usernames
|
||||
#horde_max_power = opt.horde_max_power if opt.horde_max_power else cd.horde_max_power
|
||||
## Not used yet
|
||||
#horde_models = [opt.horde_model] if opt.horde_model else cd.models_to_load
|
||||
#try:
|
||||
#horde_nsfw = not opt.horde_sfw if opt.horde_sfw else cd.horde_nsfw
|
||||
#except AttributeError:
|
||||
#horde_nsfw = True
|
||||
#try:
|
||||
#horde_censor_nsfw = opt.horde_censor_nsfw if opt.horde_censor_nsfw else cd.horde_censor_nsfw
|
||||
#except AttributeError:
|
||||
#horde_censor_nsfw = False
|
||||
#try:
|
||||
#horde_blacklist = opt.horde_blacklist if opt.horde_blacklist else cd.horde_blacklist
|
||||
#except AttributeError:
|
||||
#horde_blacklist = []
|
||||
#try:
|
||||
#horde_censorlist = opt.horde_censorlist if opt.horde_censorlist else cd.horde_censorlist
|
||||
#except AttributeError:
|
||||
#horde_censorlist = []
|
||||
#if horde_max_power < 2:
|
||||
#horde_max_power = 2
|
||||
#horde_max_pixels = 64*64*8*horde_max_power
|
||||
#logger.info(f"Joining Horde with parameters: Server Name '{horde_name}'. Horde URL '{horde_url}'. Max Pixels {horde_max_pixels}")
|
||||
|
||||
#try:
|
||||
#thread = threading.Thread(target=run_bridge(1, horde_api_key, horde_name, horde_url,
|
||||
#horde_priority_usernames, horde_max_pixels,
|
||||
#horde_nsfw, horde_censor_nsfw, horde_blacklist,
|
||||
#horde_censorlist), args=())
|
||||
#thread.daemon = True
|
||||
#thread.start()
|
||||
##run_bridge(1, horde_api_key, horde_name, horde_url, horde_priority_usernames, horde_max_pixels, horde_nsfw, horde_censor_nsfw, horde_blacklist, horde_censorlist)
|
||||
#except KeyboardInterrupt:
|
||||
#print(f"Keyboard Interrupt Received. Ending Bridge")
|
Loading…
Reference in New Issue
Block a user