Started moving all the Streamlit code and its dependencies to their own folder in webui/streamlit, and moving the backend to use nataili. This should make the UI self-contained and reduce the amount of code we keep in other places, which should also make the code easier to understand and read.

ZeroCool940711 2023-01-31 14:30:59 -07:00
parent 0781ced89a
commit ffd7883cb0
22 changed files with 9281 additions and 0 deletions

View File

@@ -0,0 +1,178 @@
/*
This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
Copyright 2022 Sygil-Dev team.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/***********************************************************
* Additional CSS for streamlit builtin components *
************************************************************/
/* Tab name (e.g. Text-to-Image); improve legibility */
button[data-baseweb="tab"] {
font-size: 25px;
}
/* Image container (only appears after a run finishes); center the image, which looks better especially on wide screens */
.css-1kyxreq{
justify-content: center;
}
/* Streamlit header */
.css-1avcm0n {
background-color: transparent;
}
/* Main streamlit container (below header); reduce the empty spaces */
.css-18e3th9 {
padding-top: 1rem;
}
/***********************************************************
* Additional CSS for streamlit custom/3rd party components *
************************************************************/
/* For stream_on_hover */
section[data-testid="stSidebar"] > div:nth-of-type(1) {
background-color: #111;
}
button[kind="header"] {
background-color: transparent;
color: rgb(180, 167, 141);
}
@media (hover) {
/* header element */
header[data-testid="stHeader"] {
/* display: none;*/ /*suggested behavior by streamlit hover components*/
pointer-events: none; /* disable interaction of the transparent background */
}
/* The button on the streamlit navigation menu */
button[kind="header"] {
/* display: none;*/ /*suggested behavior by streamlit hover components*/
pointer-events: auto; /* enable interaction with the button even if the parent's interaction is disabled */
}
/* added to keep the main sections (all elements to the right of the sidebar) from moving */
section[data-testid="stSidebar"] {
width: 3.5% !important;
min-width: 3.5% !important;
}
/* The navigation menu specs and size */
section[data-testid="stSidebar"] > div {
height: 100%;
width: 2% !important;
min-width: 100% !important;
position: relative;
z-index: 1;
top: 0;
left: 0;
background-color: #111;
overflow-x: hidden;
transition: 0.5s ease-in-out;
padding-top: 0px;
white-space: nowrap;
}
/* Open and close the navigation menu on hover; expanded size */
section[data-testid="stSidebar"] > div:hover {
width: 300px !important;
}
}
@media (max-width: 272px) {
section[data-testid="stSidebar"] > div {
width: 15rem;
}
}
/***********************************************************
* Additional CSS for other elements
************************************************************/
button[data-baseweb="tab"] {
font-size: 20px;
}
@media (min-width: 1200px){
h1 {
font-size: 1.75rem;
}
}
#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 {
width: 50rem;
align-self: center;
}
div.gallery:hover {
border: 1px solid #777;
}
.css-dg4u6x p {
font-size: 0.8rem;
text-align: center;
position: relative;
top: 6px;
}
.row-widget.stButton {
text-align: center;
}
/********************************************************************
Hide anchor links on titles
*********************************************************************/
/*
.css-15zrgzn {
display: none
}
.css-eczf16 {
display: none
}
.css-jn99sy {
display: none
}
*/
/* Make the text area widget have a similar height as the text input field */
.st-dy{
height: 54px;
min-height: 25px;
}
.css-17useex{
gap: 3px;
}
/* Remove some empty spaces to make the UI more compact. */
.css-18e3th9{
padding-left: 10px;
padding-right: 30px;
position: unset !important; /* Fixes the layout/page going up when an expander or another item is expanded and then collapsed */
}
.css-k1vhr4{
padding-top: initial;
}
.css-ret2ud{
padding-left: 10px;
padding-right: 30px;
gap: initial;
display: initial;
}
.css-w5z5an{
gap: 1px;
}
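This stylesheet only takes effect once it is injected into the Streamlit page. A minimal sketch of the usual injection pattern, assuming the file lives at a hypothetical path such as webui/streamlit/css/streamlit.main.css:

```python
import streamlit as st

def load_css(path: str = "webui/streamlit/css/streamlit.main.css"):
    """Inject a local stylesheet into the Streamlit page (the path is an assumption)."""
    with open(path) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
```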

View File

@@ -0,0 +1,34 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
#from sd_utils import *
from sd_utils import st
# streamlit imports
#streamlit components section
#other imports
#from fastapi import FastAPI
#import uvicorn
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
def layout():
st.info("Under Construction. :construction_worker:")
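Both page files added in this commit expose a layout() function as their entry point. A minimal sketch of how a main script could dispatch to such a page; the module name used here is hypothetical:

```python
import importlib

def render_page(module_name: str):
    # Each page module is expected to expose a layout() entry point,
    # as the files in this commit do.
    page = importlib.import_module(module_name)
    page.layout()

# Usage (hypothetical module name):
# render_page("some_page")
```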

View File

@@ -0,0 +1,121 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
from sd_utils import st, logger
# streamlit imports
#other imports
import os, requests
from requests.auth import HTTPBasicAuth
from requests import HTTPError
from stqdm import stqdm
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
def download_file(file_name, file_path, file_url):
if not os.path.exists(file_path):
os.makedirs(file_path)
if not os.path.exists(os.path.join(file_path , file_name)):
print('Downloading ' + file_name + '...')
# TODO - add progress bar in streamlit
# download file with `requests`
if file_name == "Stable Diffusion v1.5":
if "huggingface_token" not in st.session_state or st.session_state["defaults"].general.huggingface_token == "None":
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].error(
"You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
)
raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
try:
with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token) if "huggingface.co" in file_url else None, stream=True) as r:
r.raise_for_status()
with open(os.path.join(file_path, file_name), 'wb') as f:
for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"):
f.write(chunk)
except HTTPError as e:
if "huggingface.co" in file_url:
if "resolve"in file_url:
repo_url = file_url.split("resolve")[0]
st.session_state["progress_bar_text"].error(
f"You need to accept the license for the model in order to be able to download it. "
f"Please visit {repo_url} and accept the lincense there, then try again to download the model.")
logger.error(e)
else:
print(file_name + ' already exists.')
def download_model(models, model_name):
""" Download all files from model_list[model_name] """
for file in models[model_name]:
download_file(file['file_name'], file['file_path'], file['file_url'])
return
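Stripped of the Streamlit session handling and the Hugging Face authentication, download_file boils down to the standard requests streaming pattern. A minimal sketch, with placeholder URL and paths:

```python
import os
import requests

def stream_download(url: str, dest_dir: str, file_name: str, chunk_size: int = 8192) -> str:
    """Stream a remote file to disk in chunks; arguments here are placeholders."""
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, file_name)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(dest, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    return dest
```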
def layout():
#search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
colms = st.columns((1, 3, 3, 5, 5))
columns = ["", 'Model Name', 'Save Location', "Download", 'Download Link']
models = st.session_state["defaults"].model_manager.models
for col, field_name in zip(colms, columns):
# table header
col.write(field_name)
for x, model_name in enumerate(models):
col1, col2, col3, col4, col5 = st.columns((1, 3, 3, 3, 6))
col1.write(x) # index
col2.write(models[model_name]['model_name'])
col3.write(models[model_name]['save_location'])
with col4:
files_exist = 0
for file in models[model_name]['files']:
if "save_location" in models[model_name]['files'][file]:
os.path.exists(os.path.join(models[model_name]['files'][file]['save_location'] , models[model_name]['files'][file]['file_name']))
files_exist += 1
elif os.path.exists(os.path.join(models[model_name]['save_location'] , models[model_name]['files'][file]['file_name'])):
files_exist += 1
files_needed = []
for file in models[model_name]['files']:
if "save_location" in models[model_name]['files'][file]:
if not os.path.exists(os.path.join(models[model_name]['files'][file]['save_location'] , models[model_name]['files'][file]['file_name'])):
files_needed.append(file)
elif not os.path.exists(os.path.join(models[model_name]['save_location'] , models[model_name]['files'][file]['file_name'])):
files_needed.append(file)
if len(files_needed) > 0:
if st.button('Download', key=models[model_name]['model_name'], help='Download ' + models[model_name]['model_name']):
for file in files_needed:
if "save_location" in models[model_name]['files'][file]:
download_file(models[model_name]['files'][file]['file_name'], models[model_name]['files'][file]['save_location'], models[model_name]['files'][file]['download_link'])
else:
download_file(models[model_name]['files'][file]['file_name'], models[model_name]['save_location'], models[model_name]['files'][file]['download_link'])
st.experimental_rerun()
else:
st.empty()
else:
st.write('')
#
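For reference, the layout() above expects st.session_state["defaults"].model_manager.models to look roughly like the following. The keys mirror the lookups in the code; the names and URL are purely illustrative:

```python
models = {
    "stable_diffusion_v1_5": {
        "model_name": "Stable Diffusion v1.5",
        "save_location": "models/ldm/stable-diffusion-v1",
        "files": {
            "model_ckpt": {
                "file_name": "model.ckpt",
                "download_link": "https://example.com/model.ckpt",
                # A per-file "save_location" may also be given and takes
                # precedence over the model-level one.
            },
        },
    },
}
```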

View File

@@ -0,0 +1,899 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
from sd_utils import st, custom_models_available, logger, human_readable_size
# streamlit imports
# streamlit components section
import streamlit_nested_layout
from streamlit_server_state import server_state
# other imports
from omegaconf import OmegaConf
import torch
import os, toml
# end of imports
# ---------------------------------------------------------------------------------------------------------------
@logger.catch(reraise=True)
def layout():
#st.header("Settings")
with st.form("Settings"):
general_tab, txt2img_tab, img2img_tab, img2txt_tab, txt2vid_tab, image_processing, textual_inversion_tab, concepts_library_tab = st.tabs(
['General', "Text-To-Image", "Image-To-Image", "Image-To-Text", "Text-To-Video", "Image processing", "Textual Inversion", "Concepts Library"])
with general_tab:
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
device_list = []
device_properties = [(i, torch.cuda.get_device_properties(i)) for i in range(torch.cuda.device_count())]
for device in device_properties:
id = device[0]
name = device[1].name
total_memory = device[1].total_memory
device_list.append(f"{id}: {name} ({human_readable_size(total_memory, decimal_places=0)})")
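human_readable_size is imported from sd_utils; purely to illustrate the "0: <GPU name> (8 GB)" strings built above, a hypothetical stand-in could look like this:

```python
def human_readable_size(size, decimal_places=2):
    # Hypothetical stand-in for the sd_utils helper used above.
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if size < 1024 or unit == "TB":
            return f"{size:.{decimal_places}f} {unit}"
        size /= 1024
```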
with col1:
st.title("General")
st.session_state['defaults'].general.gpu = int(st.selectbox("GPU", device_list, index=st.session_state['defaults'].general.gpu,
help=f"Select which GPU to use. Default: {device_list[0]}").split(":")[0])
st.session_state['defaults'].general.outdir = str(st.text_input("Output directory", value=st.session_state['defaults'].general.outdir,
help="Relative directory on which the output images after a generation will be placed. Default: 'outputs'"))
# If we have custom models available on the "models/custom"
# folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
custom_models_available()
if server_state["CustomModel_available"]:
st.session_state.defaults.general.default_model = st.selectbox("Default Model:", server_state["custom_models"],
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. If you have placed custom models \
on your 'models/custom' folder they will be shown here as well. The model name that will be shown here \
is the same as the name the file for the model has on said folder, \
it is recommended to give the .ckpt file a name that \
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
else:
st.session_state.defaults.general.default_model = st.selectbox("Default Model:", [st.session_state['defaults'].general.default_model],
help="Select the model you want to use. If you have placed custom models \
on your 'models/custom' folder they will be shown here as well. \
The model name that will be shown here is the same as the name\
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
st.session_state['defaults'].general.default_model_config = st.text_input("Default Model Config", value=st.session_state['defaults'].general.default_model_config,
help="Default model config file for inference. Default: 'configs/stable-diffusion/v1-inference.yaml'")
st.session_state['defaults'].general.default_model_path = st.text_input("Default Model Path", value=st.session_state['defaults'].general.default_model_path,
help="Default model path. Default: 'models/ldm/stable-diffusion-v1/model.ckpt'")
st.session_state['defaults'].general.GFPGAN_dir = st.text_input("Default GFPGAN directory", value=st.session_state['defaults'].general.GFPGAN_dir,
help="Default GFPGAN directory. Default: './models/gfpgan'")
st.session_state['defaults'].general.RealESRGAN_dir = st.text_input("Default RealESRGAN directory", value=st.session_state['defaults'].general.RealESRGAN_dir,
help="Default GFPGAN directory. Default: './models/realesrgan'")
RealESRGAN_model_list = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"]
st.session_state['defaults'].general.RealESRGAN_model = st.selectbox("RealESRGAN model", RealESRGAN_model_list,
index=RealESRGAN_model_list.index(st.session_state['defaults'].general.RealESRGAN_model),
help="Default RealESRGAN model. Default: 'RealESRGAN_x4plus'")
Upscaler_list = ["RealESRGAN", "LDSR"]
st.session_state['defaults'].general.upscaling_method = st.selectbox("Upscaler", Upscaler_list, index=Upscaler_list.index(
st.session_state['defaults'].general.upscaling_method), help="Default upscaling method. Default: 'RealESRGAN'")
with col2:
st.title("Performance")
st.session_state["defaults"].general.gfpgan_cpu = st.checkbox("GFPGAN - CPU", value=st.session_state['defaults'].general.gfpgan_cpu,
help="Run GFPGAN on the cpu. Default: False")
st.session_state["defaults"].general.esrgan_cpu = st.checkbox("ESRGAN - CPU", value=st.session_state['defaults'].general.esrgan_cpu,
help="Run ESRGAN on the cpu. Default: False")
st.session_state["defaults"].general.extra_models_cpu = st.checkbox("Extra Models - CPU", value=st.session_state['defaults'].general.extra_models_cpu,
help="Run extra models (GFGPAN/ESRGAN) on cpu. Default: False")
st.session_state["defaults"].general.extra_models_gpu = st.checkbox("Extra Models - GPU", value=st.session_state['defaults'].general.extra_models_gpu,
help="Run extra models (GFGPAN/ESRGAN) on gpu. \
Check and save in order to be able to select the GPU that each model will use. Default: False")
if st.session_state["defaults"].general.extra_models_gpu:
st.session_state['defaults'].general.gfpgan_gpu = int(st.selectbox("GFPGAN GPU", device_list, index=st.session_state['defaults'].general.gfpgan_gpu,
help=f"Select which GPU to use. Default: {device_list[st.session_state['defaults'].general.gfpgan_gpu]}",
key="gfpgan_gpu").split(":")[0])
st.session_state["defaults"].general.esrgan_gpu = int(st.selectbox("ESRGAN - GPU", device_list, index=st.session_state['defaults'].general.esrgan_gpu,
help=f"Select which GPU to use. Default: {device_list[st.session_state['defaults'].general.esrgan_gpu]}",
key="esrgan_gpu").split(":")[0])
st.session_state["defaults"].general.no_half = st.checkbox("No Half", value=st.session_state['defaults'].general.no_half,
help="DO NOT switch the model to 16-bit floats. Default: False")
st.session_state["defaults"].general.use_cudnn = st.checkbox("Use cudnn", value=st.session_state['defaults'].general.use_cudnn,
help="Switch the pytorch backend to use cudnn, this should help with fixing Nvidia 16xx cards getting"
"a black or green image. Default: False")
st.session_state["defaults"].general.use_float16 = st.checkbox("Use float16", value=st.session_state['defaults'].general.use_float16,
help="Switch the model to 16-bit floats. Default: False")
precision_list = ['full', 'autocast']
st.session_state["defaults"].general.precision = st.selectbox("Precision", precision_list, index=precision_list.index(st.session_state['defaults'].general.precision),
help="Evaluates at this precision. Default: autocast")
st.session_state["defaults"].general.optimized = st.checkbox("Optimized Mode", value=st.session_state['defaults'].general.optimized,
help="Loads the model onto the device piecemeal instead of all at once to reduce VRAM usage\
at the cost of performance. Default: False")
st.session_state["defaults"].general.optimized_turbo = st.checkbox("Optimized Turbo Mode", value=st.session_state['defaults'].general.optimized_turbo,
help="Alternative optimization mode that does not save as much VRAM but \
runs significantly faster. Default: False")
st.session_state["defaults"].general.optimized_config = st.text_input("Optimized Config", value=st.session_state['defaults'].general.optimized_config,
help=f"Loads alternative optimized configuration for inference. \
Default: optimizedSD/v1-inference.yaml")
st.session_state["defaults"].general.enable_attention_slicing = st.checkbox("Enable Attention Slicing", value=st.session_state['defaults'].general.enable_attention_slicing,
help="Enable sliced attention computation. When this option is enabled, the attention module will \
split the input tensor in slices, to compute attention in several steps. This is useful to save some \
memory in exchange for a small speed decrease. Only works on the txt2vid tab right now. Default: False")
st.session_state["defaults"].general.enable_minimal_memory_usage = st.checkbox("Enable Minimal Memory Usage", value=st.session_state['defaults'].general.enable_minimal_memory_usage,
help="Moves only unet to fp16 and to CUDA, while keepping lighter models on CPUs \
(Not properly implemented and currently not working, check this \
link 'https://github.com/huggingface/diffusers/pull/537' for more information on it ). Default: False")
# st.session_state["defaults"].general.update_preview = st.checkbox("Update Preview Image", value=st.session_state['defaults'].general.update_preview,
# help="Enables the preview image to be updated and shown to the user on the UI during the generation.\
# If checked, once you save the settings an option to specify the frequency at which the image is updated\
# in steps will be shown, this is helpful to reduce the negative effect this option has on performance. \
# Default: True")
st.session_state["defaults"].general.update_preview = True
st.session_state["defaults"].general.update_preview_frequency = st.number_input("Update Preview Frequency",
min_value=0,
value=st.session_state['defaults'].general.update_preview_frequency,
help="Specify the frequency at which the image is updated in steps, this is helpful to reduce the \
negative effect updating the preview image has on performance. Default: 10")
with col3:
st.title("Others")
st.session_state["defaults"].general.use_sd_concepts_library = st.checkbox("Use the Concepts Library", value=st.session_state['defaults'].general.use_sd_concepts_library,
help="Use the embeds Concepts Library, if checked, once the settings are saved an option will\
appear to specify the directory where the concepts are stored. Default: True)")
if st.session_state["defaults"].general.use_sd_concepts_library:
st.session_state['defaults'].general.sd_concepts_library_folder = st.text_input("Concepts Library Folder",
value=st.session_state['defaults'].general.sd_concepts_library_folder,
help="Relative folder on which the concepts library embeds are stored. \
Default: 'models/custom/sd-concepts-library'")
st.session_state['defaults'].general.LDSR_dir = st.text_input("LDSR Folder", value=st.session_state['defaults'].general.LDSR_dir,
help="Folder where LDSR is located. Default: './models/ldsr'")
st.session_state["defaults"].general.save_metadata = st.checkbox("Save Metadata", value=st.session_state['defaults'].general.save_metadata,
help="Save metadata on the output image. Default: True")
save_format_list = ["png","jpg", "jpeg","webp"]
st.session_state["defaults"].general.save_format = st.selectbox("Save Format", save_format_list, index=save_format_list.index(st.session_state['defaults'].general.save_format),
help="Format that will be used whens saving the output images. Default: 'png'")
st.session_state["defaults"].general.skip_grid = st.checkbox("Skip Grid", value=st.session_state['defaults'].general.skip_grid,
help="Skip saving the grid output image. Default: False")
if not st.session_state["defaults"].general.skip_grid:
st.session_state["defaults"].general.grid_quality = st.number_input("Grid Quality", value=st.session_state['defaults'].general.grid_quality,
help="Format for saving the grid output image. Default: 95")
st.session_state["defaults"].general.skip_save = st.checkbox("Skip Save", value=st.session_state['defaults'].general.skip_save,
help="Skip saving the output image. Default: False")
st.session_state["defaults"].general.n_rows = st.number_input("Number of Grid Rows", value=st.session_state['defaults'].general.n_rows,
help="Number of rows the grid wil have when saving the grid output image. Default: '-1'")
st.session_state["defaults"].general.no_verify_input = st.checkbox("Do not Verify Input", value=st.session_state['defaults'].general.no_verify_input,
help="Do not verify input to check if it's too long. Default: False")
st.session_state["defaults"].general.show_percent_in_tab_title = st.checkbox("Show Percent in tab title", value=st.session_state['defaults'].general.show_percent_in_tab_title,
help="Add the progress percent value to the page title on the tab on your browser. "
"This is useful in case you need to know how the generation is going while doign something else"
"in another tab on your browser. Default: True")
st.session_state["defaults"].general.enable_suggestions = st.checkbox("Enable Suggestions Box", value=st.session_state['defaults'].general.enable_suggestions,
help="Adds a suggestion box under the prompt when clicked. Default: True")
st.session_state["defaults"].daisi_app.running_on_daisi_io = st.checkbox("Running on Daisi.io?", value=st.session_state['defaults'].daisi_app.running_on_daisi_io,
help="Specify if we are running on app.Daisi.io . Default: False")
with col4:
st.title("Streamlit Config")
default_theme_list = ["light", "dark"]
st.session_state["defaults"].general.default_theme = st.selectbox("Default Theme", default_theme_list, index=default_theme_list.index(st.session_state['defaults'].general.default_theme),
help="Defaut theme to use as base for streamlit. Default: dark")
st.session_state["streamlit_config"]["theme"]["base"] = st.session_state["defaults"].general.default_theme
if not st.session_state['defaults'].admin.hide_server_setting:
with st.expander("Server", True):
st.session_state["streamlit_config"]['server']['headless'] = st.checkbox("Run Headless", help="If false, will attempt to open a browser window on start. \
Default: false unless (1) we are on a Linux box where DISPLAY is unset, \
or (2) we are running in the Streamlit Atom plugin.")
st.session_state["streamlit_config"]['server']['port'] = st.number_input("Port", value=st.session_state["streamlit_config"]['server']['port'],
help="The port where the server will listen for browser connections. Default: 8501")
st.session_state["streamlit_config"]['server']['baseUrlPath'] = st.text_input("Base Url Path", value=st.session_state["streamlit_config"]['server']['baseUrlPath'],
help="The base path for the URL where Streamlit should be served from. Default: '' ")
st.session_state["streamlit_config"]['server']['enableCORS'] = st.checkbox("Enable CORS", value=st.session_state['streamlit_config']['server']['enableCORS'],
help="Enables support for Cross-Origin Request Sharing (CORS) protection, for added security. \
Due to conflicts between CORS and XSRF, if `server.enableXsrfProtection` is on and `server.enableCORS` \
is off at the same time, we will prioritize `server.enableXsrfProtection`. Default: true")
st.session_state["streamlit_config"]['server']['enableXsrfProtection'] = st.checkbox("Enable Xsrf Protection",
value=st.session_state['streamlit_config']['server']['enableXsrfProtection'],
help="Enables support for Cross-Site Request Forgery (XSRF) protection, \
for added security. Due to conflicts between CORS and XSRF, \
if `server.enableXsrfProtection` is on and `server.enableCORS` is off at \
the same time, we will prioritize `server.enableXsrfProtection`. Default: true")
st.session_state["streamlit_config"]['server']['maxUploadSize'] = st.number_input("Max Upload Size", value=st.session_state["streamlit_config"]['server']['maxUploadSize'],
help="Max size, in megabytes, for files uploaded with the file_uploader. Default: 200")
st.session_state["streamlit_config"]['server']['maxMessageSize'] = st.number_input("Max Message Size", value=st.session_state["streamlit_config"]['server']['maxUploadSize'],
help="Max size, in megabytes, of messages that can be sent via the WebSocket connection. Default: 200")
st.session_state["streamlit_config"]['server']['enableWebsocketCompression'] = st.checkbox("Enable Websocket Compression",
value=st.session_state["streamlit_config"]['server']['enableWebsocketCompression'],
help=" Enables support for websocket compression. Default: false")
if not st.session_state['defaults'].admin.hide_browser_setting:
with st.expander("Browser", expanded=True):
st.session_state["streamlit_config"]['browser']['serverAddress'] = st.text_input("Server Address",
value=st.session_state["streamlit_config"]['browser']['serverAddress'] if "serverAddress" in st.session_state["streamlit_config"] else "localhost",
help="Internet address where users should point their browsers in order \
to connect to the app. Can be IP address or DNS name and path.\
This is used to: - Set the correct URL for CORS and XSRF protection purposes. \
- Show the URL on the terminal - Open the browser. Default: 'localhost'")
st.session_state["defaults"].general.streamlit_telemetry = st.checkbox("Enable Telemetry", value=st.session_state['defaults'].general.streamlit_telemetry,
help="Enables or Disables streamlit telemetry. Default: False")
st.session_state["streamlit_config"]["browser"]["gatherUsageStats"] = st.session_state["defaults"].general.streamlit_telemetry
st.session_state["streamlit_config"]['browser']['serverPort'] = st.number_input("Server Port", value=st.session_state["streamlit_config"]['browser']['serverPort'],
help="Port where users should point their browsers in order to connect to the app. \
This is used to: - Set the correct URL for CORS and XSRF protection purposes. \
- Show the URL on the terminal - Open the browser \
Default: whatever value is set in server.port.")
with col5:
st.title("Huggingface")
st.session_state["defaults"].general.huggingface_token = st.text_input("Huggingface Token", value=st.session_state['defaults'].general.huggingface_token, type="password",
help="Your Huggingface Token, it's used to download the model for the diffusers library which \
is used on the Text To Video tab. This token will be saved to your user config file\
and WILL NOT be shared with us or anyone. You can get your access token \
at https://huggingface.co/settings/tokens. Default: None")
st.title("Stable Horde")
st.session_state["defaults"].general.stable_horde_api = st.text_input("Stable Horde Api", value=st.session_state["defaults"].general.stable_horde_api, type="password",
help="First Register an account at https://stablehorde.net/register which will generate for you \
an API key. Store that key somewhere safe. \n \
If you do not want to register, you can use `0000000000` as api_key to connect anonymously.\
However, anonymous accounts have the lowest priority when there are too many concurrent requests! \
To increase your priority you will need a unique API key and more Kudos; \
read more about them at https://dbzer0.com/blog/the-kudos-based-economy-for-the-koboldai-horde/.")
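The values written into st.session_state["streamlit_config"] above are plain nested dicts, so persisting them amounts to dumping them back to Streamlit's config file with the toml module this script already imports. A minimal sketch, assuming a hypothetical target path:

```python
import toml

def save_streamlit_config(streamlit_config: dict, path: str = ".streamlit/config.toml"):
    """Write the edited Streamlit options back to a TOML file (the path is an assumption)."""
    with open(path, "w") as f:
        toml.dump(streamlit_config, f)
```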
with txt2img_tab:
col1, col2, col3, col4, col5 = st.columns(5, gap='medium')
with col1:
st.title("Slider Parameters")
# Width
st.session_state["defaults"].txt2img.width.value = st.number_input("Default Image Width", value=st.session_state['defaults'].txt2img.width.value,
help="Set the default width for the generated image. Default is: 512")
st.session_state["defaults"].txt2img.width.min_value = st.number_input("Minimum Image Width", value=st.session_state['defaults'].txt2img.width.min_value,
help="Set the default minimum value for the width slider. Default is: 64")
st.session_state["defaults"].txt2img.width.max_value = st.number_input("Maximum Image Width", value=st.session_state['defaults'].txt2img.width.max_value,
help="Set the default maximum value for the width slider. Default is: 2048")
# Height
st.session_state["defaults"].txt2img.height.value = st.number_input("Default Image Height", value=st.session_state['defaults'].txt2img.height.value,
help="Set the default height for the generated image. Default is: 512")
st.session_state["defaults"].txt2img.height.min_value = st.number_input("Minimum Image Height", value=st.session_state['defaults'].txt2img.height.min_value,
help="Set the default minimum value for the height slider. Default is: 64")
st.session_state["defaults"].txt2img.height.max_value = st.number_input("Maximum Image Height", value=st.session_state['defaults'].txt2img.height.max_value,
help="Set the default maximum value for the height slider. Default is: 2048")
with col2:
# CFG
st.session_state["defaults"].txt2img.cfg_scale.value = st.number_input("Default CFG Scale", value=st.session_state['defaults'].txt2img.cfg_scale.value,
help="Set the default value for the CFG Scale. Default is: 7.5")
st.session_state["defaults"].txt2img.cfg_scale.min_value = st.number_input("Minimum CFG Scale Value", value=st.session_state['defaults'].txt2img.cfg_scale.min_value,
help="Set the default minimum value for the CFG scale slider. Default is: 1")
st.session_state["defaults"].txt2img.cfg_scale.step = st.number_input("CFG Slider Steps", value=st.session_state['defaults'].txt2img.cfg_scale.step,
help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
# Sampling Steps
st.session_state["defaults"].txt2img.sampling_steps.value = st.number_input("Default Sampling Steps", value=st.session_state['defaults'].txt2img.sampling_steps.value,
help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
st.session_state["defaults"].txt2img.sampling_steps.min_value = st.number_input("Minimum Sampling Steps",
value=st.session_state['defaults'].txt2img.sampling_steps.min_value,
help="Set the default minimum value for the sampling steps slider. Default is: 1")
st.session_state["defaults"].txt2img.sampling_steps.step = st.number_input("Sampling Slider Steps",
value=st.session_state['defaults'].txt2img.sampling_steps.step,
help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
with col3:
st.title("General Parameters")
# Batch Count
st.session_state["defaults"].txt2img.batch_count.value = st.number_input("Batch count", value=st.session_state['defaults'].txt2img.batch_count.value,
help="How many iterations or batches of images to generate in total.")
st.session_state["defaults"].txt2img.batch_size.value = st.number_input("Batch size", value=st.session_state.defaults.txt2img.batch_size.value,
help="How many images are at once in a batch.\
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
takes to finish generation as more images are generated at once.\
Default: 1")
default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
st.session_state["defaults"].txt2img.default_sampler = st.selectbox("Default Sampler",
default_sampler_list, index=default_sampler_list.index(
st.session_state['defaults'].txt2img.default_sampler),
help="Defaut sampler to use for txt2img. Default: k_euler")
st.session_state['defaults'].txt2img.seed = st.text_input("Default Seed", value=st.session_state['defaults'].txt2img.seed, help="Default seed.")
with col4:
st.session_state["defaults"].txt2img.separate_prompts = st.checkbox("Separate Prompts",
value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate Prompts. Default: False")
st.session_state["defaults"].txt2img.normalize_prompt_weights = st.checkbox("Normalize Prompt Weights",
value=st.session_state['defaults'].txt2img.normalize_prompt_weights,
help="Choose to normalize prompt weights. Default: True")
st.session_state["defaults"].txt2img.save_individual_images = st.checkbox("Save Individual Images",
value=st.session_state['defaults'].txt2img.save_individual_images,
help="Choose to save individual images. Default: True")
st.session_state["defaults"].txt2img.save_grid = st.checkbox("Save Grid Images", value=st.session_state['defaults'].txt2img.save_grid,
help="Choose to save the grid images. Default: True")
st.session_state["defaults"].txt2img.group_by_prompt = st.checkbox("Group By Prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
help="Choose to save images grouped by their prompt. Default: False")
st.session_state["defaults"].txt2img.save_as_jpg = st.checkbox("Save As JPG", value=st.session_state['defaults'].txt2img.save_as_jpg,
help="Choose to save images as jpegs. Default: False")
st.session_state["defaults"].txt2img.write_info_files = st.checkbox("Write Info Files For Images", value=st.session_state['defaults'].txt2img.write_info_files,
help="Choose to write the info files along with the generated images. Default: True")
st.session_state["defaults"].txt2img.use_GFPGAN = st.checkbox(
"Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Choose to use GFPGAN. Default: False")
st.session_state["defaults"].txt2img.use_upscaling = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling,
help="Choose to turn on upscaling by default. Default: False")
st.session_state["defaults"].txt2img.update_preview = True
st.session_state["defaults"].txt2img.update_preview_frequency = st.number_input("Preview Image Update Frequency",
min_value=0,
value=st.session_state['defaults'].txt2img.update_preview_frequency,
help="Set the default value for the frrquency of the preview image updates. Default is: 10")
with col5:
st.title("Variation Parameters")
st.session_state["defaults"].txt2img.variant_amount.value = st.number_input("Default Variation Amount",
value=st.session_state['defaults'].txt2img.variant_amount.value,
help="Set the default variation to use. Default is: 0.0")
st.session_state["defaults"].txt2img.variant_amount.min_value = st.number_input("Minimum Variation Amount",
value=st.session_state['defaults'].txt2img.variant_amount.min_value,
help="Set the default minimum value for the variation slider. Default is: 0.0")
st.session_state["defaults"].txt2img.variant_amount.max_value = st.number_input("Maximum Variation Amount",
value=st.session_state['defaults'].txt2img.variant_amount.max_value,
help="Set the default maximum value for the variation slider. Default is: 1.0")
st.session_state["defaults"].txt2img.variant_amount.step = st.number_input("Variation Slider Steps",
value=st.session_state['defaults'].txt2img.variant_amount.step,
help="Set the default value for the number of steps on the variation slider. Default is: 1")
st.session_state['defaults'].txt2img.variant_seed = st.text_input("Default Variation Seed", value=st.session_state['defaults'].txt2img.variant_seed,
help="Default variation seed.")
with img2img_tab:
col1, col2, col3, col4, col5 = st.columns(5, gap='medium')
with col1:
st.title("Image Editing")
# Denoising
st.session_state["defaults"].img2img.denoising_strength.value = st.number_input("Default Denoising Amount",
value=st.session_state['defaults'].img2img.denoising_strength.value,
help="Set the default denoising to use. Default is: 0.75")
st.session_state["defaults"].img2img.denoising_strength.min_value = st.number_input("Minimum Denoising Amount",
value=st.session_state['defaults'].img2img.denoising_strength.min_value,
help="Set the default minimum value for the denoising slider. Default is: 0.0")
st.session_state["defaults"].img2img.denoising_strength.max_value = st.number_input("Maximum Denoising Amount",
value=st.session_state['defaults'].img2img.denoising_strength.max_value,
help="Set the default maximum value for the denoising slider. Default is: 1.0")
st.session_state["defaults"].img2img.denoising_strength.step = st.number_input("Denoising Slider Steps",
value=st.session_state['defaults'].img2img.denoising_strength.step,
help="Set the default value for the number of steps on the denoising slider. Default is: 0.01")
# Masking
st.session_state["defaults"].img2img.mask_mode = st.number_input("Default Mask Mode", value=st.session_state['defaults'].img2img.mask_mode,
help="Set the default mask mode to use. 0 = Keep Masked Area, 1 = Regenerate Masked Area. Default is: 0")
st.session_state["defaults"].img2img.mask_restore = st.checkbox("Default Mask Restore", value=st.session_state['defaults'].img2img.mask_restore,
help="Mask Restore. Default: False")
st.session_state["defaults"].img2img.resize_mode = st.number_input("Default Resize Mode", value=st.session_state['defaults'].img2img.resize_mode,
help="Set the default resizing mode. 0 = Just Resize, 1 = Crop and Resize, 3 = Resize and Fill. Default is: 0")
with col2:
st.title("Slider Parameters")
# Width
st.session_state["defaults"].img2img.width.value = st.number_input("Default Outputted Image Width", value=st.session_state['defaults'].img2img.width.value,
help="Set the default width for the generated image. Default is: 512")
st.session_state["defaults"].img2img.width.min_value = st.number_input("Minimum Outputted Image Width", value=st.session_state['defaults'].img2img.width.min_value,
help="Set the default minimum value for the width slider. Default is: 64")
st.session_state["defaults"].img2img.width.max_value = st.number_input("Maximum Outputted Image Width", value=st.session_state['defaults'].img2img.width.max_value,
help="Set the default maximum value for the width slider. Default is: 2048")
# Height
st.session_state["defaults"].img2img.height.value = st.number_input("Default Outputted Image Height", value=st.session_state['defaults'].img2img.height.value,
help="Set the default height for the generated image. Default is: 512")
st.session_state["defaults"].img2img.height.min_value = st.number_input("Minimum Outputted Image Height", value=st.session_state['defaults'].img2img.height.min_value,
help="Set the default minimum value for the height slider. Default is: 64")
st.session_state["defaults"].img2img.height.max_value = st.number_input("Maximum Outputted Image Height", value=st.session_state['defaults'].img2img.height.max_value,
help="Set the default maximum value for the height slider. Default is: 2048")
# CFG
st.session_state["defaults"].img2img.cfg_scale.value = st.number_input("Default Img2Img CFG Scale", value=st.session_state['defaults'].img2img.cfg_scale.value,
help="Set the default value for the CFG Scale. Default is: 7.5")
st.session_state["defaults"].img2img.cfg_scale.min_value = st.number_input("Minimum Img2Img CFG Scale Value",
value=st.session_state['defaults'].img2img.cfg_scale.min_value,
help="Set the default minimum value for the CFG scale slider. Default is: 1")
with col3:
st.session_state["defaults"].img2img.cfg_scale.step = st.number_input("Img2Img CFG Slider Steps",
value=st.session_state['defaults'].img2img.cfg_scale.step,
help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
# Sampling Steps
st.session_state["defaults"].img2img.sampling_steps.value = st.number_input("Default Img2Img Sampling Steps",
value=st.session_state['defaults'].img2img.sampling_steps.value,
help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
st.session_state["defaults"].img2img.sampling_steps.min_value = st.number_input("Minimum Img2Img Sampling Steps",
value=st.session_state['defaults'].img2img.sampling_steps.min_value,
help="Set the default minimum value for the sampling steps slider. Default is: 1")
st.session_state["defaults"].img2img.sampling_steps.step = st.number_input("Img2Img Sampling Slider Steps",
value=st.session_state['defaults'].img2img.sampling_steps.step,
help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
# Batch Count
st.session_state["defaults"].img2img.batch_count.value = st.number_input("Img2img Batch count", value=st.session_state["defaults"].img2img.batch_count.value,
help="How many iterations or batches of images to generate in total.")
st.session_state["defaults"].img2img.batch_size.value = st.number_input("Img2img Batch size", value=st.session_state["defaults"].img2img.batch_size.value,
help="How many images are at once in a batch.\
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
takes to finish generation as more images are generated at once.\
Default: 1")
with col4:
# Inference Steps
st.session_state["defaults"].img2img.num_inference_steps.value = st.number_input("Default Inference Steps",
value=st.session_state['defaults'].img2img.num_inference_steps.value,
help="Set the default number of inference steps to use. Default is: 200")
st.session_state["defaults"].img2img.num_inference_steps.min_value = st.number_input("Minimum Sampling Steps",
value=st.session_state['defaults'].img2img.num_inference_steps.min_value,
help="Set the default minimum value for the inference steps slider. Default is: 10")
st.session_state["defaults"].img2img.num_inference_steps.max_value = st.number_input("Maximum Sampling Steps",
value=st.session_state['defaults'].img2img.num_inference_steps.max_value,
help="Set the default maximum value for the inference steps slider. Default is: 500")
st.session_state["defaults"].img2img.num_inference_steps.step = st.number_input("Inference Slider Steps",
value=st.session_state['defaults'].img2img.num_inference_steps.step,
help="Set the default value for the number of steps on the inference steps slider.\
Default is: 10")
# Find Noise Steps
st.session_state["defaults"].img2img.find_noise_steps.value = st.number_input("Default Find Noise Steps",
value=st.session_state['defaults'].img2img.find_noise_steps.value,
help="Set the default number of find noise steps to use. Default is: 100")
st.session_state["defaults"].img2img.find_noise_steps.min_value = st.number_input("Minimum Find Noise Steps",
value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
help="Set the default minimum value for the find noise steps slider. Default is: 0")
st.session_state["defaults"].img2img.find_noise_steps.step = st.number_input("Find Noise Slider Steps",
value=st.session_state['defaults'].img2img.find_noise_steps.step,
help="Set the default value for the number of steps on the find noise steps slider. \
Default is: 100")
with col5:
st.title("General Parameters")
default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
st.session_state["defaults"].img2img.sampler_name = st.selectbox("Default Img2Img Sampler", default_sampler_list,
index=default_sampler_list.index(st.session_state['defaults'].img2img.sampler_name),
help="Defaut sampler to use for img2img. Default: k_euler")
st.session_state['defaults'].img2img.seed = st.text_input("Default Img2Img Seed", value=st.session_state['defaults'].img2img.seed, help="Default seed.")
st.session_state["defaults"].img2img.separate_prompts = st.checkbox("Separate Img2Img Prompts", value=st.session_state['defaults'].img2img.separate_prompts,
help="Separate Prompts. Default: False")
st.session_state["defaults"].img2img.normalize_prompt_weights = st.checkbox("Normalize Img2Img Prompt Weights",
value=st.session_state['defaults'].img2img.normalize_prompt_weights,
help="Choose to normalize prompt weights. Default: True")
st.session_state["defaults"].img2img.save_individual_images = st.checkbox("Save Individual Img2Img Images",
value=st.session_state['defaults'].img2img.save_individual_images,
help="Choose to save individual images. Default: True")
st.session_state["defaults"].img2img.save_grid = st.checkbox("Save Img2Img Grid Images",
value=st.session_state['defaults'].img2img.save_grid, help="Choose to save the grid images. Default: True")
st.session_state["defaults"].img2img.group_by_prompt = st.checkbox("Group By Img2Img Prompt",
value=st.session_state['defaults'].img2img.group_by_prompt,
help="Choose to save images grouped by their prompt. Default: False")
st.session_state["defaults"].img2img.save_as_jpg = st.checkbox("Save Img2Img As JPG", value=st.session_state['defaults'].img2img.save_as_jpg,
help="Choose to save images as jpegs. Default: False")
st.session_state["defaults"].img2img.write_info_files = st.checkbox("Write Info Files For Img2Img Images",
value=st.session_state['defaults'].img2img.write_info_files,
help="Choose to write the info files along with the generated images. Default: True")
st.session_state["defaults"].img2img.use_GFPGAN = st.checkbox(
"Img2Img Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Choose to use GFPGAN. Default: False")
st.session_state["defaults"].img2img.use_RealESRGAN = st.checkbox("Img2Img Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN,
help="Choose to use RealESRGAN. Default: False")
st.session_state["defaults"].img2img.update_preview = True
st.session_state["defaults"].img2img.update_preview_frequency = st.number_input("Img2Img Preview Image Update Frequency",
min_value=0,
value=st.session_state['defaults'].img2img.update_preview_frequency,
help="Set the default value for the frrquency of the preview image updates. Default is: 10")
st.title("Variation Parameters")
st.session_state["defaults"].img2img.variant_amount = st.number_input("Default Img2Img Variation Amount",
value=st.session_state['defaults'].img2img.variant_amount,
help="Set the default variation to use. Default is: 0.0")
# I THINK THESE ARE MISSING FROM THE CONFIG FILE
# st.session_state["defaults"].img2img.variant_amount.min_value = st.number_input("Minimum Img2Img Variation Amount",
# value=st.session_state['defaults'].img2img.variant_amount.min_value, help="Set the default minimum value for the variation slider. Default is: 0.0"))
# st.session_state["defaults"].img2img.variant_amount.max_value = st.number_input("Maximum Img2Img Variation Amount",
# value=st.session_state['defaults'].img2img.variant_amount.max_value, help="Set the default maximum value for the variation slider. Default is: 1.0"))
# st.session_state["defaults"].img2img.variant_amount.step = st.number_input("Img2Img Variation Slider Steps",
# value=st.session_state['defaults'].img2img.variant_amount.step, help="Set the default value for the number of steps on the variation slider. Default is: 1"))
st.session_state['defaults'].img2img.variant_seed = st.text_input("Default Img2Img Variation Seed",
value=st.session_state['defaults'].img2img.variant_seed, help="Default variation seed.")
with img2txt_tab:
col1 = st.columns(1, gap="large")
st.title("Image-To-Text")
st.session_state["defaults"].img2txt.batch_size = st.number_input("Default Img2Txt Batch Size", value=st.session_state['defaults'].img2txt.batch_size,
help="Set the default batch size for Img2Txt. Default is: 420?")
st.session_state["defaults"].img2txt.blip_image_eval_size = st.number_input("Default Blip Image Size Evaluation",
value=st.session_state['defaults'].img2txt.blip_image_eval_size,
help="Set the default value for the blip image evaluation size. Default is: 512")
with txt2vid_tab:
col1, col2, col3, col4, col5 = st.columns(5, gap="medium")
with col1:
st.title("Slider Parameters")
# Width
st.session_state["defaults"].txt2vid.width.value = st.number_input("Default txt2vid Image Width",
value=st.session_state['defaults'].txt2vid.width.value,
help="Set the default width for the generated image. Default is: 512")
st.session_state["defaults"].txt2vid.width.min_value = st.number_input("Minimum txt2vid Image Width",
value=st.session_state['defaults'].txt2vid.width.min_value,
help="Set the default minimum value for the width slider. Default is: 64")
st.session_state["defaults"].txt2vid.width.max_value = st.number_input("Maximum txt2vid Image Width",
value=st.session_state['defaults'].txt2vid.width.max_value,
help="Set the default maximum value for the width slider. Default is: 2048")
# Height
st.session_state["defaults"].txt2vid.height.value = st.number_input("Default txt2vid Image Height",
value=st.session_state['defaults'].txt2vid.height.value,
help="Set the default height for the generated image. Default is: 512")
st.session_state["defaults"].txt2vid.height.min_value = st.number_input("Minimum txt2vid Image Height",
value=st.session_state['defaults'].txt2vid.height.min_value,
help="Set the default minimum value for the height slider. Default is: 64")
st.session_state["defaults"].txt2vid.height.max_value = st.number_input("Maximum txt2vid Image Height",
value=st.session_state['defaults'].txt2vid.height.max_value,
help="Set the default maximum value for the height slider. Default is: 2048")
# CFG
st.session_state["defaults"].txt2vid.cfg_scale.value = st.number_input("Default txt2vid CFG Scale",
value=st.session_state['defaults'].txt2vid.cfg_scale.value,
help="Set the default value for the CFG Scale. Default is: 7.5")
st.session_state["defaults"].txt2vid.cfg_scale.min_value = st.number_input("Minimum txt2vid CFG Scale Value",
value=st.session_state['defaults'].txt2vid.cfg_scale.min_value,
help="Set the default minimum value for the CFG scale slider. Default is: 1")
st.session_state["defaults"].txt2vid.cfg_scale.step = st.number_input("txt2vid CFG Slider Steps",
value=st.session_state['defaults'].txt2vid.cfg_scale.step,
help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
with col2:
# Sampling Steps
st.session_state["defaults"].txt2vid.sampling_steps.value = st.number_input("Default txt2vid Sampling Steps",
value=st.session_state['defaults'].txt2vid.sampling_steps.value,
help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
st.session_state["defaults"].txt2vid.sampling_steps.min_value = st.number_input("Minimum txt2vid Sampling Steps",
value=st.session_state['defaults'].txt2vid.sampling_steps.min_value,
help="Set the default minimum value for the sampling steps slider. Default is: 1")
st.session_state["defaults"].txt2vid.sampling_steps.step = st.number_input("txt2vid Sampling Slider Steps",
value=st.session_state['defaults'].txt2vid.sampling_steps.step,
help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
# Batch Count
st.session_state["defaults"].txt2vid.batch_count.value = st.number_input("txt2vid Batch count", value=st.session_state['defaults'].txt2vid.batch_count.value,
help="How many iterations or batches of images to generate in total.")
st.session_state["defaults"].txt2vid.batch_size.value = st.number_input("txt2vid Batch size", value=st.session_state.defaults.txt2vid.batch_size.value,
help="How many images are at once in a batch.\
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
takes to finish generation as more images are generated at once.\
Default: 1")
# Inference Steps
st.session_state["defaults"].txt2vid.num_inference_steps.value = st.number_input("Default Txt2Vid Inference Steps",
value=st.session_state['defaults'].txt2vid.num_inference_steps.value,
help="Set the default number of inference steps to use. Default is: 200")
st.session_state["defaults"].txt2vid.num_inference_steps.min_value = st.number_input("Minimum Txt2Vid Sampling Steps",
value=st.session_state['defaults'].txt2vid.num_inference_steps.min_value,
help="Set the default minimum value for the inference steps slider. Default is: 10")
st.session_state["defaults"].txt2vid.num_inference_steps.max_value = st.number_input("Maximum Txt2Vid Sampling Steps",
value=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
help="Set the default maximum value for the inference steps slider. Default is: 500")
st.session_state["defaults"].txt2vid.num_inference_steps.step = st.number_input("Txt2Vid Inference Slider Steps",
value=st.session_state['defaults'].txt2vid.num_inference_steps.step,
help="Set the default value for the number of steps on the inference steps slider. Default is: 10")
with col3:
st.title("General Parameters")
st.session_state['defaults'].txt2vid.default_model = st.text_input("Default Txt2Vid Model", value=st.session_state['defaults'].txt2vid.default_model,
help="Default: CompVis/stable-diffusion-v1-4")
# INSERT CUSTOM_MODELS_LIST HERE
default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
st.session_state["defaults"].txt2vid.default_sampler = st.selectbox("Default txt2vid Sampler", default_sampler_list,
index=default_sampler_list.index(st.session_state['defaults'].txt2vid.default_sampler),
help="Defaut sampler to use for txt2vid. Default: k_euler")
st.session_state['defaults'].txt2vid.seed = st.text_input("Default txt2vid Seed", value=st.session_state['defaults'].txt2vid.seed, help="Default seed.")
st.session_state['defaults'].txt2vid.scheduler_name = st.text_input("Default Txt2Vid Scheduler",
value=st.session_state['defaults'].txt2vid.scheduler_name, help="Default scheduler.")
st.session_state["defaults"].txt2vid.separate_prompts = st.checkbox("Separate txt2vid Prompts",
value=st.session_state['defaults'].txt2vid.separate_prompts, help="Separate Prompts. Default: False")
st.session_state["defaults"].txt2vid.normalize_prompt_weights = st.checkbox("Normalize txt2vid Prompt Weights",
value=st.session_state['defaults'].txt2vid.normalize_prompt_weights,
help="Choose to normalize prompt weights. Default: True")
st.session_state["defaults"].txt2vid.save_individual_images = st.checkbox("Save Individual txt2vid Images",
value=st.session_state['defaults'].txt2vid.save_individual_images,
help="Choose to save individual images. Default: True")
st.session_state["defaults"].txt2vid.save_video = st.checkbox("Save Txt2Vid Video", value=st.session_state['defaults'].txt2vid.save_video,
help="Choose to save the Txt2Vid video. Default: True")
st.session_state["defaults"].txt2vid.save_video_on_stop = st.checkbox("Save video on Stop", value=st.session_state['defaults'].txt2vid.save_video_on_stop,
help="Save a video with all the images generated as frames when we hit the stop button \
during a generation.")
st.session_state["defaults"].txt2vid.group_by_prompt = st.checkbox("Group By txt2vid Prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
help="Choose to save images grouped by their prompt. Default: False")
st.session_state["defaults"].txt2vid.save_as_jpg = st.checkbox("Save txt2vid As JPG", value=st.session_state['defaults'].txt2vid.save_as_jpg,
help="Choose to save images as jpegs. Default: False")
# Need more info for the Help dialog...
st.session_state["defaults"].txt2vid.do_loop = st.checkbox("Loop Generations", value=st.session_state['defaults'].txt2vid.do_loop,
help="Choose to loop or something, IDK.... Default: False")
st.session_state["defaults"].txt2vid.max_duration_in_seconds = st.number_input("Txt2Vid Max Duration in Seconds", value=st.session_state['defaults'].txt2vid.max_duration_in_seconds,
help="Set the default value for the max duration in seconds for the video generated. Default is: 30")
st.session_state["defaults"].txt2vid.write_info_files = st.checkbox("Write Info Files For txt2vid Images", value=st.session_state['defaults'].txt2vid.write_info_files,
help="Choose to write the info files along with the generated images. Default: True")
st.session_state["defaults"].txt2vid.use_GFPGAN = st.checkbox("txt2vid Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN,
help="Choose to use GFPGAN. Default: False")
st.session_state["defaults"].txt2vid.use_RealESRGAN = st.checkbox("txt2vid Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
help="Choose to use RealESRGAN. Default: False")
st.session_state["defaults"].txt2vid.update_preview = True
st.session_state["defaults"].txt2vid.update_preview_frequency = st.number_input("txt2vid Preview Image Update Frequency",
value=st.session_state['defaults'].txt2vid.update_preview_frequency,
help="Set the default value for the frrquency of the preview image updates. Default is: 10")
with col4:
st.title("Variation Parameters")
st.session_state["defaults"].txt2vid.variant_amount.value = st.number_input("Default txt2vid Variation Amount",
value=st.session_state['defaults'].txt2vid.variant_amount.value,
help="Set the default variation to use. Default is: 0.0")
st.session_state["defaults"].txt2vid.variant_amount.min_value = st.number_input("Minimum txt2vid Variation Amount",
value=st.session_state['defaults'].txt2vid.variant_amount.min_value,
help="Set the default minimum value for the variation slider. Default is: 0.0")
st.session_state["defaults"].txt2vid.variant_amount.max_value = st.number_input("Maximum txt2vid Variation Amount",
value=st.session_state['defaults'].txt2vid.variant_amount.max_value,
help="Set the default maximum value for the variation slider. Default is: 1.0")
st.session_state["defaults"].txt2vid.variant_amount.step = st.number_input("txt2vid Variation Slider Steps",
value=st.session_state['defaults'].txt2vid.variant_amount.step,
help="Set the default value for the number of steps on the variation slider. Default is: 1")
st.session_state['defaults'].txt2vid.variant_seed = st.text_input("Default txt2vid Variation Seed",
value=st.session_state['defaults'].txt2vid.variant_seed, help="Default variation seed.")
with col5:
st.title("Beta Parameters")
# Beta Start
st.session_state["defaults"].txt2vid.beta_start.value = st.number_input("Default txt2vid Beta Start Value",
value=st.session_state['defaults'].txt2vid.beta_start.value,
help="Set the default variation to use. Default is: 0.0")
st.session_state["defaults"].txt2vid.beta_start.min_value = st.number_input("Minimum txt2vid Beta Start Amount",
value=st.session_state['defaults'].txt2vid.beta_start.min_value,
help="Set the default minimum value for the variation slider. Default is: 0.0")
st.session_state["defaults"].txt2vid.beta_start.max_value = st.number_input("Maximum txt2vid Beta Start Amount",
value=st.session_state['defaults'].txt2vid.beta_start.max_value,
help="Set the default maximum value for the variation slider. Default is: 1.0")
st.session_state["defaults"].txt2vid.beta_start.step = st.number_input("txt2vid Beta Start Slider Steps", value=st.session_state['defaults'].txt2vid.beta_start.step,
help="Set the default value for the number of steps on the variation slider. Default is: 1")
st.session_state["defaults"].txt2vid.beta_start.format = st.text_input("Default txt2vid Beta Start Format", value=st.session_state['defaults'].txt2vid.beta_start.format,
help="Set the default Beta Start Format. Default is: %.5\f")
# Beta End
st.session_state["defaults"].txt2vid.beta_end.value = st.number_input("Default txt2vid Beta End Value", value=st.session_state['defaults'].txt2vid.beta_end.value,
help="Set the default variation to use. Default is: 0.0")
st.session_state["defaults"].txt2vid.beta_end.min_value = st.number_input("Minimum txt2vid Beta End Amount", value=st.session_state['defaults'].txt2vid.beta_end.min_value,
help="Set the default minimum value for the variation slider. Default is: 0.0")
st.session_state["defaults"].txt2vid.beta_end.max_value = st.number_input("Maximum txt2vid Beta End Amount", value=st.session_state['defaults'].txt2vid.beta_end.max_value,
help="Set the default maximum value for the variation slider. Default is: 1.0")
st.session_state["defaults"].txt2vid.beta_end.step = st.number_input("txt2vid Beta End Slider Steps", value=st.session_state['defaults'].txt2vid.beta_end.step,
help="Set the default value for the number of steps on the variation slider. Default is: 1")
st.session_state["defaults"].txt2vid.beta_end.format = st.text_input("Default txt2vid Beta End Format", value=st.session_state['defaults'].txt2vid.beta_start.format,
help="Set the default Beta Start Format. Default is: %.5\f")
with image_processing:
col1, col2, col3, col4, col5 = st.columns(5, gap="large")
with col1:
st.title("GFPGAN")
st.session_state["defaults"].gfpgan.strength = st.number_input("Default Img2Txt Batch Size", value=st.session_state['defaults'].gfpgan.strength,
help="Set the default global strength for GFPGAN. Default is: 100")
with col2:
st.title("GoBig")
with col3:
st.title("RealESRGAN")
with col4:
st.title("LDSR")
with col5:
st.title("GoLatent")
with textual_inversion_tab:
st.title("Textual Inversion")
st.session_state['defaults'].textual_inversion.pretrained_model_name_or_path = st.text_input("Default Textual Inversion Model Path",
value=st.session_state['defaults'].textual_inversion.pretrained_model_name_or_path,
help="Default: models/ldm/stable-diffusion-v1-4")
st.session_state['defaults'].textual_inversion.tokenizer_name = st.text_input("Default Textual Inversion Tokenizer Name", value=st.session_state['defaults'].textual_inversion.tokenizer_name,
help="Default tokenizer name.")
with concepts_library_tab:
st.title("Concepts Library")
#st.info("Under Construction. :construction_worker:")
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
with col1:
st.session_state["defaults"].concepts_library.concepts_per_page = st.number_input("Concepts Per Page", value=st.session_state['defaults'].concepts_library.concepts_per_page,
help="Number of concepts per page to show on the Concepts Library. Default: '12'")
# add space for the buttons at the bottom
st.markdown("---")
# We need a submit button to save the Settings
# as well as one to reset them to the defaults, just in case.
_, _, save_button_col, reset_button_col, _, _ = st.columns([1, 1, 1, 1, 1, 1], gap="large")
with save_button_col:
save_button = st.form_submit_button("Save")
with reset_button_col:
reset_button = st.form_submit_button("Reset")
if save_button:
OmegaConf.save(config=st.session_state.defaults, f="configs/webui/userconfig_streamlit.yaml")
loaded = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
assert st.session_state.defaults == loaded
#
if (os.path.exists(".streamlit/config.toml")):
with open(".streamlit/config.toml", "w") as toml_file:
toml.dump(st.session_state["streamlit_config"], toml_file)
if reset_button:
st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
st.experimental_rerun()
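For reference, a minimal sketch of how a saved user config can be layered back over the shipped defaults with OmegaConf; the merge step and the startup location are assumptions for illustration, only the two file paths come from the form above.
# Hypothetical startup sketch: load the shipped defaults and, if present,
# merge the user overrides written by the Settings form above.
import os
from omegaconf import OmegaConf

defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml")
user_config_path = "configs/webui/userconfig_streamlit.yaml"
if os.path.exists(user_config_path):
    # later arguments win, so the saved user settings override the defaults
    defaults = OmegaConf.merge(defaults, OmegaConf.load(user_config_path))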

View File

@ -0,0 +1,91 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sandbox-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
#from sd_utils import *
from sd_utils import st
# streamlit imports
#streamlit components section
#other imports
from barfi import st_barfi, barfi_schemas, Block
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
def layout():
#st.info("Under Construction. :construction_worker:")
#from barfi import st_barfi, Block
#add = Block(name='Addition')
#sub = Block(name='Subtraction')
#mul = Block(name='Multiplication')
#div = Block(name='Division')
#barfi_result = st_barfi(base_blocks= [add, sub, mul, div])
# or if you want to use a category to organise them in the frontend sub-menu
#barfi_result = st_barfi(base_blocks= {'Op 1': [add, sub], 'Op 2': [mul, div]})
col1, col2, col3 = st.columns([1, 8, 1])
with col2:
feed = Block(name='Feed')
feed.add_output()
def feed_func(self):
self.set_interface(name='Output 1', value=4)
feed.add_compute(feed_func)
splitter = Block(name='Splitter')
splitter.add_input()
splitter.add_output()
splitter.add_output()
def splitter_func(self):
in_1 = self.get_interface(name='Input 1')
value = (in_1/2)
self.set_interface(name='Output 1', value=value)
self.set_interface(name='Output 2', value=value)
splitter.add_compute(splitter_func)
mixer = Block(name='Mixer')
mixer.add_input()
mixer.add_input()
mixer.add_output()
def mixer_func(self):
in_1 = self.get_interface(name='Input 1')
in_2 = self.get_interface(name='Input 2')
value = (in_1 + in_2)
self.set_interface(name='Output 1', value=value)
mixer.add_compute(mixer_func)
result = Block(name='Result')
result.add_input()
def result_func(self):
in_1 = self.get_interface(name='Input 1')
result.add_compute(result_func)
load_schema = st.selectbox('Select a saved schema:', barfi_schemas())
compute_engine = st.checkbox('Activate barfi compute engine', value=False)
barfi_result = st_barfi(base_blocks=[feed, result, mixer, splitter],
compute_engine=compute_engine, load_schema=load_schema)
if barfi_result:
st.write(barfi_result)

View File

@ -0,0 +1,134 @@
<html>
<head>
<style>
/* our style while dragging */
.value-dragging
{
background-color: lightblue;
}
</style>
</head>
<body>
<!-- our fields -->
<input type="number" value="0" min=0 max=200 step=10>
<input type="number" value="0" min=-50 max=200 step=10>
<script>
// iframe parent
// var parentDoc = window.parent.document
// check for mouse pointer locking support, not a requirement but improves the overall experience
var havePointerLock = 'pointerLockElement' in document ||
'mozPointerLockElement' in document ||
'webkitPointerLockElement' in document;
// the pointer locking exit function
document.exitPointerLock = document.exitPointerLock || document.mozExitPointerLock || document.webkitExitPointerLock;
// how far the mouse should travel for one step, in pixels
var pixelPerStep = 50;
// how many steps the mouse has moved, as a float
var movementDelta = 0.0;
// value when drag started
var lockedValue = 0;
// minimum value from field
var lockedMin = 0;
// maximum value from field
var lockedMax = 0;
// how big should the field steps be
var lockedStep = 0;
// the currently locked in field
var lockedField = null;
// register events and pointer locking on field
RegisterField = (field) => {
if(havePointerLock)
field.requestPointerLock = field.requestPointerLock || field.mozRequestPointerLock || field.webkitRequestPointerLock;
field.title = "Click and hold middle mouse button\nmove mouse left to decrease\nmove right to increase";
field.addEventListener('mousedown', (e) => {
onDragStart(e)
});
}
onDragStart = (e) => {
// if middle mouse is down
if(e.button === 1)
{
// save current field
lockedField = e.target;
// add class for styling
lockedField.classList.add("value-dragging");
// reset movement delta
movementDelta = 0.0;
// set to 0 if field is empty
if(lockedField.value === '')
lockedField.value = '0';
// save current field value
lockedValue = parseInt(lockedField.value);
if(lockedField.min === '')
lockedField.min = '-99999999';
if(lockedField.max === '')
lockedField.max = '99999999';
if(lockedField.step === '')
lockedField.step = '10';
lockedMin = parseInt(lockedField.min);
lockedMax = parseInt(lockedField.max);
lockedStep = parseInt(lockedField.step);
// lock pointer if available
if(havePointerLock)
lockedField.requestPointerLock();
// add drag event
document.addEventListener("mousemove", onDrag, false);
// prevent event propagation
e.preventDefault();
}
};
onDrag = (e) => {
if(lockedField !== null)
{
// add movement to delta
movementDelta += e.movementX / pixelPerStep;
// set new value
let value = lockedValue + Math.floor(Math.abs(movementDelta)) * lockedStep * Math.sign(movementDelta);
lockedField.value = Math.min(Math.max(value, lockedMin), lockedMax);
}
};
document.addEventListener('mouseup', (e) => {
// if middle mouse is up
if(e.button === 1)
{
// release pointer lock if available
if(havePointerLock)
document.exitPointerLock();
if(lockedField !== null)
{
// stop drag event
document.removeEventListener("mousemove", onDrag, false);
// remove class for styling
lockedField.classList.remove("value-dragging");
// remove reference
lockedField = null;
}
}
});
// find and register all input fields of type=number
var list = document.querySelectorAll('input[type="number"]');
list.forEach(RegisterField);
</script>
</body>
</html>

View File

@ -0,0 +1,11 @@
import os
import streamlit.components.v1 as components
def load(pixel_per_step = 50):
parent_dir = os.path.dirname(os.path.abspath(__file__))
file = os.path.join(parent_dir, "main.js")
with open(file) as f:
javascript_main = f.read()
javascript_main = javascript_main.replace("%%pixelPerStep%%",str(pixel_per_step))
components.html(f"<script>{javascript_main}</script>")

View File

@ -0,0 +1,192 @@
// iframe parent
var parentDoc = window.parent.document
// check for mouse pointer locking support, not a requirement but improves the overall experience
var havePointerLock = 'pointerLockElement' in parentDoc ||
'mozPointerLockElement' in parentDoc ||
'webkitPointerLockElement' in parentDoc;
// the pointer locking exit function
parentDoc.exitPointerLock = parentDoc.exitPointerLock || parentDoc.mozExitPointerLock || parentDoc.webkitExitPointerLock;
// how far the mouse should travel for one step, in pixels
var pixelPerStep = %%pixelPerStep%%;
// how many steps the mouse has moved, as a float
var movementDelta = 0.0;
// value when drag started
var lockedValue = 0.0;
// minimum value from field
var lockedMin = 0.0;
// maximum value from field
var lockedMax = 0.0;
// how big should the field steps be
var lockedStep = 0.0;
// the currently locked in field
var lockedField = null;
// lock box to just request pointer lock for one element
var lockBox = document.createElement("div");
lockBox.classList.add("lockbox");
parentDoc.body.appendChild(lockBox);
lockBox.requestPointerLock = lockBox.requestPointerLock || lockBox.mozRequestPointerLock || lockBox.webkitRequestPointerLock;
function Lock(field)
{
var rect = field.getBoundingClientRect();
lockBox.style.left = (rect.left-2.5)+"px";
lockBox.style.top = (rect.top-2.5)+"px";
lockBox.style.width = (rect.width+2.5)+"px";
lockBox.style.height = (rect.height+5)+"px";
lockBox.requestPointerLock();
}
function Unlock()
{
parentDoc.exitPointerLock();
lockBox.style.left = "0px";
lockBox.style.top = "0px";
lockBox.style.width = "0px";
lockBox.style.height = "0px";
lockedField.focus();
}
parentDoc.addEventListener('mousedown', (e) => {
// if middle is down
if(e.button === 1)
{
if(e.target.tagName === 'INPUT' && e.target.type === 'number')
{
e.preventDefault();
var field = e.target;
if(havePointerLock)
Lock(field);
// save current field
lockedField = e.target;
// add class for styling
lockedField.classList.add("value-dragging");
// reset movement delta
movementDelta = 0.0;
// set to 0 if field is empty
if(lockedField.value === '')
lockedField.value = 0.0;
// save current field value
lockedValue = parseFloat(lockedField.value);
if(lockedField.min === '' || lockedField.min === '-Infinity')
lockedMin = -99999999.0;
else
lockedMin = parseFloat(lockedField.min);
if(lockedField.max === '' || lockedField.max === 'Infinity')
lockedMax = 99999999.0;
else
lockedMax = parseFloat(lockedField.max);
if(lockedField.step === '' || lockedField.step === 'Infinity')
lockedStep = 1.0;
else
lockedStep = parseFloat(lockedField.step);
// lock pointer if available
if(havePointerLock)
Lock(lockedField);
// add drag event
parentDoc.addEventListener("mousemove", onDrag, false);
}
}
});
function onDrag(e)
{
if(lockedField !== null)
{
// add movement to delta
movementDelta += e.movementX / pixelPerStep;
// bail out if the starting value could not be parsed as a number
if(isNaN(lockedValue))
return;
// set new value
let value = lockedValue + Math.floor(Math.abs(movementDelta)) * lockedStep * Math.sign(movementDelta);
lockedField.focus();
lockedField.select();
parentDoc.execCommand('insertText', false /*no UI*/, Math.min(Math.max(value, lockedMin), lockedMax));
}
}
parentDoc.addEventListener('mouseup', (e) => {
// if mouse is up
if(e.button === 1)
{
// release pointer lock if available
if(havePointerLock)
Unlock();
if(lockedField !== null)
{
// stop drag event
parentDoc.removeEventListener("mousemove", onDrag, false);
// remove class for styling
lockedField.classList.remove("value-dragging");
// remove reference
lockedField = null;
}
}
});
// only execute once (even though multiple iframes exist)
if(!parentDoc.hasOwnProperty("dragableInitialized"))
{
var parentCSS =
`
/* Make input-instruction not block mouse events */
.input-instructions,.input-instructions > *{
pointer-events: none;
user-select: none;
-moz-user-select: none;
-khtml-user-select: none;
-webkit-user-select: none;
-o-user-select: none;
}
.lockbox {
background-color: transparent;
position: absolute;
pointer-events: none;
user-select: none;
-moz-user-select: none;
-khtml-user-select: none;
-webkit-user-select: none;
-o-user-select: none;
border-left: dotted 2px rgb(255,75,75);
border-top: dotted 2px rgb(255,75,75);
border-bottom: dotted 2px rgb(255,75,75);
border-right: dotted 1px rgba(255,75,75,0.2);
border-top-left-radius: 0.25rem;
border-bottom-left-radius: 0.25rem;
z-index: 1000;
}
`;
// get parent document head
var head = parentDoc.getElementsByTagName('head')[0];
// add style tag
var s = document.createElement('style');
// set type attribute
s.setAttribute('type', 'text/css');
// add css forwarded from python
if (s.styleSheet) { // IE
s.styleSheet.cssText = parentCSS;
} else { // the world
s.appendChild(document.createTextNode(parentCSS));
}
// add style to head
head.appendChild(s);
// set flag so this only runs once
parentDoc["dragableInitialized"] = true;
}

View File

@ -0,0 +1,46 @@
import os
from collections import defaultdict
import streamlit.components.v1 as components
# where to save the downloaded key_phrases
key_phrases_file = "data/tags/key_phrases.json"
# the loaded key phrase json as text
key_phrases_json = ""
# where to save the downloaded thumbnails
thumbnails_file = "data/tags/thumbnails.json"
# the loaded thumbnails json as text
thumbnails_json = ""
def init():
global key_phrases_json, thumbnails_json
with open(key_phrases_file) as f:
key_phrases_json = f.read()
with open(thumbnails_file) as f:
thumbnails_json = f.read()
def suggestion_area(placeholder):
# get component path
parent_dir = os.path.dirname(os.path.abspath(__file__))
# get file paths
javascript_file = os.path.join(parent_dir, "main.js")
stylesheet_file = os.path.join(parent_dir, "main.css")
parent_stylesheet_file = os.path.join(parent_dir, "parent.css")
# load file texts
with open(javascript_file) as f:
javascript_main = f.read()
with open(stylesheet_file) as f:
stylesheet_main = f.read()
with open(parent_stylesheet_file) as f:
parent_stylesheet = f.read()
# add suggestion area div box
html = "<div id='scroll_area' class='st-bg'><div id='suggestion_area'>javascript failed</div></div>"
# add loaded style
html += f"<style>{stylesheet_main}</style>"
# set default variables
html += f"<script>var thumbnails = {thumbnails_json};\nvar keyPhrases = {key_phrases_json};\nvar parentCSS = `{parent_stylesheet}`;\nvar placeholder='{placeholder}';</script>"
# add main java script
html += f"\n<script>{javascript_main}</script>"
# add component to site
components.html(html, width=None, height=None, scrolling=True)
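For context, a short sketch of how this component is consumed elsewhere in this commit (see the img2img page below): init() loads the JSON files once, and suggestion_area() renders the suggestion box under a prompt field.
# Usage sketch mirroring the img2img page in this commit; assumes the JSON
# files under data/tags/ exist and the package is importable as shown there.
import streamlit as st
from custom_components import sygil_suggestions

sygil_suggestions.init()

placeholder = "A corgi wearing a top hat as an oil painting."
prompt = st.text_area("Input Text", "", placeholder=placeholder, height=54)
sygil_suggestions.suggestion_area(placeholder)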

View File

@ -0,0 +1,81 @@
*
{
padding: 0px;
margin: 0px;
user-select: none;
-moz-user-select: none;
-khtml-user-select: none;
-webkit-user-select: none;
-o-user-select: none;
}
body
{
width: 100%;
height: 100%;
padding-left: calc( 1em - 1px );
padding-top: calc( 1em - 1px );
overflow: hidden;
}
/* width */
::-webkit-scrollbar {
width: 7px;
}
/* Track */
::-webkit-scrollbar-track {
background: rgb(10, 13, 19);
}
/* Handle */
::-webkit-scrollbar-thumb {
background: #6c6e72;
border-radius: 3px;
}
/* Handle on hover */
::-webkit-scrollbar-thumb:hover {
background: #6c6e72;
}
#scroll_area
{
display: flex;
overflow-x: hidden;
overflow-y: auto;
}
#suggestion_area
{
overflow-x: hidden;
width: calc( 100% - 2em - 2px );
margin-bottom: calc( 1em + 13px );
min-height: 50px;
}
span
{
border: 1px solid rgba(250, 250, 250, 0.2);
border-radius: 0.25rem;
font-size: 1rem;
font-family: "Source Sans Pro", sans-serif;
background-color: rgb(38, 39, 48);
color: white;
display: inline-block;
padding: 0.5rem;
margin-right: 3px;
cursor: pointer;
user-select: none;
-moz-user-select: none;
-khtml-user-select: none;
-webkit-user-select: none;
-o-user-select: none;
}
span:hover
{
color: rgb(255,75,75);
border-color: rgb(255,75,75);
}

File diff suppressed because it is too large

View File

@ -0,0 +1,84 @@
.suggestion-frame
{
position: absolute;
/* make as small as possible */
margin: 0px;
padding: 0px;
min-height: 0px;
line-height: 0;
/* animate transitions of the height property */
-webkit-transition: height 1s;
-moz-transition: height 1s;
-ms-transition: height 1s;
-o-transition: height 1s;
transition: height 1s, border-bottom-width 1s;
/* block selection */
user-select: none;
-moz-user-select: none;
-khtml-user-select: none;
-webkit-user-select: none;
-o-user-select: none;
z-index: 700;
outline: 1px solid rgba(250, 250, 250, 0.2);
outline-offset: 0px;
border-radius: 0.25rem;
background: rgb(14, 17, 23);
box-sizing: border-box;
-moz-box-sizing: border-box;
-webkit-box-sizing: border-box;
border-bottom: solid 13px rgb(14, 17, 23) !important;
border-left: solid 13px rgb(14, 17, 23) !important;
}
#phrase-tooltip
{
display: none;
pointer-events: none;
position: absolute;
border-bottom-left-radius: 0.5rem;
border-top-right-radius: 0.5rem;
border-bottom-right-radius: 0.5rem;
border: solid rgb(255,75,75) 2px;
background-color: rgb(38, 39, 48);
color: rgb(255,75,75);
font-size: 1rem;
font-family: "Source Sans Pro", sans-serif;
padding: 0.5rem;
cursor: default;
user-select: none;
-moz-user-select: none;
-khtml-user-select: none;
-webkit-user-select: none;
-o-user-select: none;
z-index: 1000;
}
#phrase-tooltip:has(img)
{
transform: scale(1.25, 1.25);
-ms-transform: scale(1.25, 1.25);
-webkit-transform: scale(1.25, 1.25);
}
#phrase-tooltip>img
{
pointer-events: none;
border-bottom-left-radius: 0.5rem;
border-top-right-radius: 0.5rem;
border-bottom-right-radius: 0.5rem;
cursor: default;
user-select: none;
-moz-user-select: none;
-khtml-user-select: none;
-webkit-user-select: none;
-o-user-select: none;
z-index: 1500;
}

View File

@ -0,0 +1,752 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
from sd_utils import st, server_state, no_rerun, \
custom_models_available, RealESRGAN_available, GFPGAN_available, LDSR_available
#generation_callback, process_images, KDiffusionSampler, \
#load_models, hc, seed_to_int, logger, \
#resize_image, get_matched_noise, CFGMaskedDenoiser, ImageFilter, set_page_title
# streamlit imports
from streamlit.runtime.scriptrunner import StopException
#other imports
import cv2
from PIL import Image, ImageOps
import torch
import k_diffusion as K
import numpy as np
import time
import torch
import skimage
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
# streamlit components
from custom_components import sygil_suggestions
from streamlit_drawable_canvas import st_canvas
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
sygil_suggestions.init()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
logging.set_verbosity_error()
except:
pass
def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
variant_amount: float = 0.0, variant_seed: int = None, ddim_eta:float = 0.0,
write_info_files:bool = True, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.4',
use_RealESRGAN: bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
use_LDSR: bool = True, LDSR_model: str = "model",
loopback: bool = False,
random_seed_loopback: bool = False
):
outpath = st.session_state['defaults'].general.outdir_img2img
seed = seed_to_int(seed)
batch_size = 1
if sampler_name == 'PLMS':
sampler = PLMSSampler(server_state["model"])
elif sampler_name == 'DDIM':
sampler = DDIMSampler(server_state["model"])
elif sampler_name == 'k_dpm_2_a':
sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
elif sampler_name == 'k_dpm_2':
sampler = KDiffusionSampler(server_state["model"],'dpm_2')
elif sampler_name == 'k_dpmpp_2m':
sampler = KDiffusionSampler(server_state["model"],'dpmpp_2m')
elif sampler_name == 'k_euler_a':
sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
elif sampler_name == 'k_euler':
sampler = KDiffusionSampler(server_state["model"],'euler')
elif sampler_name == 'k_heun':
sampler = KDiffusionSampler(server_state["model"],'heun')
elif sampler_name == 'k_lms':
sampler = KDiffusionSampler(server_state["model"],'lms')
else:
raise Exception("Unknown sampler: " + sampler_name)
def process_init_mask(init_mask: Image):
if init_mask.mode == "RGBA":
init_mask = init_mask.convert('RGBA')
background = Image.new('RGBA', init_mask.size, (0, 0, 0))
init_mask = Image.alpha_composite(background, init_mask)
init_mask = init_mask.convert('RGB')
return init_mask
init_img = init_info
init_mask = None
if mask_mode == 0:
if init_info_mask:
init_mask = process_init_mask(init_info_mask)
elif mask_mode == 1:
if init_info_mask:
init_mask = process_init_mask(init_info_mask)
init_mask = ImageOps.invert(init_mask)
elif mask_mode == 2:
init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
init_mask = init_img_transparency
init_mask = init_mask.convert("RGB")
init_mask = resize_image(resize_mode, init_mask, width, height)
init_mask = init_mask.convert("RGB")
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(denoising_strength * ddim_steps)
if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None:
noise_q = 0.99
color_variation = 0.0
mask_blend_factor = 1.0
np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
np_mask_rgb -= np.min(np_mask_rgb)
np_mask_rgb /= np.max(np_mask_rgb)
np_mask_rgb = 1. - np_mask_rgb
np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
#np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
#np_mask_rgb = np_mask_rgb + blurred
np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
noised = noise_rgb[:]
blend_mask_rgb **= (2.)
noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
ref_mask = np_mask_grey < 1e-3
all_mask = np.ones((height, width), dtype=bool)
noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
st.session_state["editor_image"].image(init_img) # debug
def init():
image = init_img.convert('RGB')
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
mask_channel = None
if init_mask:
alpha = resize_image(resize_mode, init_mask, width // 8, height // 8)
mask_channel = alpha.split()[-1]
mask = None
if mask_channel is not None:
mask = np.array(mask_channel).astype(np.float32) / 255.0
mask = (1 - mask)
mask = np.tile(mask, (4, 1, 1))
mask = mask[None].transpose(0, 1, 2, 3)
mask = torch.from_numpy(mask).to(server_state["device"])
if st.session_state['defaults'].general.optimized:
server_state["modelFS"].to(server_state["device"] )
init_image = 2. * image - 1.
init_image = init_image.to(server_state["device"])
init_latent = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).get_first_stage_encoding((server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).encode_first_stage(init_image)) # move to latent space
if st.session_state['defaults'].general.optimized:
mem = torch.cuda.memory_allocated()/1e6
server_state["modelFS"].to("cpu")
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
return init_latent, mask,
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
t_enc_steps = t_enc
obliterate = False
if ddim_steps == t_enc_steps:
t_enc_steps = t_enc_steps - 1
obliterate = True
if sampler_name != 'DDIM':
x0, z_mask = init_data
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
noise = x * sigmas[ddim_steps - t_enc_steps - 1]
xi = x0 + noise
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=xi.device)
xi = (z_mask * noise) + ((1-z_mask) * xi)
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
callback=generation_callback if not server_state["bridge"] else None)
else:
x0, z_mask = init_data
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(server_state["device"] ))
# Obliterate masked image
if z_mask is not None and obliterate:
random = torch.randn(z_mask.shape, device=z_enc.device)
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
# decode it
samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=unconditional_conditioning,
z_mask=z_mask, x0=x0)
return samples_ddim
if loopback:
output_images, info = None, None
history = []
initial_seed = None
do_color_correction = False
try:
from skimage import exposure
do_color_correction = True
except:
logger.error("Install scikit-image to perform color correction on loopback")
for i in range(n_iter):
if do_color_correction and i == 0:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
# RealESRGAN can only run on the final iteration
is_final_iteration = i == n_iter - 1
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
save_grid=save_grid,
batch_size=1,
n_iter=1,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=separate_prompts,
use_GFPGAN=use_GFPGAN,
GFPGAN_model=GFPGAN_model,
use_RealESRGAN=use_RealESRGAN and is_final_iteration, # Forcefully disable upscaling when using loopback
realesrgan_model_name=RealESRGAN_model,
use_LDSR=use_LDSR,
LDSR_model_name=LDSR_model,
normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images,
init_img=init_img,
init_mask=init_mask,
mask_blur_strength=mask_blur_strength,
mask_restore=mask_restore,
denoising_strength=denoising_strength,
noise_mode=noise_mode,
find_noise_steps=find_noise_steps,
resize_mode=resize_mode,
uses_loopback=loopback,
uses_random_seed_loopback=random_seed_loopback,
sort_samples=group_by_prompt,
write_info_files=write_info_files,
jpg_sample=save_as_jpg
)
if initial_seed is None:
initial_seed = seed
input_image = init_img
init_img = output_images[0]
if do_color_correction and correction_target is not None:
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(init_img),
cv2.COLOR_RGB2LAB
),
correction_target,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
if mask_restore is True and init_mask is not None:
color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
color_mask = color_mask.convert('L')
source_image = input_image.convert('RGB')
target_image = init_img.convert('RGB')
init_img = Image.composite(source_image, target_image, color_mask)
if not random_seed_loopback:
seed = seed + 1
else:
seed = seed_to_int(None)
denoising_strength = max(denoising_strength * 0.95, 0.1)
history.append(init_img)
output_images = history
seed = initial_seed
else:
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
save_grid=save_grid,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=separate_prompts,
use_GFPGAN=use_GFPGAN,
GFPGAN_model=GFPGAN_model,
use_RealESRGAN=use_RealESRGAN,
realesrgan_model_name=RealESRGAN_model,
use_LDSR=use_LDSR,
LDSR_model_name=LDSR_model,
normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images,
init_img=init_img,
init_mask=init_mask,
mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength,
noise_mode=noise_mode,
find_noise_steps=find_noise_steps,
mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=loopback,
sort_samples=group_by_prompt,
write_info_files=write_info_files,
jpg_sample=save_as_jpg
)
del sampler
return output_images, seed, info, stats
#
def layout():
with st.form("img2img-inputs"):
st.session_state["generation_mode"] = "img2img"
img2img_input_col, img2img_generate_col = st.columns([10,1])
with img2img_input_col:
#prompt = st.text_area("Input Text","")
placeholder = "A corgi wearing a top hat as an oil painting."
prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
if "defaults" in st.session_state:
if st.session_state["defaults"].general.enable_suggestions:
sygil_suggestions.suggestion_area(placeholder)
if "defaults" in st.session_state:
if st.session_state['defaults'].admin.global_negative_prompt:
prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
# Every form must have a submit button; the extra blank st.write() calls below are a temporary way to align it with the input field. This should be done in CSS or some other way.
img2img_generate_col.write("")
img2img_generate_col.write("")
generate_button = img2img_generate_col.form_submit_button("Generate")
# creating the page layout using columns
col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([2,4,4], gap="medium")
with col1_img2img_layout:
# If we have custom models available on the "models/custom"
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
custom_models_available()
if server_state["CustomModel_available"]:
st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. This option is only available if you have custom models \
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.5")
else:
st.session_state["custom_model"] = "Stable Diffusion v1.5"
st.session_state["sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].img2img.sampling_steps.value,
min_value=st.session_state['defaults'].img2img.sampling_steps.min_value,
step=st.session_state['defaults'].img2img.sampling_steps.step)
sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_dpmpp_2m", "k_heun", "PLMS", "DDIM"]
st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list,
index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.")
width = st.slider("Width:", min_value=st.session_state['defaults'].img2img.width.min_value, max_value=st.session_state['defaults'].img2img.width.max_value,
value=st.session_state['defaults'].img2img.width.value, step=st.session_state['defaults'].img2img.width.step)
height = st.slider("Height:", min_value=st.session_state['defaults'].img2img.height.min_value, max_value=st.session_state['defaults'].img2img.height.max_value,
value=st.session_state['defaults'].img2img.height.value, step=st.session_state['defaults'].img2img.height.step)
seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):", min_value=st.session_state['defaults'].img2img.cfg_scale.min_value,
value=st.session_state['defaults'].img2img.cfg_scale.value,
step=st.session_state['defaults'].img2img.cfg_scale.step,
help="How strongly the image should follow the prompt.")
st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength.value,
min_value=st.session_state['defaults'].img2img.denoising_strength.min_value,
max_value=st.session_state['defaults'].img2img.denoising_strength.max_value,
step=st.session_state['defaults'].img2img.denoising_strength.step)
mask_expander = st.empty()
with mask_expander.expander("Inpainting/Outpainting"):
mask_mode_list = ["Outpainting", "Inpainting", "Image alpha"]
mask_mode = st.selectbox("Painting Mode", mask_mode_list, index=st.session_state["defaults"].img2img.mask_mode,
help="Select how you want your image to be masked/painted.\"Inpainting\" modifies the image where the mask is white.\n\
\"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
)
mask_mode = mask_mode_list.index(mask_mode)
noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
noise_mode = st.selectbox("Noise Mode", noise_mode_list, index=noise_mode_list.index(st.session_state['defaults'].img2img.noise_mode), help="")
#noise_mode = noise_mode_list.index(noise_mode)
find_noise_steps = st.number_input("Find Noise Steps", value=st.session_state['defaults'].img2img.find_noise_steps.value,
min_value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
step=st.session_state['defaults'].img2img.find_noise_steps.step)
# Specify canvas parameters in application
drawing_mode = st.selectbox(
"Drawing tool:",
(
"freedraw",
"transform",
#"line",
"rect",
"circle",
#"polygon",
),
)
stroke_width = st.slider("Stroke width: ", 1, 100, 50)
stroke_color = st.color_picker("Stroke color hex: ", value="#EEEEEE")
bg_color = st.color_picker("Background color hex: ", "#7B6E6E")
display_toolbar = st.checkbox("Display toolbar", True)
#realtime_update = st.checkbox("Update in realtime", True)
with st.expander("Batch Options"):
st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].img2img.batch_count.value,
help="How many iterations or batches of images to generate in total.")
st.session_state["batch_size"] = st.number_input("Batch size", value=st.session_state.defaults.img2img.batch_size.value,
help="How many images are at once in a batch.\
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
Default: 1")
with st.expander("Preview Settings"):
st.session_state["update_preview"] = st.session_state["defaults"].general.update_preview
st.session_state["update_preview_frequency"] = st.number_input("Update Image Preview Frequency",
min_value=0,
value=st.session_state['defaults'].img2img.update_preview_frequency,
help="Frequency in steps at which the the preview image is updated. By default the frequency \
is set to 1 step.")
#
with st.expander("Advanced"):
with st.expander("Output Settings"):
separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts,
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights,
help="Ensure the sum of all weights add up to 1.0")
loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from previous batch when creating next batch.")
random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Random loopback seed")
img2img_mask_restore = st.checkbox("Only modify regenerated parts of image",
value=st.session_state['defaults'].img2img.mask_restore,
help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail")
save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images,
help="Save each image generated before any filter or enhancement is applied.")
save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.")
group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt,
help="Saves all the images with the same prompt into the same folder. \
When using a prompt matrix each prompt combination will have its own folder.")
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files,
help="Save a file next to the image with informartion about the generation.")
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
#
# check if GFPGAN, RealESRGAN and LDSR are available.
if "GFPGAN_available" not in st.session_state:
GFPGAN_available()
if "RealESRGAN_available" not in st.session_state:
RealESRGAN_available()
if "LDSR_available" not in st.session_state:
LDSR_available()
if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
with st.expander("Post-Processing"):
face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
with face_restoration_tab:
# GFPGAN used for face restoration
if st.session_state["GFPGAN_available"]:
#with st.expander("Face Restoration"):
#if st.session_state["GFPGAN_available"]:
#with st.expander("GFPGAN"):
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN,
help="Uses the GFPGAN model to improve faces after the generation.\
This greatly improve the quality and consistency of faces but uses\
extra VRAM. Disable if you need the extra VRAM.")
st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
#st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
else:
st.session_state["use_GFPGAN"] = False
with upscaling_tab:
st.session_state['us_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].img2img.use_upscaling)
# RealESRGAN and LDSR used for upscaling.
if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
upscaling_method_list = []
if st.session_state["RealESRGAN_available"]:
upscaling_method_list.append("RealESRGAN")
if st.session_state["LDSR_available"]:
upscaling_method_list.append("LDSR")
st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
else 0)
if st.session_state["RealESRGAN_available"]:
with st.expander("RealESRGAN"):
if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['us_upscaling']:
st.session_state["use_RealESRGAN"] = True
else:
st.session_state["use_RealESRGAN"] = False
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
else:
st.session_state["use_RealESRGAN"] = False
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
#
if st.session_state["LDSR_available"]:
with st.expander("LDSR"):
if st.session_state["upscaling_method"] == "LDSR" and st.session_state['us_upscaling']:
st.session_state["use_LDSR"] = True
else:
st.session_state["use_LDSR"] = False
st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].img2img.LDSR_config.sampling_steps,
help="")
st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].img2img.LDSR_config.preDownScale,
help="")
st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].img2img.LDSR_config.postDownScale,
help="")
downsample_method_list = ['Nearest', 'Lanczos']
st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
index=downsample_method_list.index(st.session_state['defaults'].img2img.LDSR_config.downsample_method))
else:
st.session_state["use_LDSR"] = False
st.session_state["LDSR_model"] = "model"
with st.expander("Variant"):
variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed,
help="The seed to use when generating a variant, if left blank a random seed will be generated.")
with col2_img2img_layout:
editor_tab = st.tabs(["Editor"])
editor_image = st.empty()
st.session_state["editor_image"] = editor_image
st.form_submit_button("Refresh")
#if "canvas" not in st.session_state:
st.session_state["canvas"] = st.empty()
masked_image_holder = st.empty()
image_holder = st.empty()
uploaded_images = st.file_uploader(
"Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
help="Upload an image which will be used for the image to image generation.",
)
if uploaded_images:
image = Image.open(uploaded_images).convert('RGB')
new_img = image.resize((width, height))
#image_holder.image(new_img)
#mask_holder = st.empty()
#uploaded_masks = st.file_uploader(
#"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
#help="Upload an mask image which will be used for masking the image to image generation.",
#)
#
# Create a canvas component
with st.session_state["canvas"]:
st.session_state["uploaded_masks"] = st_canvas(
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
stroke_width=stroke_width,
stroke_color=stroke_color,
background_color=bg_color,
background_image=image if uploaded_images else None,
update_streamlit=True,
width=width,
height=height,
drawing_mode=drawing_mode,
initial_drawing=st.session_state["uploaded_masks"].json_data if "uploaded_masks" in st.session_state else None,
display_toolbar= display_toolbar,
key="full_app",
)
#try:
##print (type(st.session_state["uploaded_masks"]))
#if st.session_state["uploaded_masks"] != None:
#mask_expander.expander("Mask", expanded=True)
#mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
#st.image(mask)
#if mask.mode == "RGBA":
#mask = mask.convert('RGBA')
#background = Image.new('RGBA', mask.size, (0, 0, 0))
#mask = Image.alpha_composite(background, mask)
#mask = mask.resize((width, height))
#except AttributeError:
#pass
with col3_img2img_layout:
result_tab = st.tabs(["Result"])
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
preview_image = st.empty()
st.session_state["preview_image"] = preview_image
#st.session_state["loading"] = st.empty()
st.session_state["progress_bar_text"] = st.empty()
st.session_state["progress_bar"] = st.empty()
message = st.empty()
#if uploaded_images:
#image = Image.open(uploaded_images).convert('RGB')
##img_array = np.array(image) # if you want to pass it to OpenCV
#new_img = image.resize((width, height))
#st.image(new_img, use_column_width=True)
if generate_button:
#print("Loading models")
# load the models when we hit the generate button for the first time; they won't be reloaded after that, so don't worry.
with col3_img2img_layout:
with no_rerun:
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
if uploaded_images:
#image = Image.fromarray(image).convert('RGBA')
#new_img = image.resize((width, height))
###img_array = np.array(image) # if you want to pass it to OpenCV
#image_holder.image(new_img)
new_mask = None
if st.session_state["uploaded_masks"]:
mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
new_mask = mask.resize((width, height))
#masked_image_holder.image(new_mask)
try:
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
sampler_name=st.session_state["sampler_name"], n_iter=st.session_state["batch_count"],
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width,
height=height, variant_amount=variant_amount,
ddim_eta=st.session_state.defaults.img2img.ddim_eta, write_info_files=write_info_files,
separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images, save_grid=save_grid,
group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=st.session_state["use_GFPGAN"],
GFPGAN_model=st.session_state["GFPGAN_model"],
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
loopback=loopback
)
#show a message when the generation is complete.
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
except (StopException,
#KeyError
):
logger.info(f"Received Streamlit StopException")
# reset the page title so the percent doesnt stay on it confusing the user.
set_page_title(f"Stable Diffusion Playground")
# this will render all the images at the end of the generation, but it's better if it's moved to a second tab inside col2 and shown as a gallery.
# use the current col2 first tab to show the preview_img and update it as it's generated.
#preview_image.image(output_images, width=750)
#on import run init

View File

@ -0,0 +1,460 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# ---------------------------------------------------------------------------------------------------------------------------------------------------
"""
CLIP Interrogator made by @pharmapsychotic, modified to work with our WebUI.
# CLIP Interrogator by @pharmapsychotic
Twitter: https://twitter.com/pharmapsychotic
Github: https://github.com/pharmapsychotic/clip-interrogator
Description:
What do the different OpenAI CLIP models see in an image? What might be a good text prompt to create similar images using CLIP guided diffusion
or another text to image model? The CLIP Interrogator is here to get you answers!
Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following him on [twitter](https://twitter.com/pharmapsychotic).
And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
"""
# ---------------------------------------------------------------------------------------------------------------------------------------------------
# base webui import and utils.
from sd_utils import st, logger, server_state, server_state_lock, random
# streamlit imports
# streamlit components section
import streamlit_nested_layout
# other imports
import clip
import open_clip
import gc
import os
import pandas as pd
#import requests
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from ldm.models.blip import blip_decoder
#import hashlib
# end of imports
# ---------------------------------------------------------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
blip_image_eval_size = 512
st.session_state["log"] = []
def load_blip_model():
logger.info("Loading BLIP Model")
if "log" not in st.session_state:
st.session_state["log"] = []
st.session_state["log"].append("Loading BLIP Model")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if "blip_model" not in server_state:
with server_state_lock['blip_model']:
server_state["blip_model"] = blip_decoder(pretrained="models/blip/model__base_caption.pth",
image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
server_state["blip_model"] = server_state["blip_model"].eval()
server_state["blip_model"] = server_state["blip_model"].to(device).half()
logger.info("BLIP Model Loaded")
st.session_state["log"].append("BLIP Model Loaded")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
else:
logger.info("BLIP Model already loaded")
st.session_state["log"].append("BLIP Model already loaded")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
def generate_caption(pil_image):
load_blip_model()
gpu_image = transforms.Compose([ # type: ignore
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), # type: ignore
transforms.ToTensor(), # type: ignore
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) # type: ignore
])(pil_image).unsqueeze(0).to(device).half()
with torch.no_grad():
caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
return caption[0]
def load_list(filename):
with open(filename, 'r', encoding='utf-8', errors='replace') as f:
items = [line.strip() for line in f.readlines()]
return items
def rank(model, image_features, text_array, top_count=1):
top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array]).cuda()
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(device)
for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0]
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
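# Illustrative sketch (not from the original codebase): the heart of rank() is a temperature-scaled
# cosine-similarity softmax between a normalized image embedding and normalized text embeddings.
# The example below reproduces that math on the CPU with random tensors (reusing the torch import above),
# so it runs without CLIP or a GPU; the 100.0 scale and the topk selection mirror the function above.
def _rank_math_example(top_count=2):
    torch.manual_seed(0)
    labels = ["oil painting", "photograph", "pencil sketch", "3d render"]
    image_features = torch.randn(1, 512)
    text_features = torch.randn(len(labels), 512)
    # normalize both embeddings so the dot product becomes a cosine similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = similarity.topk(top_count, dim=-1)
    return [(labels[int(top_labels[0][i])], float(top_probs[0][i]) * 100) for i in range(top_count)]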
def clear_cuda():
torch.cuda.empty_cache()
gc.collect()
def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
batch_size = min(batch_size, len(text_array))
# split the text array into batches, keeping the final partial batch so no entries are dropped
batches = [text_array[i:i + batch_size] for i in range(0, len(text_array), batch_size)]
ranks = []
for batch in batches:
ranks += rank(model, image_features, batch)
return ranks
def interrogate(image, models):
load_blip_model()
logger.info("Generating Caption")
st.session_state["log"].append("Generating Caption")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
caption = generate_caption(image)
if st.session_state["defaults"].general.optimized:
del server_state["blip_model"]
clear_cuda()
logger.info("Caption Generated")
st.session_state["log"].append("Caption Generated")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if len(models) == 0:
logger.info(f"\n\n{caption}")
return
table = []
bests = [[('', 0)] for _ in range(7)]  # one independent list per category so the in-place sort below doesn't alias them
logger.info("Ranking Text")
st.session_state["log"].append("Ranking Text")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
for model_name in models:
with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
logger.info(f"Interrogating with {model_name}...")
st.session_state["log"].append(f"Interrogating with {model_name}...")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
if model_name not in server_state["clip_models"]:
if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
model_to_delete = []
for model in server_state["clip_models"]:
if model != model_name:
model_to_delete.append(model)
for model in model_to_delete:
del server_state["clip_models"][model]
del server_state["preprocesses"][model]
clear_cuda()
if model_name == 'ViT-H-14':
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip')
elif model_name == 'ViT-g-14':
server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='models/clip')
else:
server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = \
clip.load(model_name, device=device, download_root='models/clip')
server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()
image_features = server_state["clip_models"][model_name].encode_image(images).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
if st.session_state["defaults"].general.optimized:
clear_cuda()
ranks = []
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["mediums"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, ["by "+artist for artist in server_state["artists"]]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"]))
#ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["domains"]))
#ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subreddits"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"]))
ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["tags"]))
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["genres"]))
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["styles"]))
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subjects"]))
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["colors"]))
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["moods"]))
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["themes"]))
# ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["keywords"]))
#print (bests)
#print (ranks)
for i in range(len(ranks)):
confidence_sum = 0
for ci in range(len(ranks[i])):
confidence_sum += ranks[i][ci][1]
if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
bests[i] = ranks[i]
for best in bests:
best.sort(key=lambda x: x[1], reverse=True)
# prune to 3 (in place, so the pruning actually takes effect)
del best[3:]
row = [model_name]
for r in ranks:
row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
#for rank in ranks:
# rank.sort(key=lambda x: x[1], reverse=True)
# row.append(f'{rank[0][0]} {rank[0][1]:.2f}%')
table.append(row)
if st.session_state["defaults"].general.optimized:
del server_state["clip_models"][model_name]
gc.collect()
st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]))
medium = bests[0][0][0]
artist = bests[1][0][0]
trending = bests[2][0][0]
movement = bests[3][0][0]
flavors = bests[4][0][0]
#domains = bests[5][0][0]
#subreddits = bests[6][0][0]
techniques = bests[5][0][0]
tags = bests[6][0][0]
if caption.startswith(medium):
st.session_state["text_result"][st.session_state["processed_image_count"]].code(
f"\n\n{caption} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
else:
st.session_state["text_result"][st.session_state["processed_image_count"]].code(
f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
logger.info("Finished Interrogating.")
st.session_state["log"].append("Finished Interrogating.")
st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
def img2txt():
models = []
if st.session_state["ViT-L/14"]:
models.append('ViT-L/14')
if st.session_state["ViT-H-14"]:
models.append('ViT-H-14')
if st.session_state["ViT-g-14"]:
models.append('ViT-g-14')
if st.session_state["ViTB32"]:
models.append('ViT-B/32')
if st.session_state['ViTB16']:
models.append('ViT-B/16')
if st.session_state["ViTL14_336px"]:
models.append('ViT-L/14@336px')
if st.session_state["RN101"]:
models.append('RN101')
if st.session_state["RN50"]:
models.append('RN50')
if st.session_state["RN50x4"]:
models.append('RN50x4')
if st.session_state["RN50x16"]:
models.append('RN50x16')
if st.session_state["RN50x64"]:
models.append('RN50x64')
# if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
#image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
# else:
#image = Image.open(image_path_or_url).convert('RGB')
#thumb = st.session_state["uploaded_image"].image.copy()
#thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
# display(thumb)
st.session_state["processed_image_count"] = 0
for i in range(len(st.session_state["uploaded_image"])):
interrogate(st.session_state["uploaded_image"][i].pil_image, models=models)
# increase counter.
st.session_state["processed_image_count"] += 1
#
def layout():
#set_page_title("Image-to-Text - Stable Diffusion WebUI")
#st.info("Under Construction. :construction_worker:")
#
if "clip_models" not in server_state:
server_state["clip_models"] = {}
if "preprocesses" not in server_state:
server_state["preprocesses"] = {}
data_path = "data/"
if "artists" not in server_state:
server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
if "flavors" not in server_state:
server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000)
if "mediums" not in server_state:
server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
if "movements" not in server_state:
server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
if "sites" not in server_state:
server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
#server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
#server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
if "techniques" not in server_state:
server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
if "tags" not in server_state:
server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
#server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
# server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
# server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
if "trending_list" not in server_state:
server_state["trending_list"] = [site for site in server_state["sites"]]
server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
with st.form("img2txt-inputs"):
st.session_state["generation_mode"] = "img2txt"
# st.write("---")
# creating the page layout using columns
col1, col2 = st.columns([1, 4], gap="large")
with col1:
st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True)
with st.expander("CLIP models", expanded=True):
st.session_state["ViT-L/14"] = st.checkbox("ViT-L/14", value=True, help="ViT-L/14 model.")
st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
with st.expander("Others"):
st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
#
# st.subheader("Logs:")
st.session_state["log_message"] = st.empty()
st.session_state["log_message"].code('', language="")
with col2:
st.subheader("Image")
image_col1, image_col2 = st.columns([10,25])
with image_col1:
refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
if st.session_state["uploaded_image"]:
#print (type(st.session_state["uploaded_image"]))
# if len(st.session_state["uploaded_image"]) == 1:
st.session_state["input_image_preview"] = []
st.session_state["input_image_preview_container"] = []
st.session_state["prediction_table"] = []
st.session_state["text_result"] = []
for i in range(len(st.session_state["uploaded_image"])):
st.session_state["input_image_preview_container"].append(i)
st.session_state["input_image_preview_container"][i] = st.empty()
with st.session_state["input_image_preview_container"][i].container():
col1_output, col2_output = st.columns([2, 10], gap="medium")
with col1_output:
st.session_state["input_image_preview"].append(i)
st.session_state["input_image_preview"][i] = st.empty()
st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
with st.session_state["input_image_preview_container"][i].container():
with col2_output:
st.session_state["prediction_table"].append(i)
st.session_state["prediction_table"][i] = st.empty()
st.session_state["prediction_table"][i].table()
st.session_state["text_result"].append(i)
st.session_state["text_result"][i] = st.empty()
st.session_state["text_result"][i].code("", language="")
else:
#st.session_state["input_image_preview"].code('', language="")
st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
with image_col2:
#
# Every form must have a submit button; the extra blank spaces are a temporary way to align it with the input field. This should be done in CSS or some other way.
# generate_col1.title("")
# generate_col1.title("")
generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images")
if generate_button:
# if model, pipe, RealESRGAN or GFPGAN is in server_state, remove the model and pipe from server_state so that they are reloaded.
if "model" in server_state and st.session_state["defaults"].general.optimized:
del server_state["model"]
if "pipe" in server_state and st.session_state["defaults"].general.optimized:
del server_state["pipe"]
if "RealESRGAN" in server_state and st.session_state["defaults"].general.optimized:
del server_state["RealESRGAN"]
if "GFPGAN" in server_state and st.session_state["defaults"].general.optimized:
del server_state["GFPGAN"]
# run clip interrogator
img2txt()

View File

@ -0,0 +1,368 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sandbox-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
#from sd_utils import *
from sd_utils import st, server_state, logger, \
RealESRGAN_available, GFPGAN_available, LDSR_available
# NOTE: post_process() and layout() below also call load_models, torch_gc and the
# load_GFPGAN/load_RealESRGAN/load_LDSR helpers; these are assumed to be provided by the
# sd_utils/nataili backend and are not imported explicitly here.
# streamlit imports
#streamlit components section
import hydralit_components as hc
#other imports
import os
from PIL import Image
import torch
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
def post_process(use_GFPGAN=True, GFPGAN_model='', use_RealESRGAN=False, realesrgan_model_name="", use_LDSR=False, LDSR_model_name=""):
for i in range(len(st.session_state["uploaded_image"])):
#st.session_state["uploaded_image"][i].pil_image
if use_GFPGAN and server_state["GFPGAN"] is not None and not use_RealESRGAN and not use_LDSR:
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
if "progress_bar" in st.session_state:
st.session_state["progress_bar"].progress(
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
if server_state["GFPGAN"].name != GFPGAN_model:
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
torch_gc()
with torch.autocast('cuda'):
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
#if st.session_state["GFPGAN_strenght"]:
#gfpgan_sample = Image.blend(image, gfpgan_image, st.session_state["GFPGAN_strenght"])
gfpgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan'
gfpgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{gfpgan_filename}.png"))
#
elif use_RealESRGAN and server_state["RealESRGAN"] is not None and not use_GFPGAN:
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
if "progress_bar" in st.session_state:
st.session_state["progress_bar"].progress(
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
torch_gc()
if server_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = server_state["RealESRGAN"].enhance(st.session_state["uploaded_image"][i].pil_image)
esrgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-esrgan4x'
esrgan_sample = output[:,:,::-1]
esrgan_image = Image.fromarray(esrgan_sample)
esrgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{esrgan_filename}.png"))
#
elif use_LDSR and "LDSR" in server_state and not use_GFPGAN:
logger.info ("Running LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].text("Running LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
if "progress_bar" in st.session_state:
st.session_state["progress_bar"].progress(
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
torch_gc()
if server_state["LDSR"].name != LDSR_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
result = server_state["LDSR"].superResolution(st.session_state["uploaded_image"][i].pil_image, ddimSteps = st.session_state["ldsr_sampling_steps"],
preDownScale = st.session_state["preDownScale"], postDownScale = st.session_state["postDownScale"],
downsample_method=st.session_state["downsample_method"])
ldsr_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-ldsr4x'
result.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{ldsr_filename}.png"))
#
elif use_LDSR and "LDSR" in server_state and use_GFPGAN and "GFPGAN" in server_state:
logger.info ("Running GFPGAN+LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].text("Running GFPGAN+LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
if "progress_bar" in st.session_state:
st.session_state["progress_bar"].progress(
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
if server_state["GFPGAN"].name != GFPGAN_model:
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
torch_gc()
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
if server_state["LDSR"].name != LDSR_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
#LDSR.superResolution(gfpgan_image, ddimSteps=100, preDownScale='None', postDownScale='None', downsample_method="Lanczos")
result = server_state["LDSR"].superResolution(gfpgan_image, ddimSteps = st.session_state["ldsr_sampling_steps"],
preDownScale = st.session_state["preDownScale"], postDownScale = st.session_state["postDownScale"],
downsample_method=st.session_state["downsample_method"])
ldsr_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan-ldsr2x'
result.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{ldsr_filename}.png"))
elif use_RealESRGAN and server_state["RealESRGAN"] is not None and use_GFPGAN and server_state["GFPGAN"] is not None:
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
if "progress_bar" in st.session_state:
st.session_state["progress_bar"].progress(
int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
torch_gc()
cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
if server_state["RealESRGAN"].model.name != realesrgan_model_name:
#try_loading_RealESRGAN(realesrgan_model_name)
load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
output, img_mode = server_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
gfpgan_esrgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1]
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
gfpgan_esrgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{gfpgan_esrgan_filename}.png"))
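# Example call (hypothetical values, once images have been uploaded through the form below and
# the corresponding models are loaded into server_state):
#   post_process(use_GFPGAN=True, GFPGAN_model="GFPGANv1.4",
#                use_RealESRGAN=True, realesrgan_model_name="RealESRGAN_x4plus",
#                use_LDSR=False, LDSR_model_name="model")
# Each processed image is written to defaults.post_processing.outdir_post_processing with a suffix
# describing the chain that produced it (e.g. "-gfpgan", "-esrgan4x", "-ldsr4x", "-gfpgan-esrgan4x").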
def layout():
#st.info("Under Construction. :construction_worker:")
st.session_state["progress_bar_text"] = st.empty()
#st.session_state["progress_bar_text"].info("Nothing but crickets here, try generating something first.")
st.session_state["progress_bar"] = st.empty()
with st.form("post-processing-inputs"):
# creating the page layout using columns
col1, col2 = st.columns([1, 4], gap="medium")
with col1:
st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True)
# check if GFPGAN, RealESRGAN and LDSR are available.
#if "GFPGAN_available" not in st.session_state:
GFPGAN_available()
#if "RealESRGAN_available" not in st.session_state:
RealESRGAN_available()
#if "LDSR_available" not in st.session_state:
LDSR_available()
if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
with face_restoration_tab:
# GFPGAN used for face restoration
if st.session_state["GFPGAN_available"]:
#with st.expander("Face Restoration"):
#if st.session_state["GFPGAN_available"]:
#with st.expander("GFPGAN"):
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN,
help="Uses the GFPGAN model to improve faces after the generation.\
This greatly improve the quality and consistency of faces but uses\
extra VRAM. Disable if you need the extra VRAM.")
st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
#st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
else:
st.session_state["use_GFPGAN"] = False
with upscaling_tab:
st.session_state['use_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling)
# RealESRGAN and LDSR used for upscaling.
if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
upscaling_method_list = []
if st.session_state["RealESRGAN_available"]:
upscaling_method_list.append("RealESRGAN")
if st.session_state["LDSR_available"]:
upscaling_method_list.append("LDSR")
#print (st.session_state["RealESRGAN_available"])
st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
else 0)
if st.session_state["RealESRGAN_available"]:
with st.expander("RealESRGAN"):
if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['use_upscaling']:
st.session_state["use_RealESRGAN"] = True
else:
st.session_state["use_RealESRGAN"] = False
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
else:
st.session_state["use_RealESRGAN"] = False
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
#
if st.session_state["LDSR_available"]:
with st.expander("LDSR"):
if st.session_state["upscaling_method"] == "LDSR" and st.session_state['use_upscaling']:
st.session_state["use_LDSR"] = True
else:
st.session_state["use_LDSR"] = False
st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2img.LDSR_config.sampling_steps,
help="")
st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.preDownScale,
help="")
st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.postDownScale,
help="")
downsample_method_list = ['Nearest', 'Lanczos']
st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
index=downsample_method_list.index(st.session_state['defaults'].txt2img.LDSR_config.downsample_method))
else:
st.session_state["use_LDSR"] = False
st.session_state["LDSR_model"] = "model"
#process = st.form_submit_button("Process Images", help="")
#
with st.expander("Output Settings", True):
#st.session_state['defaults'].post_processing.save_original_images = st.checkbox("Save input images.", value=st.session_state['defaults'].post_processing.save_original_images,
#help="Save each original/input image next to the Post Processed image. "
#"This might be helpful for comparing the before and after images.")
st.session_state['defaults'].post_processing.outdir_post_processing = st.text_input("Output Dir",value=st.session_state['defaults'].post_processing.outdir_post_processing,
help="Folder where the images will be saved after post processing.")
with col2:
st.subheader("Image")
image_col1, image_col2, image_col3 = st.columns([2, 2, 2], gap="small")
with image_col1:
refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image.')
if st.session_state["uploaded_image"]:
#print (type(st.session_state["uploaded_image"]))
# if len(st.session_state["uploaded_image"]) == 1:
st.session_state["input_image_preview"] = []
st.session_state["input_image_caption"] = []
st.session_state["output_image_preview"] = []
st.session_state["output_image_caption"] = []
st.session_state["input_image_preview_container"] = []
st.session_state["prediction_table"] = []
st.session_state["text_result"] = []
for i in range(len(st.session_state["uploaded_image"])):
st.session_state["input_image_preview_container"].append(i)
st.session_state["input_image_preview_container"][i] = st.empty()
with st.session_state["input_image_preview_container"][i].container():
col1_output, col2_output, col3_output = st.columns([2, 2, 2], gap="medium")
with col1_output:
st.session_state["output_image_caption"].append(i)
st.session_state["output_image_caption"][i] = st.empty()
#st.session_state["output_image_caption"][i] = st.session_state["uploaded_image"][i].name
st.session_state["input_image_caption"].append(i)
st.session_state["input_image_caption"][i] = st.empty()
#st.session_state["input_image_caption"][i].caption(")
st.session_state["input_image_preview"].append(i)
st.session_state["input_image_preview"][i] = st.empty()
st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
with col2_output:
st.session_state["output_image_preview"].append(i)
st.session_state["output_image_preview"][i] = st.empty()
st.session_state["output_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
with st.session_state["input_image_preview_container"][i].container():
with col3_output:
#st.session_state["prediction_table"].append(i)
#st.session_state["prediction_table"][i] = st.empty()
#st.session_state["prediction_table"][i].table(pd.DataFrame(columns=["Model", "Filename", "Progress"]))
st.session_state["text_result"].append(i)
st.session_state["text_result"][i] = st.empty()
st.session_state["text_result"][i].code("", language="")
#else:
##st.session_state["input_image_preview"].code('', language="")
#st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
with image_col3:
# Every form must have a submit button; the extra blank spaces are a temporary way to align it with the input field. This should be done in CSS or some other way.
process = st.form_submit_button("Process Images!")
if process:
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
#load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
#use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
#use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"])
if st.session_state["use_GFPGAN"]:
load_GFPGAN(model_name=st.session_state["GFPGAN_model"])
if st.session_state["use_RealESRGAN"]:
load_RealESRGAN(st.session_state["RealESRGAN_model"])
if st.session_state["use_LDSR"]:
load_LDSR(st.session_state["LDSR_model"])
post_process(use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"],
use_RealESRGAN=st.session_state["use_RealESRGAN"], realesrgan_model_name=st.session_state["RealESRGAN_model"],
use_LDSR=st.session_state["use_LDSR"], LDSR_model_name=st.session_state["LDSR_model"])

View File

@ -0,0 +1,260 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
from sd_utils import st
# streamlit imports
import streamlit.components.v1 as components
#other imports
import os, math
from PIL import Image
# Temp imports
#from basicsr.utils.registry import ARCH_REGISTRY
# end of imports
#---------------------------------------------------------------------------------------------------------------
# Init Vuejs component
_component_func = components.declare_component(
"sd-concepts-browser", "./frontend/dists/concept-browser/dist")
def sdConceptsBrowser(concepts, key=None):
component_value = _component_func(concepts=concepts, key=key, default="")
return component_value
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def getConceptsFromPath(page, conceptPerPage, searchText=""):
#print("getConceptsFromPath", "page:", page, "conceptPerPage:", conceptPerPage, "searchText:", searchText)
# get the path where the concepts are stored
path = os.path.join(
os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder)
acceptedExtensions = ('jpeg', 'jpg', "png")
concepts = []
if os.path.exists(path):
# List all folders (concepts) in the path
folders = [f for f in os.listdir(
path) if os.path.isdir(os.path.join(path, f))]
filteredFolders = folders
# Filter the folders by the search text
if searchText != "":
filteredFolders = [
f for f in folders if searchText.lower() in f.lower()]
else:
filteredFolders = []
conceptIndex = 1
for folder in filteredFolders:
# handle pagination
if conceptIndex > (page * conceptPerPage):
continue
if conceptIndex <= ((page - 1) * conceptPerPage):
conceptIndex += 1
continue
concept = {
"name": folder,
"token": "<" + folder + ">",
"images": [],
"type": ""
}
# type of concept is inside type_of_concept.txt
typePath = os.path.join(path, folder, "type_of_concept.txt")
binFile = os.path.join(path, folder, "learned_embeds.bin")
# Continue if the concept is not valid or the download has failed (no type_of_concept.txt or no binFile)
if not os.path.exists(typePath) or not os.path.exists(binFile):
continue
with open(typePath, "r") as f:
concept["type"] = f.read()
# List all files in the concept/concept_images folder
files = [f for f in os.listdir(os.path.join(path, folder, "concept_images")) if os.path.isfile(
os.path.join(path, folder, "concept_images", f))]
# Retrieve only the 4 first images
for file in files:
# Skip if we already have 4 images
if len(concept["images"]) >= 4:
break
if file.endswith(acceptedExtensions):
try:
# Add a copy of the image to avoid file locking
originalImage = Image.open(os.path.join(
path, folder, "concept_images", file))
# Maintain the aspect ratio (max 200x200)
resizedImage = originalImage.copy()
resizedImage.thumbnail((200, 200), Image.Resampling.LANCZOS)
# concept["images"].append(resizedImage)
concept["images"].append(imageToBase64(resizedImage))
# Close original image
originalImage.close()
except:
print("Error while loading image", file, "in concept", folder, "(The file may be corrupted). Skipping it.")
concepts.append(concept)
conceptIndex += 1
# print all concepts name
#print("Results:", [c["name"] for c in concepts])
return concepts
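# Pagination contract (illustrative): with conceptPerPage=12, page=1 yields concepts 1-12 and
# page=2 yields concepts 13-24; folders whose name does not contain searchText (case-insensitive)
# are filtered out before pagination. A typical call from layout() below looks like:
#   getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"])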
@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True)
def imageToBase64(image):
import io
import base64
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_str
@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
def getTotalNumberOfConcepts(searchText=""):
# get the path where the concepts are stored
path = os.path.join(
os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder)
concepts = []
if os.path.exists(path):
# List all folders (concepts) in the path
folders = [f for f in os.listdir(
path) if os.path.isdir(os.path.join(path, f))]
filteredFolders = folders
# Filter the folders by the search text
if searchText != "":
filteredFolders = [
f for f in folders if searchText.lower() in f.lower()]
else:
filteredFolders = []
return len(filteredFolders)
def layout():
# 2 tabs, one for Concept Library and one for the Download Manager
tab_library, tab_downloader = st.tabs(["Library", "Download Manager"])
# Concept Library
with tab_library:
downloaded_concepts_count = getTotalNumberOfConcepts()
concepts_per_page = st.session_state["defaults"].concepts_library.concepts_per_page
if not "results" in st.session_state:
st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, "")
# Pagination controls
if not "cl_current_page" in st.session_state:
st.session_state["cl_current_page"] = 1
# Search
if not 'cl_search_text' in st.session_state:
st.session_state["cl_search_text"] = ""
if not 'cl_search_results_count' in st.session_state:
st.session_state["cl_search_results_count"] = downloaded_concepts_count
# Search bar
_search_col, _refresh_col = st.columns([10, 2])
with _search_col:
search_text_input = st.text_input("Search", "", placeholder=f'Search for a concept ({downloaded_concepts_count} available)', label_visibility="hidden")
if search_text_input != st.session_state["cl_search_text"]:
# Search text has changed
st.session_state["cl_search_text"] = search_text_input
st.session_state["cl_current_page"] = 1
st.session_state["cl_search_results_count"] = getTotalNumberOfConcepts(st.session_state["cl_search_text"])
st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, st.session_state["cl_search_text"])
with _refresh_col:
# Super weird fix to align the refresh button with the search bar ( Please streamlit, add css support.. )
_refresh_col.write("")
_refresh_col.write("")
if st.button("Refresh concepts", key="refresh_concepts", help="Refresh the concepts folders. Use this if you have added new concepts manually or deleted some."):
getTotalNumberOfConcepts.clear()
getConceptsFromPath.clear()
st.experimental_rerun()
# Show results
results_empty = st.empty()
# Pagination
pagination_empty = st.empty()
# Layouts
with pagination_empty:
with st.container():
if len(st.session_state["results"]) > 0:
last_page = math.ceil(st.session_state["cl_search_results_count"] / concepts_per_page)
_1, _2, _3, _4, _previous_page, _current_page, _next_page, _9, _10, _11, _12 = st.columns([1,1,1,1,1,2,1,1,1,1,1])
# Previous page
with _previous_page:
if st.button("Previous", key="cl_previous_page"):
st.session_state["cl_current_page"] -= 1
if st.session_state["cl_current_page"] <= 0:
st.session_state["cl_current_page"] = last_page
st.session_state["results"] = getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"])
# Current page
with _current_page:
_current_page_container = st.empty()
# Next page
with _next_page:
if st.button("Next", key="cl_next_page"):
st.session_state["cl_current_page"] += 1
if st.session_state["cl_current_page"] > last_page:
st.session_state["cl_current_page"] = 1
st.session_state["results"] = getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"])
# Current page
with _current_page_container:
st.markdown(f'<p style="text-align: center">Page {st.session_state["cl_current_page"]} of {last_page}</p>', unsafe_allow_html=True)
# st.write(f"Page {st.session_state['cl_current_page']} of {last_page}", key="cl_current_page")
with results_empty:
with st.container():
if downloaded_concepts_count == 0:
st.write("You don't have any concepts in your library ")
st.markdown("To add concepts to your library, download some from the [sd-concepts-library](https://github.com/Sygil-Dev/sd-concepts-library) \
repository and save the content of `sd-concepts-library` into ```./models/custom/sd-concepts-library``` or just create your own concepts :wink:.", unsafe_allow_html=False)
else:
if len(st.session_state["results"]) == 0:
st.write("No concept found in the library matching your search: " + st.session_state["cl_search_text"])
else:
# display number of results
if st.session_state["cl_search_text"]:
st.write(f"Found {st.session_state['cl_search_results_count']} {'concepts' if st.session_state['cl_search_results_count'] > 1 else 'concept' } matching your search")
sdConceptsBrowser(st.session_state['results'], key="results")
with tab_downloader:
st.write("Not implemented yet")
return False

View File

@ -0,0 +1,405 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
#from webui_streamlit import st
import hydralit as st
# streamlit imports
from streamlit.runtime.scriptrunner import StopException
#from streamlit.runtime.scriptrunner import script_run_context
#streamlit components section
from streamlit_server_state import server_state, server_state_lock, no_rerun
import hydralit_components as hc
from hydralit import HydraHeadApp
import streamlit_nested_layout
#from streamlitextras.threader import lock, trigger_rerun, \
#streamlit_thread, get_thread, \
#last_trigger_time
#other imports
import warnings
import json
import base64, cv2
import os, sys, re, random, datetime, time, math, toml
import gc
from PIL import Image, ImageFont, ImageDraw, ImageFilter
from PIL.PngImagePlugin import PngInfo
from scipy import integrate
import torch
from torchdiffeq import odeint
import k_diffusion as K
import math, requests
import mimetypes
import numpy as np
import pynvml
import threading
import torch, torchvision
from torch import autocast
from torchvision import transforms
import torch.nn as nn
from omegaconf import OmegaConf
import yaml
from pathlib import Path
from contextlib import nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from retry import retry
from slugify import slugify
import skimage
import piexif
import piexif.helper
from tqdm import trange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
#from abc import ABC, abstractmethod
from io import BytesIO
from packaging import version
from pathlib import Path
from huggingface_hub import hf_hub_download
import shutup
#import librosa
from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger
from nataili.esrgan import esrgan
#try:
#from realesrgan import RealESRGANer
#from basicsr.archs.rrdbnet_arch import RRDBNet
#except ImportError as e:
#logger.error("You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
#"pip install realesrgan")
# Temp imports
#from basicsr.utils.registry import ARCH_REGISTRY
# end of imports
#---------------------------------------------------------------------------------------------------------------
# remove all the annoying python warnings.
shutup.please()
# the following lines should help fix an issue with nvidia 16xx cards.
if "defaults" in st.session_state:
if st.session_state["defaults"].general.use_cudnn:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
logging.set_verbosity_error()
except:
pass
# disable diffusers telemetry
os.environ["DISABLE_TELEMETRY"] = "YES"
# remove some annoying deprecation warnings that show every now and then.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# this is a fix for Windows users. Without it, javascript files will be served with a text/html content-type and the browser will not show any UI.
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
# The model manager loads and unloads the SD models and has features to download them or find their location
#model_manager = ModelManager()
def load_configs():
if not "defaults" in st.session_state:
st.session_state["defaults"] = {}
st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
if (os.path.exists("configs/webui/userconfig_streamlit.yaml")):
user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
if "version" in user_defaults.general:
if version.parse(user_defaults.general.version) < version.parse(st.session_state["defaults"].general.version):
logger.error("The version of the user config file is older than the version on the defaults config file. "
"This means there were big changes we made on the config."
"We are removing this file and recreating it from the defaults in order to make sure things work properly.")
os.remove("configs/webui/userconfig_streamlit.yaml")
st.experimental_rerun()
else:
logger.error("The version of the user config file is older than the version on the defaults config file. "
"This means there were big changes we made on the config."
"We are removing this file and recreating it from the defaults in order to make sure things work properly.")
os.remove("configs/webui/userconfig_streamlit.yaml")
st.experimental_rerun()
try:
st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults)
except KeyError:
st.experimental_rerun()
else:
OmegaConf.save(config=st.session_state.defaults, f="configs/webui/userconfig_streamlit.yaml")
loaded = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
assert st.session_state.defaults == loaded
if (os.path.exists(".streamlit/config.toml")):
st.session_state["streamlit_config"] = toml.load(".streamlit/config.toml")
#if st.session_state["defaults"].daisi_app.running_on_daisi_io:
#if os.path.exists("scripts/modeldownload.py"):
#import modeldownload
#modeldownload.updateModels()
if "keep_all_models_loaded" in st.session_state.defaults.general:
with server_state_lock["keep_all_models_loaded"]:
server_state["keep_all_models_loaded"] = st.session_state["defaults"].general.keep_all_models_loaded
else:
st.session_state["defaults"].general.keep_all_models_loaded = False
with server_state_lock["keep_all_models_loaded"]:
server_state["keep_all_models_loaded"] = st.session_state["defaults"].general.keep_all_models_loaded
load_configs()
#
#if st.session_state["defaults"].debug.enable_hydralit:
#navbar_theme = {'txc_inactive': '#FFFFFF','menu_background':'#0e1117','txc_active':'black','option_active':'red'}
#app = st.HydraApp(title='Stable Diffusion WebUI', favicon="", use_cookie_cache=False, sidebar_state="expanded", layout="wide", navbar_theme=navbar_theme,
#hide_streamlit_markers=False, allow_url_nav=True , clear_cross_app_sessions=False, use_loader=False)
#else:
#app = None
#
# the configured save format may carry an optional quality suffix, e.g. "webp:-100" (a negative quality means lossless)
grid_format = st.session_state["defaults"].general.save_format.split(':')
grid_lossless = False
grid_quality = st.session_state["defaults"].general.grid_quality
if grid_format[0] == 'png':
grid_ext = 'png'
grid_format = 'png'
elif grid_format[0] in ['jpg', 'jpeg']:
grid_quality = int(grid_format[1]) if len(grid_format) > 1 else grid_quality
grid_ext = 'jpg'
grid_format = 'jpeg'
elif grid_format[0] == 'webp':
grid_quality = int(grid_format[1]) if len(grid_format) > 1 else grid_quality
grid_ext = 'webp'
grid_format = 'webp'
if grid_quality < 0: # e.g. webp:-100 for lossless mode
grid_lossless = True
grid_quality = abs(grid_quality)
#
save_format = st.session_state["defaults"].general.save_format.split(':')
save_lossless = False
save_quality = 100
if save_format[0] == 'png':
save_ext = 'png'
save_format = 'png'
elif save_format[0] in ['jpg', 'jpeg']:
save_quality = int(save_format[1]) if len(save_format) > 1 else 100
save_ext = 'jpg'
save_format = 'jpeg'
elif save_format[0] == 'webp':
save_quality = int(save_format[1]) if len(save_format) > 1 else 100
save_ext = 'webp'
save_format = 'webp'
if save_quality < 0: # e.g. webp:-100 for lossless mode
save_lossless = True
save_quality = abs(save_quality)
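# Intended mapping from the configured save format to the values used when saving images
# (a negative quality, e.g. the "webp:-100" form mentioned above, selects lossless webp):
#   "png"       -> save_ext="png",  save_format="png",  save_quality=100
#   "jpg:90"    -> save_ext="jpg",  save_format="jpeg", save_quality=90
#   "webp:-100" -> save_ext="webp", save_format="webp", save_quality=100, save_lossless=True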
# this should force GFPGAN and RealESRGAN onto the selected gpu as well
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu)
# functions to load CSS locally OR remotely start here. Options exist for future flexibility. Called via st.markdown with unsafe_allow_html for CSS injection.
# TODO: maybe look into async loading of the file, especially for remote fetching
def local_css(file_name):
with open(file_name) as f:
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
def remote_css(url):
st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)
def load_css(isLocal, nameOrURL):
if(isLocal):
local_css(nameOrURL)
else:
remote_css(nameOrURL)
def set_page_title(title):
"""
Simple function that allows us to change the title dynamically.
Normally you can use `st.set_page_config` to change the title, but it can only be used once per app.
"""
st.sidebar.markdown(unsafe_allow_html=True, body=f"""
<iframe height=0 srcdoc="<script>
const title = window.parent.document.querySelector('title') \
const oldObserver = window.parent.titleObserver
if (oldObserver) {{
oldObserver.disconnect()
}} \
const newObserver = new MutationObserver(function(mutations) {{
const target = mutations[0].target
if (target.text !== '{title}') {{
target.text = '{title}'
}}
}}) \
newObserver.observe(title, {{ childList: true }})
window.parent.titleObserver = newObserver \
title.text = '{title}'
</script>" />
""")
class MemUsageMonitor(threading.Thread):
stop_flag = False
max_usage = 0
total = -1
def __init__(self, name):
threading.Thread.__init__(self)
self.name = name
def run(self):
try:
pynvml.nvmlInit()
except:
logger.debug(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n")
return
logger.info(f"[{self.name}] Recording memory usage...\n")
# Missing context
#handle = pynvml.nvmlDeviceGetHandleByIndex(st.session_state['defaults'].general.gpu)
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
while not self.stop_flag:
m = pynvml.nvmlDeviceGetMemoryInfo(handle)
self.max_usage = max(self.max_usage, m.used)
# logger.info(self.max_usage)
time.sleep(0.1)
logger.info(f"[{self.name}] Stopped recording.\n")
pynvml.nvmlShutdown()
def read(self):
return self.max_usage, self.total
def stop(self):
self.stop_flag = True
def read_and_stop(self):
self.stop_flag = True
return self.max_usage, self.total
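# Minimal usage sketch (assumes an NVIDIA GPU with pynvml available; otherwise run() just logs and exits):
#   monitor = MemUsageMonitor("txt2img")
#   monitor.start()
#   ...run a generation...
#   max_used, total = monitor.read_and_stop()
#   logger.info(f"peak VRAM: {max_used / 1e9:.2f} / {total / 1e9:.2f} GB")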
#
def custom_models_available():
with server_state_lock["custom_models"]:
#
# Allow for custom models to be used instead of the default one,
# an example would be Waifu-Diffusion or any other fine tune of stable diffusion
server_state["custom_models"]:sorted = []
for root, dirs, files in os.walk(os.path.join("models", "custom")):
for file in files:
if os.path.splitext(file)[1] == '.ckpt':
server_state["custom_models"].append(os.path.splitext(file)[0])
with server_state_lock["CustomModel_available"]:
if len(server_state["custom_models"]) > 0:
server_state["CustomModel_available"] = True
server_state["custom_models"].append("Stable Diffusion v1.5")
else:
server_state["CustomModel_available"] = False
#
def GFPGAN_available():
#with server_state_lock["GFPGAN_models"]:
#
st.session_state["GFPGAN_models"]:sorted = []
model = st.session_state["defaults"].model_manager.models.gfpgan
files_available = 0
for file in model['files']:
if "save_location" in model['files'][file]:
if os.path.exists(os.path.join(model['files'][file]['save_location'], model['files'][file]['file_name'] )):
files_available += 1
elif os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
base_name = os.path.splitext(model['files'][file]['file_name'])[0]
if "GFPGANv" in base_name:
st.session_state["GFPGAN_models"].append(base_name)
files_available += 1
# we need to show the other models from previous versions that we have in the
# same directory in case we want to see how they perform vs each other.
for root, dirs, files in os.walk(st.session_state['defaults'].general.GFPGAN_dir):
for file in files:
if os.path.splitext(file)[1] == '.pth':
if os.path.splitext(file)[0] not in st.session_state["GFPGAN_models"]:
st.session_state["GFPGAN_models"].append(os.path.splitext(file)[0])
if len(st.session_state["GFPGAN_models"]) > 0 and files_available == len(model['files']):
st.session_state["GFPGAN_available"] = True
else:
st.session_state["GFPGAN_available"] = False
st.session_state["use_GFPGAN"] = False
st.session_state["GFPGAN_model"] = "GFPGANv1.4"
#
def RealESRGAN_available():
#with server_state_lock["RealESRGAN_models"]:
st.session_state["RealESRGAN_models"]:sorted = []
model = st.session_state["defaults"].model_manager.models.realesrgan
for file in model['files']:
if os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
base_name = os.path.splitext(model['files'][file]['file_name'])[0]
st.session_state["RealESRGAN_models"].append(base_name)
if len(st.session_state["RealESRGAN_models"]) > 0:
st.session_state["RealESRGAN_available"] = True
else:
st.session_state["RealESRGAN_available"] = False
st.session_state["use_RealESRGAN"] = False
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
#
def LDSR_available():
st.session_state["LDSR_models"]:sorted = []
files_available = 0
model = st.session_state["defaults"].model_manager.models.ldsr
for file in model['files']:
if os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
base_name = os.path.splitext(model['files'][file]['file_name'])[0]
extension = os.path.splitext(model['files'][file]['file_name'])[1]
if extension == ".ckpt":
st.session_state["LDSR_models"].append(base_name)
files_available += 1
if files_available == len(model['files']):
st.session_state["LDSR_available"] = True
else:
st.session_state["LDSR_available"] = False
st.session_state["use_LDSR"] = False
st.session_state["LDSR_model"] = "model"

View File

@ -0,0 +1,182 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
#import streamlit as st
# We import hydralit like this to replace the previous stuff
# we had with native streamlit as it lets us replace things 1:1
from nataili.util import logger
# streamlit imports
#streamlit components section
#other imports
import requests, time, json, base64
from io import BytesIO
# import custom components
# end of imports
#---------------------------------------------------------------------------------------------------------------
@logger.catch(reraise=True)
def run_bridge(interval, api_key, horde_name, horde_url, priority_usernames, horde_max_pixels, horde_nsfw, horde_censor_nsfw, horde_blacklist, horde_censorlist):
current_id = None
current_payload = None
loop_retry = 0
# load the model for stable horde if it's not in memory already
# we should load it after we get the request from the API in
# case the model is different from the one loaded in memory, but
# for now we can load it here so it's ready right away.
load_models(use_GFPGAN=True)
while True:
if loop_retry > 10 and current_id:
logger.info(f"Exceeded retry count {loop_retry} for generation id {current_id}. Aborting generation!")
current_id = None
current_payload = None
current_generation = None
loop_retry = 0
elif current_id:
logger.info(f"Retrying ({loop_retry}/10) for generation id {current_id}...")
gen_dict = {
"name": horde_name,
"max_pixels": horde_max_pixels,
"priority_usernames": priority_usernames,
"nsfw": horde_nsfw,
"blacklist": horde_blacklist,
"models": ["stable_diffusion"],
}
headers = {"apikey": api_key}
if current_id:
loop_retry += 1
else:
try:
pop_req = requests.post(horde_url + '/api/v2/generate/pop', json = gen_dict, headers = headers)
except requests.exceptions.ConnectionError:
logger.warning(f"Server {horde_url} unavailable during pop. Waiting 10 seconds...")
time.sleep(10)
continue
except requests.exceptions.JSONDecodeError:
logger.warning(f"Server {horde_url} returned an invalid JSON response during pop. Waiting 10 seconds...")
time.sleep(10)
continue
try:
pop = pop_req.json()
except json.decoder.JSONDecodeError:
logger.warning(f"Could not decode response from {horde_url} as json. Please inform its administrator!")
time.sleep(interval)
continue
if pop is None:
logger.warning(f"Something has gone wrong with {horde_url}. Please inform its administrator!")
time.sleep(interval)
continue
if not pop_req.ok:
message = pop['message']
logger.warning(f"During gen pop, server {horde_url} responded with status code {pop_req.status_code}: {pop['message']}. Waiting for 10 seconds...")
if 'errors' in pop:
logger.debug(f"Detailed Request Errors: {pop['errors']}")
time.sleep(10)
continue
if not pop.get("id"):
skipped_info = pop.get('skipped')
if skipped_info and len(skipped_info):
skipped_info = f" Skipped Info: {skipped_info}."
else:
skipped_info = ''
logger.info(f"Server {horde_url} has no valid generations to do for us.{skipped_info}")
time.sleep(interval)
continue
current_id = pop['id']
logger.info(f"Request with id {current_id} picked up. Initiating work...")
current_payload = pop['payload']
if 'toggles' in current_payload and current_payload['toggles'] is None:
logger.error(f"Received Bad payload: {pop}")
current_id = None
current_payload = None
current_generation = None
loop_retry = 0
time.sleep(10)
continue
logger.debug(current_payload)
current_payload['toggles'] = current_payload.get('toggles', [1,4])
# In bridge-mode, matrix is prepared on the horde and split in multiple nodes
if 0 in current_payload['toggles']:
current_payload['toggles'].remove(0)
if 8 not in current_payload['toggles']:
if horde_censor_nsfw and not horde_nsfw:
current_payload['toggles'].append(8)
elif any(word in current_payload['prompt'] for word in horde_censorlist):
current_payload['toggles'].append(8)
from txt2img import txt2img
"""{'prompt': 'Centred Husky, inside spiral with circular patterns, trending on dribbble, knotwork, spirals, key patterns,
zoomorphics, ', 'ddim_steps': 30, 'n_iter': 1, 'sampler_name': 'DDIM', 'cfg_scale': 16.0, 'seed': '3405278433', 'height': 512, 'width': 512}"""
#images, seed, info, stats = txt2img(**current_payload)
images, seed, info, stats = txt2img(str(current_payload['prompt']), int(current_payload['ddim_steps']), str(current_payload['sampler_name']),
int(current_payload['n_iter']), 1, float(current_payload["cfg_scale"]), str(current_payload["seed"]),
int(current_payload["height"]), int(current_payload["width"]), save_grid=False, group_by_prompt=False,
save_individual_images=False,write_info_files=False)
buffer = BytesIO()
# We send as WebP to avoid using all the horde bandwidth
images[0].save(buffer, format="WebP", quality=90)
# logger.info(info)
submit_dict = {
"id": current_id,
"generation": base64.b64encode(buffer.getvalue()).decode("utf8"),
"api_key": api_key,
"seed": seed,
"max_pixels": horde_max_pixels,
}
current_generation = seed
while current_id and current_generation is not None:
try:
submit_req = requests.post(horde_url + '/api/v2/generate/submit', json = submit_dict, headers = headers)
try:
submit = submit_req.json()
except json.decoder.JSONDecodeError:
logger.error(f"Something has gone wrong with {horde_url} during submit. Please inform its administrator! (Retry {loop_retry}/10)")
time.sleep(interval)
continue
if submit_req.status_code == 404:
logger.info(f"The generation we were working on got stale. Aborting!")
elif not submit_req.ok:
logger.error(f"During gen submit, server {horde_url} responded with status code {submit_req.status_code}: {submit['message']}. Waiting for 10 seconds... (Retry {loop_retry}/10)")
if 'errors' in submit:
logger.debug(f"Detailed Request Errors: {submit['errors']}")
time.sleep(10)
continue
else:
logger.info(f'Submitted generation with id {current_id} and contributed for {submit_req.json()["reward"]}')
current_id = None
current_payload = None
current_generation = None
loop_retry = 0
except requests.exceptions.ConnectionError:
logger.warning(f"Server {horde_url} unavailable during submit. Waiting 10 seconds... (Retry {loop_retry}/10)")
time.sleep(10)
continue
time.sleep(interval)
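# Example (illustrative sketch): one way to start the bridge loop above. Every value below
# is a placeholder; the real entry point may read these from a config file or CLI flags.
#
#   if __name__ == "__main__":
#       run_bridge(
#           interval=1,                        # seconds between polls of the horde
#           api_key="0000000000",              # anonymous API key
#           horde_name="my-worker",
#           horde_url="https://stablehorde.net",
#           priority_usernames=[],
#           horde_max_pixels=512 * 512,
#           horde_nsfw=False,
#           horde_censor_nsfw=True,
#           horde_blacklist=[],
#           horde_censorlist=[],
#       )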

View File

@ -0,0 +1,938 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
from sd_utils import st, set_page_title, seed_to_int
# streamlit imports
from streamlit.runtime.scriptrunner import StopException
from streamlit_tensorboard import st_tensorboard
#streamlit components section
from streamlit_server_state import server_state
#other imports
from transformers import CLIPTextModel, CLIPTokenizer
# Temp imports
import itertools
import math
import os
import random
#import datetime
#from pathlib import Path
#from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel#, PNDMScheduler
from diffusers.optimization import get_scheduler
#from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from pipelines.stable_diffusion.no_check import NoCheck
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from slugify import slugify
import json
import os#, subprocess
#from io import StringIO
# end of imports
#---------------------------------------------------------------------------------------------------------------
logger = get_logger(__name__)
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
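# Example (illustrative): each template is turned into a concrete training prompt with
# str.format; the placeholder token value below is an assumption.
#
#   prompt = imagenet_templates_small[0].format("<my-concept>")
#   # -> "a photo of a <my-concept>"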
class TextualInversionDataset(Dataset):
def __init__(
self,
data_root,
tokenizer,
learnable_property="object", # [object, style]
size=512,
repeats=100,
interpolation="bicubic",
set="train",
placeholder_token="*",
center_crop=False,
templates=None
):
self.data_root = data_root
self.tokenizer = tokenizer
self.learnable_property = learnable_property
self.size = size
self.placeholder_token = placeholder_token
self.center_crop = center_crop
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) if file_path.lower().endswith(('.png', '.jpg', '.jpeg'))]
self.num_images = len(self.image_paths)
self._length = self.num_images
if set == "train":
self._length = self.num_images * repeats
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
}[interpolation]
self.templates = templates
self.cache = {}
self.tokenized_templates = [self.tokenizer(
text.format(self.placeholder_token),
padding="max_length",
truncation=True,
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids[0] for text in self.templates]
def __len__(self):
return self._length
def get_example(self, image_path, flipped):
if image_path in self.cache:
return self.cache[image_path]
example = {}
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w = img.shape[0], img.shape[1]
img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
image = Image.fromarray(img)
image = image.resize((self.size, self.size), resample=self.interpolation)
image = transforms.RandomHorizontalFlip(p=1 if flipped else 0)(image)
image = np.array(image).astype(np.uint8)
image = (image / 127.5 - 1.0).astype(np.float32)
example["key"] = "-".join([image_path, "-", str(flipped)])
example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
self.cache[image_path] = example
return example
def __getitem__(self, i):
flipped = random.choice([False, True])
example = self.get_example(self.image_paths[i % self.num_images], flipped)
example["input_ids"] = random.choice(self.tokenized_templates)
return example
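# Example (illustrative sketch, not executed): building the dataset above and wrapping it in
# a DataLoader. The data folder and tokenizer checkpoint are assumptions.
#
#   from torch.utils.data import DataLoader
#   tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   dataset = TextualInversionDataset(
#       data_root="inputs/my-concept",        # folder of training images (placeholder path)
#       tokenizer=tokenizer,
#       placeholder_token="<my-concept>",
#       templates=imagenet_templates_small,   # or imagenet_style_templates_small for a style
#       size=512,
#       repeats=100,
#   )
#   loader = DataLoader(dataset, batch_size=1, shuffle=True)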
def freeze_params(params):
for param in params:
param.requires_grad = False
def save_resume_file(basepath, extra=None, config=None):
info = {"args": config["args"]}
info["args"].update(extra or {})
with open(f"{os.path.join(basepath, 'resume.json')}", "w") as f:
#print (info)
json.dump(info, f, indent=4)
with open(f"{basepath}/token_identifier.txt", "w") as f:
f.write(f"{config['args']['placeholder_token']}")
with open(f"{basepath}/type_of_concept.txt", "w") as f:
f.write(f"{config['args']['learnable_property']}")
config['args'] = info["args"]
return config['args']
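# Example (illustrative sketch, not executed): persisting resume state with the helper above
# so training can be continued later. Paths and token values are placeholders.
#
#   basepath_example = os.path.join("outputs", "textual_inversion", "my-concept")
#   os.makedirs(basepath_example, exist_ok=True)
#   config_example = {"args": {"placeholder_token": "<my-concept>", "learnable_property": "object"}}
#   config_example["args"] = save_resume_file(
#       basepath_example,
#       extra={"global_step": 500, "resume_checkpoint": os.path.join(basepath_example, "checkpoints", "last.bin")},
#       config=config_example,
#   )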
class Checkpointer:
def __init__(
self,
accelerator,
vae,
unet,
tokenizer,
placeholder_token,
placeholder_token_id,
templates,
output_dir,
random_sample_batches,
sample_batch_size,
stable_sample_batches,
seed
):
self.accelerator = accelerator
self.vae = vae
self.unet = unet
self.tokenizer = tokenizer
self.placeholder_token = placeholder_token
self.placeholder_token_id = placeholder_token_id
self.templates = templates
self.output_dir = output_dir
self.seed = seed
self.random_sample_batches = random_sample_batches
self.sample_batch_size = sample_batch_size
self.stable_sample_batches = stable_sample_batches
@torch.no_grad()
def checkpoint(self, step, text_encoder, save_samples=True, path=None):
print("Saving checkpoint for step %d..." % step)
with torch.autocast("cuda"):
if path is None:
checkpoints_path = f"{self.output_dir}/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)
unwrapped = self.accelerator.unwrap_model(text_encoder)
# Save a checkpoint
learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
filename = f"%s_%d.bin" % (slugify(self.placeholder_token), step)
if path is not None:
torch.save(learned_embeds_dict, path)
else:
torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
del unwrapped
del learned_embeds
@torch.no_grad()
def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
samples_path = f"{self.output_dir}/concept_images"
os.makedirs(samples_path, exist_ok=True)
#if "checker" not in server_state['textual_inversion']:
#with server_state_lock['textual_inversion']["checker"]:
server_state['textual_inversion']["checker"] = NoCheck()
#if "unwrapped" not in server_state['textual_inversion']:
# with server_state_lock['textual_inversion']["unwrapped"]:
server_state['textual_inversion']["unwrapped"] = self.accelerator.unwrap_model(text_encoder)
#if "pipeline" not in server_state['textual_inversion']:
# with server_state_lock['textual_inversion']["pipeline"]:
# Save a sample image
server_state['textual_inversion']["pipeline"] = StableDiffusionPipeline(
text_encoder=server_state['textual_inversion']["unwrapped"],
vae=self.vae,
unet=self.unet,
tokenizer=self.tokenizer,
scheduler=LMSDiscreteScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
),
safety_checker=NoCheck(),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to("cuda")
server_state['textual_inversion']["pipeline"].enable_attention_slicing()
if self.stable_sample_batches > 0:
stable_latents = torch.randn(
(self.sample_batch_size, server_state['textual_inversion']["pipeline"].unet.in_channels, height // 8, width // 8),
device=server_state['textual_inversion']["pipeline"].device,
generator=torch.Generator(device=server_state['textual_inversion']["pipeline"].device).manual_seed(self.seed),
)
stable_prompts = [choice.format(self.placeholder_token) for choice in (self.templates * self.sample_batch_size)[:self.sample_batch_size]]
# Generate and save stable samples
for i in range(0, self.stable_sample_batches):
samples = server_state['textual_inversion']["pipeline"](
prompt=stable_prompts,
height=384,
latents=stable_latents,
width=384,
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil'
)["sample"]
for idx, im in enumerate(samples):
filename = f"stable_sample_%d_%d_step_%d.png" % (i+1, idx+1, step)
im.save(f"{samples_path}/{filename}")
del samples
del stable_latents
prompts = [choice.format(self.placeholder_token) for choice in random.choices(self.templates, k=self.sample_batch_size)]
# Generate and save random samples
for i in range(0, self.random_sample_batches):
samples = server_state['textual_inversion']["pipeline"](
prompt=prompts,
height=384,
width=384,
guidance_scale=guidance_scale,
eta=eta,
num_inference_steps=num_inference_steps,
output_type='pil'
)["sample"]
for idx, im in enumerate(samples):
filename = f"step_%d_sample_%d_%d.png" % (step, i+1, idx+1)
im.save(f"{samples_path}/{filename}")
del samples
del server_state['textual_inversion']["checker"]
del server_state['textual_inversion']["unwrapped"]
del server_state['textual_inversion']["pipeline"]
torch.cuda.empty_cache()
#@retry(RuntimeError, tries=5)
def textual_inversion(config):
print ("Running textual inversion.")
#if "pipeline" in server_state["textual_inversion"]:
#del server_state['textual_inversion']["checker"]
#del server_state['textual_inversion']["unwrapped"]
#del server_state['textual_inversion']["pipeline"]
#torch.cuda.empty_cache()
global_step_offset = 0
#print(config['args']['resume_from'])
if config['args']['resume_from']:
try:
basepath = f"{config['args']['resume_from']}"
with open(f"{basepath}/resume.json", 'r') as f:
state = json.load(f)
global_step_offset = state["args"].get("global_step", 0)
print("Resuming state from %s" % config['args']['resume_from'])
print("We've trained %d steps so far" % global_step_offset)
except json.decoder.JSONDecodeError:
pass
else:
basepath = f"{config['args']['output_dir']}/{slugify(config['args']['placeholder_token'])}"
os.makedirs(basepath, exist_ok=True)
accelerator = Accelerator(
gradient_accumulation_steps=config['args']['gradient_accumulation_steps'],
mixed_precision=config['args']['mixed_precision']
)
# If passed along, set the training seed.
if config['args']['seed']:
set_seed(config['args']['seed'])
#if "tokenizer" not in server_state["textual_inversion"]:
# Load the tokenizer and add the placeholder token as an additional special token
#with server_state_lock['textual_inversion']["tokenizer"]:
if config['args']['tokenizer_name']:
server_state['textual_inversion']["tokenizer"] = CLIPTokenizer.from_pretrained(config['args']['tokenizer_name'])
elif config['args']['pretrained_model_name_or_path']:
server_state['textual_inversion']["tokenizer"] = CLIPTokenizer.from_pretrained(
config['args']['pretrained_model_name_or_path'] + '/tokenizer'
)
# Add the placeholder token in tokenizer
num_added_tokens = server_state['textual_inversion']["tokenizer"].add_tokens(config['args']['placeholder_token'])
if num_added_tokens == 0:
st.error(
f"The tokenizer already contains the token {config['args']['placeholder_token']}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Convert the initializer_token, placeholder_token to ids
token_ids = server_state['textual_inversion']["tokenizer"].encode(config['args']['initializer_token'], add_special_tokens=False)
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
st.error("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
placeholder_token_id = server_state['textual_inversion']["tokenizer"].convert_tokens_to_ids(config['args']['placeholder_token'])
#if "text_encoder" not in server_state['textual_inversion']:
# Load models and create wrapper for stable diffusion
#with server_state_lock['textual_inversion']["text_encoder"]:
server_state['textual_inversion']["text_encoder"] = CLIPTextModel.from_pretrained(
config['args']['pretrained_model_name_or_path'] + '/text_encoder',
)
#if "vae" not in server_state['textual_inversion']:
#with server_state_lock['textual_inversion']["vae"]:
server_state['textual_inversion']["vae"] = AutoencoderKL.from_pretrained(
config['args']['pretrained_model_name_or_path'] + '/vae',
)
#if "unet" not in server_state['textual_inversion']:
#with server_state_lock['textual_inversion']["unet"]:
server_state['textual_inversion']["unet"] = UNet2DConditionModel.from_pretrained(
config['args']['pretrained_model_name_or_path'] + '/unet',
)
base_templates = imagenet_style_templates_small if config['args']['learnable_property'] == "style" else imagenet_templates_small
if config['args']['custom_templates']:
templates = config['args']['custom_templates'].split(";")
else:
templates = base_templates
slice_size = server_state['textual_inversion']["unet"].config.attention_head_dim // 2
server_state['textual_inversion']["unet"].set_attention_slice(slice_size)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
server_state['textual_inversion']["text_encoder"].resize_token_embeddings(len(server_state['textual_inversion']["tokenizer"]))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = server_state['textual_inversion']["text_encoder"].get_input_embeddings().weight.data
if "resume_checkpoint" in config['args']:
if config['args']['resume_checkpoint'] is not None:
token_embeds[placeholder_token_id] = torch.load(config['args']['resume_checkpoint'])[config['args']['placeholder_token']]
else:
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
# Freeze vae and unet
freeze_params(server_state['textual_inversion']["vae"].parameters())
freeze_params(server_state['textual_inversion']["unet"].parameters())
# Freeze all parameters except for the token embeddings in text encoder
params_to_freeze = itertools.chain(
server_state['textual_inversion']["text_encoder"].text_model.encoder.parameters(),
server_state['textual_inversion']["text_encoder"].text_model.final_layer_norm.parameters(),
server_state['textual_inversion']["text_encoder"].text_model.embeddings.position_embedding.parameters(),
)
freeze_params(params_to_freeze)
checkpointer = Checkpointer(
accelerator=accelerator,
vae=server_state['textual_inversion']["vae"],
unet=server_state['textual_inversion']["unet"],
tokenizer=server_state['textual_inversion']["tokenizer"],
placeholder_token=config['args']['placeholder_token'],
placeholder_token_id=placeholder_token_id,
templates=templates,
output_dir=basepath,
sample_batch_size=config['args']['sample_batch_size'],
random_sample_batches=config['args']['random_sample_batches'],
stable_sample_batches=config['args']['stable_sample_batches'],
seed=config['args']['seed']
)
if config['args']['scale_lr']:
config['args']['learning_rate'] = (
config['args']['learning_rate'] * config[
'args']['gradient_accumulation_steps'] * config['args']['train_batch_size'] * accelerator.num_processes
)
# Initialize the optimizer
optimizer = torch.optim.AdamW(
server_state['textual_inversion']["text_encoder"].get_input_embeddings().parameters(), # only optimize the embeddings
lr=config['args']['learning_rate'],
betas=(config['args']['adam_beta1'], config['args']['adam_beta2']),
weight_decay=config['args']['adam_weight_decay'],
eps=config['args']['adam_epsilon'],
)
# TODO (patil-suraj): load scheduler using args
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
)
train_dataset = TextualInversionDataset(
data_root=config['args']['train_data_dir'],
tokenizer=server_state['textual_inversion']["tokenizer"],
size=config['args']['resolution'],
placeholder_token=config['args']['placeholder_token'],
repeats=config['args']['repeats'],
learnable_property=config['args']['learnable_property'],
center_crop=config['args']['center_crop'],
set="train",
templates=templates
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config['args']['train_batch_size'], shuffle=True)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config['args']['gradient_accumulation_steps'])
if config['args']['max_train_steps'] is None:
config['args']['max_train_steps'] = config['args']['num_train_epochs'] * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
config['args']['lr_scheduler'],
optimizer=optimizer,
num_warmup_steps=config['args']['lr_warmup_steps'] * config['args']['gradient_accumulation_steps'],
num_training_steps=config['args']['max_train_steps'] * config['args']['gradient_accumulation_steps'],
)
server_state['textual_inversion']["text_encoder"], optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
server_state['textual_inversion']["text_encoder"], optimizer, train_dataloader, lr_scheduler
)
# Move vae and unet to device
server_state['textual_inversion']["vae"].to(accelerator.device)
server_state['textual_inversion']["unet"].to(accelerator.device)
# Keep vae and unet in eval mode as we don't train these
server_state['textual_inversion']["vae"].eval()
server_state['textual_inversion']["unet"].eval()
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config['args']['gradient_accumulation_steps'])
if overrode_max_train_steps:
config['args']['max_train_steps'] = config['args']['num_train_epochs'] * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
config['args']['num_train_epochs'] = math.ceil(config['args']['max_train_steps'] / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion", config=config['args'])
# Train!
total_batch_size = config['args']['train_batch_size'] * accelerator.num_processes * config['args']['gradient_accumulation_steps']
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {config['args']['num_train_epochs']}")
logger.info(f" Instantaneous batch size per device = {config['args']['train_batch_size']}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {config['args']['gradient_accumulation_steps']}")
logger.info(f" Total optimization steps = {config['args']['max_train_steps']}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(config['args']['max_train_steps']), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
encoded_pixel_values_cache = {}
try:
for epoch in range(config['args']['num_train_epochs']):
server_state['textual_inversion']["text_encoder"].train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(server_state['textual_inversion']["text_encoder"]):
# Convert images to latent space
key = "|".join(batch["key"])
if encoded_pixel_values_cache.get(key, None) is None:
encoded_pixel_values_cache[key] = server_state['textual_inversion']["vae"].encode(batch["pixel_values"]).latent_dist
latents = encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn(latents.shape).to(latents.device)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = server_state['textual_inversion']["text_encoder"](batch["input_ids"])[0]
# Predict the noise residual
noise_pred = server_state['textual_inversion']["unet"](noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
# Zero out the gradients for all token embeddings except the newly added
# embeddings for the concept, as we only want to optimize the concept embeddings
if accelerator.num_processes > 1:
grads = server_state['textual_inversion']["text_encoder"].module.get_input_embeddings().weight.grad
else:
grads = server_state['textual_inversion']["text_encoder"].get_input_embeddings().weight.grad
# Get the index for tokens that we want to zero the grads for
index_grads_to_zero = torch.arange(len(server_state['textual_inversion']["tokenizer"])) != placeholder_token_id
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
#try:
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % config['args']['checkpoint_frequency'] == 0 and global_step > 0 and accelerator.is_main_process:
checkpointer.checkpoint(global_step + global_step_offset, server_state['textual_inversion']["text_encoder"])
save_resume_file(basepath, {
"global_step": global_step + global_step_offset,
"resume_checkpoint": f"{basepath}/checkpoints/last.bin"
}, config)
checkpointer.save_samples(
global_step + global_step_offset,
server_state['textual_inversion']["text_encoder"],
config['args']['resolution'], config['args'][
'resolution'], 7.5, 0.0, config['args']['sample_steps'])
checkpointer.checkpoint(
global_step + global_step_offset,
server_state['textual_inversion']["text_encoder"],
path=f"{basepath}/learned_embeds.bin"
)
#except KeyError:
#raise StopException
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
#accelerator.log(logs, step=global_step)
#try:
if global_step >= config['args']['max_train_steps']:
break
#except:
#pass
accelerator.wait_for_everyone()
# Create the pipeline using the trained modules and save it.
if accelerator.is_main_process:
print("Finished! Saving final checkpoint and resume state.")
checkpointer.checkpoint(
global_step + global_step_offset,
server_state['textual_inversion']["text_encoder"],
path=f"{basepath}/learned_embeds.bin"
)
save_resume_file(basepath, {
"global_step": global_step + global_step_offset,
"resume_checkpoint": f"{basepath}/checkpoints/last.bin"
}, config)
accelerator.end_training()
except (KeyboardInterrupt, StopException) as e:
print(f"Received Streamlit StopException or KeyboardInterrupt")
if accelerator.is_main_process:
print("Interrupted, saving checkpoint and resume state...")
checkpointer.checkpoint(global_step + global_step_offset, server_state['textual_inversion']["text_encoder"])
config['args'] = save_resume_file(basepath, {
"global_step": global_step + global_step_offset,
"resume_checkpoint": f"{basepath}/checkpoints/last.bin"
}, config)
checkpointer.checkpoint(
global_step + global_step_offset,
server_state['textual_inversion']["text_encoder"],
path=f"{basepath}/learned_embeds.bin"
)
quit()
def layout():
with st.form("textual-inversion"):
#st.info("Under Construction. :construction_worker:")
#parser = argparse.ArgumentParser(description="Simple example of a training script.")
set_page_title("Textual Inversion - Stable Diffusion Playground")
config_tab, output_tab, tensorboard_tab = st.tabs(["Textual Inversion Config", "Output", "TensorBoard"])
with config_tab:
col1, col2, col3, col4, col5 = st.columns(5, gap='large')
if "textual_inversion" not in st.session_state:
st.session_state["textual_inversion"] = {}
if "textual_inversion" not in server_state:
server_state["textual_inversion"] = {}
if "args" not in st.session_state["textual_inversion"]:
st.session_state["textual_inversion"]["args"] = {}
with col1:
st.session_state["textual_inversion"]["args"]["pretrained_model_name_or_path"] = st.text_input("Pretrained Model Path",
value=st.session_state["defaults"].textual_inversion.pretrained_model_name_or_path,
help="Path to pretrained model or model identifier from huggingface.co/models.")
st.session_state["textual_inversion"]["args"]["tokenizer_name"] = st.text_input("Tokenizer Name",
value=st.session_state["defaults"].textual_inversion.tokenizer_name,
help="Pretrained tokenizer name or path if not the same as model_name")
st.session_state["textual_inversion"]["args"]["train_data_dir"] = st.text_input("train_data_dir", value="", help="A folder containing the training data.")
st.session_state["textual_inversion"]["args"]["placeholder_token"] = st.text_input("Placeholder Token", value="", help="A token to use as a placeholder for the concept.")
st.session_state["textual_inversion"]["args"]["initializer_token"] = st.text_input("Initializer Token", value="", help="A token to use as initializer word.")
st.session_state["textual_inversion"]["args"]["learnable_property"] = st.selectbox("Learnable Property", ["object", "style"], index=0, help="Choose between 'object' and 'style'")
st.session_state["textual_inversion"]["args"]["repeats"] = int(st.text_input("Number of times to Repeat", value=100, help="How many times to repeat the training data."))
with col2:
st.session_state["textual_inversion"]["args"]["output_dir"] = st.text_input("Output Directory",
value=str(os.path.join("outputs", "textual_inversion")),
help="The output directory where the model predictions and checkpoints will be written.")
st.session_state["textual_inversion"]["args"]["seed"] = seed_to_int(st.text_input("Seed", value=0,
help="A seed for reproducible training, if left empty a random one will be generated. Default: 0"))
st.session_state["textual_inversion"]["args"]["resolution"] = int(st.text_input("Resolution", value=512,
help="The resolution for input images, all the images in the train/validation dataset will be resized to this resolution"))
st.session_state["textual_inversion"]["args"]["center_crop"] = st.checkbox("Center Image", value=True, help="Whether to center crop images before resizing to resolution")
st.session_state["textual_inversion"]["args"]["train_batch_size"] = int(st.text_input("Train Batch Size", value=1, help="Batch size (per device) for the training dataloader."))
st.session_state["textual_inversion"]["args"]["num_train_epochs"] = int(st.text_input("Number of Steps to Train", value=100, help="Number of steps to train."))
st.session_state["textual_inversion"]["args"]["max_train_steps"] = int(st.text_input("Max Number of Steps to Train", value=5000,
help="Total number of training steps to perform. If provided, overrides 'Number of Steps to Train'."))
with col3:
st.session_state["textual_inversion"]["args"]["gradient_accumulation_steps"] = int(st.text_input("Gradient Accumulation Steps", value=1,
help="Number of updates steps to accumulate before performing a backward/update pass."))
st.session_state["textual_inversion"]["args"]["learning_rate"] = float(st.text_input("Learning Rate", value=5.0e-04,
help="Initial learning rate (after the potential warmup period) to use."))
st.session_state["textual_inversion"]["args"]["scale_lr"] = st.checkbox("Scale Learning Rate", value=True,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.")
st.session_state["textual_inversion"]["args"]["lr_scheduler"] = st.text_input("Learning Rate Scheduler", value="constant",
help=("The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',"
" 'constant', 'constant_with_warmup']" ))
st.session_state["textual_inversion"]["args"]["lr_warmup_steps"] = int(st.text_input("Learning Rate Warmup Steps", value=500, help="Number of steps for the warmup in the lr scheduler."))
st.session_state["textual_inversion"]["args"]["adam_beta1"] = float(st.text_input("Adam Beta 1", value=0.9, help="The beta1 parameter for the Adam optimizer."))
st.session_state["textual_inversion"]["args"]["adam_beta2"] = float(st.text_input("Adam Beta 2", value=0.999, help="The beta2 parameter for the Adam optimizer."))
st.session_state["textual_inversion"]["args"]["adam_weight_decay"] = float(st.text_input("Adam Weight Decay", value=1e-2, help="Weight decay to use."))
st.session_state["textual_inversion"]["args"]["adam_epsilon"] = float(st.text_input("Adam Epsilon", value=1e-08, help="Epsilon value for the Adam optimizer"))
with col4:
st.session_state["textual_inversion"]["args"]["mixed_precision"] = st.selectbox("Mixed Precision", ["no", "fp16", "bf16"], index=1,
help="Whether to use mixed precision. Choose" "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.")
st.session_state["textual_inversion"]["args"]["local_rank"] = int(st.text_input("Local Rank", value=1, help="For distributed training: local_rank"))
st.session_state["textual_inversion"]["args"]["checkpoint_frequency"] = int(st.text_input("Checkpoint Frequency", value=500, help="How often to save a checkpoint and sample image"))
# stable_sample_batches is crashing when saving the samples so for now I will disable it until it's fixed.
#st.session_state["textual_inversion"]["args"]["stable_sample_batches"] = int(st.text_input("Stable Sample Batches", value=0,
#help="Number of fixed seed sample batches to generate per checkpoint"))
st.session_state["textual_inversion"]["args"]["stable_sample_batches"] = 0
st.session_state["textual_inversion"]["args"]["random_sample_batches"] = int(st.text_input("Random Sample Batches", value=2,
help="Number of random seed sample batches to generate per checkpoint"))
st.session_state["textual_inversion"]["args"]["sample_batch_size"] = int(st.text_input("Sample Batch Size", value=1, help="Number of samples to generate per batch"))
st.session_state["textual_inversion"]["args"]["sample_steps"] = int(st.text_input("Sample Steps", value=100,
help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes."))
st.session_state["textual_inversion"]["args"]["custom_templates"] = st.text_input("Custom Templates", value="",
help="A semicolon-delimited list of custom template to use for samples, using {} as a placeholder for the concept.")
with col5:
st.session_state["textual_inversion"]["args"]["resume"] = st.checkbox(label="Resume Previous Run?", value=False,
help="Resume previous run, if a valid resume.json file is on the output dir \
it will be used, otherwise if the 'Resume From' field bellow contains a valid resume.json file \
that one will be used.")
st.session_state["textual_inversion"]["args"]["resume_from"] = st.text_input(label="Resume From", help="Path to a directory to resume training from (ie, logs/token_name)")
#st.session_state["textual_inversion"]["args"]["resume_checkpoint"] = st.file_uploader("Resume Checkpoint", type=["bin"],
#help="Path to a specific checkpoint to resume training from (ie, logs/token_name/checkpoints/something.bin).")
#st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"] = st.file_uploader("st.session_state["textual_inversion"] File", type=["json"],
#help="Path to a JSON st.session_state["textual_inversion"]uration file containing arguments for invoking this script."
#"If resume_from is given, its resume.json takes priority over this.")
#
#print (os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json"))
#print (os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json")))
if os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json")):
st.session_state["textual_inversion"]["args"]["resume_from"] = os.path.join(
st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"))
#print (st.session_state["textual_inversion"]["args"]["resume_from"])
if os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"), "checkpoints","last.bin")):
st.session_state["textual_inversion"]["args"]["resume_checkpoint"] = os.path.join(
st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"), "checkpoints","last.bin")
#if "resume_from" in st.session_state["textual_inversion"]["args"]:
#if st.session_state["textual_inversion"]["args"]["resume_from"]:
#if os.path.exists(os.path.join(st.session_state["textual_inversion"]['args']['resume_from'], "resume.json")):
#with open(os.path.join(st.session_state["textual_inversion"]['args']['resume_from'], "resume.json"), 'rt') as f:
#try:
#resume_json = json.load(f)["args"]
#st.session_state["textual_inversion"]["args"] = OmegaConf.merge(st.session_state["textual_inversion"]["args"], resume_json)
#st.session_state["textual_inversion"]["args"]["resume_from"] = os.path.join(
#st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"))
#except json.decoder.JSONDecodeError:
#pass
#print(st.session_state["textual_inversion"]["args"])
#print(st.session_state["textual_inversion"]["args"]['resume_from'])
#elif st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"] is not None:
#with open(st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"], 'rt') as f:
#args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)["args"]))
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != st.session_state["textual_inversion"]["args"]["local_rank"]:
st.session_state["textual_inversion"]["args"]["local_rank"] = env_local_rank
if st.session_state["textual_inversion"]["args"]["train_data_dir"] is None:
st.error("You must specify --train_data_dir")
if st.session_state["textual_inversion"]["args"]["pretrained_model_name_or_path"] is None:
st.error("You must specify --pretrained_model_name_or_path")
if st.session_state["textual_inversion"]["args"]["placeholder_token"] is None:
st.error("You must specify --placeholder_token")
if st.session_state["textual_inversion"]["args"]["initializer_token"] is None:
st.error("You must specify --initializer_token")
if st.session_state["textual_inversion"]["args"]["output_dir"] is None:
st.error("You must specify --output_dir")
# add a spacer and the submit button for the form.
st.session_state["textual_inversion"]["message"] = st.empty()
st.session_state["textual_inversion"]["progress_bar"] = st.empty()
st.write("---")
submit = st.form_submit_button("Run",help="")
if submit:
if "pipe" in st.session_state:
del st.session_state["pipe"]
if "model" in st.session_state:
del st.session_state["model"]
set_page_title("Running Textual Inversion - Stable Diffusion WebUI")
#st.session_state["textual_inversion"]["message"].info("Textual Inversion Running. For more info check the progress on your console or the Ouput Tab.")
try:
#try:
# run textual inversion.
config = st.session_state['textual_inversion']
textual_inversion(config)
#except RuntimeError:
#if "pipeline" in server_state["textual_inversion"]:
#del server_state['textual_inversion']["checker"]
#del server_state['textual_inversion']["unwrapped"]
#del server_state['textual_inversion']["pipeline"]
# run textual inversion.
#config = st.session_state['textual_inversion']
#textual_inversion(config)
set_page_title("Textual Inversion - Stable Diffusion WebUI")
except StopException:
set_page_title("Textual Inversion - Stable Diffusion WebUI")
print(f"Received Streamlit StopException")
st.session_state["textual_inversion"]["message"].empty()
#
with output_tab:
st.info("Under Construction. :construction_worker:")
#st.info("Nothing to show yet. Maybe try running some training first.")
#st.session_state["textual_inversion"]["preview_image"] = st.empty()
#st.session_state["textual_inversion"]["progress_bar"] = st.empty()
with tensorboard_tab:
#st.info("Under Construction. :construction_worker:")
# Start TensorBoard
st_tensorboard(logdir=os.path.join("outputs", "textual_inversion"), port=8888)

View File

@ -0,0 +1,708 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
from sd_utils import st, MemUsageMonitor, server_state, no_rerun, logger, set_page_title, \
custom_models_available, RealESRGAN_available, GFPGAN_available, \
LDSR_available, seed_to_int, \
get_next_sequence_number, check_prompt_length, torch_gc, \
save_sample, generation_callback, process_images, \
KDiffusionSampler
# streamlit imports
from streamlit.runtime.scriptrunner import StopException
#streamlit components section
import streamlit_nested_layout #used to allow nested columns, just importing it is enough
#from streamlit.elements import image as STImage
import streamlit.components.v1 as components
#from streamlit.runtime.media_file_manager import media_file_manager
from streamlit.elements.image import image_to_url
#other imports
import base64, uuid
import os, sys, datetime, time
from PIL import Image
import requests
from slugify import slugify
from typing import Union
from io import BytesIO
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
# streamlit components
from custom_components import sygil_suggestions
# Temp imports
# end of imports
#---------------------------------------------------------------------------------------------------------------
sygil_suggestions.init()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
logging.set_verbosity_error()
except:
pass
#
# Dev mode (server)
# _component_func = components.declare_component(
# "sd-gallery",
# url="http://localhost:3001",
# )
# Init Vuejs component
_component_func = components.declare_component(
"sd-gallery", "./frontend/dists/sd-gallery/dist")
def sdGallery(images=[], key=None):
component_value = _component_func(images=imgsToGallery(images), key=key, default="")
return component_value
def imgsToGallery(images):
urls = []
for i in images:
# random string for id
random_id = str(uuid.uuid4())
url = image_to_url(
image=i,
image_id= random_id,
width=i.width,
clamp=False,
channels="RGB",
output_format="PNG"
)
# image_io = BytesIO()
# i.save(image_io, 'PNG')
# width, height = i.size
# image_id = "%s" % (str(images.index(i)))
# (data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
# this_file = media_file_manager.add(data, mimetype, image_id)
# img_str = this_file.url
urls.append(url)
return urls
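# Example (illustrative sketch, not executed): rendering a small batch of PIL images in the
# sd-gallery component declared above. The images here are synthetic placeholders.
#
#   placeholder_images = [Image.new("RGB", (512, 512), color=(40, 40, 40)) for _ in range(4)]
#   sdGallery(placeholder_images, key="txt2img-gallery")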
class plugin_info():
plugname = "txt2img"
description = "Text to Image"
isTab = True
displayPriority = 1
@logger.catch(reraise=True)
def stable_horde(outpath, prompt, seed, sampler_name, save_grid, batch_size,
n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, GFPGAN_model,
use_RealESRGAN, realesrgan_model_name, use_LDSR,
LDSR_model_name, ddim_eta, normalize_prompt_weights,
save_individual_images, sort_samples, write_info_files,
jpg_sample, variant_amount, variant_seed, api_key,
nsfw=True, censor_nsfw=False):
log = []
log.append("Generating image with Stable Horde.")
st.session_state["progress_bar_text"].code('\n'.join(log), language='')
# start time after garbage collection (or before?)
start_time = time.time()
# We will use this date later for the folder name; start_time above is only needed for the timing stats.
run_start_dt = datetime.datetime.now()
mem_mon = MemUsageMonitor('MemMon')
mem_mon.start()
os.makedirs(outpath, exist_ok=True)
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
params = {
"sampler_name": "k_euler",
"toggles": [1,4],
"cfg_scale": cfg_scale,
"seed": str(seed),
"width": width,
"height": height,
"seed_variation": variant_seed if variant_seed else 1,
"steps": int(steps),
"n": int(n_iter)
# You can put extra params here if you wish
}
final_submit_dict = {
"prompt": prompt,
"params": params,
"nsfw": nsfw,
"censor_nsfw": censor_nsfw,
"trusted_workers": True,
"workers": []
}
log.append(final_submit_dict)
headers = {"apikey": api_key}
logger.debug(final_submit_dict)
st.session_state["progress_bar_text"].code('\n'.join(str(log)), language='')
horde_url = "https://stablehorde.net"
submit_req = requests.post(f'{horde_url}/api/v2/generate/async', json = final_submit_dict, headers = headers)
if submit_req.ok:
submit_results = submit_req.json()
logger.debug(submit_results)
log.append(submit_results)
st.session_state["progress_bar_text"].code(''.join(str(log)), language='')
req_id = submit_results['id']
is_done = False
while not is_done:
chk_req = requests.get(f'{horde_url}/api/v2/generate/check/{req_id}')
if not chk_req.ok:
logger.error(chk_req.text)
return
chk_results = chk_req.json()
logger.info(chk_results)
is_done = chk_results['done']
time.sleep(1)
retrieve_req = requests.get(f'{horde_url}/api/v2/generate/status/{req_id}')
if not retrieve_req.ok:
logger.error(retrieve_req.text)
return
results_json = retrieve_req.json()
# logger.debug(results_json)
results = results_json['generations']
output_images = []
comments = []
prompt_matrix_parts = []
if not st.session_state['defaults'].general.no_verify_input:
try:
check_prompt_length(prompt, comments)
except:
import traceback
logger.info("Error verifying input:", file=sys.stderr)
logger.info(traceback.format_exc(), file=sys.stderr)
all_prompts = batch_size * n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))]
for iter in range(len(results)):
b64img = results[iter]["img"]
base64_bytes = b64img.encode('utf-8')
img_bytes = base64.b64decode(base64_bytes)
img = Image.open(BytesIO(img_bytes))
sanitized_prompt = slugify(prompt)
prompts = all_prompts[iter * batch_size:(iter + 1) * batch_size]
#captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size]
seeds = all_seeds[iter * batch_size:(iter + 1) * batch_size]
if sort_samples:
full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt)
sanitized_prompt = sanitized_prompt[:200-len(full_path)]
sample_path_i = os.path.join(sample_path, sanitized_prompt)
#print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}")
#print(os.path.join(os.getcwd(), sample_path_i))
os.makedirs(sample_path_i, exist_ok=True)
base_count = get_next_sequence_number(sample_path_i)
filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[iter]}"
else:
full_path = os.path.join(os.getcwd(), sample_path)
sample_path_i = sample_path
base_count = get_next_sequence_number(sample_path_i)
filename = f"{base_count:05}-{steps}_{sampler_name}_{seed}_{sanitized_prompt}"[:200-len(full_path)] #same as before
save_sample(img, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img=None,
denoising_strength=0.75, resize_mode=None, uses_loopback=False, uses_random_seed_loopback=False,
save_grid=save_grid,
sort_samples=sort_samples, sampler_name=sampler_name, ddim_eta=ddim_eta, n_iter=n_iter,
batch_size=batch_size, i=iter, save_individual_images=save_individual_images,
model_name="Stable Diffusion v1.5")
output_images.append(img)
# update image on the UI so we can see the progress
if "preview_image" in st.session_state:
st.session_state["preview_image"].image(img)
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].empty()
#if len(results) > 1:
#final_filename = f"{iter}_{filename}"
#img.save(final_filename)
#logger.info(f"Saved {final_filename}")
else:
if "progress_bar_text" in st.session_state:
st.session_state["progress_bar_text"].error(submit_req.text)
logger.error(submit_req.text)
mem_max_used, mem_total = mem_mon.read_and_stop()
time_diff = time.time()-start_time
info = f"""
{prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN else ''}{', '+realesrgan_model_name if use_RealESRGAN else ''}
{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
stats = f'''
Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
for comment in comments:
info += "\n\n" + comment
#mem_mon.stop()
#del mem_mon
torch_gc()
return output_images, seed, info, stats
#
@logger.catch(reraise=True)
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.3', use_RealESRGAN: bool = False,
RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", use_LDSR: bool = True, LDSR_model: str = "model",
fp = None, variant_amount: float = 0.0,
variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True,
use_stable_horde: bool = False, stable_horde_key:str = "0000000000"):
outpath = st.session_state['defaults'].general.outdir_txt2img
seed = seed_to_int(seed)
if not use_stable_horde:
if sampler_name == 'PLMS':
sampler = PLMSSampler(server_state["model"])
elif sampler_name == 'DDIM':
sampler = DDIMSampler(server_state["model"])
elif sampler_name == 'k_dpm_2_a':
sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
elif sampler_name == 'k_dpm_2':
sampler = KDiffusionSampler(server_state["model"],'dpm_2')
elif sampler_name == 'k_dpmpp_2m':
sampler = KDiffusionSampler(server_state["model"],'dpmpp_2m')
elif sampler_name == 'k_euler_a':
sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
elif sampler_name == 'k_euler':
sampler = KDiffusionSampler(server_state["model"],'euler')
elif sampler_name == 'k_heun':
sampler = KDiffusionSampler(server_state["model"],'heun')
elif sampler_name == 'k_lms':
sampler = KDiffusionSampler(server_state["model"],'lms')
else:
raise Exception("Unknown sampler: " + sampler_name)
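# init() and sample() are callbacks handed to process_images(); txt2img needs no init step, and sample() runs the chosen sampler.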
def init():
pass
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x,
img_callback=generation_callback if not server_state["bridge"] else None,
log_every_t=int(st.session_state.update_preview_frequency if not server_state["bridge"] else 100))
return samples_ddim
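# Route the generation: either submit the job to the Stable Horde or run process_images() locally with the callbacks above.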
if use_stable_horde:
output_images, seed, info, stats = stable_horde(
prompt=prompt,
seed=seed,
outpath=outpath,
sampler_name=sampler_name,
save_grid=save_grid,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=separate_prompts,
use_GFPGAN=use_GFPGAN,
GFPGAN_model=GFPGAN_model,
use_RealESRGAN=use_RealESRGAN,
realesrgan_model_name=RealESRGAN_model,
use_LDSR=use_LDSR,
LDSR_model_name=LDSR_model,
ddim_eta=ddim_eta,
normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images,
sort_samples=group_by_prompt,
write_info_files=write_info_files,
jpg_sample=save_as_jpg,
variant_amount=variant_amount,
variant_seed=variant_seed,
api_key=stable_horde_key
)
else:
#try:
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
func_sample=sample,
prompt=prompt,
seed=seed,
sampler_name=sampler_name,
save_grid=save_grid,
batch_size=batch_size,
n_iter=n_iter,
steps=ddim_steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=separate_prompts,
use_GFPGAN=use_GFPGAN,
GFPGAN_model=GFPGAN_model,
use_RealESRGAN=use_RealESRGAN,
realesrgan_model_name=RealESRGAN_model,
use_LDSR=use_LDSR,
LDSR_model_name=LDSR_model,
ddim_eta=ddim_eta,
normalize_prompt_weights=normalize_prompt_weights,
save_individual_images=save_individual_images,
sort_samples=group_by_prompt,
write_info_files=write_info_files,
jpg_sample=save_as_jpg,
variant_amount=variant_amount,
variant_seed=variant_seed,
)
del sampler
return output_images, seed, info, stats
#except RuntimeError as e:
#err = e
#err_msg = f'CRASHED:<br><textarea rows="5" style="color:white;background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
#stats = err_msg
#return [], seed, 'err', stats
#
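# layout(): builds the txt2img tab UI (prompt, size/CFG/seed inputs, sampler and post-processing options) and wires the Generate button to txt2img().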
@logger.catch(reraise=True)
def layout():
with st.form("txt2img-inputs"):
st.session_state["generation_mode"] = "txt2img"
input_col1, generate_col1 = st.columns([10,1])
with input_col1:
#prompt = st.text_area("Input Text","")
placeholder = "A corgi wearing a top hat as an oil painting."
prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
if "defaults" in st.session_state:
if st.session_state["defaults"].general.enable_suggestions:
sygil_suggestions.suggestion_area(placeholder)
if "defaults" in st.session_state:
if st.session_state['defaults'].admin.global_negative_prompt:
prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
#print(prompt)
# creating the page layout using columns
col1, col2, col3 = st.columns([2,5,2], gap="large")
with col1:
width = st.slider("Width:", min_value=st.session_state['defaults'].txt2img.width.min_value, max_value=st.session_state['defaults'].txt2img.width.max_value,
value=st.session_state['defaults'].txt2img.width.value, step=st.session_state['defaults'].txt2img.width.step)
height = st.slider("Height:", min_value=st.session_state['defaults'].txt2img.height.min_value, max_value=st.session_state['defaults'].txt2img.height.max_value,
value=st.session_state['defaults'].txt2img.height.value, step=st.session_state['defaults'].txt2img.height.step)
cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):", min_value=st.session_state['defaults'].txt2img.cfg_scale.min_value,
value=st.session_state['defaults'].txt2img.cfg_scale.value, step=st.session_state['defaults'].txt2img.cfg_scale.step,
help="How strongly the image should follow the prompt.")
seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
with st.expander("Batch Options"):
#batch_count = st.slider("Batch count.", min_value=st.session_state['defaults'].txt2img.batch_count.min_value, max_value=st.session_state['defaults'].txt2img.batch_count.max_value,
#value=st.session_state['defaults'].txt2img.batch_count.value, step=st.session_state['defaults'].txt2img.batch_count.step,
#help="How many iterations or batches of images to generate in total.")
#batch_size = st.slider("Batch size", min_value=st.session_state['defaults'].txt2img.batch_size.min_value, max_value=st.session_state['defaults'].txt2img.batch_size.max_value,
#value=st.session_state.defaults.txt2img.batch_size.value, step=st.session_state.defaults.txt2img.batch_size.step,
#help="How many images are at once in a batch.\
#It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
#Default: 1")
st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].txt2img.batch_count.value,
help="How many iterations or batches of images to generate in total.")
st.session_state["batch_size"] = st.number_input("Batch size", value=st.session_state.defaults.txt2img.batch_size.value,
help="How many images are at once in a batch.\
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes \
to finish generation as more images are generated at once.\
Default: 1")
with st.expander("Preview Settings"):
st.session_state["update_preview"] = st.session_state["defaults"].general.update_preview
st.session_state["update_preview_frequency"] = st.number_input("Update Image Preview Frequency",
min_value=0,
value=st.session_state['defaults'].txt2img.update_preview_frequency,
help="Frequency in steps at which the the preview image is updated. By default the frequency \
is set to 10 step.")
with col2:
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
with preview_tab:
#st.write("Image")
#Image for testing
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
#new_image = image.resize((175, 240))
#preview_image = st.image(image)
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
st.session_state["preview_image"] = st.empty()
st.session_state["progress_bar_text"] = st.empty()
st.session_state["progress_bar_text"].info("Nothing but crickets here, try generating something first.")
st.session_state["progress_bar"] = st.empty()
message = st.empty()
with gallery_tab:
st.session_state["gallery"] = st.empty()
#st.session_state["gallery"].info("Nothing but crickets here, try generating something first.")
with col3:
# If we have custom models available on the "models/custom"
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
custom_models_available()
if server_state["CustomModel_available"]:
st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
help="Select the model you want to use. This option is only available if you have custom models \
in your 'models/custom' folder. The model name shown here is the same as the filename \
the model has in that folder; it is recommended to give the .ckpt file a name that \
makes it easy for you to distinguish it from other models. Default: Stable Diffusion v1.5")
st.session_state.sampling_steps = st.number_input("Sampling Steps", value=st.session_state.defaults.txt2img.sampling_steps.value,
min_value=st.session_state.defaults.txt2img.sampling_steps.min_value,
step=st.session_state['defaults'].txt2img.sampling_steps.step,
help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_dpmpp_2m", "k_heun", "PLMS", "DDIM"]
sampler_name = st.selectbox("Sampling method", sampler_name_list,
index=sampler_name_list.index(st.session_state['defaults'].txt2img.default_sampler), help="Sampling method to use. Default: k_euler")
with st.expander("Advanced"):
with st.expander("Stable Horde"):
use_stable_horde = st.checkbox("Use Stable Horde", value=False, help="Use the Stable Horde to generate images. More info can be found at https://stablehorde.net/")
stable_horde_key = st.text_input("Stable Horde API Key", value=st.session_state['defaults'].general.stable_horde_api, type="password",
help="Optional API key used for the Stable Horde bridge; if no API key is provided the horde will be used anonymously.")
with st.expander("Output Settings"):
separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2img.separate_prompts,
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].txt2img.normalize_prompt_weights,
help="Ensure the sum of all weights add up to 1.0")
save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].txt2img.save_individual_images,
help="Save each image generated before any filter or enhancement is applied.")
save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].txt2img.save_grid, help="Save a grid with all the images generated into a single image.")
group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files,
help="Save a file next to the image with informartion about the generation.")
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
# check if GFPGAN, RealESRGAN and LDSR are available.
#if "GFPGAN_available" not in st.session_state:
GFPGAN_available()
#if "RealESRGAN_available" not in st.session_state:
RealESRGAN_available()
#if "LDSR_available" not in st.session_state:
LDSR_available()
if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
with st.expander("Post-Processing"):
face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
with face_restoration_tab:
# GFPGAN used for face restoration
if st.session_state["GFPGAN_available"]:
#with st.expander("Face Restoration"):
#if st.session_state["GFPGAN_available"]:
#with st.expander("GFPGAN"):
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN,
help="Uses the GFPGAN model to improve faces after the generation.\
This greatly improves the quality and consistency of faces but uses\
extra VRAM. Disable if you need the extra VRAM.")
st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
#st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
else:
st.session_state["use_GFPGAN"] = False
with upscaling_tab:
st.session_state['use_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling)
# RealESRGAN and LDSR used for upscaling.
if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
upscaling_method_list = []
if st.session_state["RealESRGAN_available"]:
upscaling_method_list.append("RealESRGAN")
if st.session_state["LDSR_available"]:
upscaling_method_list.append("LDSR")
#print (st.session_state["RealESRGAN_available"])
st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
else 0)
if st.session_state["RealESRGAN_available"]:
with st.expander("RealESRGAN"):
if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['use_upscaling']:
st.session_state["use_RealESRGAN"] = True
else:
st.session_state["use_RealESRGAN"] = False
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
else:
st.session_state["use_RealESRGAN"] = False
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
#
if st.session_state["LDSR_available"]:
with st.expander("LDSR"):
if st.session_state["upscaling_method"] == "LDSR" and st.session_state['use_upscaling']:
st.session_state["use_LDSR"] = True
else:
st.session_state["use_LDSR"] = False
st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2img.LDSR_config.sampling_steps,
help="")
st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.preDownScale,
help="")
st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.postDownScale,
help="")
downsample_method_list = ['Nearest', 'Lanczos']
st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
index=downsample_method_list.index(st.session_state['defaults'].txt2img.LDSR_config.downsample_method))
else:
st.session_state["use_LDSR"] = False
st.session_state["LDSR_model"] = "model"
with st.expander("Variant"):
variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2img.variant_amount.value,
min_value=st.session_state['defaults'].txt2img.variant_amount.min_value, max_value=st.session_state['defaults'].txt2img.variant_amount.max_value,
step=st.session_state['defaults'].txt2img.variant_amount.step)
variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2img.seed,
help="The seed to use when generating a variant, if left blank a random seed will be generated.")
#galleryCont = st.empty()
# Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
generate_col1.write("")
generate_col1.write("")
generate_button = generate_col1.form_submit_button("Generate")
#
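# On submit: load the required models (skipped when using the Stable Horde), run txt2img() and show the results in the preview and gallery tabs.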
if generate_button:
with col2:
with no_rerun:
if not use_stable_horde:
with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
#print(st.session_state['use_RealESRGAN'])
#print(st.session_state['use_LDSR'])
try:
output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, st.session_state["batch_count"], st.session_state["batch_size"],
cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
save_grid, group_by_prompt, save_as_jpg, st.session_state["use_GFPGAN"], st.session_state['GFPGAN_model'],
use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files,
use_stable_horde=use_stable_horde, stable_horde_key=stable_horde_key)
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="")
with gallery_tab:
logger.info(seeds)
st.session_state["gallery"].text = ""
sdGallery(output_images)
except (StopException,
#KeyError
):
print(f"Received Streamlit StopException")
# reset the page title so the percent doesnt stay on it confusing the user.
set_page_title(f"Stable Diffusion Playground")
# this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
# use the current col2 first tab to show the preview_img and update it as its generated.
#preview_image.image(output_images)

File diff suppressed because it is too large

View File

@ -0,0 +1,277 @@
# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
# Copyright 2022 Sygil-Dev team.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# base webui import and utils.
#import streamlit as st
# We import hydralit like this to replace the previous stuff
# we had with native streamlit as it lets us replace things 1:1
from sd_utils import st, hc, load_configs, load_css, set_logger_verbosity,\
logger, quiesce_logger, set_page_title, random
# streamlit imports
import streamlit_nested_layout
#streamlit components section
#from st_on_hover_tabs import on_hover_tabs
#from streamlit_server_state import server_state, server_state_lock
#other imports
import argparse
#from sd_utils.bridge import run_bridge
# import custom components
from custom_components import draggable_number_input
# end of imports
#---------------------------------------------------------------------------------------------------------------
load_configs()
help = """
A double dash (`--`) is used to separate streamlit arguments from app arguments.
As a result, using "streamlit run webui_streamlit.py --headless"
will show the help for streamlit itself and not pass any arguments to our app;
we need to use "streamlit run webui_streamlit.py -- --headless"
in order to pass command-line arguments to this app."""
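# Illustrative example (flags taken from the parser below): everything after the bare "--"
# is forwarded to this app's argparse instead of being consumed by streamlit, e.g.
#   streamlit run webui_streamlit.py -- --bridge --horde_api_key 0000000000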
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--headless", action='store_true', help="Don't launch web server, util if you just want to run the stable horde bridge.", default=False)
parser.add_argument("--bridge", action='store_true', help="don't launch web server, but make this instance into a Horde bridge.", default=False)
parser.add_argument('--horde_api_key', action="store", required=False, type=str, help="The API key corresponding to the owner of this Horde instance")
parser.add_argument('--horde_name', action="store", required=False, type=str, help="The server name for the Horde. It will be shown to the world and there can be only one.")
parser.add_argument('--horde_url', action="store", required=False, type=str, help="The SH Horde URL. Where the bridge will pick up prompts and send the finished generations.")
parser.add_argument('--horde_priority_usernames',type=str, action='append', required=False, help="Usernames which get priority use in this horde instance. The owner's username is always in this list.")
parser.add_argument('--horde_max_power',type=int, required=False, help="How much power this instance has to generate pictures. Min: 2")
parser.add_argument('--horde_sfw', action='store_true', required=False, help="Set to true if you do not want this worker generating NSFW images.")
parser.add_argument('--horde_blacklist', nargs='+', required=False, help="List the words that you want to blacklist.")
parser.add_argument('--horde_censorlist', nargs='+', required=False, help="List the words that you want to censor.")
parser.add_argument('--horde_censor_nsfw', action='store_true', required=False, help="Set to true if you want this bridge worker to censor NSFW images.")
parser.add_argument('--horde_model', action='store', required=False, help="Which model to run on this horde.")
parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen on your screen.")
parser.add_argument('-q', '--quiet', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen on your screen.")
opt = parser.parse_args()
#with server_state_lock["bridge"]:
#server_state["bridge"] = opt.bridge
@logger.catch(reraise=True)
def layout():
"""Layout functions to define all the streamlit layout here."""
if not st.session_state["defaults"].debug.enable_hydralit:
st.set_page_config(page_title="Stable Diffusion Playground", layout="wide", initial_sidebar_state="collapsed")
#app = st.HydraApp(title='Stable Diffusion WebUI', favicon="", sidebar_state="expanded", layout="wide",
#hide_streamlit_markers=False, allow_url_nav=True , clear_cross_app_sessions=False)
# load css as an external file; the function has an option to load from a local or remote url. Potentially useful when running from cloud infra that might not have access to a local path.
load_css(True, 'frontend/css/streamlit.main.css')
#
# specify the primary menu definition
menu_data = [
{'id': 'Stable Diffusion', 'label': 'Stable Diffusion', 'icon': 'bi bi-grid-1x2-fill'},
{'id': 'Train','label':"Train", 'icon': "bi bi-lightbulb-fill", 'submenu':[
{'id': 'Textual Inversion', 'label': 'Textual Inversion', 'icon': 'bi bi-lightbulb-fill'},
{'id': 'Fine Tuning', 'label': 'Fine Tuning', 'icon': 'bi bi-lightbulb-fill'},
]},
{'id': 'Model Manager', 'label': 'Model Manager', 'icon': 'bi bi-cloud-arrow-down-fill'},
{'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
{'id': 'Barfi/BaklavaJS', 'label': 'Barfi/BaklavaJS', 'icon': 'bi bi-diagram-3-fill'},
#{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
]},
{'id': 'Settings', 'label': 'Settings', 'icon': 'bi bi-gear-fill'},
]
over_theme = {'txc_inactive': '#FFFFFF', "menu_background":'#000000'}
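# Render the top navigation bar with hydralit_components; the returned menu_id drives the page routing below.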
menu_id = hc.nav_bar(
menu_definition=menu_data,
#home_name='Home',
#login_name='Logout',
hide_streamlit_markers=False,
override_theme=over_theme,
sticky_nav=True,
sticky_mode='pinned',
)
#
#if menu_id == "Home":
#st.info("Under Construction. :construction_worker:")
if menu_id == "Stable Diffusion":
# set the page url and title
#st.experimental_set_query_params(page='stable-diffusion')
try:
set_page_title("Stable Diffusion Playground")
except NameError:
st.experimental_rerun()
txt2img_tab, img2img_tab, txt2vid_tab, img2txt_tab, post_processing_tab, concept_library_tab = st.tabs(["Text-to-Image", "Image-to-Image",
#"Inpainting",
"Text-to-Video", "Image-To-Text",
"Post-Processing","Concept Library"])
#with home_tab:
#from home import layout
#layout()
with txt2img_tab:
from txt2img import layout
layout()
with img2img_tab:
from img2img import layout
layout()
#with inpainting_tab:
#from inpainting import layout
#layout()
with txt2vid_tab:
from txt2vid import layout
layout()
with img2txt_tab:
from img2txt import layout
layout()
with post_processing_tab:
from post_processing import layout
layout()
with concept_library_tab:
from sd_concept_library import layout
layout()
#
elif menu_id == 'Model Manager':
set_page_title("Model Manager - Stable Diffusion Playground")
from ModelManager import layout
layout()
elif menu_id == 'Textual Inversion':
from textual_inversion import layout
layout()
elif menu_id == 'Fine Tuning':
#from textual_inversion import layout
#layout()
st.info("Under Construction. :construction_worker:")
elif menu_id == 'API Server':
set_page_title("API Server - Stable Diffusion Playground")
from APIServer import layout
layout()
elif menu_id == 'Barfi/BaklavaJS':
set_page_title("Barfi/BaklavaJS - Stable Diffusion Playground")
from barfi_baklavajs import layout
layout()
elif menu_id == 'Settings':
set_page_title("Settings - Stable Diffusion Playground")
from Settings import layout
layout()
# calling draggable input component module at the end, so it works on all pages
draggable_number_input.load()
if __name__ == '__main__':
set_logger_verbosity(opt.verbosity)
quiesce_logger(opt.quiet)
if not opt.headless:
layout()
#with server_state_lock["bridge"]:
#if server_state["bridge"]:
#try:
#import bridgeData as cd
#except ModuleNotFoundError as e:
#logger.warning("No bridgeData found. Falling back to default where no CLI args are set.")
#logger.debug(str(e))
#except SyntaxError as e:
#logger.warning("bridgeData found, but is malformed. Falling back to default where no CLI args are set.")
#logger.debug(str(e))
#except Exception as e:
#logger.warning("No bridgeData found, use default where no CLI args are set")
#logger.debug(str(e))
#finally:
#try: # check if cd exists (i.e. bridgeData loaded properly)
#cd
#except: # if not, create defaults
#class temp(object):
#def __init__(self):
#random.seed()
#self.horde_url = "https://stablehorde.net"
## Give a cool name to your instance
#self.horde_name = f"Automated Instance #{random.randint(-100000000, 100000000)}"
## The api_key identifies a unique user in the horde
#self.horde_api_key = "0000000000"
## Put other users whose prompts you want to prioritize.
## The owner's username is always included so you don't need to add it here, unless you want it to have lower priority than another user
#self.horde_priority_usernames = []
#self.horde_max_power = 8
#self.nsfw = True
#self.censor_nsfw = False
#self.blacklist = []
#self.censorlist = []
#self.models_to_load = ["stable_diffusion"]
#cd = temp()
#horde_api_key = opt.horde_api_key if opt.horde_api_key else cd.horde_api_key
#horde_name = opt.horde_name if opt.horde_name else cd.horde_name
#horde_url = opt.horde_url if opt.horde_url else cd.horde_url
#horde_priority_usernames = opt.horde_priority_usernames if opt.horde_priority_usernames else cd.horde_priority_usernames
#horde_max_power = opt.horde_max_power if opt.horde_max_power else cd.horde_max_power
## Not used yet
#horde_models = [opt.horde_model] if opt.horde_model else cd.models_to_load
#try:
#horde_nsfw = not opt.horde_sfw if opt.horde_sfw else cd.horde_nsfw
#except AttributeError:
#horde_nsfw = True
#try:
#horde_censor_nsfw = opt.horde_censor_nsfw if opt.horde_censor_nsfw else cd.horde_censor_nsfw
#except AttributeError:
#horde_censor_nsfw = False
#try:
#horde_blacklist = opt.horde_blacklist if opt.horde_blacklist else cd.horde_blacklist
#except AttributeError:
#horde_blacklist = []
#try:
#horde_censorlist = opt.horde_censorlist if opt.horde_censorlist else cd.horde_censorlist
#except AttributeError:
#horde_censorlist = []
#if horde_max_power < 2:
#horde_max_power = 2
#horde_max_pixels = 64*64*8*horde_max_power
#logger.info(f"Joining Horde with parameters: Server Name '{horde_name}'. Horde URL '{horde_url}'. Max Pixels {horde_max_pixels}")
#try:
#thread = threading.Thread(target=run_bridge(1, horde_api_key, horde_name, horde_url,
#horde_priority_usernames, horde_max_pixels,
#horde_nsfw, horde_censor_nsfw, horde_blacklist,
#horde_censorlist), args=())
#thread.daemon = True
#thread.start()
##run_bridge(1, horde_api_key, horde_name, horde_url, horde_priority_usernames, horde_max_pixels, horde_nsfw, horde_censor_nsfw, horde_blacklist, horde_censorlist)
#except KeyboardInterrupt:
#print(f"Keyboard Interrupt Received. Ending Bridge")