diff --git a/webui/streamlit/frontend/css/streamlit.main.css b/webui/streamlit/frontend/css/streamlit.main.css
new file mode 100644
index 0000000..9162d05
--- /dev/null
+++ b/webui/streamlit/frontend/css/streamlit.main.css
@@ -0,0 +1,178 @@
+/*
+This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+Copyright 2022 Sygil-Dev team.
+This program is free software: you can redistribute it and/or modify
+it under the terms of the GNU Affero General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU Affero General Public License for more details.
+
+You should have received a copy of the GNU Affero General Public License
+along with this program. If not, see <https://www.gnu.org/licenses/>.
+*/
+
+/***********************************************************
+* Additional CSS for streamlit builtin components *
+************************************************************/
+
+/* Tab name (e.g. Text-to-Image): larger font size to improve legibility */
+button[data-baseweb="tab"] {
+ font-size: 25px;
+}
+
+/* Image container (only appears after a run finishes): center the image, which looks better on wide screens */
+.css-1kyxreq{
+ justify-content: center;
+}
+
+
+/* Streamlit header */
+.css-1avcm0n {
+ background-color: transparent;
+}
+
+/* Main streamlit container (below header): reduce the empty space */
+.css-18e3th9 {
+ padding-top: 1rem;
+}
+
+
+
+/***********************************************************
+* Additional CSS for streamlit custom/3rd party components *
+************************************************************/
+/* For stream_on_hover */
+section[data-testid="stSidebar"] > div:nth-of-type(1) {
+ background-color: #111;
+}
+
+button[kind="header"] {
+ background-color: transparent;
+ color: rgb(180, 167, 141);
+}
+
+@media (hover) {
+ /* header element */
+ header[data-testid="stHeader"] {
+ /* display: none;*/ /*suggested behavior by streamlit hover components*/
+ pointer-events: none; /* disable interaction of the transparent background */
+ }
+
+ /* The button on the streamlit navigation menu */
+ button[kind="header"] {
+ /* display: none;*/ /*suggested behavior by streamlit hover components*/
+        pointer-events: auto; /* enable interaction with the button even if the parent's interaction is disabled */
+ }
+
+    /* added to keep the main section (all elements to the right of the sidebar) from moving */
+ section[data-testid="stSidebar"] {
+ width: 3.5% !important;
+ min-width: 3.5% !important;
+ }
+
+ /* The navigation menu specs and size */
+ section[data-testid="stSidebar"] > div {
+ height: 100%;
+ width: 2% !important;
+ min-width: 100% !important;
+ position: relative;
+ z-index: 1;
+ top: 0;
+ left: 0;
+ background-color: #111;
+ overflow-x: hidden;
+ transition: 0.5s ease-in-out;
+ padding-top: 0px;
+ white-space: nowrap;
+ }
+
+ /* The navigation menu open and close on hover and size */
+ section[data-testid="stSidebar"] > div:hover {
+ width: 300px !important;
+ }
+}
+
+@media (max-width: 272px) {
+ section[data-testid="stSidebar"] > div {
+ width: 15rem;
+ }
+}
+
+/***********************************************************
+* Additional CSS for other elements
+************************************************************/
+button[data-baseweb="tab"] {
+ font-size: 20px;
+}
+
+@media (min-width: 1200px){
+h1 {
+ font-size: 1.75rem;
+}
+}
+#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 {
+ width: 50rem;
+ align-self: center;
+}
+div.gallery:hover {
+ border: 1px solid #777;
+}
+.css-dg4u6x p {
+ font-size: 0.8rem;
+ text-align: center;
+ position: relative;
+ top: 6px;
+}
+
+.row-widget.stButton {
+ text-align: center;
+}
+
+/********************************************************************
+ Hide anchor links on titles
+*********************************************************************/
+/*
+.css-15zrgzn {
+ display: none
+ }
+.css-eczf16 {
+ display: none
+ }
+.css-jn99sy {
+ display: none
+ }
+*/
+
+/* Make the text area widget have a similar height as the text input field */
+.st-dy{
+ height: 54px;
+ min-height: 25px;
+}
+.css-17useex{
+ gap: 3px;
+}
+
+/* Remove some empty spaces to make the UI more compact. */
+.css-18e3th9{
+ padding-left: 10px;
+ padding-right: 30px;
+ position: unset !important; /* Fixes the layout/page going up when an expander or another item is expanded and then collapsed */
+}
+.css-k1vhr4{
+ padding-top: initial;
+}
+.css-ret2ud{
+ padding-left: 10px;
+ padding-right: 30px;
+ gap: initial;
+ display: initial;
+}
+
+.css-w5z5an{
+ gap: 1px;
+}
diff --git a/webui/streamlit/scripts/APIServer.py b/webui/streamlit/scripts/APIServer.py
new file mode 100644
index 0000000..a4da6c5
--- /dev/null
+++ b/webui/streamlit/scripts/APIServer.py
@@ -0,0 +1,34 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+# base webui import and utils.
+#from sd_utils import *
+from sd_utils import st
+# streamlit imports
+
+#streamlit components section
+
+#other imports
+#from fastapi import FastAPI
+#import uvicorn
+
+# Temp imports
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+
+def layout():
+ st.info("Under Construction. :construction_worker:")
\ No newline at end of file
diff --git a/webui/streamlit/scripts/ModelManager.py b/webui/streamlit/scripts/ModelManager.py
new file mode 100644
index 0000000..3db0105
--- /dev/null
+++ b/webui/streamlit/scripts/ModelManager.py
@@ -0,0 +1,121 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+# base webui import and utils.
+from sd_utils import st, logger
+# streamlit imports
+
+
+#other imports
+import os, requests
+from requests.auth import HTTPBasicAuth
+from requests import HTTPError
+from stqdm import stqdm
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+def download_file(file_name, file_path, file_url):
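+    # Stream `file_url` into `file_path`/`file_name`, creating the directory if needed and
+    # skipping the download entirely when the file is already on disk.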
+ if not os.path.exists(file_path):
+ os.makedirs(file_path)
+
+ if not os.path.exists(os.path.join(file_path , file_name)):
+ print('Downloading ' + file_name + '...')
+ # TODO - add progress bar in streamlit
+        # download the file with `requests`
+ if file_name == "Stable Diffusion v1.5":
+ if "huggingface_token" not in st.session_state or st.session_state["defaults"].general.huggingface_token == "None":
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].error(
+ "You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token."
+ )
+ raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.")
+
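+            # Hugging Face URLs are fetched with the user's token via HTTP basic auth;
+            # any other URL is downloaded anonymously.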
+ try:
+ with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token) if "huggingface.co" in file_url else None, stream=True) as r:
+ r.raise_for_status()
+ with open(os.path.join(file_path, file_name), 'wb') as f:
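+                    # stqdm is a tqdm-style progress bar that renders inside the Streamlit app.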
+ for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"):
+ f.write(chunk)
+ except HTTPError as e:
+ if "huggingface.co" in file_url:
+                if "resolve" in file_url:
+ repo_url = file_url.split("resolve")[0]
+
+ st.session_state["progress_bar_text"].error(
+ f"You need to accept the license for the model in order to be able to download it. "
+                        f"Please visit {repo_url} and accept the license there, then try again to download the model.")
+
+ logger.error(e)
+
+ else:
+ print(file_name + ' already exists.')
+
+
+def download_model(models, model_name):
+ """ Download all files from model_list[model_name] """
+ for file in models[model_name]:
+ download_file(file['file_name'], file['file_path'], file['file_url'])
+ return
+
+
+def layout():
+ #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
+
+ colms = st.columns((1, 3, 3, 5, 5))
+ columns = ["№", 'Model Name', 'Save Location', "Download", 'Download Link']
+
+ models = st.session_state["defaults"].model_manager.models
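+    # Each entry in `models` is expected to provide 'model_name', 'save_location' and a 'files'
+    # mapping; individual files may override the save location and carry their own 'download_link'.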
+
+ for col, field_name in zip(colms, columns):
+ # table header
+ col.write(field_name)
+
+ for x, model_name in enumerate(models):
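+        # One table row per model: index, name, save location, then either a Download button
+        # (when files are missing) or a checkmark (when everything is already on disk).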
+ col1, col2, col3, col4, col5 = st.columns((1, 3, 3, 3, 6))
+ col1.write(x) # index
+ col2.write(models[model_name]['model_name'])
+ col3.write(models[model_name]['save_location'])
+ with col4:
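+            # Count how many of this model's files already exist on disk and collect the missing ones.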
+ files_exist = 0
+ for file in models[model_name]['files']:
+                if "save_location" in models[model_name]['files'][file]:
+                    if os.path.exists(os.path.join(models[model_name]['files'][file]['save_location'], models[model_name]['files'][file]['file_name'])):
+                        files_exist += 1
+ elif os.path.exists(os.path.join(models[model_name]['save_location'] , models[model_name]['files'][file]['file_name'])):
+ files_exist += 1
+ files_needed = []
+ for file in models[model_name]['files']:
+ if "save_location" in models[model_name]['files'][file]:
+ if not os.path.exists(os.path.join(models[model_name]['files'][file]['save_location'] , models[model_name]['files'][file]['file_name'])):
+ files_needed.append(file)
+ elif not os.path.exists(os.path.join(models[model_name]['save_location'] , models[model_name]['files'][file]['file_name'])):
+ files_needed.append(file)
+ if len(files_needed) > 0:
+ if st.button('Download', key=models[model_name]['model_name'], help='Download ' + models[model_name]['model_name']):
+ for file in files_needed:
+ if "save_location" in models[model_name]['files'][file]:
+ download_file(models[model_name]['files'][file]['file_name'], models[model_name]['files'][file]['save_location'], models[model_name]['files'][file]['download_link'])
+ else:
+ download_file(models[model_name]['files'][file]['file_name'], models[model_name]['save_location'], models[model_name]['files'][file]['download_link'])
+ st.experimental_rerun()
+ else:
+ st.empty()
+ else:
+ st.write('✅')
+
diff --git a/webui/streamlit/scripts/Settings.py b/webui/streamlit/scripts/Settings.py
new file mode 100644
index 0000000..0b228cb
--- /dev/null
+++ b/webui/streamlit/scripts/Settings.py
@@ -0,0 +1,899 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+# base webui import and utils.
+from sd_utils import st, custom_models_available, logger, human_readable_size
+
+# streamlit imports
+
+# streamlit components section
+import streamlit_nested_layout
+from streamlit_server_state import server_state
+
+# other imports
+from omegaconf import OmegaConf
+import torch
+import os, toml
+
+# end of imports
+# ---------------------------------------------------------------------------------------------------------------
+
+@logger.catch(reraise=True)
+def layout():
+ #st.header("Settings")
+
+ with st.form("Settings"):
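+        # Widgets placed inside st.form only push their new values back when the form's
+        # submit button is pressed, so the whole settings page is applied in one go.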
+ general_tab, txt2img_tab, img2img_tab, img2txt_tab, txt2vid_tab, image_processing, textual_inversion_tab, concepts_library_tab = st.tabs(
+ ['General', "Text-To-Image", "Image-To-Image", "Image-To-Text", "Text-To-Video", "Image processing", "Textual Inversion", "Concepts Library"])
+
+ with general_tab:
+ col1, col2, col3, col4, col5 = st.columns(5, gap='large')
+
+ device_list = []
+ device_properties = [(i, torch.cuda.get_device_properties(i)) for i in range(torch.cuda.device_count())]
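+            # Entries look like "0: <device name> (<memory>)"; the GPU selectboxes below split
+            # on ":" to recover the integer device index from whichever entry the user picks.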
+ for device in device_properties:
+ id = device[0]
+ name = device[1].name
+ total_memory = device[1].total_memory
+
+ device_list.append(f"{id}: {name} ({human_readable_size(total_memory, decimal_places=0)})")
+
+ with col1:
+ st.title("General")
+ st.session_state['defaults'].general.gpu = int(st.selectbox("GPU", device_list, index=st.session_state['defaults'].general.gpu,
+ help=f"Select which GPU to use. Default: {device_list[0]}").split(":")[0])
+
+ st.session_state['defaults'].general.outdir = str(st.text_input("Output directory", value=st.session_state['defaults'].general.outdir,
+ help="Relative directory on which the output images after a generation will be placed. Default: 'outputs'"))
+
+ # If we have custom models available on the "models/custom"
+ # folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
+ custom_models_available()
+
+ if server_state["CustomModel_available"]:
+ st.session_state.defaults.general.default_model = st.selectbox("Default Model:", server_state["custom_models"],
+ index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
+ help="Select the model you want to use. If you have placed custom models \
+ on your 'models/custom' folder they will be shown here as well. The model name that will be shown here \
+ is the same as the name the file for the model has on said folder, \
+ it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
+ else:
+ st.session_state.defaults.general.default_model = st.selectbox("Default Model:", [st.session_state['defaults'].general.default_model],
+ help="Select the model you want to use. If you have placed custom models \
+ on your 'models/custom' folder they will be shown here as well. \
+ The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
+
+ st.session_state['defaults'].general.default_model_config = st.text_input("Default Model Config", value=st.session_state['defaults'].general.default_model_config,
+ help="Default model config file for inference. Default: 'configs/stable-diffusion/v1-inference.yaml'")
+
+            st.session_state['defaults'].general.default_model_path = st.text_input("Default Model Path", value=st.session_state['defaults'].general.default_model_path,
+ help="Default model path. Default: 'models/ldm/stable-diffusion-v1/model.ckpt'")
+
+ st.session_state['defaults'].general.GFPGAN_dir = st.text_input("Default GFPGAN directory", value=st.session_state['defaults'].general.GFPGAN_dir,
+ help="Default GFPGAN directory. Default: './models/gfpgan'")
+
+ st.session_state['defaults'].general.RealESRGAN_dir = st.text_input("Default RealESRGAN directory", value=st.session_state['defaults'].general.RealESRGAN_dir,
+                                                                        help="Default RealESRGAN directory. Default: './models/realesrgan'")
+
+ RealESRGAN_model_list = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"]
+ st.session_state['defaults'].general.RealESRGAN_model = st.selectbox("RealESRGAN model", RealESRGAN_model_list,
+ index=RealESRGAN_model_list.index(st.session_state['defaults'].general.RealESRGAN_model),
+ help="Default RealESRGAN model. Default: 'RealESRGAN_x4plus'")
+ Upscaler_list = ["RealESRGAN", "LDSR"]
+ st.session_state['defaults'].general.upscaling_method = st.selectbox("Upscaler", Upscaler_list, index=Upscaler_list.index(
+ st.session_state['defaults'].general.upscaling_method), help="Default upscaling method. Default: 'RealESRGAN'")
+
+ with col2:
+ st.title("Performance")
+
+ st.session_state["defaults"].general.gfpgan_cpu = st.checkbox("GFPGAN - CPU", value=st.session_state['defaults'].general.gfpgan_cpu,
+ help="Run GFPGAN on the cpu. Default: False")
+
+ st.session_state["defaults"].general.esrgan_cpu = st.checkbox("ESRGAN - CPU", value=st.session_state['defaults'].general.esrgan_cpu,
+ help="Run ESRGAN on the cpu. Default: False")
+
+ st.session_state["defaults"].general.extra_models_cpu = st.checkbox("Extra Models - CPU", value=st.session_state['defaults'].general.extra_models_cpu,
+                                                                            help="Run extra models (GFPGAN/ESRGAN) on cpu. Default: False")
+
+ st.session_state["defaults"].general.extra_models_gpu = st.checkbox("Extra Models - GPU", value=st.session_state['defaults'].general.extra_models_gpu,
+                                                                            help="Run extra models (GFPGAN/ESRGAN) on gpu. \
+ Check and save in order to be able to select the GPU that each model will use. Default: False")
+ if st.session_state["defaults"].general.extra_models_gpu:
+            st.session_state['defaults'].general.gfpgan_gpu = int(st.selectbox("GFPGAN GPU", device_list, index=st.session_state['defaults'].general.gfpgan_gpu,
+ help=f"Select which GPU to use. Default: {device_list[st.session_state['defaults'].general.gfpgan_gpu]}",
+ key="gfpgan_gpu").split(":")[0])
+
+ st.session_state["defaults"].general.esrgan_gpu = int(st.selectbox("ESRGAN - GPU", device_list, index=st.session_state['defaults'].general.esrgan_gpu,
+ help=f"Select which GPU to use. Default: {device_list[st.session_state['defaults'].general.esrgan_gpu]}",
+ key="esrgan_gpu").split(":")[0])
+
+ st.session_state["defaults"].general.no_half = st.checkbox("No Half", value=st.session_state['defaults'].general.no_half,
+ help="DO NOT switch the model to 16-bit floats. Default: False")
+
+ st.session_state["defaults"].general.use_cudnn = st.checkbox("Use cudnn", value=st.session_state['defaults'].general.use_cudnn,
+                                                                         help="Switch the pytorch backend to use cudnn, this should help with fixing Nvidia 16xx cards getting "
+ "a black or green image. Default: False")
+
+ st.session_state["defaults"].general.use_float16 = st.checkbox("Use float16", value=st.session_state['defaults'].general.use_float16,
+ help="Switch the model to 16-bit floats. Default: False")
+
+
+ precision_list = ['full', 'autocast']
+ st.session_state["defaults"].general.precision = st.selectbox("Precision", precision_list, index=precision_list.index(st.session_state['defaults'].general.precision),
+ help="Evaluates at this precision. Default: autocast")
+
+ st.session_state["defaults"].general.optimized = st.checkbox("Optimized Mode", value=st.session_state['defaults'].general.optimized,
+ help="Loads the model onto the device piecemeal instead of all at once to reduce VRAM usage\
+ at the cost of performance. Default: False")
+
+ st.session_state["defaults"].general.optimized_turbo = st.checkbox("Optimized Turbo Mode", value=st.session_state['defaults'].general.optimized_turbo,
+ help="Alternative optimization mode that does not save as much VRAM but \
+                                                                              runs significantly faster. Default: False")
+
+ st.session_state["defaults"].general.optimized_config = st.text_input("Optimized Config", value=st.session_state['defaults'].general.optimized_config,
+ help=f"Loads alternative optimized configuration for inference. \
+ Default: optimizedSD/v1-inference.yaml")
+
+ st.session_state["defaults"].general.enable_attention_slicing = st.checkbox("Enable Attention Slicing", value=st.session_state['defaults'].general.enable_attention_slicing,
+ help="Enable sliced attention computation. When this option is enabled, the attention module will \
+ split the input tensor in slices, to compute attention in several steps. This is useful to save some \
+                                                                                     memory in exchange for a small speed decrease. Only works on the txt2vid tab right now. Default: False")
+
+ st.session_state["defaults"].general.enable_minimal_memory_usage = st.checkbox("Enable Minimal Memory Usage", value=st.session_state['defaults'].general.enable_minimal_memory_usage,
+                                                                                        help="Moves only the unet to fp16 and to CUDA, while keeping lighter models on CPUs \
+ (Not properly implemented and currently not working, check this \
+ link 'https://github.com/huggingface/diffusers/pull/537' for more information on it ). Default: False")
+
+ # st.session_state["defaults"].general.update_preview = st.checkbox("Update Preview Image", value=st.session_state['defaults'].general.update_preview,
+ # help="Enables the preview image to be updated and shown to the user on the UI during the generation.\
+ # If checked, once you save the settings an option to specify the frequency at which the image is updated\
+ # in steps will be shown, this is helpful to reduce the negative effect this option has on performance. \
+ # Default: True")
+ st.session_state["defaults"].general.update_preview = True
+ st.session_state["defaults"].general.update_preview_frequency = st.number_input("Update Preview Frequency",
+ min_value=0,
+ value=st.session_state['defaults'].general.update_preview_frequency,
+ help="Specify the frequency at which the image is updated in steps, this is helpful to reduce the \
+ negative effect updating the preview image has on performance. Default: 10")
+
+ with col3:
+ st.title("Others")
+ st.session_state["defaults"].general.use_sd_concepts_library = st.checkbox("Use the Concepts Library", value=st.session_state['defaults'].general.use_sd_concepts_library,
+                                                                                    help="Use the Concepts Library embeddings; if checked, once the settings are saved an option will\
+                                                                                        appear to specify the directory where the concepts are stored. Default: True")
+
+ if st.session_state["defaults"].general.use_sd_concepts_library:
+ st.session_state['defaults'].general.sd_concepts_library_folder = st.text_input("Concepts Library Folder",
+ value=st.session_state['defaults'].general.sd_concepts_library_folder,
+ help="Relative folder on which the concepts library embeds are stored. \
+ Default: 'models/custom/sd-concepts-library'")
+
+ st.session_state['defaults'].general.LDSR_dir = st.text_input("LDSR Folder", value=st.session_state['defaults'].general.LDSR_dir,
+ help="Folder where LDSR is located. Default: './models/ldsr'")
+
+ st.session_state["defaults"].general.save_metadata = st.checkbox("Save Metadata", value=st.session_state['defaults'].general.save_metadata,
+ help="Save metadata on the output image. Default: True")
+ save_format_list = ["png","jpg", "jpeg","webp"]
+ st.session_state["defaults"].general.save_format = st.selectbox("Save Format", save_format_list, index=save_format_list.index(st.session_state['defaults'].general.save_format),
+                                                                        help="Format that will be used when saving the output images. Default: 'png'")
+
+ st.session_state["defaults"].general.skip_grid = st.checkbox("Skip Grid", value=st.session_state['defaults'].general.skip_grid,
+ help="Skip saving the grid output image. Default: False")
+ if not st.session_state["defaults"].general.skip_grid:
+            st.session_state["defaults"].general.grid_quality = st.number_input("Grid Quality", value=st.session_state['defaults'].general.grid_quality,
+                                                                                help="Quality used when saving the grid output image. Default: 95")
+
+ st.session_state["defaults"].general.skip_save = st.checkbox("Skip Save", value=st.session_state['defaults'].general.skip_save,
+ help="Skip saving the output image. Default: False")
+
+ st.session_state["defaults"].general.n_rows = st.number_input("Number of Grid Rows", value=st.session_state['defaults'].general.n_rows,
+                                                                      help="Number of rows the grid will have when saving the grid output image. Default: '-1'")
+
+ st.session_state["defaults"].general.no_verify_input = st.checkbox("Do not Verify Input", value=st.session_state['defaults'].general.no_verify_input,
+ help="Do not verify input to check if it's too long. Default: False")
+
+ st.session_state["defaults"].general.show_percent_in_tab_title = st.checkbox("Show Percent in tab title", value=st.session_state['defaults'].general.show_percent_in_tab_title,
+ help="Add the progress percent value to the page title on the tab on your browser. "
+                                                                                          "This is useful in case you need to know how the generation is going while doing something else "
+ "in another tab on your browser. Default: True")
+
+ st.session_state["defaults"].general.enable_suggestions = st.checkbox("Enable Suggestions Box", value=st.session_state['defaults'].general.enable_suggestions,
+ help="Adds a suggestion box under the prompt when clicked. Default: True")
+
+ st.session_state["defaults"].daisi_app.running_on_daisi_io = st.checkbox("Running on Daisi.io?", value=st.session_state['defaults'].daisi_app.running_on_daisi_io,
+                                                                                  help="Specify if we are running on app.Daisi.io. Default: False")
+
+ with col4:
+ st.title("Streamlit Config")
+
+ default_theme_list = ["light", "dark"]
+ st.session_state["defaults"].general.default_theme = st.selectbox("Default Theme", default_theme_list, index=default_theme_list.index(st.session_state['defaults'].general.default_theme),
+                                                                          help="Default theme to use as base for streamlit. Default: dark")
+ st.session_state["streamlit_config"]["theme"]["base"] = st.session_state["defaults"].general.default_theme
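+            # st.session_state["streamlit_config"] mirrors the [theme], [server] and [browser]
+            # sections of Streamlit's config.toml, which is why the keys below follow its naming.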
+
+
+ if not st.session_state['defaults'].admin.hide_server_setting:
+ with st.expander("Server", True):
+
+ st.session_state["streamlit_config"]['server']['headless'] = st.checkbox("Run Headless", help="If false, will attempt to open a browser window on start. \
+ Default: false unless (1) we are on a Linux box where DISPLAY is unset, \
+ or (2) we are running in the Streamlit Atom plugin.")
+
+ st.session_state["streamlit_config"]['server']['port'] = st.number_input("Port", value=st.session_state["streamlit_config"]['server']['port'],
+ help="The port where the server will listen for browser connections. Default: 8501")
+
+ st.session_state["streamlit_config"]['server']['baseUrlPath'] = st.text_input("Base Url Path", value=st.session_state["streamlit_config"]['server']['baseUrlPath'],
+ help="The base path for the URL where Streamlit should be served from. Default: '' ")
+
+ st.session_state["streamlit_config"]['server']['enableCORS'] = st.checkbox("Enable CORS", value=st.session_state['streamlit_config']['server']['enableCORS'],
+ help="Enables support for Cross-Origin Request Sharing (CORS) protection, for added security. \
+ Due to conflicts between CORS and XSRF, if `server.enableXsrfProtection` is on and `server.enableCORS` \
+ is off at the same time, we will prioritize `server.enableXsrfProtection`. Default: true")
+
+ st.session_state["streamlit_config"]['server']['enableXsrfProtection'] = st.checkbox("Enable Xsrf Protection",
+ value=st.session_state['streamlit_config']['server']['enableXsrfProtection'],
+ help="Enables support for Cross-Site Request Forgery (XSRF) protection, \
+ for added security. Due to conflicts between CORS and XSRF, \
+ if `server.enableXsrfProtection` is on and `server.enableCORS` is off at \
+ the same time, we will prioritize `server.enableXsrfProtection`. Default: true")
+
+ st.session_state["streamlit_config"]['server']['maxUploadSize'] = st.number_input("Max Upload Size", value=st.session_state["streamlit_config"]['server']['maxUploadSize'],
+ help="Max size, in megabytes, for files uploaded with the file_uploader. Default: 200")
+
+                    st.session_state["streamlit_config"]['server']['maxMessageSize'] = st.number_input("Max Message Size", value=st.session_state["streamlit_config"]['server']['maxMessageSize'],
+ help="Max size, in megabytes, of messages that can be sent via the WebSocket connection. Default: 200")
+
+ st.session_state["streamlit_config"]['server']['enableWebsocketCompression'] = st.checkbox("Enable Websocket Compression",
+ value=st.session_state["streamlit_config"]['server']['enableWebsocketCompression'],
+ help=" Enables support for websocket compression. Default: false")
+ if not st.session_state['defaults'].admin.hide_browser_setting:
+ with st.expander("Browser", expanded=True):
+ st.session_state["streamlit_config"]['browser']['serverAddress'] = st.text_input("Server Address",
+                                                                                      value=st.session_state["streamlit_config"]['browser']['serverAddress'] if "serverAddress" in st.session_state["streamlit_config"]["browser"] else "localhost",
+ help="Internet address where users should point their browsers in order \
+ to connect to the app. Can be IP address or DNS name and path.\
+ This is used to: - Set the correct URL for CORS and XSRF protection purposes. \
+ - Show the URL on the terminal - Open the browser. Default: 'localhost'")
+
+ st.session_state["defaults"].general.streamlit_telemetry = st.checkbox("Enable Telemetry", value=st.session_state['defaults'].general.streamlit_telemetry,
+ help="Enables or Disables streamlit telemetry. Default: False")
+ st.session_state["streamlit_config"]["browser"]["gatherUsageStats"] = st.session_state["defaults"].general.streamlit_telemetry
+
+ st.session_state["streamlit_config"]['browser']['serverPort'] = st.number_input("Server Port", value=st.session_state["streamlit_config"]['browser']['serverPort'],
+ help="Port where users should point their browsers in order to connect to the app. \
+ This is used to: - Set the correct URL for CORS and XSRF protection purposes. \
+ - Show the URL on the terminal - Open the browser \
+ Default: whatever value is set in server.port.")
+
+ with col5:
+ st.title("Huggingface")
+ st.session_state["defaults"].general.huggingface_token = st.text_input("Huggingface Token", value=st.session_state['defaults'].general.huggingface_token, type="password",
+ help="Your Huggingface Token, it's used to download the model for the diffusers library which \
+ is used on the Text To Video tab. This token will be saved to your user config file\
+                                                                        and WILL NOT be shared with us or anyone. You can get your access token \
+ at https://huggingface.co/settings/tokens. Default: None")
+
+ st.title("Stable Horde")
+ st.session_state["defaults"].general.stable_horde_api = st.text_input("Stable Horde Api", value=st.session_state["defaults"].general.stable_horde_api, type="password",
+                                                                               help="First register an account at https://stablehorde.net/register, which will generate an API key for you. \
+                                                                               Store that key somewhere safe. \n \
+                                                                               If you do not want to register, you can use `0000000000` as the api_key to connect anonymously.\
+                                                                               However, anonymous accounts have the lowest priority when there are too many concurrent requests! \
+                                                                               To increase your priority you will need a unique API key and more Kudos; \
+                                                                               read more about them at https://dbzer0.com/blog/the-kudos-based-economy-for-the-koboldai-horde/.")
+
+ with txt2img_tab:
+ col1, col2, col3, col4, col5 = st.columns(5, gap='medium')
+
+ with col1:
+ st.title("Slider Parameters")
+
+ # Width
+ st.session_state["defaults"].txt2img.width.value = st.number_input("Default Image Width", value=st.session_state['defaults'].txt2img.width.value,
+ help="Set the default width for the generated image. Default is: 512")
+
+ st.session_state["defaults"].txt2img.width.min_value = st.number_input("Minimum Image Width", value=st.session_state['defaults'].txt2img.width.min_value,
+ help="Set the default minimum value for the width slider. Default is: 64")
+
+ st.session_state["defaults"].txt2img.width.max_value = st.number_input("Maximum Image Width", value=st.session_state['defaults'].txt2img.width.max_value,
+ help="Set the default maximum value for the width slider. Default is: 2048")
+
+ # Height
+ st.session_state["defaults"].txt2img.height.value = st.number_input("Default Image Height", value=st.session_state['defaults'].txt2img.height.value,
+ help="Set the default height for the generated image. Default is: 512")
+
+ st.session_state["defaults"].txt2img.height.min_value = st.number_input("Minimum Image Height", value=st.session_state['defaults'].txt2img.height.min_value,
+ help="Set the default minimum value for the height slider. Default is: 64")
+
+ st.session_state["defaults"].txt2img.height.max_value = st.number_input("Maximum Image Height", value=st.session_state['defaults'].txt2img.height.max_value,
+ help="Set the default maximum value for the height slider. Default is: 2048")
+
+ with col2:
+ # CFG
+ st.session_state["defaults"].txt2img.cfg_scale.value = st.number_input("Default CFG Scale", value=st.session_state['defaults'].txt2img.cfg_scale.value,
+ help="Set the default value for the CFG Scale. Default is: 7.5")
+
+ st.session_state["defaults"].txt2img.cfg_scale.min_value = st.number_input("Minimum CFG Scale Value", value=st.session_state['defaults'].txt2img.cfg_scale.min_value,
+ help="Set the default minimum value for the CFG scale slider. Default is: 1")
+
+ st.session_state["defaults"].txt2img.cfg_scale.step = st.number_input("CFG Slider Steps", value=st.session_state['defaults'].txt2img.cfg_scale.step,
+ help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
+ # Sampling Steps
+ st.session_state["defaults"].txt2img.sampling_steps.value = st.number_input("Default Sampling Steps", value=st.session_state['defaults'].txt2img.sampling_steps.value,
+ help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
+
+ st.session_state["defaults"].txt2img.sampling_steps.min_value = st.number_input("Minimum Sampling Steps",
+ value=st.session_state['defaults'].txt2img.sampling_steps.min_value,
+ help="Set the default minimum value for the sampling steps slider. Default is: 1")
+
+ st.session_state["defaults"].txt2img.sampling_steps.step = st.number_input("Sampling Slider Steps",
+ value=st.session_state['defaults'].txt2img.sampling_steps.step,
+ help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
+
+ with col3:
+ st.title("General Parameters")
+
+ # Batch Count
+ st.session_state["defaults"].txt2img.batch_count.value = st.number_input("Batch count", value=st.session_state['defaults'].txt2img.batch_count.value,
+ help="How many iterations or batches of images to generate in total.")
+
+ st.session_state["defaults"].txt2img.batch_size.value = st.number_input("Batch size", value=st.session_state.defaults.txt2img.batch_size.value,
+                                                                               help="How many images are generated at once in a batch.\
+ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
+ takes to finish generation as more images are generated at once.\
+ Default: 1")
+
+ default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ st.session_state["defaults"].txt2img.default_sampler = st.selectbox("Default Sampler",
+ default_sampler_list, index=default_sampler_list.index(
+ st.session_state['defaults'].txt2img.default_sampler),
+                                                                            help="Default sampler to use for txt2img. Default: k_euler")
+
+ st.session_state['defaults'].txt2img.seed = st.text_input("Default Seed", value=st.session_state['defaults'].txt2img.seed, help="Default seed.")
+
+ with col4:
+
+ st.session_state["defaults"].txt2img.separate_prompts = st.checkbox("Separate Prompts",
+ value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate Prompts. Default: False")
+
+ st.session_state["defaults"].txt2img.normalize_prompt_weights = st.checkbox("Normalize Prompt Weights",
+ value=st.session_state['defaults'].txt2img.normalize_prompt_weights,
+ help="Choose to normalize prompt weights. Default: True")
+
+ st.session_state["defaults"].txt2img.save_individual_images = st.checkbox("Save Individual Images",
+ value=st.session_state['defaults'].txt2img.save_individual_images,
+ help="Choose to save individual images. Default: True")
+
+ st.session_state["defaults"].txt2img.save_grid = st.checkbox("Save Grid Images", value=st.session_state['defaults'].txt2img.save_grid,
+ help="Choose to save the grid images. Default: True")
+
+ st.session_state["defaults"].txt2img.group_by_prompt = st.checkbox("Group By Prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
+ help="Choose to save images grouped by their prompt. Default: False")
+
+ st.session_state["defaults"].txt2img.save_as_jpg = st.checkbox("Save As JPG", value=st.session_state['defaults'].txt2img.save_as_jpg,
+ help="Choose to save images as jpegs. Default: False")
+
+ st.session_state["defaults"].txt2img.write_info_files = st.checkbox("Write Info Files For Images", value=st.session_state['defaults'].txt2img.write_info_files,
+ help="Choose to write the info files along with the generated images. Default: True")
+
+ st.session_state["defaults"].txt2img.use_GFPGAN = st.checkbox(
+ "Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Choose to use GFPGAN. Default: False")
+
+ st.session_state["defaults"].txt2img.use_upscaling = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling,
+ help="Choose to turn on upscaling by default. Default: False")
+
+ st.session_state["defaults"].txt2img.update_preview = True
+ st.session_state["defaults"].txt2img.update_preview_frequency = st.number_input("Preview Image Update Frequency",
+ min_value=0,
+ value=st.session_state['defaults'].txt2img.update_preview_frequency,
+                                                                                             help="Set the default value for the frequency of the preview image updates. Default is: 10")
+
+ with col5:
+ st.title("Variation Parameters")
+
+ st.session_state["defaults"].txt2img.variant_amount.value = st.number_input("Default Variation Amount",
+ value=st.session_state['defaults'].txt2img.variant_amount.value,
+ help="Set the default variation to use. Default is: 0.0")
+
+ st.session_state["defaults"].txt2img.variant_amount.min_value = st.number_input("Minimum Variation Amount",
+ value=st.session_state['defaults'].txt2img.variant_amount.min_value,
+ help="Set the default minimum value for the variation slider. Default is: 0.0")
+
+ st.session_state["defaults"].txt2img.variant_amount.max_value = st.number_input("Maximum Variation Amount",
+ value=st.session_state['defaults'].txt2img.variant_amount.max_value,
+ help="Set the default maximum value for the variation slider. Default is: 1.0")
+
+ st.session_state["defaults"].txt2img.variant_amount.step = st.number_input("Variation Slider Steps",
+ value=st.session_state['defaults'].txt2img.variant_amount.step,
+ help="Set the default value for the number of steps on the variation slider. Default is: 1")
+
+ st.session_state['defaults'].txt2img.variant_seed = st.text_input("Default Variation Seed", value=st.session_state['defaults'].txt2img.variant_seed,
+ help="Default variation seed.")
+
+ with img2img_tab:
+ col1, col2, col3, col4, col5 = st.columns(5, gap='medium')
+
+ with col1:
+ st.title("Image Editing")
+
+ # Denoising
+ st.session_state["defaults"].img2img.denoising_strength.value = st.number_input("Default Denoising Amount",
+ value=st.session_state['defaults'].img2img.denoising_strength.value,
+ help="Set the default denoising to use. Default is: 0.75")
+
+ st.session_state["defaults"].img2img.denoising_strength.min_value = st.number_input("Minimum Denoising Amount",
+ value=st.session_state['defaults'].img2img.denoising_strength.min_value,
+ help="Set the default minimum value for the denoising slider. Default is: 0.0")
+
+ st.session_state["defaults"].img2img.denoising_strength.max_value = st.number_input("Maximum Denoising Amount",
+ value=st.session_state['defaults'].img2img.denoising_strength.max_value,
+ help="Set the default maximum value for the denoising slider. Default is: 1.0")
+
+ st.session_state["defaults"].img2img.denoising_strength.step = st.number_input("Denoising Slider Steps",
+ value=st.session_state['defaults'].img2img.denoising_strength.step,
+ help="Set the default value for the number of steps on the denoising slider. Default is: 0.01")
+
+ # Masking
+ st.session_state["defaults"].img2img.mask_mode = st.number_input("Default Mask Mode", value=st.session_state['defaults'].img2img.mask_mode,
+ help="Set the default mask mode to use. 0 = Keep Masked Area, 1 = Regenerate Masked Area. Default is: 0")
+
+ st.session_state["defaults"].img2img.mask_restore = st.checkbox("Default Mask Restore", value=st.session_state['defaults'].img2img.mask_restore,
+ help="Mask Restore. Default: False")
+
+ st.session_state["defaults"].img2img.resize_mode = st.number_input("Default Resize Mode", value=st.session_state['defaults'].img2img.resize_mode,
+                                                                               help="Set the default resizing mode. 0 = Just Resize, 1 = Crop and Resize, 2 = Resize and Fill. Default is: 0")
+
+ with col2:
+ st.title("Slider Parameters")
+
+ # Width
+ st.session_state["defaults"].img2img.width.value = st.number_input("Default Outputted Image Width", value=st.session_state['defaults'].img2img.width.value,
+ help="Set the default width for the generated image. Default is: 512")
+
+ st.session_state["defaults"].img2img.width.min_value = st.number_input("Minimum Outputted Image Width", value=st.session_state['defaults'].img2img.width.min_value,
+ help="Set the default minimum value for the width slider. Default is: 64")
+
+ st.session_state["defaults"].img2img.width.max_value = st.number_input("Maximum Outputted Image Width", value=st.session_state['defaults'].img2img.width.max_value,
+ help="Set the default maximum value for the width slider. Default is: 2048")
+
+ # Height
+ st.session_state["defaults"].img2img.height.value = st.number_input("Default Outputted Image Height", value=st.session_state['defaults'].img2img.height.value,
+ help="Set the default height for the generated image. Default is: 512")
+
+ st.session_state["defaults"].img2img.height.min_value = st.number_input("Minimum Outputted Image Height", value=st.session_state['defaults'].img2img.height.min_value,
+ help="Set the default minimum value for the height slider. Default is: 64")
+
+ st.session_state["defaults"].img2img.height.max_value = st.number_input("Maximum Outputted Image Height", value=st.session_state['defaults'].img2img.height.max_value,
+ help="Set the default maximum value for the height slider. Default is: 2048")
+
+ # CFG
+ st.session_state["defaults"].img2img.cfg_scale.value = st.number_input("Default Img2Img CFG Scale", value=st.session_state['defaults'].img2img.cfg_scale.value,
+ help="Set the default value for the CFG Scale. Default is: 7.5")
+
+ st.session_state["defaults"].img2img.cfg_scale.min_value = st.number_input("Minimum Img2Img CFG Scale Value",
+ value=st.session_state['defaults'].img2img.cfg_scale.min_value,
+ help="Set the default minimum value for the CFG scale slider. Default is: 1")
+
+ with col3:
+ st.session_state["defaults"].img2img.cfg_scale.step = st.number_input("Img2Img CFG Slider Steps",
+ value=st.session_state['defaults'].img2img.cfg_scale.step,
+ help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
+
+ # Sampling Steps
+ st.session_state["defaults"].img2img.sampling_steps.value = st.number_input("Default Img2Img Sampling Steps",
+ value=st.session_state['defaults'].img2img.sampling_steps.value,
+ help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
+
+ st.session_state["defaults"].img2img.sampling_steps.min_value = st.number_input("Minimum Img2Img Sampling Steps",
+ value=st.session_state['defaults'].img2img.sampling_steps.min_value,
+ help="Set the default minimum value for the sampling steps slider. Default is: 1")
+
+ st.session_state["defaults"].img2img.sampling_steps.step = st.number_input("Img2Img Sampling Slider Steps",
+ value=st.session_state['defaults'].img2img.sampling_steps.step,
+ help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
+
+ # Batch Count
+ st.session_state["defaults"].img2img.batch_count.value = st.number_input("Img2img Batch count", value=st.session_state["defaults"].img2img.batch_count.value,
+ help="How many iterations or batches of images to generate in total.")
+
+ st.session_state["defaults"].img2img.batch_size.value = st.number_input("Img2img Batch size", value=st.session_state["defaults"].img2img.batch_size.value,
+                                                                                     help="How many images are generated at once in a batch.\
+ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
+ takes to finish generation as more images are generated at once.\
+ Default: 1")
+ with col4:
+ # Inference Steps
+ st.session_state["defaults"].img2img.num_inference_steps.value = st.number_input("Default Inference Steps",
+ value=st.session_state['defaults'].img2img.num_inference_steps.value,
+ help="Set the default number of inference steps to use. Default is: 200")
+
+ st.session_state["defaults"].img2img.num_inference_steps.min_value = st.number_input("Minimum Sampling Steps",
+ value=st.session_state['defaults'].img2img.num_inference_steps.min_value,
+ help="Set the default minimum value for the inference steps slider. Default is: 10")
+
+ st.session_state["defaults"].img2img.num_inference_steps.max_value = st.number_input("Maximum Sampling Steps",
+ value=st.session_state['defaults'].img2img.num_inference_steps.max_value,
+ help="Set the default maximum value for the inference steps slider. Default is: 500")
+
+ st.session_state["defaults"].img2img.num_inference_steps.step = st.number_input("Inference Slider Steps",
+ value=st.session_state['defaults'].img2img.num_inference_steps.step,
+ help="Set the default value for the number of steps on the inference steps slider.\
+ Default is: 10")
+
+ # Find Noise Steps
+ st.session_state["defaults"].img2img.find_noise_steps.value = st.number_input("Default Find Noise Steps",
+ value=st.session_state['defaults'].img2img.find_noise_steps.value,
+ help="Set the default number of find noise steps to use. Default is: 100")
+
+ st.session_state["defaults"].img2img.find_noise_steps.min_value = st.number_input("Minimum Find Noise Steps",
+ value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
+ help="Set the default minimum value for the find noise steps slider. Default is: 0")
+
+ st.session_state["defaults"].img2img.find_noise_steps.step = st.number_input("Find Noise Slider Steps",
+ value=st.session_state['defaults'].img2img.find_noise_steps.step,
+ help="Set the default value for the number of steps on the find noise steps slider. \
+ Default is: 100")
+
+ with col5:
+ st.title("General Parameters")
+
+ default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ st.session_state["defaults"].img2img.sampler_name = st.selectbox("Default Img2Img Sampler", default_sampler_list,
+ index=default_sampler_list.index(st.session_state['defaults'].img2img.sampler_name),
+                                                                     help="Default sampler to use for img2img. Default: k_euler")
+
+ st.session_state['defaults'].img2img.seed = st.text_input("Default Img2Img Seed", value=st.session_state['defaults'].img2img.seed, help="Default seed.")
+
+ st.session_state["defaults"].img2img.separate_prompts = st.checkbox("Separate Img2Img Prompts", value=st.session_state['defaults'].img2img.separate_prompts,
+ help="Separate Prompts. Default: False")
+
+ st.session_state["defaults"].img2img.normalize_prompt_weights = st.checkbox("Normalize Img2Img Prompt Weights",
+ value=st.session_state['defaults'].img2img.normalize_prompt_weights,
+ help="Choose to normalize prompt weights. Default: True")
+
+ st.session_state["defaults"].img2img.save_individual_images = st.checkbox("Save Individual Img2Img Images",
+ value=st.session_state['defaults'].img2img.save_individual_images,
+ help="Choose to save individual images. Default: True")
+
+ st.session_state["defaults"].img2img.save_grid = st.checkbox("Save Img2Img Grid Images",
+ value=st.session_state['defaults'].img2img.save_grid, help="Choose to save the grid images. Default: True")
+
+ st.session_state["defaults"].img2img.group_by_prompt = st.checkbox("Group By Img2Img Prompt",
+ value=st.session_state['defaults'].img2img.group_by_prompt,
+ help="Choose to save images grouped by their prompt. Default: False")
+
+ st.session_state["defaults"].img2img.save_as_jpg = st.checkbox("Save Img2Img As JPG", value=st.session_state['defaults'].img2img.save_as_jpg,
+ help="Choose to save images as jpegs. Default: False")
+
+ st.session_state["defaults"].img2img.write_info_files = st.checkbox("Write Info Files For Img2Img Images",
+ value=st.session_state['defaults'].img2img.write_info_files,
+ help="Choose to write the info files along with the generated images. Default: True")
+
+ st.session_state["defaults"].img2img.use_GFPGAN = st.checkbox(
+ "Img2Img Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Choose to use GFPGAN. Default: False")
+
+ st.session_state["defaults"].img2img.use_RealESRGAN = st.checkbox("Img2Img Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN,
+ help="Choose to use RealESRGAN. Default: False")
+
+ st.session_state["defaults"].img2img.update_preview = True
+ st.session_state["defaults"].img2img.update_preview_frequency = st.number_input("Img2Img Preview Image Update Frequency",
+ min_value=0,
+ value=st.session_state['defaults'].img2img.update_preview_frequency,
+                                                                                             help="Set the default value for the frequency of the preview image updates. Default is: 10")
+
+ st.title("Variation Parameters")
+
+ st.session_state["defaults"].img2img.variant_amount = st.number_input("Default Img2Img Variation Amount",
+ value=st.session_state['defaults'].img2img.variant_amount,
+ help="Set the default variation to use. Default is: 0.0")
+
+ # I THINK THESE ARE MISSING FROM THE CONFIG FILE
+ # st.session_state["defaults"].img2img.variant_amount.min_value = st.number_input("Minimum Img2Img Variation Amount",
+ # value=st.session_state['defaults'].img2img.variant_amount.min_value, help="Set the default minimum value for the variation slider. Default is: 0.0"))
+
+ # st.session_state["defaults"].img2img.variant_amount.max_value = st.number_input("Maximum Img2Img Variation Amount",
+ # value=st.session_state['defaults'].img2img.variant_amount.max_value, help="Set the default maximum value for the variation slider. Default is: 1.0"))
+
+ # st.session_state["defaults"].img2img.variant_amount.step = st.number_input("Img2Img Variation Slider Steps",
+ # value=st.session_state['defaults'].img2img.variant_amount.step, help="Set the default value for the number of steps on the variation slider. Default is: 1"))
+
+ st.session_state['defaults'].img2img.variant_seed = st.text_input("Default Img2Img Variation Seed",
+ value=st.session_state['defaults'].img2img.variant_seed, help="Default variation seed.")
+
+ with img2txt_tab:
+ col1 = st.columns(1, gap="large")
+
+ st.title("Image-To-Text")
+
+ st.session_state["defaults"].img2txt.batch_size = st.number_input("Default Img2Txt Batch Size", value=st.session_state['defaults'].img2txt.batch_size,
+ help="Set the default batch size for Img2Txt. Default is: 420?")
+
+ st.session_state["defaults"].img2txt.blip_image_eval_size = st.number_input("Default Blip Image Size Evaluation",
+ value=st.session_state['defaults'].img2txt.blip_image_eval_size,
+ help="Set the default value for the blip image evaluation size. Default is: 512")
+
+ with txt2vid_tab:
+ col1, col2, col3, col4, col5 = st.columns(5, gap="medium")
+
+ with col1:
+ st.title("Slider Parameters")
+
+ # Width
+ st.session_state["defaults"].txt2vid.width.value = st.number_input("Default txt2vid Image Width",
+ value=st.session_state['defaults'].txt2vid.width.value,
+ help="Set the default width for the generated image. Default is: 512")
+
+ st.session_state["defaults"].txt2vid.width.min_value = st.number_input("Minimum txt2vid Image Width",
+ value=st.session_state['defaults'].txt2vid.width.min_value,
+ help="Set the default minimum value for the width slider. Default is: 64")
+
+ st.session_state["defaults"].txt2vid.width.max_value = st.number_input("Maximum txt2vid Image Width",
+ value=st.session_state['defaults'].txt2vid.width.max_value,
+ help="Set the default maximum value for the width slider. Default is: 2048")
+
+ # Height
+ st.session_state["defaults"].txt2vid.height.value = st.number_input("Default txt2vid Image Height",
+ value=st.session_state['defaults'].txt2vid.height.value,
+ help="Set the default height for the generated image. Default is: 512")
+
+ st.session_state["defaults"].txt2vid.height.min_value = st.number_input("Minimum txt2vid Image Height",
+ value=st.session_state['defaults'].txt2vid.height.min_value,
+ help="Set the default minimum value for the height slider. Default is: 64")
+
+ st.session_state["defaults"].txt2vid.height.max_value = st.number_input("Maximum txt2vid Image Height",
+ value=st.session_state['defaults'].txt2vid.height.max_value,
+ help="Set the default maximum value for the height slider. Default is: 2048")
+
+ # CFG
+ st.session_state["defaults"].txt2vid.cfg_scale.value = st.number_input("Default txt2vid CFG Scale",
+ value=st.session_state['defaults'].txt2vid.cfg_scale.value,
+ help="Set the default value for the CFG Scale. Default is: 7.5")
+
+ st.session_state["defaults"].txt2vid.cfg_scale.min_value = st.number_input("Minimum txt2vid CFG Scale Value",
+ value=st.session_state['defaults'].txt2vid.cfg_scale.min_value,
+ help="Set the default minimum value for the CFG scale slider. Default is: 1")
+
+ st.session_state["defaults"].txt2vid.cfg_scale.step = st.number_input("txt2vid CFG Slider Steps",
+ value=st.session_state['defaults'].txt2vid.cfg_scale.step,
+ help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5")
+
+ with col2:
+ # Sampling Steps
+ st.session_state["defaults"].txt2vid.sampling_steps.value = st.number_input("Default txt2vid Sampling Steps",
+ value=st.session_state['defaults'].txt2vid.sampling_steps.value,
+ help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
+
+ st.session_state["defaults"].txt2vid.sampling_steps.min_value = st.number_input("Minimum txt2vid Sampling Steps",
+ value=st.session_state['defaults'].txt2vid.sampling_steps.min_value,
+ help="Set the default minimum value for the sampling steps slider. Default is: 1")
+
+ st.session_state["defaults"].txt2vid.sampling_steps.step = st.number_input("txt2vid Sampling Slider Steps",
+ value=st.session_state['defaults'].txt2vid.sampling_steps.step,
+ help="Set the default value for the number of steps on the sampling steps slider. Default is: 10")
+
+ # Batch Count
+ st.session_state["defaults"].txt2vid.batch_count.value = st.number_input("txt2vid Batch count", value=st.session_state['defaults'].txt2vid.batch_count.value,
+ help="How many iterations or batches of images to generate in total.")
+
+ st.session_state["defaults"].txt2vid.batch_size.value = st.number_input("txt2vid Batch size", value=st.session_state.defaults.txt2vid.batch_size.value,
+                                                                                  help="How many images are generated at once in a batch.\
+ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \
+ takes to finish generation as more images are generated at once.\
+ Default: 1")
+
+ # Inference Steps
+ st.session_state["defaults"].txt2vid.num_inference_steps.value = st.number_input("Default Txt2Vid Inference Steps",
+ value=st.session_state['defaults'].txt2vid.num_inference_steps.value,
+ help="Set the default number of inference steps to use. Default is: 200")
+
+ st.session_state["defaults"].txt2vid.num_inference_steps.min_value = st.number_input("Minimum Txt2Vid Sampling Steps",
+ value=st.session_state['defaults'].txt2vid.num_inference_steps.min_value,
+ help="Set the default minimum value for the inference steps slider. Default is: 10")
+
+ st.session_state["defaults"].txt2vid.num_inference_steps.max_value = st.number_input("Maximum Txt2Vid Sampling Steps",
+ value=st.session_state['defaults'].txt2vid.num_inference_steps.max_value,
+ help="Set the default maximum value for the inference steps slider. Default is: 500")
+ st.session_state["defaults"].txt2vid.num_inference_steps.step = st.number_input("Txt2Vid Inference Slider Steps",
+ value=st.session_state['defaults'].txt2vid.num_inference_steps.step,
+ help="Set the default value for the number of steps on the inference steps slider. Default is: 10")
+
+ with col3:
+ st.title("General Parameters")
+
+ st.session_state['defaults'].txt2vid.default_model = st.text_input("Default Txt2Vid Model", value=st.session_state['defaults'].txt2vid.default_model,
+ help="Default: CompVis/stable-diffusion-v1-4")
+
+ # INSERT CUSTOM_MODELS_LIST HERE
+
+ default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ st.session_state["defaults"].txt2vid.default_sampler = st.selectbox("Default txt2vid Sampler", default_sampler_list,
+ index=default_sampler_list.index(st.session_state['defaults'].txt2vid.default_sampler),
+ help="Defaut sampler to use for txt2vid. Default: k_euler")
+
+ st.session_state['defaults'].txt2vid.seed = st.text_input("Default txt2vid Seed", value=st.session_state['defaults'].txt2vid.seed, help="Default seed.")
+
+ st.session_state['defaults'].txt2vid.scheduler_name = st.text_input("Default Txt2Vid Scheduler",
+ value=st.session_state['defaults'].txt2vid.scheduler_name, help="Default scheduler.")
+
+ st.session_state["defaults"].txt2vid.separate_prompts = st.checkbox("Separate txt2vid Prompts",
+ value=st.session_state['defaults'].txt2vid.separate_prompts, help="Separate Prompts. Default: False")
+
+ st.session_state["defaults"].txt2vid.normalize_prompt_weights = st.checkbox("Normalize txt2vid Prompt Weights",
+ value=st.session_state['defaults'].txt2vid.normalize_prompt_weights,
+ help="Choose to normalize prompt weights. Default: True")
+
+ st.session_state["defaults"].txt2vid.save_individual_images = st.checkbox("Save Individual txt2vid Images",
+ value=st.session_state['defaults'].txt2vid.save_individual_images,
+ help="Choose to save individual images. Default: True")
+
+ st.session_state["defaults"].txt2vid.save_video = st.checkbox("Save Txt2Vid Video", value=st.session_state['defaults'].txt2vid.save_video,
+ help="Choose to save the Txt2Vid video. Default: True")
+
+ st.session_state["defaults"].txt2vid.save_video_on_stop = st.checkbox("Save video on Stop", value=st.session_state['defaults'].txt2vid.save_video_on_stop,
+ help="Save a video with all the images generated as frames when we hit the stop button \
+ during a generation.")
+
+ st.session_state["defaults"].txt2vid.group_by_prompt = st.checkbox("Group By txt2vid Prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
+ help="Choose to save images grouped by their prompt. Default: False")
+
+ st.session_state["defaults"].txt2vid.save_as_jpg = st.checkbox("Save txt2vid As JPG", value=st.session_state['defaults'].txt2vid.save_as_jpg,
+ help="Choose to save images as jpegs. Default: False")
+
+ # Need more info for the Help dialog...
+ st.session_state["defaults"].txt2vid.do_loop = st.checkbox("Loop Generations", value=st.session_state['defaults'].txt2vid.do_loop,
+ help="Choose to loop or something, IDK.... Default: False")
+
+ st.session_state["defaults"].txt2vid.max_duration_in_seconds = st.number_input("Txt2Vid Max Duration in Seconds", value=st.session_state['defaults'].txt2vid.max_duration_in_seconds,
+ help="Set the default value for the max duration in seconds for the video generated. Default is: 30")
+
+ st.session_state["defaults"].txt2vid.write_info_files = st.checkbox("Write Info Files For txt2vid Images", value=st.session_state['defaults'].txt2vid.write_info_files,
+ help="Choose to write the info files along with the generated images. Default: True")
+
+ st.session_state["defaults"].txt2vid.use_GFPGAN = st.checkbox("txt2vid Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN,
+ help="Choose to use GFPGAN. Default: False")
+
+ st.session_state["defaults"].txt2vid.use_RealESRGAN = st.checkbox("txt2vid Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
+ help="Choose to use RealESRGAN. Default: False")
+
+ st.session_state["defaults"].txt2vid.update_preview = True
+ st.session_state["defaults"].txt2vid.update_preview_frequency = st.number_input("txt2vid Preview Image Update Frequency",
+ value=st.session_state['defaults'].txt2vid.update_preview_frequency,
+ help="Set the default value for the frrquency of the preview image updates. Default is: 10")
+
+ with col4:
+ st.title("Variation Parameters")
+
+ st.session_state["defaults"].txt2vid.variant_amount.value = st.number_input("Default txt2vid Variation Amount",
+ value=st.session_state['defaults'].txt2vid.variant_amount.value,
+ help="Set the default variation to use. Default is: 0.0")
+
+ st.session_state["defaults"].txt2vid.variant_amount.min_value = st.number_input("Minimum txt2vid Variation Amount",
+ value=st.session_state['defaults'].txt2vid.variant_amount.min_value,
+ help="Set the default minimum value for the variation slider. Default is: 0.0")
+
+ st.session_state["defaults"].txt2vid.variant_amount.max_value = st.number_input("Maximum txt2vid Variation Amount",
+ value=st.session_state['defaults'].txt2vid.variant_amount.max_value,
+ help="Set the default maximum value for the variation slider. Default is: 1.0")
+
+ st.session_state["defaults"].txt2vid.variant_amount.step = st.number_input("txt2vid Variation Slider Steps",
+ value=st.session_state['defaults'].txt2vid.variant_amount.step,
+ help="Set the default value for the number of steps on the variation slider. Default is: 1")
+
+ st.session_state['defaults'].txt2vid.variant_seed = st.text_input("Default txt2vid Variation Seed",
+ value=st.session_state['defaults'].txt2vid.variant_seed, help="Default variation seed.")
+
+ with col5:
+ st.title("Beta Parameters")
+
+ # Beta Start
+ st.session_state["defaults"].txt2vid.beta_start.value = st.number_input("Default txt2vid Beta Start Value",
+ value=st.session_state['defaults'].txt2vid.beta_start.value,
+ help="Set the default variation to use. Default is: 0.0")
+
+ st.session_state["defaults"].txt2vid.beta_start.min_value = st.number_input("Minimum txt2vid Beta Start Amount",
+ value=st.session_state['defaults'].txt2vid.beta_start.min_value,
+ help="Set the default minimum value for the variation slider. Default is: 0.0")
+
+ st.session_state["defaults"].txt2vid.beta_start.max_value = st.number_input("Maximum txt2vid Beta Start Amount",
+ value=st.session_state['defaults'].txt2vid.beta_start.max_value,
+ help="Set the default maximum value for the variation slider. Default is: 1.0")
+
+ st.session_state["defaults"].txt2vid.beta_start.step = st.number_input("txt2vid Beta Start Slider Steps", value=st.session_state['defaults'].txt2vid.beta_start.step,
+ help="Set the default value for the number of steps on the variation slider. Default is: 1")
+
+ st.session_state["defaults"].txt2vid.beta_start.format = st.text_input("Default txt2vid Beta Start Format", value=st.session_state['defaults'].txt2vid.beta_start.format,
+ help="Set the default Beta Start Format. Default is: %.5\f")
+
+ # Beta End
+ st.session_state["defaults"].txt2vid.beta_end.value = st.number_input("Default txt2vid Beta End Value", value=st.session_state['defaults'].txt2vid.beta_end.value,
+ help="Set the default variation to use. Default is: 0.0")
+
+ st.session_state["defaults"].txt2vid.beta_end.min_value = st.number_input("Minimum txt2vid Beta End Amount", value=st.session_state['defaults'].txt2vid.beta_end.min_value,
+ help="Set the default minimum value for the variation slider. Default is: 0.0")
+
+ st.session_state["defaults"].txt2vid.beta_end.max_value = st.number_input("Maximum txt2vid Beta End Amount", value=st.session_state['defaults'].txt2vid.beta_end.max_value,
+ help="Set the default maximum value for the variation slider. Default is: 1.0")
+
+ st.session_state["defaults"].txt2vid.beta_end.step = st.number_input("txt2vid Beta End Slider Steps", value=st.session_state['defaults'].txt2vid.beta_end.step,
+ help="Set the default value for the number of steps on the variation slider. Default is: 1")
+
+ st.session_state["defaults"].txt2vid.beta_end.format = st.text_input("Default txt2vid Beta End Format", value=st.session_state['defaults'].txt2vid.beta_start.format,
+ help="Set the default Beta Start Format. Default is: %.5\f")
+
+ with image_processing:
+ col1, col2, col3, col4, col5 = st.columns(5, gap="large")
+
+ with col1:
+ st.title("GFPGAN")
+
+ st.session_state["defaults"].gfpgan.strength = st.number_input("Default Img2Txt Batch Size", value=st.session_state['defaults'].gfpgan.strength,
+ help="Set the default global strength for GFPGAN. Default is: 100")
+ with col2:
+ st.title("GoBig")
+ with col3:
+ st.title("RealESRGAN")
+ with col4:
+ st.title("LDSR")
+ with col5:
+ st.title("GoLatent")
+
+ with textual_inversion_tab:
+ st.title("Textual Inversion")
+
+ st.session_state['defaults'].textual_inversion.pretrained_model_name_or_path = st.text_input("Default Textual Inversion Model Path",
+ value=st.session_state['defaults'].textual_inversion.pretrained_model_name_or_path,
+ help="Default: models/ldm/stable-diffusion-v1-4")
+
+ st.session_state['defaults'].textual_inversion.tokenizer_name = st.text_input("Default Textual Inversion Tokenizer Name", value=st.session_state['defaults'].textual_inversion.tokenizer_name,
+ help="Default tokenizer name.")
+
+ with concepts_library_tab:
+ st.title("Concepts Library")
+ #st.info("Under Construction. :construction_worker:")
+ col1, col2, col3, col4, col5 = st.columns(5, gap='large')
+ with col1:
+ st.session_state["defaults"].concepts_library.concepts_per_page = st.number_input("Concepts Per Page", value=st.session_state['defaults'].concepts_library.concepts_per_page,
+ help="Number of concepts per page to show on the Concepts Library. Default: '12'")
+
+ # add space for the buttons at the bottom
+ st.markdown("---")
+
+ # We need a submit button to save the Settings
+ # as well as one to reset them to the defaults, just in case.
+ _, _, save_button_col, reset_button_col, _, _ = st.columns([1, 1, 1, 1, 1, 1], gap="large")
+ with save_button_col:
+ save_button = st.form_submit_button("Save")
+
+ with reset_button_col:
+ reset_button = st.form_submit_button("Reset")
+
+ if save_button:
+ OmegaConf.save(config=st.session_state.defaults, f="configs/webui/userconfig_streamlit.yaml")
+ loaded = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
+ assert st.session_state.defaults == loaded
+
+ # also persist the current Streamlit server settings to .streamlit/config.toml if that file exists
+ if (os.path.exists(".streamlit/config.toml")):
+ with open(".streamlit/config.toml", "w") as toml_file:
+ toml.dump(st.session_state["streamlit_config"], toml_file)
+
+ if reset_button:
+ st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
+ st.experimental_rerun()
diff --git a/webui/streamlit/scripts/barfi_baklavajs.py b/webui/streamlit/scripts/barfi_baklavajs.py
new file mode 100644
index 0000000..9d3f802
--- /dev/null
+++ b/webui/streamlit/scripts/barfi_baklavajs.py
@@ -0,0 +1,91 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sandbox-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+# base webui import and utils.
+#from sd_utils import *
+from sd_utils import st
+# streamlit imports
+
+#streamlit components section
+
+#other imports
+from barfi import st_barfi, barfi_schemas, Block
+
+# Temp imports
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+
+def layout():
+ #st.info("Under Construction. :construction_worker:")
+
+ #from barfi import st_barfi, Block
+
+ #add = Block(name='Addition')
+ #sub = Block(name='Subtraction')
+ #mul = Block(name='Multiplication')
+ #div = Block(name='Division')
+
+ #barfi_result = st_barfi(base_blocks= [add, sub, mul, div])
+ # or if you want to use a category to organise them in the frontend sub-menu
+ #barfi_result = st_barfi(base_blocks= {'Op 1': [add, sub], 'Op 2': [mul, div]})
+
+ col1, col2, col3 = st.columns([1, 8, 1])
+
+ with col2:
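+ # minimal demo graph: Feed outputs a constant, Splitter halves it into two outputs,
+ # Mixer adds its two inputs, and Result reads the final value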
+ feed = Block(name='Feed')
+ feed.add_output()
+ def feed_func(self):
+ self.set_interface(name='Output 1', value=4)
+ feed.add_compute(feed_func)
+
+ splitter = Block(name='Splitter')
+ splitter.add_input()
+ splitter.add_output()
+ splitter.add_output()
+ def splitter_func(self):
+ in_1 = self.get_interface(name='Input 1')
+ value = (in_1/2)
+ self.set_interface(name='Output 1', value=value)
+ self.set_interface(name='Output 2', value=value)
+ splitter.add_compute(splitter_func)
+
+ mixer = Block(name='Mixer')
+ mixer.add_input()
+ mixer.add_input()
+ mixer.add_output()
+ def mixer_func(self):
+ in_1 = self.get_interface(name='Input 1')
+ in_2 = self.get_interface(name='Input 2')
+ value = (in_1 + in_2)
+ self.set_interface(name='Output 1', value=value)
+ mixer.add_compute(mixer_func)
+
+ result = Block(name='Result')
+ result.add_input()
+ def result_func(self):
+ in_1 = self.get_interface(name='Input 1')
+ result.add_compute(result_func)
+
+ load_schema = st.selectbox('Select a saved schema:', barfi_schemas())
+
+ compute_engine = st.checkbox('Activate barfi compute engine', value=False)
+
+ barfi_result = st_barfi(base_blocks=[feed, result, mixer, splitter],
+ compute_engine=compute_engine, load_schema=load_schema)
+
+ if barfi_result:
+ st.write(barfi_result)
diff --git a/webui/streamlit/scripts/custom_components/dragable_number_input/index.html b/webui/streamlit/scripts/custom_components/dragable_number_input/index.html
new file mode 100644
index 0000000..c26e3fd
--- /dev/null
+++ b/webui/streamlit/scripts/custom_components/dragable_number_input/index.html
@@ -0,0 +1,134 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/webui/streamlit/scripts/custom_components/draggable_number_input/__init__.py b/webui/streamlit/scripts/custom_components/draggable_number_input/__init__.py
new file mode 100644
index 0000000..5e79185
--- /dev/null
+++ b/webui/streamlit/scripts/custom_components/draggable_number_input/__init__.py
@@ -0,0 +1,11 @@
+import os
+import streamlit.components.v1 as components
+
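+# Inject main.js into the parent page through an (invisible) Streamlit component.
+# pixel_per_step controls how many pixels of mouse travel correspond to one value step.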
+def load(pixel_per_step = 50):
+ parent_dir = os.path.dirname(os.path.abspath(__file__))
+ file = os.path.join(parent_dir, "main.js")
+
+ with open(file) as f:
+ javascript_main = f.read()
+ javascript_main = javascript_main.replace("%%pixelPerStep%%",str(pixel_per_step))
+ components.html(f"")
\ No newline at end of file
diff --git a/webui/streamlit/scripts/custom_components/draggable_number_input/main.js b/webui/streamlit/scripts/custom_components/draggable_number_input/main.js
new file mode 100644
index 0000000..574d133
--- /dev/null
+++ b/webui/streamlit/scripts/custom_components/draggable_number_input/main.js
@@ -0,0 +1,192 @@
+// iframe parent
+var parentDoc = window.parent.document
+
+// check for mouse pointer locking support, not a requirement but improves the overall experience
+var havePointerLock = 'pointerLockElement' in parentDoc ||
+ 'mozPointerLockElement' in parentDoc ||
+ 'webkitPointerLockElement' in parentDoc;
+
+// the pointer locking exit function
+parentDoc.exitPointerLock = parentDoc.exitPointerLock || parentDoc.mozExitPointerLock || parentDoc.webkitExitPointerLock;
+
+// how far should the mouse travel for a step in pixel
+var pixelPerStep = %%pixelPerStep%%;
+// how many steps did the mouse move in as float
+var movementDelta = 0.0;
+// value when drag started
+var lockedValue = 0.0;
+// minimum value from field
+var lockedMin = 0.0;
+// maximum value from field
+var lockedMax = 0.0;
+// how big should the field steps be
+var lockedStep = 0.0;
+// the currently locked in field
+var lockedField = null;
+
+// lock box to just request pointer lock for one element
+var lockBox = document.createElement("div");
+lockBox.classList.add("lockbox");
+parentDoc.body.appendChild(lockBox);
+lockBox.requestPointerLock = lockBox.requestPointerLock || lockBox.mozRequestPointerLock || lockBox.webkitRequestPointerLock;
+
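+// Stretch the invisible lock box over the hovered number field and request pointer lock on it,
+// so the cursor is captured while the value is dragged with the middle mouse button.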
+function Lock(field)
+{
+ var rect = field.getBoundingClientRect();
+ lockBox.style.left = (rect.left-2.5)+"px";
+ lockBox.style.top = (rect.top-2.5)+"px";
+
+ lockBox.style.width = (rect.width+2.5)+"px";
+ lockBox.style.height = (rect.height+5)+"px";
+
+ lockBox.requestPointerLock();
+}
+
+function Unlock()
+{
+ parentDoc.exitPointerLock();
+ lockBox.style.left = "0px";
+ lockBox.style.top = "0px";
+
+ lockBox.style.width = "0px";
+ lockBox.style.height = "0px";
+ lockedField.focus();
+}
+
+parentDoc.addEventListener('mousedown', (e) => {
+ // if middle is down
+ if(e.button === 1)
+ {
+ if(e.target.tagName === 'INPUT' && e.target.type === 'number')
+ {
+ e.preventDefault();
+ var field = e.target;
+ if(havePointerLock)
+ Lock(field);
+
+ // save current field
+ lockedField = e.target;
+ // add class for styling
+ lockedField.classList.add("value-dragging");
+ // reset movement delta
+ movementDelta = 0.0;
+ // set to 0 if field is empty
+ if(lockedField.value === '')
+ lockedField.value = 0.0;
+
+ // save current field value
+ lockedValue = parseFloat(lockedField.value);
+
+ if(lockedField.min === '' || lockedField.min === '-Infinity')
+ lockedMin = -99999999.0;
+ else
+ lockedMin = parseFloat(lockedField.min);
+
+ if(lockedField.max === '' || lockedField.max === 'Infinity')
+ lockedMax = 99999999.0;
+ else
+ lockedMax = parseFloat(lockedField.max);
+
+ if(lockedField.step === '' || lockedField.step === 'Infinity')
+ lockedStep = 1.0;
+ else
+ lockedStep = parseFloat(lockedField.step);
+
+ // lock pointer if available
+ if(havePointerLock)
+ Lock(lockedField);
+
+ // add drag event
+ parentDoc.addEventListener("mousemove", onDrag, false);
+ }
+ }
+});
+
+function onDrag(e)
+{
+ if(lockedField !== null)
+ {
+ // add movement to delta
+ movementDelta += e.movementX / pixelPerStep;
+ if(isNaN(lockedValue))
+ return;
+ // set new value
+ let value = lockedValue + Math.floor(Math.abs(movementDelta)) * lockedStep * Math.sign(movementDelta);
+ lockedField.focus();
+ lockedField.select();
+ parentDoc.execCommand('insertText', false /*no UI*/, Math.min(Math.max(value, lockedMin), lockedMax));
+ }
+}
+
+parentDoc.addEventListener('mouseup', (e) => {
+ // if mouse is up
+ if(e.button === 1)
+ {
+ // release pointer lock if available
+ if(havePointerLock)
+ Unlock();
+
+ if(lockedField !== null)
+ {
+ // stop drag event
+ parentDoc.removeEventListener("mousemove", onDrag, false);
+ // remove class for styling
+ lockedField.classList.remove("value-dragging");
+ // remove reference
+ lockedField = null;
+ }
+ }
+});
+
+// only execute once (even though multiple iframes exist)
+if(!parentDoc.hasOwnProperty("dragableInitialized"))
+{
+ var parentCSS =
+`
+/* Make input-instruction not block mouse events */
+.input-instructions,.input-instructions > *{
+ pointer-events: none;
+ user-select: none;
+ -moz-user-select: none;
+ -khtml-user-select: none;
+ -webkit-user-select: none;
+ -o-user-select: none;
+}
+
+.lockbox {
+ background-color: transparent;
+ position: absolute;
+ pointer-events: none;
+ user-select: none;
+ -moz-user-select: none;
+ -khtml-user-select: none;
+ -webkit-user-select: none;
+ -o-user-select: none;
+ border-left: dotted 2px rgb(255,75,75);
+ border-top: dotted 2px rgb(255,75,75);
+ border-bottom: dotted 2px rgb(255,75,75);
+ border-right: dotted 1px rgba(255,75,75,0.2);
+ border-top-left-radius: 0.25rem;
+ border-bottom-left-radius: 0.25rem;
+ z-index: 1000;
+}
+`;
+
+ // get parent document head
+ var head = parentDoc.getElementsByTagName('head')[0];
+ // add style tag
+ var s = document.createElement('style');
+ // set type attribute
+ s.setAttribute('type', 'text/css');
+ // add css forwarded from python
+ if (s.styleSheet) { // IE
+ s.styleSheet.cssText = parentCSS;
+ } else { // the world
+ s.appendChild(document.createTextNode(parentCSS));
+ }
+ // add style to head
+ head.appendChild(s);
+ // set flag so this only runs once
+ parentDoc["dragableInitialized"] = true;
+}
+
diff --git a/webui/streamlit/scripts/custom_components/sygil_suggestions/__init__.py b/webui/streamlit/scripts/custom_components/sygil_suggestions/__init__.py
new file mode 100644
index 0000000..e25e1ac
--- /dev/null
+++ b/webui/streamlit/scripts/custom_components/sygil_suggestions/__init__.py
@@ -0,0 +1,46 @@
+import os
+from collections import defaultdict
+import streamlit.components.v1 as components
+
+# where to save the downloaded key_phrases
+key_phrases_file = "data/tags/key_phrases.json"
+# the loaded key phrase json as text
+key_phrases_json = ""
+# where to save the downloaded key_phrases
+thumbnails_file = "data/tags/thumbnails.json"
+# the loaded key phrase json as text
+thumbnails_json = ""
+
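+# read the key phrase and thumbnail JSON files into memory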
+def init():
+ global key_phrases_json, thumbnails_json
+ with open(key_phrases_file) as f:
+ key_phrases_json = f.read()
+ with open(thumbnails_file) as f:
+ thumbnails_json = f.read()
+
+def suggestion_area(placeholder):
+ # get component path
+ parent_dir = os.path.dirname(os.path.abspath(__file__))
+ # get file paths
+ javascript_file = os.path.join(parent_dir, "main.js")
+ stylesheet_file = os.path.join(parent_dir, "main.css")
+ parent_stylesheet_file = os.path.join(parent_dir, "parent.css")
+
+ # load file texts
+ with open(javascript_file) as f:
+ javascript_main = f.read()
+ with open(stylesheet_file) as f:
+ stylesheet_main = f.read()
+ with open(parent_stylesheet_file) as f:
+ parent_stylesheet = f.read()
+
+ # add suggestion area div box
+ html = "
javascript failed
"
+ # add loaded style
+ html += f""
+ # set default variables
+ html += f""
+ # add main java script
+ html += f"\n"
+ # add component to site
+ components.html(html, width=None, height=None, scrolling=True)
\ No newline at end of file
diff --git a/webui/streamlit/scripts/custom_components/sygil_suggestions/main.css b/webui/streamlit/scripts/custom_components/sygil_suggestions/main.css
new file mode 100644
index 0000000..c8729b4
--- /dev/null
+++ b/webui/streamlit/scripts/custom_components/sygil_suggestions/main.css
@@ -0,0 +1,81 @@
+*
+{
+ padding: 0px;
+ margin: 0px;
+ user-select: none;
+ -moz-user-select: none;
+ -khtml-user-select: none;
+ -webkit-user-select: none;
+ -o-user-select: none;
+}
+
+body
+{
+ width: 100%;
+ height: 100%;
+ padding-left: calc( 1em - 1px );
+ padding-top: calc( 1em - 1px );
+ overflow: hidden;
+}
+
+/* width */
+::-webkit-scrollbar {
+ width: 7px;
+}
+
+/* Track */
+::-webkit-scrollbar-track {
+ background: rgb(10, 13, 19);
+}
+
+/* Handle */
+::-webkit-scrollbar-thumb {
+ background: #6c6e72;
+ border-radius: 3px;
+}
+
+/* Handle on hover */
+::-webkit-scrollbar-thumb:hover {
+ background: #6c6e72;
+}
+
+#scroll_area
+{
+ display: flex;
+ overflow-x: hidden;
+ overflow-y: auto;
+}
+
+#suggestion_area
+{
+ overflow-x: hidden;
+ width: calc( 100% - 2em - 2px );
+ margin-bottom: calc( 1em + 13px );
+ min-height: 50px;
+}
+
+span
+{
+ border: 1px solid rgba(250, 250, 250, 0.2);
+ border-radius: 0.25rem;
+ font-size: 1rem;
+ font-family: "Source Sans Pro", sans-serif;
+
+ background-color: rgb(38, 39, 48);
+ color: white;
+ display: inline-block;
+ padding: 0.5rem;
+ margin-right: 3px;
+ cursor: pointer;
+ user-select: none;
+ -moz-user-select: none;
+ -khtml-user-select: none;
+ -webkit-user-select: none;
+ -o-user-select: none;
+}
+
+span:hover
+{
+ color: rgb(255,75,75);
+ border-color: rgb(255,75,75);
+}
\ No newline at end of file
diff --git a/webui/streamlit/scripts/custom_components/sygil_suggestions/main.js b/webui/streamlit/scripts/custom_components/sygil_suggestions/main.js
new file mode 100644
index 0000000..9826702
--- /dev/null
+++ b/webui/streamlit/scripts/custom_components/sygil_suggestions/main.js
@@ -0,0 +1,1048 @@
+
+// parent document
+var parentDoc = window.parent.document;
+// iframe element in parent document
+var frame = window.frameElement;
+// the area to put the suggestions in
+var suggestionArea = document.getElementById('suggestion_area');
+var scrollArea = document.getElementById('scroll_area');
+// button height is read when the first button gets created
+var buttonHeight = -1;
+// the maximum size of the iframe in buttons (3 x buttons height)
+var maxHeightInButtons = 3;
+// the prompt field connected to this iframe
+var promptField = null;
+// the category of suggestions
+var activeCategory = contextCategory;
+
+var conditionalButtons = null;
+
+var contextCategory = "[context]";
+
+var frameHeight = "calc( 3em - 3px + {} )";
+
+var filterGroups = {nsfw_mild: "nsfw_mild", nsfw_basic: "nsfw_basic", nsfw_strict: "nsfw_strict", gore_mild: "gore_mild", gore_basic: "gore_basic", gore_strict: "gore_strict"};
+var activeFilters = [filterGroups.nsfw_mild, filterGroups.nsfw_basic, filterGroups.gore_mild];
+
+var triggers = {empty: "empty", nsfw: "nsfw", nude: "nude"};
+var activeContext = [];
+
+var triggerIndex = {};
+
+var wordMap = {};
+var tagMap = {};
+
+// could pass in an array of specific stylesheets for optimization
+function getAllCSSVariableNames(styleSheets = parentDoc.styleSheets){
+ var cssVars = [];
+ // loop each stylesheet
+ for(var i = 0; i < styleSheets.length; i++){
+ // loop stylesheet's cssRules
+ try{ // try/catch used because 'hasOwnProperty' doesn't work
+ for( var j = 0; j < styleSheets[i].cssRules.length; j++){
+ try{
+ //console.log(styleSheets[i].cssRules[j].selectorText);
+ // loop stylesheet's cssRules' style (property names)
+ for(var k = 0; k < styleSheets[i].cssRules[j].style.length; k++){
+ let name = styleSheets[i].cssRules[j].style[k];
+ // test name for css variable signature and uniqueness
+ if(name.startsWith('--') && cssVars.indexOf(name) == -1){
+ cssVars.push(name);
+ }
+ }
+ } catch (error) {}
+ }
+ } catch (error) {}
+ }
+ return cssVars;
+}
+
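+// compute this iframe's absolute position in the top-level window by summing the offsets of all ancestor iframes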
+function currentFrameAbsolutePosition() {
+ let currentWindow = window;
+ let currentParentWindow;
+ let positions = [];
+ let rect;
+
+ while (currentWindow !== window.top) {
+ currentParentWindow = currentWindow.parent;
+ for (let idx = 0; idx < currentParentWindow.frames.length; idx++)
+ if (currentParentWindow.frames[idx] === currentWindow) {
+ for (let frameElement of currentParentWindow.document.getElementsByTagName('iframe')) {
+ if (frameElement.contentWindow === currentWindow) {
+ rect = frameElement.getBoundingClientRect();
+ positions.push({x: rect.x, y: rect.y});
+ }
+ }
+ currentWindow = currentParentWindow;
+ break;
+ }
+ }
+ return positions.reduce((accumulator, currentValue) => {
+ return {
+ x: accumulator.x + currentValue.x,
+ y: accumulator.y + currentValue.y
+ };
+ }, { x: 0, y: 0 });
+}
+
+// check if element is visible
+function isVisible(e) {
+ return !!( e.offsetWidth || e.offsetHeight || e.getClientRects().length );
+}
+
+// remove everything from the suggestion area
+function ClearSuggestionArea(text = "")
+{
+ suggestionArea.innerHTML = text;
+ conditionalButtons = [];
+}
+
+// update iframe size depending on button rows
+function UpdateSize()
+{
+ // calculate maximum height
+ var maxHeight = buttonHeight * maxHeightInButtons;
+
+ var height = suggestionArea.lastChild.offsetTop + buttonHeight;
+ // apply height to iframe
+ frame.style.height = frameHeight.replace("{}", Math.min(height,maxHeight)+"px");
+ scrollArea.style.height = frame.style.height;
+}
+
+// add a button to the suggestion area
+function AddButton(label, action, dataTooltip = null, tooltipImage = null, pattern = null, data = null)
+{
+ // create span
+ var button = document.createElement("span");
+ // label it
+ button.innerHTML = label;
+ if(data != null)
+ {
+ // add category attribute to button, will be read on click
+ button.setAttribute("data",data);
+ }
+ if(pattern != null)
+ {
+ // add category attribute to button, will be read on click
+ button.setAttribute("pattern",pattern);
+ }
+ if(dataTooltip != null)
+ {
+ // add category attribute to button, will be read on click
+ button.setAttribute("tooltip-text",dataTooltip);
+ }
+ if(tooltipImage != null)
+ {
+ // add category attribute to button, will be read on click
+ button.setAttribute("tooltip-image",tooltipImage);
+ }
+ // add button function
+ button.addEventListener('click', action, false);
+ button.addEventListener('mouseover', ButtonHoverEnter);
+ button.addEventListener('mouseout', ButtonHoverExit);
+ // add button to suggestion area
+ suggestionArea.appendChild(button);
+ // get buttonHeight if not set
+ if(buttonHeight < 0)
+ buttonHeight = button.offsetHeight;
+ return button;
+}
+
+// find visible prompt field to connect to this iframe
+function GetPromptField()
+{
+ // get all prompt fields, the %% placeholder %% is set in python
+ var all = parentDoc.querySelectorAll('textarea[placeholder="'+placeholder+'"]');
+ // filter visible
+ for(var i = 0; i < all.length; i++)
+ {
+ if(isVisible(all[i]))
+ {
+ promptField = all[i];
+ promptField.addEventListener('input', OnChange, false);
+ promptField.addEventListener('click', OnClick, false);
+ promptField.addEventListener('keyup', OnKey, false);
+ break;
+ }
+ }
+}
+
+function OnChange(e)
+{
+ ButtonConditions();
+ ButtonUpdateContext(true);
+}
+
+function OnClick(e)
+{
+ ButtonUpdateContext(true);
+}
+
+function OnKey(e)
+{
+ if (e.keyCode == '37' || e.keyCode == '38' || e.keyCode == '39' || e.keyCode == '40') {
+ ButtonUpdateContext(false);
+ }
+}
+
+function getCaretPosition(ctrl) {
+ // IE < 9 Support
+ if (document.selection) {
+ ctrl.focus();
+ var range = document.selection.createRange();
+ var rangelen = range.text.length;
+ range.moveStart('character', -ctrl.value.length);
+ var start = range.text.length - rangelen;
+ return {
+ 'start': start,
+ 'end': start + rangelen
+ };
+ } // IE >=9 and other browsers
+ else if (ctrl.selectionStart || ctrl.selectionStart == '0') {
+ return {
+ 'start': ctrl.selectionStart,
+ 'end': ctrl.selectionEnd
+ };
+ } else {
+ return {
+ 'start': 0,
+ 'end': 0
+ };
+ }
+}
+
+function setCaretPosition(ctrl, start, end) {
+ // IE >= 9 and other browsers
+ if (ctrl.setSelectionRange) {
+ ctrl.focus();
+ ctrl.setSelectionRange(start, end);
+ }
+ // IE < 9
+ else if (ctrl.createTextRange) {
+ var range = ctrl.createTextRange();
+ range.collapse(true);
+ range.moveEnd('character', end);
+ range.moveStart('character', start);
+ range.select();
+ }
+}
+
+function isEmptyOrSpaces(str){
+ return str === null || str.match(/^ *$/) !== null;
+}
+
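+// Rebuild the context-sensitive suggestions from the text left of the caret: an empty prompt,
+// a trailing "by", or known words (looked up through wordMap) decide which category is shown.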
+function ButtonUpdateContext(changeCategory)
+{
+ let targetCategory = contextCategory;
+ let text = promptField.value;
+ if(document.activeElement === promptField)
+ {
+ var pos = getCaretPosition(promptField).end;
+ text = promptField.value.slice(0, pos);
+ }
+
+ activeContext = [];
+
+ var parts = text.split(/[\.?!,]/);
+ if(activeCategory == "Artists" && !isEmptyOrSpaces(parts[parts.length-1]))
+ {
+ return;
+ }
+ if(text == "")
+ {
+ activeContext.push(triggers.empty);
+ }
+ if(text.endsWith("by"))
+ {
+ changeCategory = true;
+ targetCategory = "Artists";
+ activeContext.push("Artists");
+ }
+ else
+ {
+ var parts = text.split(/[\.,!?;]/);
+ parts = parts.reverse();
+
+ parts.forEach( part =>
+ {
+ var words = part.split(" ");
+ words = words.reverse();
+ words.forEach( word =>
+ {
+ word = word.replace(/[^a-zA-Z0-9 \._\-]/g, '').trim().toLowerCase();
+ word = WordToKey(word);
+ if(wordMap.hasOwnProperty(word))
+ {
+ activeContext = activeContext.concat(wordMap[word]).unique();
+ }
+ });
+ });
+ }
+
+ if(activeContext.length == 0)
+ {
+ if(activeCategory == contextCategory)
+ {
+ activeCategory = "";
+ ShowMenu();
+ }
+ }
+ else if(changeCategory)
+ {
+ activeCategory = targetCategory;
+ ShowMenu();
+ }
+ else if(activeCategory == contextCategory)
+ ShowMenu();
+}
+
+// when pressing a button, give the focus back to the prompt field
+function KeepFocus(e)
+{
+ e.preventDefault();
+ promptField.focus();
+}
+
+function selectCategory(e)
+{
+ KeepFocus(e);
+ // set category from attribute
+ activeCategory = e.target.getAttribute("data");
+ // rebuild menu
+ ShowMenu();
+}
+
+function leaveCategory(e)
+{
+ KeepFocus(e);
+ activeCategory = "";
+ // rebuild menu
+ ShowMenu();
+}
+
+// [...]=block "..."=requirement ...=add {|}=cursor {}=insert .,!?;=start
+// [{} {|}]
+// [,by {}{|}]["by "* and by {}{|}]
+// [, {}{|}]
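+// PatternWalk compares these patterns against the text left of the caret and returns the snippet
+// to insert: {} is replaced by the chosen phrase, {|} marks where the caret ends up, and the
+// generic ", {}" is used as a fallback when no pattern variant matches.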
+
+function PatternWalk(text, pattern)
+{
+ var parts = text.split(/[\,!?;]/);
+ var part = parts[parts.length - 1];
+
+ var indent = 0;
+ var outPattern = "";
+ var requirement = ""
+ var mode = "";
+ var patternFailed = false;
+ var partIndex = 0;
+ for( let i = 0; i < pattern.length; i++)
+ {
+ if(mode == "")
+ {
+ if(pattern[i] == "[")
+ {
+ indent++;
+ mode = "pattern";
+ console.log("pattern start:");
+ }
+ }
+ else if(indent > 0)
+ {
+ if(pattern[i] == "[")
+ {
+ indent++;
+ }
+ else if(mode == "pattern")
+ {
+ if(patternFailed)
+ {
+ if(pattern[i] == "]")
+ {
+ indent--;
+ if(indent == 0)
+ {
+ mode = "";
+ outPattern = "";
+ partIndex = 0;
+ patternFailed = false;
+ part = parts[parts.length - 1];
+ }
+ }
+ else
+ {
+ }
+ }
+ else
+ {
+ if(pattern[i] == "\"")
+ {
+ mode = "requirement";
+ }
+ else if(pattern[i] == "]")
+ {
+ indent--;
+ if(indent == 0)
+ {
+ mode = "";
+ return outPattern;
+ }
+ }
+ else if(pattern[i] == "," || pattern[i] == "!" || pattern[i] == "?" || pattern[i] == ";" )
+ {
+ let textToCheck = (text+outPattern).trim();
+
+ if(textToCheck.endsWith("and"))
+ {
+ outPattern += "{_}";
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("with"))
+ {
+ outPattern += "{_}";
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("of"))
+ {
+ outPattern += "{_}";
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("at"))
+ {
+ outPattern += "{_}";
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("and a"))
+ {
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("with a"))
+ {
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("of a"))
+ {
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("at a"))
+ {
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("and an"))
+ {
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("with an"))
+ {
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("of an"))
+ {
+ part = "";
+ partIndex = 0;
+ }
+ else if(textToCheck.endsWith("at an"))
+ {
+ part = "";
+ partIndex = 0;
+ }
+ else if(!textToCheck.endsWith(pattern[i]))
+ {
+ outPattern += pattern[i];
+ part = "";
+ partIndex = 0;
+ }
+ }
+ else if(pattern[i] == "{")
+ {
+ outPattern += pattern[i];
+ mode = "write";
+ }
+ else if(pattern[i] == "." && pattern[i+1] == "*" || pattern[i] == "*")
+ {
+ let minLength = false;
+ if(pattern[i] == "." && pattern[i+1] == "*")
+ {
+ minLength = true;
+ i++;
+ }
+ var o = pattern.slice(i+1).search(/[^\w\s]/);
+ var subpattern = pattern.slice(i+1,i+1+o);
+
+ var index = part.lastIndexOf(subpattern);
+ var subPatternIndex = subpattern.length;
+ while(index == -1)
+ {
+ if(subPatternIndex <= 1)
+ {
+ patternFailed = true;
+ break;
+ }
+
+ subPatternIndex--;
+ var slice = subpattern.slice(0,subPatternIndex);
+ index = part.lastIndexOf(slice);
+ }
+ if(!patternFailed)
+ {
+ if(minLength && index == 0)
+ {
+ patternFailed = true;
+ }
+ partIndex += index;
+ }
+ else
+ {
+ }
+ }
+ else
+ {
+ if(partIndex >= part.length)
+ {
+ outPattern += pattern[i];
+ }
+ else if(part[partIndex] == pattern[i])
+ {
+ partIndex++;
+ }
+ else
+ {
+ patternFailed = true;
+ }
+ }
+ }
+ }
+ else if(mode == "requirement")
+ {
+ if(pattern[i] == "\"")
+ {
+ if(!part.includes(requirement))
+ {
+ patternFailed = true;
+ }
+ else
+ {
+ partIndex = part.indexOf(requirement)+requirement.length;
+ }
+ mode = "pattern";
+ requirement = "";
+ }
+ else
+ {
+ requirement += pattern[i];
+ }
+ }
+ else if(mode == "write")
+ {
+ if(pattern[i] == "}")
+ {
+ outPattern += pattern[i];
+ mode = "pattern";
+ }
+ else
+ {
+ outPattern += pattern[i];
+ }
+ }
+ }
+ else if(pattern[i] == "[")
+ indent++;
+ }
+ // fallback
+ return ", {}";
+}
+
+function InsertPhrase(phrase, pattern)
+{
+ var text = promptField.value ?? "";
+ if(document.activeElement === promptField)
+ {
+ var pos = getCaretPosition(promptField).end;
+ text = promptField.value.slice(0, pos);
+ }
+ var insert = PatternWalk(text,pattern);
+ insert = insert.replace('{}',phrase);
+
+ let firstLetter = phrase.trim()[0];
+
+ if(firstLetter == "a" || firstLetter == "e" || firstLetter == "i" || firstLetter == "o" || firstLetter == "u")
+ insert = insert.replace('{_}',"an");
+ else
+ insert = insert.replace('{_}',"a");
+
+ insert = insert.replace(/{[^|]/,"");
+ insert = insert.replace(/[^|]}/,"");
+
+ var caret = (text+insert).indexOf("{|}");
+ insert = insert.replace('{|}',"");
+ // inserting via execCommand is required, this triggers all native browser functionality as if the user wrote into the prompt field.
+ parentDoc.execCommand('insertText', false, insert);
+ setCaretPosition(promptField, caret, caret);
+}
+
+function SelectPhrase(e)
+{
+ KeepFocus(e);
+ var pattern = e.target.getAttribute("pattern");
+ var phrase = e.target.getAttribute("data");
+
+ InsertPhrase(phrase,pattern);
+}
+
+function CheckButtonCondition(condition)
+{
+ var pos = getCaretPosition(promptField).end;
+ var text = promptField.value.slice(0, pos);
+ if(condition === "empty")
+ {
+ return text == "";
+ }
+}
+
+function ButtonConditions()
+{
+ conditionalButtons.forEach(entry =>
+ {
+ let filtered = !CheckButtonCondition(entry.condition);
+
+ if(entry.filterGroup != null)
+ {
+ entry.filterGroup.split(",").forEach( (group) =>
+ {
+
+ if(activeFilters.includes(group.trim().toLowerCase()))
+ {
+ filtered = filtered || true;
+ return;
+ }
+ });
+ }
+ if(filtered)
+ entry.element.style.display = "none";
+ else
+ entry.element.style.display = "inline-block";
+ });
+}
+
+function ButtonHoverEnter(e)
+{
+ var text = e.target.getAttribute("tooltip-text");
+ var image = e.target.getAttribute("tooltip-image");
+ ShowTooltip(text, e.target, image)
+}
+
+function ButtonHoverExit(e)
+{
+ HideTooltip();
+}
+
+function ShowTooltip(text, target, image = "")
+{
+ var cleanedName = image == null? null : image.replace(/[^a-zA-Z0-9 \._\-]/g, '');
+ if((text == "" || text == null) && (image == "" || image == null || thumbnails[cleanedName] === undefined))
+ return;
+
+ var currentFramePosition = currentFrameAbsolutePosition();
+ var rect = target.getBoundingClientRect();
+ var element = parentDoc["phraseTooltip"];
+ element.innerText = text;
+ if(image != "" && image != null && thumbnails[cleanedName] !== undefined)
+ {
+
+ var img = parentDoc.createElement('img');
+ img.src = GetThumbnailURL(cleanedName);
+ element.appendChild(img)
+ }
+ element.style.display = "flex";
+ element.style.top = (rect.bottom+currentFramePosition.y)+"px";
+ element.style.left = (rect.right+currentFramePosition.x)+"px";
+ element.style.width = "inherit";
+ element.style.height = "inherit";
+}
+
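+// decode a base64 string into a Blob, processing it in 1 KB slices to keep the intermediate arrays small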
+function base64toBlob(base64Data, contentType) {
+ contentType = contentType || '';
+ var sliceSize = 1024;
+ var byteCharacters = atob(base64Data);
+ var bytesLength = byteCharacters.length;
+ var slicesCount = Math.ceil(bytesLength / sliceSize);
+ var byteArrays = new Array(slicesCount);
+
+ for (var sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) {
+ var begin = sliceIndex * sliceSize;
+ var end = Math.min(begin + sliceSize, bytesLength);
+
+ var bytes = new Array(end - begin);
+ for (var offset = begin, i = 0; offset < end; ++i, ++offset) {
+ bytes[i] = byteCharacters[offset].charCodeAt(0);
+ }
+ byteArrays[sliceIndex] = new Uint8Array(bytes);
+ }
+ return new Blob(byteArrays, { type: contentType });
+}
+
+function GetThumbnailURL(image)
+{
+ if(parentDoc["keyPhraseSuggestionsLoadedBlobs"].hasOwnProperty(image))
+ {
+ return parentDoc["keyPhraseSuggestionsLoadedBlobs"][image];
+ }
+ else
+ {
+ let url = URL.createObjectURL(GetThumbnail(image));
+ parentDoc["keyPhraseSuggestionsLoadedBlobs"][image] = url;
+ return url;
+ }
+}
+
+function GetThumbnail(image)
+{
+ return base64toBlob(thumbnails[image], 'image/webp');
+}
+
+function HideTooltip()
+{
+ var element = parentDoc["phraseTooltip"];
+ element.style.display= "none";
+ element.innerHTML = "";
+ element.style.top = "0px";
+ element.style.left = "0px";
+ element.style.width = "0px";
+ element.style.height = "0px";
+}
+
+function RemoveDouble(str, symbol)
+{
+ let doubleSymbole = symbol+symbol;
+ while(str.includes(doubleSymbole))
+ {
+ str = str.replace(doubleSymbole, symbol);
+ }
+ return str;
+}
+
+function ReplaceAll(str, toReplace, seperator, symbol)
+{
+ toReplace.split(seperator).forEach( (replaceSymbol) =>
+ {
+ str = str.replace(replaceSymbol, symbol);
+ });
+ return str;
+}
+
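+// reduce a word to a rough phonetic key (plural endings stripped, similar sounding letter groups
+// collapsed to a single letter) so different spellings map to the same wordMap/trigger entry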
+function WordToKey(word)
+{
+ if(word.endsWith("s"))
+ word = word.slice(0, -1);
+ word = word.replace("'", "");
+ if(word.endsWith("s"))
+ word = word.slice(0, -1);
+ word = ReplaceAll(word, "sch;sh;ch;ll;gg;r;l;j;g", ';', 'h');
+ word = ReplaceAll(word, "sw;ss;zz;qu;kk;k;z;q;s;x", ';','c');
+ word = ReplaceAll(word, "pp;bb;tt;th;ff;p;t;b;f;v", ';','d');
+ word = ReplaceAll(word, "yu;yo;oo;u;y;w", ';','o');
+ word = ReplaceAll(word, "ee;ie;a;i", ';','e');
+ word = ReplaceAll(word, "mm;nn;n", ';','n');
+ word = RemoveDouble(word, "l");
+ word = RemoveDouble(word, "c");
+ word = RemoveDouble(word, "e");
+ word = RemoveDouble(word, "m");
+ word = RemoveDouble(word, "j");
+ word = RemoveDouble(word, "o");
+ word = RemoveDouble(word, "d");
+ word = RemoveDouble(word, "f");
+ return word;
+}
+
+Array.prototype.unique = function() {
+ var a = this.concat();
+ for(var i = 0; i < a.length; ++i) {
+ for(var j = i + 1; j < a.length; ++j) {
+ if(a[i] === a[j])
+ a.splice(j--, 1);
+ }
+ }
+ return a;
+};
+
+// build triggerIndex, mapping trigger words (or their phonetic keys) to the key phrase entries they surface
+function BuildTriggerIndex()
+{
+ for (var category in keyPhrases)
+ {
+ for(let i = 0; i < keyPhrases[category]["entries"].length; i++)
+ {
+ let entry = keyPhrases[category]["entries"][i];
+ // "triggers" is the assumed name of the comma separated trigger word field on each entry
+ if(entry["triggers"] != "")
+ {
+ entry["triggers"].split(",").forEach( trigger =>
+ {
+ trigger = trigger.replace(/[^a-zA-Z0-9 \._\-]/g, '').trim().toLowerCase();
+ if(!triggers.hasOwnProperty(trigger))
+ {
+ trigger = WordToKey(trigger);
+ }
+ if(triggerIndex.hasOwnProperty(trigger))
+ {
+ triggerIndex[trigger].push( { category: category, index: i });
+ }
+ else
+ {
+ triggerIndex[trigger] = [];
+ triggerIndex[trigger].push( { category: category, index: i });
+ }
+ });
+ }
+
+ /*let words = entry["phrase"].split(" ");
+ let wordCount = words.length;
+ for(let e = 0; e < wordCount; e++)
+ {
+ let wordKey = WordToKey(words[e].replace(/[^a-zA-Z0-9 \._\-]/g, '').trim().toLowerCase());
+
+ if(wordKey.length < 2)
+ continue;
+
+ if(!wordMap.hasOwnProperty(wordKey))
+ {
+ wordMap[wordKey] = [];
+ }
+
+ let entrySearchTags = entry["search_tags"].split(",");
+ entrySearchTags.push(category);
+ entrySearchTags.forEach( search_tag =>
+ {
+ if(search_tag != null && search_tag != "")
+ {
+ if(search_tag.endsWith("'s"))
+ search_tag = search_tag.slice(0, -2);
+ if(search_tag.endsWith("s"))
+ search_tag = search_tag.slice(0, -1);
+ search_tag = search_tag.replace(/[^a-zA-Z0-9 \._\-]/g, '').trim().toLowerCase();
+ wordMap[wordKey].push(search_tag);
+ if(!tagMap.hasOwnProperty(search_tag))
+ {
+ tagMap[search_tag] = [];
+ }
+ tagMap[search_tag].push({ category: category, index: i });
+ tagMap[search_tag] = tagMap[search_tag].unique();
+ }
+ });
+ wordMap[wordKey] = wordMap[wordKey].unique();
+ }*/
+ }
+ }
+}
+
+function ConditionalButton(entry, button)
+{
+ if(entry["show_if"] != "" || entry["filter_group"] != "")
+ conditionalButtons.push({element:button,condition:entry["show_if"], filterGroup:entry["filter_group"]});
+}
+
+// generate menu in suggestion area
+function ShowMenu()
+{
+ // clear all buttons from menu
+ ClearSuggestionArea();
+ HideTooltip();
+
+ // if no category is selected
+ if(activeCategory == "")
+ {
+ if(activeContext.length != 0)
+ {
+ AddButton("Context", selectCategory, "A dynamicly updating category based on the current prompt.", null, null, contextCategory);
+ }
+ for (var category in keyPhrases)
+ {
+ AddButton(category, selectCategory, keyPhrases[category]["description"], null, null, category);
+ }
+ // change iframe size after buttons have been added
+ UpdateSize();
+ }
+ else if(activeCategory == contextCategory)
+ {
+ // add a button to leave the category
+ var backbutton = AddButton("↑ back", leaveCategory);
+ activeContext.forEach( context =>
+ {
+ if(tagMap.hasOwnProperty(context))
+ {
+ var words = tagMap[context].unique();
+ words.forEach( word =>
+ {
+ var entry = keyPhrases[word.category]["entries"][word.index];
+ var tempPattern = keyPhrases[word.category]["pattern"];
+
+ if(entry["pattern_override"] != "")
+ {
+ tempPattern = entry["pattern_override"];
+ }
+
+ var button = AddButton(entry["phrase"], SelectPhrase, entry["description"], entry["phrase"],tempPattern, entry["phrase"]);
+
+ ConditionalButton(entry, button);
+ });
+ }
+ if(triggerIndex.hasOwnProperty(context))
+ {
+ var triggered = triggerIndex[context];
+ triggered.forEach( trigger =>
+ {
+ var entry = keyPhrases[trigger.category]["entries"][trigger.index];
+ var tempPattern = keyPhrases[trigger.category]["pattern"];
+
+ if(entry["pattern_override"] != "")
+ {
+ tempPattern = entry["pattern_override"];
+ }
+
+ var button = AddButton(entry["phrase"], SelectPhrase, entry["description"], entry["phrase"],tempPattern, entry["phrase"]);
+
+ ConditionalButton(entry, button);
+ });
+ }
+ });
+
+ ButtonConditions();
+ // change iframe size after buttons have been added
+ UpdateSize();
+ }
+ // if a category is selected
+ else
+ {
+ // add a button to leave the category
+ var backbutton = AddButton("↑ back", leaveCategory);
+ var pattern = keyPhrases[activeCategory]["pattern"];
+ keyPhrases[activeCategory]["entries"].forEach(entry =>
+ {
+ var tempPattern = pattern;
+ if(entry["pattern_override"] != "")
+ {
+ tempPattern = entry["pattern_override"];
+ }
+
+ var button = AddButton(entry["phrase"], SelectPhrase, entry["description"], entry["phrase"],tempPattern, entry["phrase"]);
+
+ ConditionalButton(entry, button);
+ });
+ ButtonConditions();
+ // change iframe size after buttons have been added
+ UpdateSize();
+ }
+}
+
+// listen for clicks on the prompt field
+parentDoc.addEventListener("click", (e) =>
+{
+ // skip if this frame is not visible
+ if(!isVisible(frame))
+ return;
+
+ // if the iframes prompt field is not set, get it and set it
+ if(promptField === null)
+ {
+ GetPromptField();
+ ButtonUpdateContext(true);
+ }
+
+ // get the field with focus
+ var target = parentDoc.activeElement;
+
+ // if the field with focus is a prompt field, the %% placeholder %% is set in python
+ if( target.placeholder === placeholder)
+ {
+ // generate menu
+ ShowMenu();
+ frame.style.borderBottomWidth = '13px';
+ }
+ else
+ {
+ // else hide the iframe
+ frame.style.height = "0px";
+ frame.style.borderBottomWidth = '0px';
+ }
+});
+
+function AppendStyle(targetDoc, id, content)
+{
+ // get parent document head
+ var head = targetDoc.getElementsByTagName('head')[0];
+
+ // add style tag
+ var style = targetDoc.createElement('style');
+ // set type attribute
+ style.setAttribute('type', 'text/css');
+ style.id = id;
+ // add css forwarded from python
+ if (style.styleSheet) { // IE
+ style.styleSheet.cssText = content;
+ } else { // the world
+ style.appendChild(parentDoc.createTextNode(content));
+ }
+ // add style to head
+ head.appendChild(style);
+}
+
+// Transfer all styles
+var head = document.getElementsByTagName("head")[0];
+var parentStyle = parentDoc.getElementsByTagName("style");
+for (var i = 0; i < parentStyle.length; i++)
+ head.appendChild(parentStyle[i].cloneNode(true));
+var parentLinks = parentDoc.querySelectorAll('link[rel="stylesheet"]');
+for (var i = 0; i < parentLinks.length; i++)
+ head.appendChild(parentLinks[i].cloneNode(true));
+
+// add custom style to iframe
+frame.classList.add("suggestion-frame");
+// clear suggestion area to remove the "javascript failed" message
+ClearSuggestionArea();
+// collapse the iframe by default
+frame.style.height = "0px";
+frame.style.borderBottomWidth = '0px';
+
+BuildTriggerIndex();
+
+// only execute once (even though multiple iframes exist)
+if(!parentDoc.hasOwnProperty('keyPhraseSuggestionsInitialized'))
+{
+ AppendStyle(parentDoc, "key-phrase-suggestions", parentCSS);
+
+ var tooltip = parentDoc.createElement('div');
+ tooltip.id = "phrase-tooltip";
+ parentDoc.body.appendChild(tooltip);
+ parentDoc["phraseTooltip"] = tooltip;
+ // set flag so this only runs once
+ parentDoc["keyPhraseSuggestionsLoadedBlobs"] = {};
+ parentDoc["keyPhraseSuggestionsInitialized"] = true;
+
+ var cssVars = getAllCSSVariableNames();
+ computedStyle = getComputedStyle(parentDoc.documentElement);
+
+ parentDoc["keyPhraseSuggestionsCSSvariables"] = ":root{";
+
+ cssVars.forEach( (rule) =>
+ {
+ parentDoc["keyPhraseSuggestionsCSSvariables"] += rule+": "+computedStyle.getPropertyValue(rule)+";";
+ });
+ parentDoc["keyPhraseSuggestionsCSSvariables"] += "}";
+}
+
+AppendStyle(document, "variables", parentDoc["keyPhraseSuggestionsCSSvariables"]);
\ No newline at end of file
diff --git a/webui/streamlit/scripts/custom_components/sygil_suggestions/parent.css b/webui/streamlit/scripts/custom_components/sygil_suggestions/parent.css
new file mode 100644
index 0000000..6e2b285
--- /dev/null
+++ b/webui/streamlit/scripts/custom_components/sygil_suggestions/parent.css
@@ -0,0 +1,84 @@
+.suggestion-frame
+{
+ position: absolute;
+
+ /* make as small as possible */
+ margin: 0px;
+ padding: 0px;
+ min-height: 0px;
+ line-height: 0;
+
+ /* animate transitions of the height property */
+ -webkit-transition: height 1s;
+ -moz-transition: height 1s;
+ -ms-transition: height 1s;
+ -o-transition: height 1s;
+ transition: height 1s, border-bottom-width 1s;
+
+ /* block selection */
+ user-select: none;
+ -moz-user-select: none;
+ -khtml-user-select: none;
+ -webkit-user-select: none;
+ -o-user-select: none;
+
+ z-index: 700;
+
+ outline: 1px solid rgba(250, 250, 250, 0.2);
+ outline-offset: 0px;
+ border-radius: 0.25rem;
+ background: rgb(14, 17, 23);
+
+ box-sizing: border-box;
+ -moz-box-sizing: border-box;
+ -webkit-box-sizing: border-box;
+ border-bottom: solid 13px rgb(14, 17, 23) !important;
+ border-left: solid 13px rgb(14, 17, 23) !important;
+}
+
+#phrase-tooltip
+{
+ display: none;
+ pointer-events: none;
+ position: absolute;
+ border-bottom-left-radius: 0.5rem;
+ border-top-right-radius: 0.5rem;
+ border-bottom-right-radius: 0.5rem;
+ border: solid rgb(255,75,75) 2px;
+ background-color: rgb(38, 39, 48);
+ color: rgb(255,75,75);
+ font-size: 1rem;
+ font-family: "Source Sans Pro", sans-serif;
+ padding: 0.5rem;
+
+ cursor: default;
+ user-select: none;
+ -moz-user-select: none;
+ -khtml-user-select: none;
+ -webkit-user-select: none;
+ -o-user-select: none;
+ z-index: 1000;
+}
+
+#phrase-tooltip:has(img)
+{
+ transform: scale(1.25, 1.25);
+ -ms-transform: scale(1.25, 1.25);
+ -webkit-transform: scale(1.25, 1.25);
+}
+
+#phrase-tooltip>img
+{
+ pointer-events: none;
+ border-bottom-left-radius: 0.5rem;
+ border-top-right-radius: 0.5rem;
+ border-bottom-right-radius: 0.5rem;
+
+ cursor: default;
+ user-select: none;
+ -moz-user-select: none;
+ -khtml-user-select: none;
+ -webkit-user-select: none;
+ -o-user-select: none;
+ z-index: 1500;
+}
\ No newline at end of file
diff --git a/webui/streamlit/scripts/img2img.py b/webui/streamlit/scripts/img2img.py
new file mode 100644
index 0000000..4cee56a
--- /dev/null
+++ b/webui/streamlit/scripts/img2img.py
@@ -0,0 +1,752 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+# base webui import and utils.
+from sd_utils import st, server_state, no_rerun, \
+ custom_models_available, RealESRGAN_available, GFPGAN_available, LDSR_available
+ #generation_callback, process_images, KDiffusionSampler, \
+ #load_models, hc, seed_to_int, logger, \
+ #resize_image, get_matched_noise, CFGMaskedDenoiser, ImageFilter, set_page_title
+
+# streamlit imports
+from streamlit.runtime.scriptrunner import StopException
+
+#other imports
+import cv2
+from PIL import Image, ImageOps
+import torch
+import k_diffusion as K
+import numpy as np
+import time
+import torch
+import skimage
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+
+# streamlit components
+from custom_components import sygil_suggestions
+from streamlit_drawable_canvas import st_canvas
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+sygil_suggestions.init()
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
+ mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
+ n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
+ seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
+ variant_amount: float = 0.0, variant_seed: int = None, ddim_eta:float = 0.0,
+ write_info_files:bool = True, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
+ save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
+ save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.4',
+ use_RealESRGAN: bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
+ use_LDSR: bool = True, LDSR_model: str = "model",
+ loopback: bool = False,
+ random_seed_loopback: bool = False
+ ):
+
+ outpath = st.session_state['defaults'].general.outdir_img2img
+ seed = seed_to_int(seed)
+
+ batch_size = 1
+
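+ # map the sampler name selected in the UI to its implementation:
+ # PLMS/DDIM come from ldm, the k_* names are wrapped k-diffusion samplers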
+ if sampler_name == 'PLMS':
+ sampler = PLMSSampler(server_state["model"])
+ elif sampler_name == 'DDIM':
+ sampler = DDIMSampler(server_state["model"])
+ elif sampler_name == 'k_dpm_2_a':
+ sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
+ elif sampler_name == 'k_dpm_2':
+ sampler = KDiffusionSampler(server_state["model"],'dpm_2')
+ elif sampler_name == 'k_dpmpp_2m':
+ sampler = KDiffusionSampler(server_state["model"],'dpmpp_2m')
+ elif sampler_name == 'k_euler_a':
+ sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
+ elif sampler_name == 'k_euler':
+ sampler = KDiffusionSampler(server_state["model"],'euler')
+ elif sampler_name == 'k_heun':
+ sampler = KDiffusionSampler(server_state["model"],'heun')
+ elif sampler_name == 'k_lms':
+ sampler = KDiffusionSampler(server_state["model"],'lms')
+ else:
+ raise Exception("Unknown sampler: " + sampler_name)
+
+ def process_init_mask(init_mask: Image):
+ if init_mask.mode == "RGBA":
+ init_mask = init_mask.convert('RGBA')
+ background = Image.new('RGBA', init_mask.size, (0, 0, 0))
+ init_mask = Image.alpha_composite(background, init_mask)
+ init_mask = init_mask.convert('RGB')
+ return init_mask
+
+ init_img = init_info
+ init_mask = None
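+ # mask_mode: 0 = use the drawn mask as-is, 1 = use the inverted mask, 2 = build the mask from the init image's alpha channel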
+ if mask_mode == 0:
+ if init_info_mask:
+ init_mask = process_init_mask(init_info_mask)
+ elif mask_mode == 1:
+ if init_info_mask:
+ init_mask = process_init_mask(init_info_mask)
+ init_mask = ImageOps.invert(init_mask)
+ elif mask_mode == 2:
+ init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
+ init_mask = init_img_transparency
+ init_mask = init_mask.convert("RGB")
+ init_mask = resize_image(resize_mode, init_mask, width, height)
+ init_mask = init_mask.convert("RGB")
+
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ t_enc = int(denoising_strength * ddim_steps)
+
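+ # for noise_mode 2 or 3, pre-fill the masked region with noise that statistically matches the
+ # unmasked part of the init image (see get_matched_noise) so the regenerated area blends in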
+ if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None:
+ noise_q = 0.99
+ color_variation = 0.0
+ mask_blend_factor = 1.0
+
+ np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
+ np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
+ np_mask_rgb -= np.min(np_mask_rgb)
+ np_mask_rgb /= np.max(np_mask_rgb)
+ np_mask_rgb = 1. - np_mask_rgb
+ np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
+ blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
+ blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
+ #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
+ #np_mask_rgb = np_mask_rgb + blurred
+ np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
+ np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
+
+ noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
+ blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
+ noised = noise_rgb[:]
+ blend_mask_rgb **= (2.)
+ noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
+
+ np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
+ ref_mask = np_mask_grey < 1e-3
+
+ all_mask = np.ones((height, width), dtype=bool)
+ noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
+
+ init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
+ st.session_state["editor_image"].image(init_img) # debug
+
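+ # init() encodes the (possibly noise-filled) init image into the VAE latent space and, when a mask is
+ # present, downscales it to latent resolution (width // 8, height // 8), inverts it and tiles it across
+ # the 4 latent channels so the sampler can blend original and generated content per pixel in latent space.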
+ def init():
+ image = init_img.convert('RGB')
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+
+ mask_channel = None
+ if init_mask:
+ alpha = resize_image(resize_mode, init_mask, width // 8, height // 8)
+ mask_channel = alpha.split()[-1]
+
+ mask = None
+ if mask_channel is not None:
+ mask = np.array(mask_channel).astype(np.float32) / 255.0
+ mask = (1 - mask)
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3)
+ mask = torch.from_numpy(mask).to(server_state["device"])
+
+ if st.session_state['defaults'].general.optimized:
+ server_state["modelFS"].to(server_state["device"] )
+
+ init_image = 2. * image - 1.
+ init_image = init_image.to(server_state["device"])
+ init_latent = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).get_first_stage_encoding((server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).encode_first_stage(init_image)) # move to latent space
+
+ if st.session_state['defaults'].general.optimized:
+ mem = torch.cuda.memory_allocated()/1e6
+ server_state["modelFS"].to("cpu")
+ while(torch.cuda.memory_allocated()/1e6 >= mem):
+ time.sleep(1)
+
+ return init_latent, mask,
+
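+ # sample() runs the actual denoising: for k-diffusion samplers the init latent is noised up to the chosen
+ # strength via the sigma schedule, for DDIM it is stochastically encoded to t_enc steps, and in both paths
+ # a denoising strength of 1.0 ("obliterate") replaces the masked latents entirely with noise.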
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ t_enc_steps = t_enc
+ obliterate = False
+ if ddim_steps == t_enc_steps:
+ t_enc_steps = t_enc_steps - 1
+ obliterate = True
+
+ if sampler_name != 'DDIM':
+ x0, z_mask = init_data
+
+ sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
+ noise = x * sigmas[ddim_steps - t_enc_steps - 1]
+
+ xi = x0 + noise
+
+ # Obliterate masked image: replace masked latents with the sigma-scaled noise
+ if z_mask is not None and obliterate:
+ xi = (z_mask * noise) + ((1 - z_mask) * xi)
+
+ sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
+ model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
+ samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
+ extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
+ 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
+ callback=generation_callback if not server_state["bridge"] else None)
+ else:
+
+ x0, z_mask = init_data
+
+ sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
+ z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(server_state["device"] ))
+
+ # Obliterate masked image
+ if z_mask is not None and obliterate:
+ random = torch.randn(z_mask.shape, device=z_enc.device)
+ z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
+
+ # decode it
+ samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ z_mask=z_mask, x0=x0)
+ return samples_ddim
+
+
+
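+ # Loopback mode feeds each generated image back in as the next iteration's init image, optionally matching
+ # its color histogram (in LAB space) to the original input and gradually decaying the denoising strength so
+ # the composition stabilises over iterations.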
+ if loopback:
+ output_images, info = None, None
+ history = []
+ initial_seed = None
+
+ do_color_correction = False
+ try:
+ from skimage import exposure
+ do_color_correction = True
+ except ImportError:
+ logger.error("Install scikit-image to perform color correction on loopback")
+
+ for i in range(n_iter):
+ if do_color_correction and i == 0:
+ correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
+
+ # RealESRGAN can only run on the final iteration
+ is_final_iteration = i == n_iter - 1
+
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=1,
+ n_iter=1,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ GFPGAN_model=GFPGAN_model,
+ use_RealESRGAN=use_RealESRGAN and is_final_iteration, # Forcefully disable upscaling when using loopback
+ realesrgan_model_name=RealESRGAN_model,
+ use_LDSR=use_LDSR,
+ LDSR_model_name=LDSR_model,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ init_img=init_img,
+ init_mask=init_mask,
+ mask_blur_strength=mask_blur_strength,
+ mask_restore=mask_restore,
+ denoising_strength=denoising_strength,
+ noise_mode=noise_mode,
+ find_noise_steps=find_noise_steps,
+ resize_mode=resize_mode,
+ uses_loopback=loopback,
+ uses_random_seed_loopback=random_seed_loopback,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg
+ )
+
+ if initial_seed is None:
+ initial_seed = seed
+
+ input_image = init_img
+ init_img = output_images[0]
+
+ if do_color_correction and correction_target is not None:
+ init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
+ cv2.cvtColor(
+ np.asarray(init_img),
+ cv2.COLOR_RGB2LAB
+ ),
+ correction_target,
+ channel_axis=2
+ ), cv2.COLOR_LAB2RGB).astype("uint8"))
+ if mask_restore is True and init_mask is not None:
+ color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
+ color_mask = color_mask.convert('L')
+ source_image = input_image.convert('RGB')
+ target_image = init_img.convert('RGB')
+
+ init_img = Image.composite(source_image, target_image, color_mask)
+
+ if not random_seed_loopback:
+ seed = seed + 1
+ else:
+ seed = seed_to_int(None)
+
+ denoising_strength = max(denoising_strength * 0.95, 0.1)
+ history.append(init_img)
+
+ output_images = history
+ seed = initial_seed
+
+ else:
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ GFPGAN_model=GFPGAN_model,
+ use_RealESRGAN=use_RealESRGAN,
+ realesrgan_model_name=RealESRGAN_model,
+ use_LDSR=use_LDSR,
+ LDSR_model_name=LDSR_model,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ init_img=init_img,
+ init_mask=init_mask,
+ mask_blur_strength=mask_blur_strength,
+ denoising_strength=denoising_strength,
+ noise_mode=noise_mode,
+ find_noise_steps=find_noise_steps,
+ mask_restore=mask_restore,
+ resize_mode=resize_mode,
+ uses_loopback=loopback,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg
+ )
+
+ del sampler
+
+ return output_images, seed, info, stats
+
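+# A minimal illustrative call (sketch only; the file name and parameter values below are hypothetical,
+# the UI in layout() normally builds this call from the widget values):
+#
+#   init = Image.open("input.png").convert("RGB").resize((512, 512))
+#   images, seed, info, stats = img2img(prompt="a watercolor landscape", init_info=init,
+#                                       ddim_steps=30, sampler_name="k_euler_a",
+#                                       denoising_strength=0.6, seed=-1)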
+#
+def layout():
+ with st.form("img2img-inputs"):
+ st.session_state["generation_mode"] = "img2img"
+
+ img2img_input_col, img2img_generate_col = st.columns([10,1])
+ with img2img_input_col:
+ #prompt = st.text_area("Input Text","")
+ placeholder = "A corgi wearing a top hat as an oil painting."
+ prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
+
+ if "defaults" in st.session_state:
+ if st.session_state["defaults"].general.enable_suggestions:
+ sygil_suggestions.suggestion_area(placeholder)
+
+ if "defaults" in st.session_state:
+ if st.session_state['defaults'].admin.global_negative_prompt:
+ prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
+
+ # Every form must have a submit button; the extra blank writes below are a temporary way to align it with the input field. This should eventually be done in CSS or some other way.
+ img2img_generate_col.write("")
+ img2img_generate_col.write("")
+ generate_button = img2img_generate_col.form_submit_button("Generate")
+
+
+ # creating the page layout using columns
+ col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([2,4,4], gap="medium")
+
+ with col1_img2img_layout:
+ # If we have custom models available in the "models/custom"
+ # folder then we show a menu to select which model we want to use, otherwise we use the main SD model.
+ custom_models_available()
+ if server_state["CustomModel_available"]:
+ st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
+ index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.5")
+ else:
+ st.session_state["custom_model"] = "Stable Diffusion v1.5"
+
+
+ st.session_state["sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].img2img.sampling_steps.value,
+ min_value=st.session_state['defaults'].img2img.sampling_steps.min_value,
+ step=st.session_state['defaults'].img2img.sampling_steps.step)
+
+ sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_dpmpp_2m", "k_heun", "PLMS", "DDIM"]
+ st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list,
+ index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.")
+
+ width = st.slider("Width:", min_value=st.session_state['defaults'].img2img.width.min_value, max_value=st.session_state['defaults'].img2img.width.max_value,
+ value=st.session_state['defaults'].img2img.width.value, step=st.session_state['defaults'].img2img.width.step)
+ height = st.slider("Height:", min_value=st.session_state['defaults'].img2img.height.min_value, max_value=st.session_state['defaults'].img2img.height.max_value,
+ value=st.session_state['defaults'].img2img.height.value, step=st.session_state['defaults'].img2img.height.step)
+ seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
+
+ cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):", min_value=st.session_state['defaults'].img2img.cfg_scale.min_value,
+ value=st.session_state['defaults'].img2img.cfg_scale.value,
+ step=st.session_state['defaults'].img2img.cfg_scale.step,
+ help="How strongly the image should follow the prompt.")
+
+ st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength.value,
+ min_value=st.session_state['defaults'].img2img.denoising_strength.min_value,
+ max_value=st.session_state['defaults'].img2img.denoising_strength.max_value,
+ step=st.session_state['defaults'].img2img.denoising_strength.step)
+
+
+ mask_expander = st.empty()
+ with mask_expander.expander("Inpainting/Outpainting"):
+ mask_mode_list = ["Outpainting", "Inpainting", "Image alpha"]
+ mask_mode = st.selectbox("Painting Mode", mask_mode_list, index=st.session_state["defaults"].img2img.mask_mode,
+ help="Select how you want your image to be masked/painted.\"Inpainting\" modifies the image where the mask is white.\n\
+ \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
+ )
+ mask_mode = mask_mode_list.index(mask_mode)
+
+
+ noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
+ noise_mode = st.selectbox("Noise Mode", noise_mode_list, index=noise_mode_list.index(st.session_state['defaults'].img2img.noise_mode), help="")
+ #noise_mode = noise_mode_list.index(noise_mode)
+ find_noise_steps = st.number_input("Find Noise Steps", value=st.session_state['defaults'].img2img.find_noise_steps.value,
+ min_value=st.session_state['defaults'].img2img.find_noise_steps.min_value,
+ step=st.session_state['defaults'].img2img.find_noise_steps.step)
+
+ # Specify canvas parameters in application
+ drawing_mode = st.selectbox(
+ "Drawing tool:",
+ (
+ "freedraw",
+ "transform",
+ #"line",
+ "rect",
+ "circle",
+ #"polygon",
+ ),
+ )
+
+ stroke_width = st.slider("Stroke width: ", 1, 100, 50)
+ stroke_color = st.color_picker("Stroke color hex: ", value="#EEEEEE")
+ bg_color = st.color_picker("Background color hex: ", "#7B6E6E")
+
+ display_toolbar = st.checkbox("Display toolbar", True)
+ #realtime_update = st.checkbox("Update in realtime", True)
+
+ with st.expander("Batch Options"):
+ st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].img2img.batch_count.value,
+ help="How many iterations or batches of images to generate in total.")
+
+ st.session_state["batch_size"] = st.number_input("Batch size", value=st.session_state.defaults.img2img.batch_size.value,
+ help="How many images are at once in a batch.\
+ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
+ Default: 1")
+
+ with st.expander("Preview Settings"):
+ st.session_state["update_preview"] = st.session_state["defaults"].general.update_preview
+ st.session_state["update_preview_frequency"] = st.number_input("Update Image Preview Frequency",
+ min_value=0,
+ value=st.session_state['defaults'].img2img.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 1 step.")
+ #
+ with st.expander("Advanced"):
+ with st.expander("Output Settings"):
+ separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts,
+ help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+ normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights,
+ help="Ensure the sum of all weights add up to 1.0")
+ loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from previous batch when creating next batch.")
+ random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Use a new random seed for each loopback iteration.")
+ img2img_mask_restore = st.checkbox("Only modify regenerated parts of image",
+ value=st.session_state['defaults'].img2img.mask_restore,
+ help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail")
+ save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images,
+ help="Save each image generated before any filter or enhancement is applied.")
+ save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.")
+ group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. \
+ When using a prompt matrix each prompt combination will have its own folder.")
+ write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files,
+ help="Save a file next to the image with informartion about the generation.")
+ save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ #
+ # check if GFPGAN, RealESRGAN and LDSR are available.
+ if "GFPGAN_available" not in st.session_state:
+ GFPGAN_available()
+
+ if "RealESRGAN_available" not in st.session_state:
+ RealESRGAN_available()
+
+ if "LDSR_available" not in st.session_state:
+ LDSR_available()
+
+ if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
+ with st.expander("Post-Processing"):
+ face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
+ with face_restoration_tab:
+ # GFPGAN used for face restoration
+ if st.session_state["GFPGAN_available"]:
+ #with st.expander("Face Restoration"):
+ #if st.session_state["GFPGAN_available"]:
+ #with st.expander("GFPGAN"):
+ st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN,
+ help="Uses the GFPGAN model to improve faces after the generation.\
+ This greatly improves the quality and consistency of faces but uses\
+ extra VRAM. Disable if you need the extra VRAM.")
+
+ st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
+ index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
+
+ #st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
+
+ else:
+ st.session_state["use_GFPGAN"] = False
+
+ with upscaling_tab:
+ st.session_state['us_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].img2img.use_upscaling)
+
+ # RealESRGAN and LDSR used for upscaling.
+ if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
+
+ upscaling_method_list = []
+ if st.session_state["RealESRGAN_available"]:
+ upscaling_method_list.append("RealESRGAN")
+ if st.session_state["LDSR_available"]:
+ upscaling_method_list.append("LDSR")
+
+ st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
+ index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
+ if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
+ else 0)
+
+ if st.session_state["RealESRGAN_available"]:
+ with st.expander("RealESRGAN"):
+ if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['us_upscaling']:
+ st.session_state["use_RealESRGAN"] = True
+ else:
+ st.session_state["use_RealESRGAN"] = False
+
+ st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
+ index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
+ else:
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+
+
+ #
+ if st.session_state["LDSR_available"]:
+ with st.expander("LDSR"):
+ if st.session_state["upscaling_method"] == "LDSR" and st.session_state['us_upscaling']:
+ st.session_state["use_LDSR"] = True
+ else:
+ st.session_state["use_LDSR"] = False
+
+ st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
+ index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
+
+ st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].img2img.LDSR_config.sampling_steps,
+ help="")
+
+ st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].img2img.LDSR_config.preDownScale,
+ help="")
+
+ st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].img2img.LDSR_config.postDownScale,
+ help="")
+
+ downsample_method_list = ['Nearest', 'Lanczos']
+ st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
+ index=downsample_method_list.index(st.session_state['defaults'].img2img.LDSR_config.downsample_method))
+
+ else:
+ st.session_state["use_LDSR"] = False
+ st.session_state["LDSR_model"] = "model"
+
+ with st.expander("Variant"):
+ variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
+ variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed,
+ help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+
+
+ with col2_img2img_layout:
+ editor_tab = st.tabs(["Editor"])
+
+ editor_image = st.empty()
+ st.session_state["editor_image"] = editor_image
+
+ st.form_submit_button("Refresh")
+
+ #if "canvas" not in st.session_state:
+ st.session_state["canvas"] = st.empty()
+
+ masked_image_holder = st.empty()
+ image_holder = st.empty()
+
+ uploaded_images = st.file_uploader(
+ "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
+ help="Upload an image which will be used for the image to image generation.",
+ )
+ if uploaded_images:
+ image = Image.open(uploaded_images).convert('RGB')
+ new_img = image.resize((width, height))
+ #image_holder.image(new_img)
+
+ #mask_holder = st.empty()
+
+ #uploaded_masks = st.file_uploader(
+ #"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'],
+ #help="Upload an mask image which will be used for masking the image to image generation.",
+ #)
+
+ #
+ # Create a canvas component
+ with st.session_state["canvas"]:
+ st.session_state["uploaded_masks"] = st_canvas(
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
+ stroke_width=stroke_width,
+ stroke_color=stroke_color,
+ background_color=bg_color,
+ background_image=image if uploaded_images else None,
+ update_streamlit=True,
+ width=width,
+ height=height,
+ drawing_mode=drawing_mode,
+ initial_drawing=st.session_state["uploaded_masks"].json_data if "uploaded_masks" in st.session_state else None,
+ display_toolbar= display_toolbar,
+ key="full_app",
+ )
+
+ #try:
+ ##print (type(st.session_state["uploaded_masks"]))
+ #if st.session_state["uploaded_masks"] != None:
+ #mask_expander.expander("Mask", expanded=True)
+ #mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
+
+ #st.image(mask)
+
+ #if mask.mode == "RGBA":
+ #mask = mask.convert('RGBA')
+ #background = Image.new('RGBA', mask.size, (0, 0, 0))
+ #mask = Image.alpha_composite(background, mask)
+ #mask = mask.resize((width, height))
+ #except AttributeError:
+ #pass
+
+ with col3_img2img_layout:
+ result_tab = st.tabs(["Result"])
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ preview_image = st.empty()
+ st.session_state["preview_image"] = preview_image
+
+ #st.session_state["loading"] = st.empty()
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar"] = st.empty()
+
+ message = st.empty()
+
+ #if uploaded_images:
+ #image = Image.open(uploaded_images).convert('RGB')
+ ##img_array = np.array(image) # if you want to pass it to OpenCV
+ #new_img = image.resize((width, height))
+ #st.image(new_img, use_column_width=True)
+
+
+ if generate_button:
+ #print("Loading models")
+ # load the models when we hit the generate button for the first time; they won't be reloaded after that, so don't worry.
+ with col3_img2img_layout:
+ with no_rerun:
+ with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
+ load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
+ use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
+ use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
+ CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
+
+ if uploaded_images:
+ #image = Image.fromarray(image).convert('RGBA')
+ #new_img = image.resize((width, height))
+ ###img_array = np.array(image) # if you want to pass it to OpenCV
+ #image_holder.image(new_img)
+ new_mask = None
+
+ if st.session_state["uploaded_masks"]:
+ mask = Image.fromarray(st.session_state["uploaded_masks"].image_data)
+ new_mask = mask.resize((width, height))
+
+ #masked_image_holder.image(new_mask)
+ try:
+ output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
+ mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
+ sampler_name=st.session_state["sampler_name"], n_iter=st.session_state["batch_count"],
+ cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
+ seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width,
+ height=height, variant_amount=variant_amount,
+ ddim_eta=st.session_state.defaults.img2img.ddim_eta, write_info_files=write_info_files,
+ separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images, save_grid=save_grid,
+ group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=st.session_state["use_GFPGAN"],
+ GFPGAN_model=st.session_state["GFPGAN_model"],
+ use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
+ use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
+ loopback=loopback
+ )
+
+ #show a message when the generation is complete.
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ except (StopException,
+ #KeyError
+ ):
+ logger.info(f"Received Streamlit StopException")
+ # reset the page title so the percent doesnt stay on it confusing the user.
+ set_page_title(f"Stable Diffusion Playground")
+
+ # this will render all the images at the end of the generation, but it would be better moved to a second tab inside col2 and shown as a gallery.
+ # use the current col2 first tab to show the preview_img and update it as it's generated.
+ #preview_image.image(output_images, width=750)
+
+#on import run init
diff --git a/webui/streamlit/scripts/img2txt.py b/webui/streamlit/scripts/img2txt.py
new file mode 100644
index 0000000..1268ebf
--- /dev/null
+++ b/webui/streamlit/scripts/img2txt.py
@@ -0,0 +1,460 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+# ---------------------------------------------------------------------------------------------------------------------------------------------------
+"""
+CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI.
+
+# CLIP Interrogator by @pharmapsychotic
+Twitter: https://twitter.com/pharmapsychotic
+Github: https://github.com/pharmapsychotic/clip-interrogator
+
+Description:
+What do the different OpenAI CLIP models see in an image? What might be a good text prompt to create similar images using CLIP guided diffusion
+or another text to image model? The CLIP Interrogator is here to get you answers!
+
+Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following him on [twitter](https://twitter.com/pharmapsychotic).
+
+And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html).
+
+"""
+# ---------------------------------------------------------------------------------------------------------------------------------------------------
+
+# base webui import and utils.
+from sd_utils import st, logger, server_state, server_state_lock, random
+
+# streamlit imports
+
+# streamlit components section
+import streamlit_nested_layout
+
+# other imports
+
+import clip
+import open_clip
+import gc
+import os
+import pandas as pd
+#import requests
+import torch
+from PIL import Image
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from ldm.models.blip import blip_decoder
+#import hashlib
+
+# end of imports
+# ---------------------------------------------------------------------------------------------------------------
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+blip_image_eval_size = 512
+
+st.session_state["log"] = []
+
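+# load_blip_model() lazily loads the BLIP captioning model into server_state (shared across sessions) the
+# first time it is needed, converting it to half precision on the selected device; later calls reuse it.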
+def load_blip_model():
+ logger.info("Loading BLIP Model")
+ if "log" not in st.session_state:
+ st.session_state["log"] = []
+
+ st.session_state["log"].append("Loading BLIP Model")
+ st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
+
+ if "blip_model" not in server_state:
+ with server_state_lock['blip_model']:
+ server_state["blip_model"] = blip_decoder(pretrained="models/blip/model__base_caption.pth",
+ image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json")
+
+ server_state["blip_model"] = server_state["blip_model"].eval()
+
+ server_state["blip_model"] = server_state["blip_model"].to(device).half()
+
+ logger.info("BLIP Model Loaded")
+ st.session_state["log"].append("BLIP Model Loaded")
+ st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
+ else:
+ logger.info("BLIP Model already loaded")
+ st.session_state["log"].append("BLIP Model already loaded")
+ st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
+
+
+def generate_caption(pil_image):
+
+ load_blip_model()
+
+ gpu_image = transforms.Compose([ # type: ignore
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), # type: ignore
+ transforms.ToTensor(), # type: ignore
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) # type: ignore
+ ])(pil_image).unsqueeze(0).to(device).half()
+
+ with torch.no_grad():
+ caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
+
+ return caption[0]
+
+def load_list(filename):
+ with open(filename, 'r', encoding='utf-8', errors='replace') as f:
+ items = [line.strip() for line in f.readlines()]
+ return items
+
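+# rank() tokenizes the candidate strings, encodes them with the given CLIP model, and scores each one by
+# the softmax of its (scaled) cosine similarity with the image embedding, returning the top_count
+# (text, confidence %) pairs.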
+def rank(model, image_features, text_array, top_count=1):
+ top_count = min(top_count, len(text_array))
+ text_tokens = clip.tokenize([text for text in text_array]).cuda()
+ with torch.no_grad():
+ text_features = model.encode_text(text_tokens).float()
+ text_features /= text_features.norm(dim=-1, keepdim=True)
+
+ similarity = torch.zeros((1, len(text_array))).to(device)
+ for i in range(image_features.shape[0]):
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
+ similarity /= image_features.shape[0]
+
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
+
+
+def clear_cuda():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
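+# batch_rank() splits large term lists into chunks so that only batch_size texts are tokenized and encoded
+# at a time (keeping text-encoding memory bounded) and concatenates the per-chunk rankings.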
+def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size):
+ batch_size = min(batch_size, len(text_array))
+ batch_count = (len(text_array) + batch_size - 1) // batch_size # ceiling division so trailing items are not dropped
+ batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)]
+ ranks = []
+ for batch in batches:
+ ranks += rank(model, image_features, batch)
+ return ranks
+
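+# interrogate() first captions the image with BLIP, then, for each selected CLIP model, ranks the
+# mediums/artists/trending sites/movements/flavors/techniques/tags lists against the image embedding,
+# shows the per-category results in a table, and assembles the best matches into a prompt-style string.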
+def interrogate(image, models):
+ load_blip_model()
+
+ logger.info("Generating Caption")
+ st.session_state["log"].append("Generating Caption")
+ st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
+ caption = generate_caption(image)
+
+ if st.session_state["defaults"].general.optimized:
+ del server_state["blip_model"]
+ clear_cuda()
+
+ logger.info("Caption Generated")
+ st.session_state["log"].append("Caption Generated")
+ st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
+
+ if len(models) == 0:
+ logger.info(f"\n\n{caption}")
+ return
+
+ table = []
+ # one slot per ranked category: medium, artist, trending, movement, flavors, techniques, tags
+ bests = [[('', 0)]] * 7
+
+ logger.info("Ranking Text")
+ st.session_state["log"].append("Ranking Text")
+ st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
+
+ for model_name in models:
+ with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16):
+ logger.info(f"Interrogating with {model_name}...")
+ st.session_state["log"].append(f"Interrogating with {model_name}...")
+ st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
+
+ if model_name not in server_state["clip_models"]:
+ if not st.session_state["defaults"].img2txt.keep_all_models_loaded:
+ model_to_delete = []
+ for model in server_state["clip_models"]:
+ if model != model_name:
+ model_to_delete.append(model)
+ for model in model_to_delete:
+ del server_state["clip_models"][model]
+ del server_state["preprocesses"][model]
+ clear_cuda()
+ if model_name == 'ViT-H-14':
+ server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
+ open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip')
+ elif model_name == 'ViT-g-14':
+ server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \
+ open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='models/clip')
+ else:
+ server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = \
+ clip.load(model_name, device=device, download_root='models/clip')
+ server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval()
+
+ images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda()
+
+
+ image_features = server_state["clip_models"][model_name].encode_image(images).float()
+
+ image_features /= image_features.norm(dim=-1, keepdim=True)
+
+ if st.session_state["defaults"].general.optimized:
+ clear_cuda()
+
+ ranks = []
+ ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["mediums"]))
+ ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, ["by "+artist for artist in server_state["artists"]]))
+ ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"]))
+ ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"]))
+ ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"]))
+ #ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["domains"]))
+ #ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subreddits"]))
+ ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"]))
+ ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["tags"]))
+
+ # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["genres"]))
+ # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["styles"]))
+ # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subjects"]))
+ # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["colors"]))
+ # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["moods"]))
+ # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["themes"]))
+ # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["keywords"]))
+
+ #print (bests)
+ #print (ranks)
+
+ for i in range(len(ranks)):
+ confidence_sum = 0
+ for ci in range(len(ranks[i])):
+ confidence_sum += ranks[i][ci][1]
+ if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
+ bests[i] = ranks[i]
+
+ for best in bests:
+ # keep each category sorted by confidence; only the top entry is used later
+ best.sort(key=lambda x: x[1], reverse=True)
+
+ row = [model_name]
+
+ for r in ranks:
+ row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r]))
+
+ #for rank in ranks:
+ # rank.sort(key=lambda x: x[1], reverse=True)
+ # row.append(f'{rank[0][0]} {rank[0][1]:.2f}%')
+
+ table.append(row)
+
+ if st.session_state["defaults"].general.optimized:
+ del server_state["clip_models"][model_name]
+ gc.collect()
+
+ st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame(
+ table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"]))
+
+ medium = bests[0][0][0]
+ artist = bests[1][0][0]
+ trending = bests[2][0][0]
+ movement = bests[3][0][0]
+ flavors = bests[4][0][0]
+ #domains = bests[5][0][0]
+ #subreddits = bests[6][0][0]
+ techniques = bests[5][0][0]
+ tags = bests[6][0][0]
+
+
+ if caption.startswith(medium):
+ st.session_state["text_result"][st.session_state["processed_image_count"]].code(
+ f"\n\n{caption} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
+ else:
+ st.session_state["text_result"][st.session_state["processed_image_count"]].code(
+ f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="")
+
+ logger.info("Finished Interrogating.")
+ st.session_state["log"].append("Finished Interrogating.")
+ st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='')
+
+
+def img2txt():
+ models = []
+
+ if st.session_state["ViT-L/14"]:
+ models.append('ViT-L/14')
+ if st.session_state["ViT-H-14"]:
+ models.append('ViT-H-14')
+ if st.session_state["ViT-g-14"]:
+ models.append('ViT-g-14')
+
+ if st.session_state["ViTB32"]:
+ models.append('ViT-B/32')
+ if st.session_state['ViTB16']:
+ models.append('ViT-B/16')
+
+ if st.session_state["ViTL14_336px"]:
+ models.append('ViT-L/14@336px')
+ if st.session_state["RN101"]:
+ models.append('RN101')
+ if st.session_state["RN50"]:
+ models.append('RN50')
+ if st.session_state["RN50x4"]:
+ models.append('RN50x4')
+ if st.session_state["RN50x16"]:
+ models.append('RN50x16')
+ if st.session_state["RN50x64"]:
+ models.append('RN50x64')
+
+ # if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'):
+ #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB')
+ # else:
+ #image = Image.open(image_path_or_url).convert('RGB')
+
+ #thumb = st.session_state["uploaded_image"].image.copy()
+ #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size])
+ # display(thumb)
+
+ st.session_state["processed_image_count"] = 0
+
+ for i in range(len(st.session_state["uploaded_image"])):
+
+ interrogate(st.session_state["uploaded_image"][i].pil_image, models=models)
+ # increase counter.
+ st.session_state["processed_image_count"] += 1
+#
+
+
+def layout():
+ #set_page_title("Image-to-Text - Stable Diffusion WebUI")
+ #st.info("Under Construction. :construction_worker:")
+ #
+ if "clip_models" not in server_state:
+ server_state["clip_models"] = {}
+ if "preprocesses" not in server_state:
+ server_state["preprocesses"] = {}
+ data_path = "data/"
+ if "artists" not in server_state:
+ server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt'))
+ if "flavors" not in server_state:
+ server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000)
+ if "mediums" not in server_state:
+ server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt'))
+ if "movements" not in server_state:
+ server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt'))
+ if "sites" not in server_state:
+ server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt'))
+ #server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt'))
+ #server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt'))
+ if "techniques" not in server_state:
+ server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt'))
+ if "tags" not in server_state:
+ server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt'))
+ #server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt'))
+ # server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt'))
+ # server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt'))
+ if "trending_list" not in server_state:
+ server_state["trending_list"] = [site for site in server_state["sites"]]
+ server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]])
+ server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]])
+ server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]])
+ with st.form("img2txt-inputs"):
+ st.session_state["generation_mode"] = "img2txt"
+
+ # st.write("---")
+ # creating the page layout using columns
+ col1, col2 = st.columns([1, 4], gap="large")
+
+ with col1:
+ st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True)
+
+ with st.expander("CLIP models", expanded=True):
+ st.session_state["ViT-L/14"] = st.checkbox("ViT-L/14", value=True, help="ViT-L/14 model.")
+ st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.")
+ st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.")
+
+
+
+ with st.expander("Others"):
+ st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.")
+
+ st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.")
+ st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.")
+ st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.")
+ st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.")
+ st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.")
+ st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.")
+ st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.")
+ st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.")
+
+ #
+ # st.subheader("Logs:")
+
+ st.session_state["log_message"] = st.empty()
+ st.session_state["log_message"].code('', language="")
+
+ with col2:
+ st.subheader("Image")
+
+ image_col1, image_col2 = st.columns([10,25])
+ with image_col1:
+ refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.')
+
+ if st.session_state["uploaded_image"]:
+ #print (type(st.session_state["uploaded_image"]))
+ # if len(st.session_state["uploaded_image"]) == 1:
+ st.session_state["input_image_preview"] = []
+ st.session_state["input_image_preview_container"] = []
+ st.session_state["prediction_table"] = []
+ st.session_state["text_result"] = []
+
+ for i in range(len(st.session_state["uploaded_image"])):
+ st.session_state["input_image_preview_container"].append(i)
+ st.session_state["input_image_preview_container"][i] = st.empty()
+
+ with st.session_state["input_image_preview_container"][i].container():
+ col1_output, col2_output = st.columns([2, 10], gap="medium")
+ with col1_output:
+ st.session_state["input_image_preview"].append(i)
+ st.session_state["input_image_preview"][i] = st.empty()
+ st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
+
+ st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
+
+ with st.session_state["input_image_preview_container"][i].container():
+
+ with col2_output:
+
+ st.session_state["prediction_table"].append(i)
+ st.session_state["prediction_table"][i] = st.empty()
+ st.session_state["prediction_table"][i].table()
+
+ st.session_state["text_result"].append(i)
+ st.session_state["text_result"][i] = st.empty()
+ st.session_state["text_result"][i].code("", language="")
+
+ else:
+ #st.session_state["input_image_preview"].code('', language="")
+ st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
+
+ with image_col2:
+ #
+ # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
+ # generate_col1.title("")
+ # generate_col1.title("")
+ generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images")
+
+ if generate_button:
+ # if model, pipe, RealESRGAN or GFPGAN is in server_state, remove it from server_state so that it gets reloaded.
+ if "model" in server_state and st.session_state["defaults"].general.optimized:
+ del server_state["model"]
+ if "pipe" in server_state and st.session_state["defaults"].general.optimized:
+ del server_state["pipe"]
+ if "RealESRGAN" in server_state and st.session_state["defaults"].general.optimized:
+ del server_state["RealESRGAN"]
+ if "GFPGAN" in server_state and st.session_state["defaults"].general.optimized:
+ del server_state["GFPGAN"]
+
+ # run clip interrogator
+ img2txt()
diff --git a/webui/streamlit/scripts/post_processing.py b/webui/streamlit/scripts/post_processing.py
new file mode 100644
index 0000000..61de288
--- /dev/null
+++ b/webui/streamlit/scripts/post_processing.py
@@ -0,0 +1,368 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sandbox-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+# base webui import and utils.
+#from sd_utils import *
+from sd_utils import st, server_state, logger, torch_gc, load_models, \
+ RealESRGAN_available, GFPGAN_available, LDSR_available
+
+# streamlit imports
+
+#streamlit components section
+import hydralit_components as hc
+
+#other imports
+import os
+from PIL import Image
+import torch
+
+# Temp imports
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
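+# post_process() runs every uploaded image through the selected enhancement path (GFPGAN face restoration,
+# RealESRGAN or LDSR upscaling, or the GFPGAN+upscaler combinations), reloading the relevant model if a
+# different checkpoint was requested, and saves each result to defaults.post_processing.outdir_post_processing.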
+def post_process(use_GFPGAN=True, GFPGAN_model='', use_RealESRGAN=False, realesrgan_model_name="", use_LDSR=False, LDSR_model_name=""):
+
+ for i in range(len(st.session_state["uploaded_image"])):
+ #st.session_state["uploaded_image"][i].pil_image
+
+ if use_GFPGAN and server_state["GFPGAN"] is not None and not use_RealESRGAN and not use_LDSR:
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
+
+ if "progress_bar" in st.session_state:
+ st.session_state["progress_bar"].progress(
+ int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
+
+ if server_state["GFPGAN"].name != GFPGAN_model:
+ load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ torch_gc()
+
+ with torch.autocast('cuda'):
+ cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
+
+ gfpgan_sample = restored_img[:,:,::-1]
+ gfpgan_image = Image.fromarray(gfpgan_sample)
+
+ #if st.session_state["GFPGAN_strenght"]:
+ #gfpgan_sample = Image.blend(image, gfpgan_image, st.session_state["GFPGAN_strenght"])
+
+ gfpgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan'
+
+ gfpgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{gfpgan_filename}.png"))
+
+ #
+ elif use_RealESRGAN and server_state["RealESRGAN"] is not None and not use_GFPGAN:
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
+
+ if "progress_bar" in st.session_state:
+ st.session_state["progress_bar"].progress(
+ int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
+
+ torch_gc()
+
+ if server_state["RealESRGAN"].model.name != realesrgan_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ output, img_mode = server_state["RealESRGAN"].enhance(st.session_state["uploaded_image"][i].pil_image)
+ esrgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-esrgan4x'
+ esrgan_sample = output[:,:,::-1]
+ esrgan_image = Image.fromarray(esrgan_sample)
+
+ esrgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{esrgan_filename}.png"))
+
+ #
+ elif use_LDSR and "LDSR" in server_state and not use_GFPGAN:
+ logger.info ("Running LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].text("Running LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
+ if "progress_bar" in st.session_state:
+ st.session_state["progress_bar"].progress(
+ int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
+
+ torch_gc()
+
+ if server_state["LDSR"].name != LDSR_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ result = server_state["LDSR"].superResolution(st.session_state["uploaded_image"][i].pil_image, ddimSteps = st.session_state["ldsr_sampling_steps"],
+ preDownScale = st.session_state["preDownScale"], postDownScale = st.session_state["postDownScale"],
+ downsample_method=st.session_state["downsample_method"])
+
+ ldsr_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-ldsr4x'
+
+ result.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{ldsr_filename}.png"))
+
+ #
+ elif use_LDSR and "LDSR" in server_state and use_GFPGAN and "GFPGAN" in server_state:
+ logger.info ("Running GFPGAN+LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].text("Running GFPGAN+LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
+
+ if "progress_bar" in st.session_state:
+ st.session_state["progress_bar"].progress(
+ int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
+
+ if server_state["GFPGAN"].name != GFPGAN_model:
+ load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ torch_gc()
+ cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
+
+ gfpgan_sample = restored_img[:,:,::-1]
+ gfpgan_image = Image.fromarray(gfpgan_sample)
+
+ if server_state["LDSR"].name != LDSR_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ #LDSR.superResolution(gfpgan_image, ddimSteps=100, preDownScale='None', postDownScale='None', downsample_method="Lanczos")
+ result = server_state["LDSR"].superResolution(gfpgan_image, ddimSteps = st.session_state["ldsr_sampling_steps"],
+ preDownScale = st.session_state["preDownScale"], postDownScale = st.session_state["postDownScale"],
+ downsample_method=st.session_state["downsample_method"])
+
+ ldsr_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan-ldsr2x'
+
+ result.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{ldsr_filename}.png"))
+
+ elif use_RealESRGAN and server_state["RealESRGAN"] is not None and use_GFPGAN and server_state["GFPGAN"] is not None:
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"])))
+
+ if "progress_bar" in st.session_state:
+ st.session_state["progress_bar"].progress(
+ int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"]))))
+
+ torch_gc()
+ cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True)
+ gfpgan_sample = restored_img[:,:,::-1]
+
+ if server_state["RealESRGAN"].model.name != realesrgan_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ output, img_mode = server_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
+ gfpgan_esrgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan-esrgan4x'
+ gfpgan_esrgan_sample = output[:,:,::-1]
+ gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
+
+ gfpgan_esrgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{gfpgan_esrgan_filename}.png"))
+
+
+
+def layout():
+ #st.info("Under Construction. :construction_worker:")
+ st.session_state["progress_bar_text"] = st.empty()
+ #st.session_state["progress_bar_text"].info("Nothing but crickets here, try generating something first.")
+
+ st.session_state["progress_bar"] = st.empty()
+
+ with st.form("post-processing-inputs"):
+ # creating the page layout using columns
+ col1, col2 = st.columns([1, 4], gap="medium")
+
+ with col1:
+ st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True)
+
+
+ # check if GFPGAN, RealESRGAN and LDSR are available.
+ #if "GFPGAN_available" not in st.session_state:
+ GFPGAN_available()
+
+ #if "RealESRGAN_available" not in st.session_state:
+ RealESRGAN_available()
+
+ #if "LDSR_available" not in st.session_state:
+ LDSR_available()
+
+ if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
+ face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
+ with face_restoration_tab:
+ # GFPGAN used for face restoration
+ if st.session_state["GFPGAN_available"]:
+ #with st.expander("Face Restoration"):
+ #if st.session_state["GFPGAN_available"]:
+ #with st.expander("GFPGAN"):
+ st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN,
+ help="Uses the GFPGAN model to improve faces after the generation.\
+ This greatly improves the quality and consistency of faces but uses\
+ extra VRAM. Disable it if you need the extra VRAM.")
+
+ st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
+ index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
+
+ #st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
+
+ else:
+ st.session_state["use_GFPGAN"] = False
+
+ with upscaling_tab:
+ st.session_state['use_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling)
+
+ # RealESRGAN and LDSR used for upscaling.
+ if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
+
+ upscaling_method_list = []
+ if st.session_state["RealESRGAN_available"]:
+ upscaling_method_list.append("RealESRGAN")
+ if st.session_state["LDSR_available"]:
+ upscaling_method_list.append("LDSR")
+
+ #print (st.session_state["RealESRGAN_available"])
+ st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
+ index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
+ if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
+ else 0)
+
+ if st.session_state["RealESRGAN_available"]:
+ with st.expander("RealESRGAN"):
+ if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['use_upscaling']:
+ st.session_state["use_RealESRGAN"] = True
+ else:
+ st.session_state["use_RealESRGAN"] = False
+
+ st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
+ index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
+ else:
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+
+
+ #
+ if st.session_state["LDSR_available"]:
+ with st.expander("LDSR"):
+ if st.session_state["upscaling_method"] == "LDSR" and st.session_state['use_upscaling']:
+ st.session_state["use_LDSR"] = True
+ else:
+ st.session_state["use_LDSR"] = False
+
+ st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
+ index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
+
+ st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2img.LDSR_config.sampling_steps,
+ help="")
+
+ st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.preDownScale,
+ help="")
+
+ st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.postDownScale,
+ help="")
+
+ downsample_method_list = ['Nearest', 'Lanczos']
+ st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
+ index=downsample_method_list.index(st.session_state['defaults'].txt2img.LDSR_config.downsample_method))
+
+ else:
+ st.session_state["use_LDSR"] = False
+ st.session_state["LDSR_model"] = "model"
+
+ #process = st.form_submit_button("Process Images", help="")
+
+ #
+ with st.expander("Output Settings", True):
+ #st.session_state['defaults'].post_processing.save_original_images = st.checkbox("Save input images.", value=st.session_state['defaults'].post_processing.save_original_images,
+ #help="Save each original/input image next to the Post Processed image. "
+ #"This might be helpful for comparing the before and after images.")
+
+ st.session_state['defaults'].post_processing.outdir_post_processing = st.text_input("Output Dir",value=st.session_state['defaults'].post_processing.outdir_post_processing,
+ help="Folder where the images will be saved after post processing.")
+
+ with col2:
+ st.subheader("Image")
+
+ image_col1, image_col2, image_col3 = st.columns([2, 2, 2], gap="small")
+ with image_col1:
+ refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image.')
+
+ if st.session_state["uploaded_image"]:
+ #print (type(st.session_state["uploaded_image"]))
+ # if len(st.session_state["uploaded_image"]) == 1:
+ st.session_state["input_image_preview"] = []
+ st.session_state["input_image_caption"] = []
+ st.session_state["output_image_preview"] = []
+ st.session_state["output_image_caption"] = []
+ st.session_state["input_image_preview_container"] = []
+ st.session_state["prediction_table"] = []
+ st.session_state["text_result"] = []
+
+ for i in range(len(st.session_state["uploaded_image"])):
+ st.session_state["input_image_preview_container"].append(i)
+ st.session_state["input_image_preview_container"][i] = st.empty()
+
+ with st.session_state["input_image_preview_container"][i].container():
+ col1_output, col2_output, col3_output = st.columns([2, 2, 2], gap="medium")
+ with col1_output:
+ st.session_state["output_image_caption"].append(i)
+ st.session_state["output_image_caption"][i] = st.empty()
+ #st.session_state["output_image_caption"][i] = st.session_state["uploaded_image"][i].name
+
+ st.session_state["input_image_caption"].append(i)
+ st.session_state["input_image_caption"][i] = st.empty()
+ #st.session_state["input_image_caption"][i].caption(")
+
+ st.session_state["input_image_preview"].append(i)
+ st.session_state["input_image_preview"][i] = st.empty()
+ st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB')
+
+ st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
+
+ with col2_output:
+ st.session_state["output_image_preview"].append(i)
+ st.session_state["output_image_preview"][i] = st.empty()
+
+ st.session_state["output_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True)
+
+ with st.session_state["input_image_preview_container"][i].container():
+
+ with col3_output:
+
+ #st.session_state["prediction_table"].append(i)
+ #st.session_state["prediction_table"][i] = st.empty()
+ #st.session_state["prediction_table"][i].table(pd.DataFrame(columns=["Model", "Filename", "Progress"]))
+
+ st.session_state["text_result"].append(i)
+ st.session_state["text_result"][i] = st.empty()
+ st.session_state["text_result"][i].code("", language="")
+
+ #else:
+ ##st.session_state["input_image_preview"].code('', language="")
+ #st.image("images/streamlit/img2txt_placeholder.png", clamp=True)
+
+ with image_col3:
+ # Every form must have a submit button; the extra blank spaces are a temporary way to align it with the input field. This should be done in CSS or some other way.
+ process = st.form_submit_button("Process Images!")
+
+ if process:
+ with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
+ #load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
+ #use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
+ #use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"])
+
+ if st.session_state["use_GFPGAN"]:
+ load_GFPGAN(model_name=st.session_state["GFPGAN_model"])
+
+ if st.session_state["use_RealESRGAN"]:
+ load_RealESRGAN(st.session_state["RealESRGAN_model"])
+
+ if st.session_state["use_LDSR"]:
+ load_LDSR(st.session_state["LDSR_model"])
+
+ post_process(use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"],
+ use_RealESRGAN=st.session_state["use_RealESRGAN"], realesrgan_model_name=st.session_state["RealESRGAN_model"],
+ use_LDSR=st.session_state["use_LDSR"], LDSR_model_name=st.session_state["LDSR_model"])
\ No newline at end of file
diff --git a/webui/streamlit/scripts/sd_concept_library.py b/webui/streamlit/scripts/sd_concept_library.py
new file mode 100644
index 0000000..f8c1217
--- /dev/null
+++ b/webui/streamlit/scripts/sd_concept_library.py
@@ -0,0 +1,260 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+# base webui import and utils.
+from sd_utils import st
+
+# streamlit imports
+import streamlit.components.v1 as components
+#other imports
+
+import os, math
+from PIL import Image
+
+# Temp imports
+#from basicsr.utils.registry import ARCH_REGISTRY
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+# Init Vuejs component
+_component_func = components.declare_component(
+ "sd-concepts-browser", "./frontend/dists/concept-browser/dist")
+
+
+def sdConceptsBrowser(concepts, key=None):
+ component_value = _component_func(concepts=concepts, key=key, default="")
+ return component_value
+
+
+@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
+def getConceptsFromPath(page, conceptPerPage, searchText=""):
+ #print("getConceptsFromPath", "page:", page, "conceptPerPage:", conceptPerPage, "searchText:", searchText)
+ # get the path where the concepts are stored
+ path = os.path.join(
+ os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder)
+ acceptedExtensions = ('jpeg', 'jpg', "png")
+ concepts = []
+
+ if os.path.exists(path):
+ # List all folders (concepts) in the path
+ folders = [f for f in os.listdir(
+ path) if os.path.isdir(os.path.join(path, f))]
+ filteredFolders = folders
+
+ # Filter the folders by the search text
+ if searchText != "":
+ filteredFolders = [
+ f for f in folders if searchText.lower() in f.lower()]
+ else:
+ filteredFolders = []
+
+ conceptIndex = 1
+ for folder in filteredFolders:
+ # handle pagination
+ if conceptIndex > (page * conceptPerPage):
+ continue
+ if conceptIndex <= ((page - 1) * conceptPerPage):
+ conceptIndex += 1
+ continue
+
+ concept = {
+ "name": folder,
+ "token": "<" + folder + ">",
+ "images": [],
+ "type": ""
+ }
+
+ # type of concept is inside type_of_concept.txt
+ typePath = os.path.join(path, folder, "type_of_concept.txt")
+ binFile = os.path.join(path, folder, "learned_embeds.bin")
+
+ # Continue if the concept is not valid or the download has failed (no type_of_concept.txt or no binFile)
+ if not os.path.exists(typePath) or not os.path.exists(binFile):
+ continue
+
+ with open(typePath, "r") as f:
+ concept["type"] = f.read()
+
+ # List all files in the concept/concept_images folder
+ files = [f for f in os.listdir(os.path.join(path, folder, "concept_images")) if os.path.isfile(
+ os.path.join(path, folder, "concept_images", f))]
+ # Retrieve only the first 4 images
+ for file in files:
+
+ # Skip if we already have 4 images
+ if len(concept["images"]) >= 4:
+ break
+
+ if file.endswith(acceptedExtensions):
+ try:
+ # Add a copy of the image to avoid file locking
+ originalImage = Image.open(os.path.join(
+ path, folder, "concept_images", file))
+
+ # Maintain the aspect ratio (max 200x200)
+ resizedImage = originalImage.copy()
+ resizedImage.thumbnail((200, 200), Image.Resampling.LANCZOS)
+
+ # concept["images"].append(resizedImage)
+
+ concept["images"].append(imageToBase64(resizedImage))
+ # Close original image
+ originalImage.close()
+ except:
+ print("Error while loading image", file, "in concept", folder, "(The file may be corrupted). Skipping it.")
+
+ concepts.append(concept)
+ conceptIndex += 1
+ # print all concepts name
+ #print("Results:", [c["name"] for c in concepts])
+ return concepts
+
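+# The conceptIndex bookkeeping above implements 1-based pagination by skipping
+# valid concepts that fall outside the requested window. A minimal, equivalent
+# sketch using a plain list slice (hypothetical helper, not used in this module):
+def _paginate_example(items, page, per_page):
+    # Items for `page` are those whose 1-based index lies in
+    # ((page - 1) * per_page, page * per_page], i.e. exactly this slice:
+    return items[(page - 1) * per_page : page * per_page]
+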
+@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True)
+def imageToBase64(image):
+ import io
+ import base64
+ buffered = io.BytesIO()
+ image.save(buffered, format="PNG")
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ return img_str
+
+
+@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True)
+def getTotalNumberOfConcepts(searchText=""):
+ # get the path where the concepts are stored
+ path = os.path.join(
+ os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder)
+ concepts = []
+
+ if os.path.exists(path):
+ # List all folders (concepts) in the path
+ folders = [f for f in os.listdir(
+ path) if os.path.isdir(os.path.join(path, f))]
+ filteredFolders = folders
+
+ # Filter the folders by the search text
+ if searchText != "":
+ filteredFolders = [
+ f for f in folders if searchText.lower() in f.lower()]
+ else:
+ filteredFolders = []
+ return len(filteredFolders)
+
+
+
+def layout():
+ # 2 tabs, one for Concept Library and one for the Download Manager
+ tab_library, tab_downloader = st.tabs(["Library", "Download Manager"])
+
+ # Concept Library
+ with tab_library:
+ downloaded_concepts_count = getTotalNumberOfConcepts()
+ concepts_per_page = st.session_state["defaults"].concepts_library.concepts_per_page
+
+ if not "results" in st.session_state:
+ st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, "")
+
+ # Pagination controls
+ if not "cl_current_page" in st.session_state:
+ st.session_state["cl_current_page"] = 1
+
+ # Search
+ if not 'cl_search_text' in st.session_state:
+ st.session_state["cl_search_text"] = ""
+
+ if not 'cl_search_results_count' in st.session_state:
+ st.session_state["cl_search_results_count"] = downloaded_concepts_count
+
+ # Search bar
+ _search_col, _refresh_col = st.columns([10, 2])
+ with _search_col:
+ search_text_input = st.text_input("Search", "", placeholder=f'Search for a concept ({downloaded_concepts_count} available)', label_visibility="hidden")
+ if search_text_input != st.session_state["cl_search_text"]:
+ # Search text has changed
+ st.session_state["cl_search_text"] = search_text_input
+ st.session_state["cl_current_page"] = 1
+ st.session_state["cl_search_results_count"] = getTotalNumberOfConcepts(st.session_state["cl_search_text"])
+ st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, st.session_state["cl_search_text"])
+
+ with _refresh_col:
+ # Super weird fix to align the refresh button with the search bar ( Please streamlit, add css support.. )
+ _refresh_col.write("")
+ _refresh_col.write("")
+ if st.button("Refresh concepts", key="refresh_concepts", help="Refresh the concepts folders. Use this if you have added new concepts manually or deleted some."):
+ getTotalNumberOfConcepts.clear()
+ getConceptsFromPath.clear()
+ st.experimental_rerun()
+
+
+ # Show results
+ results_empty = st.empty()
+
+ # Pagination
+ pagination_empty = st.empty()
+
+ # Layouts
+ with pagination_empty:
+ with st.container():
+ if len(st.session_state["results"]) > 0:
+ last_page = math.ceil(st.session_state["cl_search_results_count"] / concepts_per_page)
+ _1, _2, _3, _4, _previous_page, _current_page, _next_page, _9, _10, _11, _12 = st.columns([1,1,1,1,1,2,1,1,1,1,1])
+
+ # Previous page
+ with _previous_page:
+ if st.button("Previous", key="cl_previous_page"):
+ st.session_state["cl_current_page"] -= 1
+ if st.session_state["cl_current_page"] <= 0:
+ st.session_state["cl_current_page"] = last_page
+ st.session_state["results"] = getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"])
+
+ # Current page
+ with _current_page:
+ _current_page_container = st.empty()
+
+ # Next page
+ with _next_page:
+ if st.button("Next", key="cl_next_page"):
+ st.session_state["cl_current_page"] += 1
+ if st.session_state["cl_current_page"] > last_page:
+ st.session_state["cl_current_page"] = 1
+ st.session_state["results"] = getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"])
+
+ # Current page
+ with _current_page_container:
+ st.markdown(f'<div style="text-align: center">Page {st.session_state["cl_current_page"]} of {last_page}</div>', unsafe_allow_html=True)
+ # st.write(f"Page {st.session_state['cl_current_page']} of {last_page}", key="cl_current_page")
+
+ with results_empty:
+ with st.container():
+ if downloaded_concepts_count == 0:
+ st.write("You don't have any concepts in your library ")
+ st.markdown("To add concepts to your library, download some from the [sd-concepts-library](https://github.com/Sygil-Dev/sd-concepts-library) \
+ repository and save the content of `sd-concepts-library` into ```./models/custom/sd-concepts-library``` or just create your own concepts :wink:.", unsafe_allow_html=False)
+ else:
+ if len(st.session_state["results"]) == 0:
+ st.write("No concept found in the library matching your search: " + st.session_state["cl_search_text"])
+ else:
+ # display number of results
+ if st.session_state["cl_search_text"]:
+ st.write(f"Found {st.session_state['cl_search_results_count']} {'concepts' if st.session_state['cl_search_results_count'] > 1 else 'concept' } matching your search")
+ sdConceptsBrowser(st.session_state['results'], key="results")
+
+
+ with tab_downloader:
+ st.write("Not implemented yet")
+
+ return False
diff --git a/webui/streamlit/scripts/sd_utils/__init__.py b/webui/streamlit/scripts/sd_utils/__init__.py
new file mode 100644
index 0000000..f38124b
--- /dev/null
+++ b/webui/streamlit/scripts/sd_utils/__init__.py
@@ -0,0 +1,405 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+# base webui import and utils.
+#from webui_streamlit import st
+import hydralit as st
+
+# streamlit imports
+from streamlit.runtime.scriptrunner import StopException
+#from streamlit.runtime.scriptrunner import script_run_context
+
+#streamlit components section
+from streamlit_server_state import server_state, server_state_lock, no_rerun
+import hydralit_components as hc
+from hydralit import HydraHeadApp
+import streamlit_nested_layout
+#from streamlitextras.threader import lock, trigger_rerun, \
+ #streamlit_thread, get_thread, \
+ #last_trigger_time
+
+#other imports
+
+import warnings
+import json
+
+import base64, cv2
+import os, sys, re, random, datetime, time, math, toml
+import gc
+from PIL import Image, ImageFont, ImageDraw, ImageFilter
+from PIL.PngImagePlugin import PngInfo
+from scipy import integrate
+import torch
+from torchdiffeq import odeint
+import k_diffusion as K
+import math, requests
+import mimetypes
+import numpy as np
+import pynvml
+import threading
+import torch, torchvision
+from torch import autocast
+from torchvision import transforms
+import torch.nn as nn
+from omegaconf import OmegaConf
+import yaml
+from pathlib import Path
+from contextlib import nullcontext
+from einops import rearrange, repeat
+from ldm.util import instantiate_from_config
+from retry import retry
+from slugify import slugify
+import skimage
+import piexif
+import piexif.helper
+from tqdm import trange
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.util import ismap
+#from abc import ABC, abstractmethod
+from io import BytesIO
+from packaging import version
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+import shutup
+
+#import librosa
+from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger
+from nataili.esrgan import esrgan
+
+
+#try:
+ #from realesrgan import RealESRGANer
+ #from basicsr.archs.rrdbnet_arch import RRDBNet
+#except ImportError as e:
+ #logger.error("You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n"
+ #"pip install realesrgan")
+
+# Temp imports
+#from basicsr.utils.registry import ARCH_REGISTRY
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+# remove all the annoying python warnings.
+shutup.please()
+
+# the following lines should help fix an issue with nvidia 16xx cards.
+if "defaults" in st.session_state:
+ if st.session_state["defaults"].general.use_cudnn:
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.enabled = True
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+# disable diffusers telemetry
+os.environ["DISABLE_TELEMETRY"] = "YES"
+
+# remove some annoying deprecation warnings that show every now and then.
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+
+# this is a fix for Windows users. Without it, javascript files will be served with a text/html content-type and the browser will not show any UI
+mimetypes.init()
+mimetypes.add_type('application/javascript', '.js')
+
+# some of those options should not be changed at all because they would break the model, so I removed them from options.
+opt_C = 4
+opt_f = 8
+
+# The model manager loads and unloads the SD models and has features to download them or find their location
+#model_manager = ModelManager()
+
+def load_configs():
+ if not "defaults" in st.session_state:
+ st.session_state["defaults"] = {}
+
+ st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
+
+ if (os.path.exists("configs/webui/userconfig_streamlit.yaml")):
+ user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
+
+ if "version" in user_defaults.general:
+ if version.parse(user_defaults.general.version) < version.parse(st.session_state["defaults"].general.version):
+ logger.error("The version of the user config file is older than the version of the defaults config file. "
+ "This means the config changed significantly since the user config was created. "
+ "We are removing this file and recreating it from the defaults in order to make sure things work properly.")
+ os.remove("configs/webui/userconfig_streamlit.yaml")
+ st.experimental_rerun()
+ else:
+ logger.error("The user config file does not specify a version, so it predates the current defaults config file. "
+ "This means the config changed significantly since the user config was created. "
+ "We are removing this file and recreating it from the defaults in order to make sure things work properly.")
+ os.remove("configs/webui/userconfig_streamlit.yaml")
+ st.experimental_rerun()
+
+ try:
+ st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults)
+ except KeyError:
+ st.experimental_rerun()
+ else:
+ OmegaConf.save(config=st.session_state.defaults, f="configs/webui/userconfig_streamlit.yaml")
+ loaded = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
+ assert st.session_state.defaults == loaded
+
+ if (os.path.exists(".streamlit/config.toml")):
+ st.session_state["streamlit_config"] = toml.load(".streamlit/config.toml")
+
+ #if st.session_state["defaults"].daisi_app.running_on_daisi_io:
+ #if os.path.exists("scripts/modeldownload.py"):
+ #import modeldownload
+ #modeldownload.updateModels()
+
+ if "keep_all_models_loaded" in st.session_state.defaults.general:
+ with server_state_lock["keep_all_models_loaded"]:
+ server_state["keep_all_models_loaded"] = st.session_state["defaults"].general.keep_all_models_loaded
+ else:
+ st.session_state["defaults"].general.keep_all_models_loaded = False
+ with server_state_lock["keep_all_models_loaded"]:
+ server_state["keep_all_models_loaded"] = st.session_state["defaults"].general.keep_all_models_loaded
+
+load_configs()
+
+#
+#if st.session_state["defaults"].debug.enable_hydralit:
+ #navbar_theme = {'txc_inactive': '#FFFFFF','menu_background':'#0e1117','txc_active':'black','option_active':'red'}
+ #app = st.HydraApp(title='Stable Diffusion WebUI', favicon="", use_cookie_cache=False, sidebar_state="expanded", layout="wide", navbar_theme=navbar_theme,
+ #hide_streamlit_markers=False, allow_url_nav=True , clear_cross_app_sessions=False, use_loader=False)
+#else:
+ #app = None
+
+#
+# save_format may carry an optional quality suffix, e.g. "jpg:90" or "webp:-100" (a negative webp quality means lossless).
+grid_format = st.session_state["defaults"].general.save_format.split(':')
+grid_lossless = False
+grid_quality = st.session_state["defaults"].general.grid_quality
+if grid_format[0] == 'png':
+    grid_ext = 'png'
+    grid_format = 'png'
+elif grid_format[0] in ['jpg', 'jpeg']:
+    grid_quality = int(grid_format[1]) if len(grid_format) > 1 else grid_quality
+    grid_ext = 'jpg'
+    grid_format = 'jpeg'
+elif grid_format[0] == 'webp':
+    grid_quality = int(grid_format[1]) if len(grid_format) > 1 else grid_quality
+    grid_ext = 'webp'
+    grid_format = 'webp'
+    if grid_quality < 0:  # e.g. webp:-100 for lossless mode
+        grid_lossless = True
+        grid_quality = abs(grid_quality)
+
+#
+save_format = st.session_state["defaults"].general.save_format.split(':')
+save_lossless = False
+save_quality = 100
+if save_format[0] == 'png':
+    save_ext = 'png'
+    save_format = 'png'
+elif save_format[0] in ['jpg', 'jpeg']:
+    save_quality = int(save_format[1]) if len(save_format) > 1 else 100
+    save_ext = 'jpg'
+    save_format = 'jpeg'
+elif save_format[0] == 'webp':
+    save_quality = int(save_format[1]) if len(save_format) > 1 else 100
+    save_ext = 'webp'
+    save_format = 'webp'
+    if save_quality < 0:  # e.g. webp:-100 for lossless mode
+        save_lossless = True
+        save_quality = abs(save_quality)
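+
+# Illustrative sketch only (the helper below is hypothetical and never called):
+# this is how the parsed save_format / save_ext / save_quality / save_lossless
+# values are typically handed to PIL, e.g. "webp:-100" ends up as lossless webp.
+def _example_save_image(image, basename):
+    if save_format == 'webp':
+        image.save(f"{basename}.{save_ext}", format=save_format, quality=save_quality, lossless=save_lossless)
+    else:
+        # PNG ignores the quality argument; JPEG uses it directly.
+        image.save(f"{basename}.{save_ext}", format=save_format, quality=save_quality)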
+
+# this should force GFPGAN and RealESRGAN onto the selected gpu as well
+os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
+os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu)
+
+
+# functions to load css locally OR remotely start here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html for css injection.
+# TODO, maybe look into async loading the file, especially for remote fetching
+def local_css(file_name):
+    with open(file_name) as f:
+        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
+
+def remote_css(url):
+    st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)
+
+def load_css(isLocal, nameOrURL):
+ if(isLocal):
+ local_css(nameOrURL)
+ else:
+ remote_css(nameOrURL)
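+
+# Minimal usage sketch (hypothetical helper, never called): the relative path is
+# an assumption about the app's working directory, not something enforced here.
+def _example_load_main_css():
+    load_css(True, "frontend/css/streamlit.main.css")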
+
+def set_page_title(title):
+    """
+    Simple function that allows us to change the title dynamically.
+    Normally you can use `st.set_page_config` to change the title, but it can only be used once per app.
+    """
+
+    # Inject a tiny script through a zero-height iframe so it runs in the parent page and updates its title.
+    st.sidebar.markdown(unsafe_allow_html=True, body=f"""
+        <iframe height=0 srcdoc="<script>
+            window.parent.document.title = '{title}';
+        </script>" />
+        """)
+
+class MemUsageMonitor(threading.Thread):
+ stop_flag = False
+ max_usage = 0
+ total = -1
+
+ def __init__(self, name):
+ threading.Thread.__init__(self)
+ self.name = name
+
+ def run(self):
+ try:
+ pynvml.nvmlInit()
+ except:
+ logger.debug(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n")
+ return
+ logger.info(f"[{self.name}] Recording memory usage...\n")
+ # st.session_state is not available from inside this thread (missing script run context),
+ # so fall back to the first GPU instead of the configured one.
+ #handle = pynvml.nvmlDeviceGetHandleByIndex(st.session_state['defaults'].general.gpu)
+ handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+ self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
+ while not self.stop_flag:
+ m = pynvml.nvmlDeviceGetMemoryInfo(handle)
+ self.max_usage = max(self.max_usage, m.used)
+ # logger.info(self.max_usage)
+ time.sleep(0.1)
+ logger.info(f"[{self.name}] Stopped recording.\n")
+ pynvml.nvmlShutdown()
+
+ def read(self):
+ return self.max_usage, self.total
+
+ def stop(self):
+ self.stop_flag = True
+
+ def read_and_stop(self):
+ self.stop_flag = True
+ return self.max_usage, self.total
+
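+# Minimal usage sketch for MemUsageMonitor (hypothetical helper, never called):
+# start the monitor, do some work, then read the peak VRAM usage.
+def _example_mem_usage():
+    monitor = MemUsageMonitor("example")
+    monitor.start()
+    time.sleep(0.5)  # stand-in for a real generation run
+    max_used, total = monitor.read_and_stop()
+    logger.info(f"Peak VRAM usage: {max_used} of {total} bytes")
+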
+#
+def custom_models_available():
+ with server_state_lock["custom_models"]:
+ #
+ # Allow for custom models to be used instead of the default one,
+ # an example would be Waifu-Diffusion or any other fine tune of stable diffusion
+ server_state["custom_models"] = []
+
+ for root, dirs, files in os.walk(os.path.join("models", "custom")):
+ for file in files:
+ if os.path.splitext(file)[1] == '.ckpt':
+ server_state["custom_models"].append(os.path.splitext(file)[0])
+
+ with server_state_lock["CustomModel_available"]:
+ if len(server_state["custom_models"]) > 0:
+ server_state["CustomModel_available"] = True
+ server_state["custom_models"].append("Stable Diffusion v1.5")
+ else:
+ server_state["CustomModel_available"] = False
+
+#
+def GFPGAN_available():
+ #with server_state_lock["GFPGAN_models"]:
+ #
+
+ st.session_state["GFPGAN_models"] = []
+ model = st.session_state["defaults"].model_manager.models.gfpgan
+
+ files_available = 0
+
+ for file in model['files']:
+ if "save_location" in model['files'][file]:
+ if os.path.exists(os.path.join(model['files'][file]['save_location'], model['files'][file]['file_name'] )):
+ files_available += 1
+
+ elif os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
+ base_name = os.path.splitext(model['files'][file]['file_name'])[0]
+ if "GFPGANv" in base_name:
+ st.session_state["GFPGAN_models"].append(base_name)
+ files_available += 1
+
+ # we need to show the other models from previous versions that we have in the
+ # same directory, in case we want to see how they perform vs each other.
+ for root, dirs, files in os.walk(st.session_state['defaults'].general.GFPGAN_dir):
+ for file in files:
+ if os.path.splitext(file)[1] == '.pth':
+ if os.path.splitext(file)[0] not in st.session_state["GFPGAN_models"]:
+ st.session_state["GFPGAN_models"].append(os.path.splitext(file)[0])
+
+
+ if len(st.session_state["GFPGAN_models"]) > 0 and files_available == len(model['files']):
+ st.session_state["GFPGAN_available"] = True
+ else:
+ st.session_state["GFPGAN_available"] = False
+ st.session_state["use_GFPGAN"] = False
+ st.session_state["GFPGAN_model"] = "GFPGANv1.4"
+
+#
+def RealESRGAN_available():
+ #with server_state_lock["RealESRGAN_models"]:
+
+ st.session_state["RealESRGAN_models"] = []
+ model = st.session_state["defaults"].model_manager.models.realesrgan
+ for file in model['files']:
+ if os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
+ base_name = os.path.splitext(model['files'][file]['file_name'])[0]
+ st.session_state["RealESRGAN_models"].append(base_name)
+
+ if len(st.session_state["RealESRGAN_models"]) > 0:
+ st.session_state["RealESRGAN_available"] = True
+ else:
+ st.session_state["RealESRGAN_available"] = False
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+#
+def LDSR_available():
+ st.session_state["LDSR_models"] = []
+ files_available = 0
+ model = st.session_state["defaults"].model_manager.models.ldsr
+ for file in model['files']:
+ if os.path.exists(os.path.join(model['save_location'], model['files'][file]['file_name'] )):
+ base_name = os.path.splitext(model['files'][file]['file_name'])[0]
+ extension = os.path.splitext(model['files'][file]['file_name'])[1]
+ if extension == ".ckpt":
+ st.session_state["LDSR_models"].append(base_name)
+ files_available += 1
+ if files_available == len(model['files']):
+ st.session_state["LDSR_available"] = True
+ else:
+ st.session_state["LDSR_available"] = False
+ st.session_state["use_LDSR"] = False
+ st.session_state["LDSR_model"] = "model"
diff --git a/webui/streamlit/scripts/sd_utils/bridge.py b/webui/streamlit/scripts/sd_utils/bridge.py
new file mode 100644
index 0000000..8b677a4
--- /dev/null
+++ b/webui/streamlit/scripts/sd_utils/bridge.py
@@ -0,0 +1,182 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+# base webui import and utils.
+#import streamlit as st
+
+# We import hydralit like this to replace the previous stuff
+# we had with native streamlit, as it lets us replace things 1:1
+from nataili.util import logger
+
+# streamlit imports
+
+#streamlit components section
+
+#other imports
+import requests, time, json, base64
+from io import BytesIO
+
+# import custom components
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+
+
+@logger.catch(reraise=True)
+def run_bridge(interval, api_key, horde_name, horde_url, priority_usernames, horde_max_pixels, horde_nsfw, horde_censor_nsfw, horde_blacklist, horde_censorlist):
+ current_id = None
+ current_payload = None
+ loop_retry = 0
+ # load the model for stable horde if it's not in memory already.
+ # We should load it after we get the request from the API, in
+ # case the model differs from the one already in memory, but
+ # for now we can load it here so it's ready right away.
+ load_models(use_GFPGAN=True)
+ while True:
+
+ if loop_retry > 10 and current_id:
+ logger.info(f"Exceeded retry count {loop_retry} for generation id {current_id}. Aborting generation!")
+ current_id = None
+ current_payload = None
+ current_generation = None
+ loop_retry = 0
+ elif current_id:
+ logger.info(f"Retrying ({loop_retry}/10) for generation id {current_id}...")
+ gen_dict = {
+ "name": horde_name,
+ "max_pixels": horde_max_pixels,
+ "priority_usernames": priority_usernames,
+ "nsfw": horde_nsfw,
+ "blacklist": horde_blacklist,
+ "models": ["stable_diffusion"],
+ }
+ headers = {"apikey": api_key}
+ if current_id:
+ loop_retry += 1
+ else:
+ try:
+ pop_req = requests.post(horde_url + '/api/v2/generate/pop', json = gen_dict, headers = headers)
+ except requests.exceptions.ConnectionError:
+ logger.warning(f"Server {horde_url} unavailable during pop. Waiting 10 seconds...")
+ time.sleep(10)
+ continue
+ except requests.exceptions.JSONDecodeError:
+ logger.warning(f"Server {horde_url} unavailable during pop. Waiting 10 seconds...")
+ time.sleep(10)
+ continue
+ try:
+ pop = pop_req.json()
+ except json.decoder.JSONDecodeError:
+ logger.warning(f"Could not decode response from {horde_url} as json. Please inform its administrator!")
+ time.sleep(interval)
+ continue
+ if pop is None:
+ logger.warning(f"Something has gone wrong with {horde_url}. Please inform its administrator!")
+ time.sleep(interval)
+ continue
+ if not pop_req.ok:
+ message = pop['message']
+ logger.warning(f"During gen pop, server {horde_url} responded with status code {pop_req.status_code}: {pop['message']}. Waiting for 10 seconds...")
+ if 'errors' in pop:
+ logger.debug(f"Detailed Request Errors: {pop['errors']}")
+ time.sleep(10)
+ continue
+ if not pop.get("id"):
+ skipped_info = pop.get('skipped')
+ if skipped_info and len(skipped_info):
+ skipped_info = f" Skipped Info: {skipped_info}."
+ else:
+ skipped_info = ''
+ logger.info(f"Server {horde_url} has no valid generations to do for us.{skipped_info}")
+ time.sleep(interval)
+ continue
+ current_id = pop['id']
+ logger.info(f"Request with id {current_id} picked up. Initiating work...")
+ current_payload = pop['payload']
+ if 'toggles' in current_payload and current_payload['toggles'] is None:
+ logger.error(f"Received Bad payload: {pop}")
+ current_id = None
+ current_payload = None
+ current_generation = None
+ loop_retry = 0
+ time.sleep(10)
+ continue
+
+ logger.debug(current_payload)
+ current_payload['toggles'] = current_payload.get('toggles', [1,4])
+ # In bridge-mode, matrix is prepared on the horde and split in multiple nodes
+ if 0 in current_payload['toggles']:
+ current_payload['toggles'].remove(0)
+ if 8 not in current_payload['toggles']:
+ if horde_censor_nsfw and not horde_nsfw:
+ current_payload['toggles'].append(8)
+ elif any(word in current_payload['prompt'] for word in horde_censorlist):
+ current_payload['toggles'].append(8)
+
+ from txt2img import txt2img
+
+
+ """{'prompt': 'Centred Husky, inside spiral with circular patterns, trending on dribbble, knotwork, spirals, key patterns,
+ zoomorphics, ', 'ddim_steps': 30, 'n_iter': 1, 'sampler_name': 'DDIM', 'cfg_scale': 16.0, 'seed': '3405278433', 'height': 512, 'width': 512}"""
+
+ #images, seed, info, stats = txt2img(**current_payload)
+ images, seed, info, stats = txt2img(str(current_payload['prompt']), int(current_payload['ddim_steps']), str(current_payload['sampler_name']),
+ int(current_payload['n_iter']), 1, float(current_payload["cfg_scale"]), str(current_payload["seed"]),
+ int(current_payload["height"]), int(current_payload["width"]), save_grid=False, group_by_prompt=False,
+ save_individual_images=False,write_info_files=False)
+
+ buffer = BytesIO()
+ # We send as WebP to avoid using all the horde bandwidth
+ images[0].save(buffer, format="WebP", quality=90)
+ # logger.info(info)
+ submit_dict = {
+ "id": current_id,
+ "generation": base64.b64encode(buffer.getvalue()).decode("utf8"),
+ "api_key": api_key,
+ "seed": seed,
+ "max_pixels": horde_max_pixels,
+ }
+ current_generation = seed
+ while current_id and current_generation is not None:
+ try:
+ submit_req = requests.post(horde_url + '/api/v2/generate/submit', json = submit_dict, headers = headers)
+ try:
+ submit = submit_req.json()
+ except json.decoder.JSONDecodeError:
+ logger.error(f"Something has gone wrong with {horde_url} during submit. Please inform its administrator! (Retry {loop_retry}/10)")
+ time.sleep(interval)
+ continue
+ if submit_req.status_code == 404:
+ logger.info(f"The generation we were working on got stale. Aborting!")
+ elif not submit_req.ok:
+ logger.error(f"During gen submit, server {horde_url} responded with status code {submit_req.status_code}: {submit['message']}. Waiting for 10 seconds... (Retry {loop_retry}/10)")
+ if 'errors' in submit:
+ logger.debug(f"Detailed Request Errors: {submit['errors']}")
+ time.sleep(10)
+ continue
+ else:
+ logger.info(f'Submitted generation with id {current_id} and contributed for {submit_req.json()["reward"]}')
+ current_id = None
+ current_payload = None
+ current_generation = None
+ loop_retry = 0
+ except requests.exceptions.ConnectionError:
+ logger.warning(f"Server {horde_url} unavailable during submit. Waiting 10 seconds... (Retry {loop_retry}/10)")
+ time.sleep(10)
+ continue
+ time.sleep(interval)
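+
+# Hypothetical launch sketch for the horde worker loop above; every value below
+# is a placeholder and the helper is never called anywhere.
+def _example_run_bridge():
+    run_bridge(interval=5, api_key="0000000000", horde_name="my-worker",
+               horde_url="https://horde.example.org", priority_usernames=[],
+               horde_max_pixels=262144, horde_nsfw=False, horde_censor_nsfw=True,
+               horde_blacklist=[], horde_censorlist=[])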
diff --git a/webui/streamlit/scripts/textual_inversion.py b/webui/streamlit/scripts/textual_inversion.py
new file mode 100644
index 0000000..317fb85
--- /dev/null
+++ b/webui/streamlit/scripts/textual_inversion.py
@@ -0,0 +1,938 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+# base webui import and utils.
+from sd_utils import st, set_page_title, seed_to_int
+
+# streamlit imports
+from streamlit.runtime.scriptrunner import StopException
+from streamlit_tensorboard import st_tensorboard
+
+#streamlit components section
+from streamlit_server_state import server_state
+
+#other imports
+from transformers import CLIPTextModel, CLIPTokenizer
+
+# Temp imports
+
+import itertools
+import math
+import os
+import random
+#import datetime
+#from pathlib import Path
+#from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.utils.data import Dataset
+
+import PIL
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel#, PNDMScheduler
+from diffusers.optimization import get_scheduler
+#from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from pipelines.stable_diffusion.no_check import NoCheck
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from slugify import slugify
+import json
+import os#, subprocess
+#from io import StringIO
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+logger = get_logger(__name__)
+
+imagenet_templates_small = [
+ "a photo of a {}",
+ "a rendering of a {}",
+ "a cropped photo of the {}",
+ "the photo of a {}",
+ "a photo of a clean {}",
+ "a photo of a dirty {}",
+ "a dark photo of the {}",
+ "a photo of my {}",
+ "a photo of the cool {}",
+ "a close-up photo of a {}",
+ "a bright photo of the {}",
+ "a cropped photo of a {}",
+ "a photo of the {}",
+ "a good photo of the {}",
+ "a photo of one {}",
+ "a close-up photo of the {}",
+ "a rendition of the {}",
+ "a photo of the clean {}",
+ "a rendition of a {}",
+ "a photo of a nice {}",
+ "a good photo of a {}",
+ "a photo of the nice {}",
+ "a photo of the small {}",
+ "a photo of the weird {}",
+ "a photo of the large {}",
+ "a photo of a cool {}",
+ "a photo of a small {}",
+]
+
+imagenet_style_templates_small = [
+ "a painting in the style of {}",
+ "a rendering in the style of {}",
+ "a cropped painting in the style of {}",
+ "the painting in the style of {}",
+ "a clean painting in the style of {}",
+ "a dirty painting in the style of {}",
+ "a dark painting in the style of {}",
+ "a picture in the style of {}",
+ "a cool painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a bright painting in the style of {}",
+ "a cropped painting in the style of {}",
+ "a good painting in the style of {}",
+ "a close-up painting in the style of {}",
+ "a rendition in the style of {}",
+ "a nice painting in the style of {}",
+ "a small painting in the style of {}",
+ "a weird painting in the style of {}",
+ "a large painting in the style of {}",
+]
+
+class TextualInversionDataset(Dataset):
+ def __init__(
+ self,
+ data_root,
+ tokenizer,
+ learnable_property="object", # [object, style]
+ size=512,
+ repeats=100,
+ interpolation="bicubic",
+ set="train",
+ placeholder_token="*",
+ center_crop=False,
+ templates=None
+ ):
+
+ self.data_root = data_root
+ self.tokenizer = tokenizer
+ self.learnable_property = learnable_property
+ self.size = size
+ self.placeholder_token = placeholder_token
+ self.center_crop = center_crop
+
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) if file_path.lower().endswith(('.png', '.jpg', '.jpeg'))]
+
+ self.num_images = len(self.image_paths)
+ self._length = self.num_images
+
+ if set == "train":
+ self._length = self.num_images * repeats
+
+ self.interpolation = {
+ "linear": PIL.Image.Resampling.BILINEAR,
+ "bilinear": PIL.Image.Resampling.BILINEAR,
+ "bicubic": PIL.Image.Resampling.BICUBIC,
+ "lanczos": PIL.Image.Resampling.LANCZOS,
+ }[interpolation]
+
+ self.templates = templates
+ self.cache = {}
+ self.tokenized_templates = [self.tokenizer(
+ text.format(self.placeholder_token),
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids[0] for text in self.templates]
+
+ def __len__(self):
+ return self._length
+
+ def get_example(self, image_path, flipped):
+ if image_path in self.cache:
+ return self.cache[image_path]
+
+ example = {}
+ image = Image.open(image_path)
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+ if self.center_crop:
+ crop = min(img.shape[0], img.shape[1])
+ h, w, = (
+ img.shape[0],
+ img.shape[1],
+ )
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
+ image = Image.fromarray(img)
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+ image = transforms.RandomHorizontalFlip(p=1 if flipped else 0)(image)
+ image = np.array(image).astype(np.uint8)
+ image = (image / 127.5 - 1.0).astype(np.float32)
+ example["key"] = "-".join([image_path, "-", str(flipped)])
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+
+ self.cache[image_path] = example
+ return example
+
+ def __getitem__(self, i):
+ flipped = random.choice([False, True])
+ example = self.get_example(self.image_paths[i % self.num_images], flipped)
+ example["input_ids"] = random.choice(self.tokenized_templates)
+ return example
+
+
+def freeze_params(params):
+ for param in params:
+ param.requires_grad = False
+
+
+def save_resume_file(basepath, extra = {}, config=''):
+ info = {"args": config["args"]}
+ info["args"].update(extra)
+
+ with open(f"{os.path.join(basepath, 'resume.json')}", "w") as f:
+ #print (info)
+ json.dump(info, f, indent=4)
+
+ with open(f"{basepath}/token_identifier.txt", "w") as f:
+ f.write(f"{config['args']['placeholder_token']}")
+
+ with open(f"{basepath}/type_of_concept.txt", "w") as f:
+ f.write(f"{config['args']['learnable_property']}")
+
+ config['args'] = info["args"]
+
+ return config['args']
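+
+# Minimal sketch of how save_resume_file() is meant to be called; the helper and
+# every value below are hypothetical, only the keys read above actually matter.
+def _example_save_resume(basepath):
+    config = {"args": {"placeholder_token": "<my-token>", "learnable_property": "object"}}
+    return save_resume_file(basepath, extra={"global_step": 500}, config=config)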
+
+class Checkpointer:
+ def __init__(
+ self,
+ accelerator,
+ vae,
+ unet,
+ tokenizer,
+ placeholder_token,
+ placeholder_token_id,
+ templates,
+ output_dir,
+ random_sample_batches,
+ sample_batch_size,
+ stable_sample_batches,
+ seed
+ ):
+ self.accelerator = accelerator
+ self.vae = vae
+ self.unet = unet
+ self.tokenizer = tokenizer
+ self.placeholder_token = placeholder_token
+ self.placeholder_token_id = placeholder_token_id
+ self.templates = templates
+ self.output_dir = output_dir
+ self.seed = seed
+ self.random_sample_batches = random_sample_batches
+ self.sample_batch_size = sample_batch_size
+ self.stable_sample_batches = stable_sample_batches
+
+ @torch.no_grad()
+ def checkpoint(self, step, text_encoder, save_samples=True, path=None):
+ print("Saving checkpoint for step %d..." % step)
+ with torch.autocast("cuda"):
+ if path is None:
+ checkpoints_path = f"{self.output_dir}/checkpoints"
+ os.makedirs(checkpoints_path, exist_ok=True)
+
+ unwrapped = self.accelerator.unwrap_model(text_encoder)
+
+ # Save a checkpoint
+ learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+ learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+
+ filename = "%s_%d.bin" % (slugify(self.placeholder_token), step)
+ if path is not None:
+ torch.save(learned_embeds_dict, path)
+ else:
+ torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
+ torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
+
+ del unwrapped
+ del learned_embeds
+
+
+ @torch.no_grad()
+ def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
+ samples_path = f"{self.output_dir}/concept_images"
+ os.makedirs(samples_path, exist_ok=True)
+
+ #if "checker" not in server_state['textual_inversion']:
+ #with server_state_lock['textual_inversion']["checker"]:
+ server_state['textual_inversion']["checker"] = NoCheck()
+
+ #if "unwrapped" not in server_state['textual_inversion']:
+ # with server_state_lock['textual_inversion']["unwrapped"]:
+ server_state['textual_inversion']["unwrapped"] = self.accelerator.unwrap_model(text_encoder)
+
+ #if "pipeline" not in server_state['textual_inversion']:
+ # with server_state_lock['textual_inversion']["pipeline"]:
+ # Save a sample image
+ server_state['textual_inversion']["pipeline"] = StableDiffusionPipeline(
+ text_encoder=server_state['textual_inversion']["unwrapped"],
+ vae=self.vae,
+ unet=self.unet,
+ tokenizer=self.tokenizer,
+ scheduler=LMSDiscreteScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+ ),
+ safety_checker=NoCheck(),
+ feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+ ).to("cuda")
+
+ server_state['textual_inversion']["pipeline"].enable_attention_slicing()
+
+ if self.stable_sample_batches > 0:
+ stable_latents = torch.randn(
+ (self.sample_batch_size, server_state['textual_inversion']["pipeline"].unet.in_channels, height // 8, width // 8),
+ device=server_state['textual_inversion']["pipeline"].device,
+ generator=torch.Generator(device=server_state['textual_inversion']["pipeline"].device).manual_seed(self.seed),
+ )
+
+ stable_prompts = [choice.format(self.placeholder_token) for choice in (self.templates * self.sample_batch_size)[:self.sample_batch_size]]
+
+ # Generate and save stable samples
+ for i in range(0, self.stable_sample_batches):
+ samples = server_state['textual_inversion']["pipeline"](
+ prompt=stable_prompts,
+ height=384,
+ latents=stable_latents,
+ width=384,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ num_inference_steps=num_inference_steps,
+ output_type='pil'
+ )["sample"]
+
+ for idx, im in enumerate(samples):
+ filename = "stable_sample_%d_%d_step_%d.png" % (i+1, idx+1, step)
+ im.save(f"{samples_path}/{filename}")
+ del samples
+ del stable_latents
+
+ prompts = [choice.format(self.placeholder_token) for choice in random.choices(self.templates, k=self.sample_batch_size)]
+ # Generate and save random samples
+ for i in range(0, self.random_sample_batches):
+ samples = server_state['textual_inversion']["pipeline"](
+ prompt=prompts,
+ height=384,
+ width=384,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ num_inference_steps=num_inference_steps,
+ output_type='pil'
+ )["sample"]
+ for idx, im in enumerate(samples):
+ filename = "step_%d_sample_%d_%d.png" % (step, i+1, idx+1)
+ im.save(f"{samples_path}/{filename}")
+ del samples
+
+ del server_state['textual_inversion']["checker"]
+ del server_state['textual_inversion']["unwrapped"]
+ del server_state['textual_inversion']["pipeline"]
+ torch.cuda.empty_cache()
+
+#@retry(RuntimeError, tries=5)
+def textual_inversion(config):
+ print ("Running textual inversion.")
+
+ #if "pipeline" in server_state["textual_inversion"]:
+ #del server_state['textual_inversion']["checker"]
+ #del server_state['textual_inversion']["unwrapped"]
+ #del server_state['textual_inversion']["pipeline"]
+ #torch.cuda.empty_cache()
+
+ global_step_offset = 0
+
+ #print(config['args']['resume_from'])
+ if config['args']['resume_from']:
+ try:
+ basepath = f"{config['args']['resume_from']}"
+
+ with open(f"{basepath}/resume.json", 'r') as f:
+ state = json.load(f)
+
+ global_step_offset = state["args"].get("global_step", 0)
+
+ print("Resuming state from %s" % config['args']['resume_from'])
+ print("We've trained %d steps so far" % global_step_offset)
+
+ except json.decoder.JSONDecodeError:
+ pass
+ else:
+ basepath = f"{config['args']['output_dir']}/{slugify(config['args']['placeholder_token'])}"
+ os.makedirs(basepath, exist_ok=True)
+
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=config['args']['gradient_accumulation_steps'],
+ mixed_precision=config['args']['mixed_precision']
+ )
+
+ # If passed along, set the training seed.
+ if config['args']['seed']:
+ set_seed(config['args']['seed'])
+
+ #if "tokenizer" not in server_state["textual_inversion"]:
+ # Load the tokenizer and add the placeholder token as an additional special token
+ #with server_state_lock['textual_inversion']["tokenizer"]:
+ if config['args']['tokenizer_name']:
+ server_state['textual_inversion']["tokenizer"] = CLIPTokenizer.from_pretrained(config['args']['tokenizer_name'])
+ elif config['args']['pretrained_model_name_or_path']:
+ server_state['textual_inversion']["tokenizer"] = CLIPTokenizer.from_pretrained(
+ config['args']['pretrained_model_name_or_path'] + '/tokenizer'
+ )
+
+ # Add the placeholder token in tokenizer
+ num_added_tokens = server_state['textual_inversion']["tokenizer"].add_tokens(config['args']['placeholder_token'])
+ if num_added_tokens == 0:
+ st.error(
+ f"The tokenizer already contains the token {config['args']['placeholder_token']}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ # Convert the initializer_token, placeholder_token to ids
+ token_ids = server_state['textual_inversion']["tokenizer"].encode(config['args']['initializer_token'], add_special_tokens=False)
+ # Check if initializer_token is a single token or a sequence of tokens
+ if len(token_ids) > 1:
+ st.error("The initializer token must be a single token.")
+
+ initializer_token_id = token_ids[0]
+ placeholder_token_id = server_state['textual_inversion']["tokenizer"].convert_tokens_to_ids(config['args']['placeholder_token'])
+
+ #if "text_encoder" not in server_state['textual_inversion']:
+ # Load models and create wrapper for stable diffusion
+ #with server_state_lock['textual_inversion']["text_encoder"]:
+ server_state['textual_inversion']["text_encoder"] = CLIPTextModel.from_pretrained(
+ config['args']['pretrained_model_name_or_path'] + '/text_encoder',
+ )
+
+ #if "vae" not in server_state['textual_inversion']:
+ #with server_state_lock['textual_inversion']["vae"]:
+ server_state['textual_inversion']["vae"] = AutoencoderKL.from_pretrained(
+ config['args']['pretrained_model_name_or_path'] + '/vae',
+ )
+
+ #if "unet" not in server_state['textual_inversion']:
+ #with server_state_lock['textual_inversion']["unet"]:
+ server_state['textual_inversion']["unet"] = UNet2DConditionModel.from_pretrained(
+ config['args']['pretrained_model_name_or_path'] + '/unet',
+ )
+
+ base_templates = imagenet_style_templates_small if config['args']['learnable_property'] == "style" else imagenet_templates_small
+ if config['args']['custom_templates']:
+ templates = config['args']['custom_templates'].split(";")
+ else:
+ templates = base_templates
+
+ slice_size = server_state['textual_inversion']["unet"].config.attention_head_dim // 2
+ server_state['textual_inversion']["unet"].set_attention_slice(slice_size)
+
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ server_state['textual_inversion']["text_encoder"].resize_token_embeddings(len(server_state['textual_inversion']["tokenizer"]))
+
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
+ token_embeds = server_state['textual_inversion']["text_encoder"].get_input_embeddings().weight.data
+
+ if "resume_checkpoint" in config['args']:
+ if config['args']['resume_checkpoint'] is not None:
+ token_embeds[placeholder_token_id] = torch.load(config['args']['resume_checkpoint'])[config['args']['placeholder_token']]
+ else:
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+ # Freeze vae and unet
+ freeze_params(server_state['textual_inversion']["vae"].parameters())
+ freeze_params(server_state['textual_inversion']["unet"].parameters())
+ # Freeze all parameters except for the token embeddings in text encoder
+ params_to_freeze = itertools.chain(
+ server_state['textual_inversion']["text_encoder"].text_model.encoder.parameters(),
+ server_state['textual_inversion']["text_encoder"].text_model.final_layer_norm.parameters(),
+ server_state['textual_inversion']["text_encoder"].text_model.embeddings.position_embedding.parameters(),
+ )
+ freeze_params(params_to_freeze)
+
+ checkpointer = Checkpointer(
+ accelerator=accelerator,
+ vae=server_state['textual_inversion']["vae"],
+ unet=server_state['textual_inversion']["unet"],
+ tokenizer=server_state['textual_inversion']["tokenizer"],
+ placeholder_token=config['args']['placeholder_token'],
+ placeholder_token_id=placeholder_token_id,
+ templates=templates,
+ output_dir=basepath,
+ sample_batch_size=config['args']['sample_batch_size'],
+ random_sample_batches=config['args']['random_sample_batches'],
+ stable_sample_batches=config['args']['stable_sample_batches'],
+ seed=config['args']['seed']
+ )
+
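+ # Optionally scale the learning rate linearly with the effective batch size (train_batch_size * gradient_accumulation_steps * number of processes).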
+ if config['args']['scale_lr']:
+ config['args']['learning_rate'] = (
+ config['args']['learning_rate'] * config[
+ 'args']['gradient_accumulation_steps'] * config['args']['train_batch_size'] * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ optimizer = torch.optim.AdamW(
+ server_state['textual_inversion']["text_encoder"].get_input_embeddings().parameters(), # only optimize the embeddings
+ lr=config['args']['learning_rate'],
+ betas=(config['args']['adam_beta1'], config['args']['adam_beta2']),
+ weight_decay=config['args']['adam_weight_decay'],
+ eps=config['args']['adam_epsilon'],
+ )
+
+ # TODO (patil-suraj): load scheduler using args
+ noise_scheduler = DDPMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
+ )
+
+ train_dataset = TextualInversionDataset(
+ data_root=config['args']['train_data_dir'],
+ tokenizer=server_state['textual_inversion']["tokenizer"],
+ size=config['args']['resolution'],
+ placeholder_token=config['args']['placeholder_token'],
+ repeats=config['args']['repeats'],
+ learnable_property=config['args']['learnable_property'],
+ center_crop=config['args']['center_crop'],
+ set="train",
+ templates=templates
+ )
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config['args']['train_batch_size'], shuffle=True)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config['args']['gradient_accumulation_steps'])
+ if config['args']['max_train_steps'] is None:
+ config['args']['max_train_steps'] = config['args']['num_train_epochs'] * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ config['args']['lr_scheduler'],
+ optimizer=optimizer,
+ num_warmup_steps=config['args']['lr_warmup_steps'] * config['args']['gradient_accumulation_steps'],
+ num_training_steps=config['args']['max_train_steps'] * config['args']['gradient_accumulation_steps'],
+ )
+
+ server_state['textual_inversion']["text_encoder"], optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ server_state['textual_inversion']["text_encoder"], optimizer, train_dataloader, lr_scheduler
+ )
+
+ # Move vae and unet to device
+ server_state['textual_inversion']["vae"].to(accelerator.device)
+ server_state['textual_inversion']["unet"].to(accelerator.device)
+
+ # Keep vae and unet in eval mode as we don't train these
+ server_state['textual_inversion']["vae"].eval()
+ server_state['textual_inversion']["unet"].eval()
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config['args']['gradient_accumulation_steps'])
+ if overrode_max_train_steps:
+ config['args']['max_train_steps'] = config['args']['num_train_epochs'] * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ config['args']['num_train_epochs'] = math.ceil(config['args']['max_train_steps'] / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("textual_inversion", config=config['args'])
+
+ # Train!
+ total_batch_size = config['args']['train_batch_size'] * accelerator.num_processes * config['args']['gradient_accumulation_steps']
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {config['args']['num_train_epochs']}")
+ logger.info(f" Instantaneous batch size per device = {config['args']['train_batch_size']}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {config['args']['gradient_accumulation_steps']}")
+ logger.info(f" Total optimization steps = {config['args']['max_train_steps']}")
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(config['args']['max_train_steps']), disable=not accelerator.is_local_main_process)
+ progress_bar.set_description("Steps")
+ global_step = 0
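+ # Cache the VAE latent distribution per unique batch so the frozen VAE only has to encode each set of images once.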
+ encoded_pixel_values_cache = {}
+
+ try:
+ for epoch in range(config['args']['num_train_epochs']):
+ server_state['textual_inversion']["text_encoder"].train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(server_state['textual_inversion']["text_encoder"]):
+ # Convert images to latent space
+ key = "|".join(batch["key"])
+ if encoded_pixel_values_cache.get(key, None) is None:
+ encoded_pixel_values_cache[key] = server_state['textual_inversion']["vae"].encode(batch["pixel_values"]).latent_dist
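+ # Sample from the cached latent distribution and apply 0.18215, the latent scaling factor used by the Stable Diffusion VAE.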
+ latents = encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn(latents.shape).to(latents.device)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = server_state['textual_inversion']["text_encoder"](batch["input_ids"])[0]
+
+ # Predict the noise residual
+ noise_pred = server_state['textual_inversion']["unet"](noisy_latents, timesteps, encoder_hidden_states).sample
+
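+ # Standard DDPM objective: mean squared error between the predicted noise and the noise that was added to the latents.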
+ loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+ accelerator.backward(loss)
+
+ # Zero out the gradients for all token embeddings except the newly added
+ # embeddings for the concept, as we only want to optimize the concept embeddings
+ if accelerator.num_processes > 1:
+ grads = server_state['textual_inversion']["text_encoder"].module.get_input_embeddings().weight.grad
+ else:
+ grads = server_state['textual_inversion']["text_encoder"].get_input_embeddings().weight.grad
+ # Get the index for tokens that we want to zero the grads for
+ index_grads_to_zero = torch.arange(len(server_state['textual_inversion']["tokenizer"])) != placeholder_token_id
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ #try:
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % config['args']['checkpoint_frequency'] == 0 and global_step > 0 and accelerator.is_main_process:
+ checkpointer.checkpoint(global_step + global_step_offset, server_state['textual_inversion']["text_encoder"])
+ save_resume_file(basepath, {
+ "global_step": global_step + global_step_offset,
+ "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
+ }, config)
+
+ checkpointer.save_samples(
+ global_step + global_step_offset,
+ server_state['textual_inversion']["text_encoder"],
+ config['args']['resolution'], config['args'][
+ 'resolution'], 7.5, 0.0, config['args']['sample_steps'])
+
+ checkpointer.checkpoint(
+ global_step + global_step_offset,
+ server_state['textual_inversion']["text_encoder"],
+ path=f"{basepath}/learned_embeds.bin"
+ )
+ #except KeyError:
+ #raise StopException
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ #accelerator.log(logs, step=global_step)
+
+ #try:
+ if global_step >= config['args']['max_train_steps']:
+ break
+ #except:
+ #pass
+
+ accelerator.wait_for_everyone()
+
+ # Create the pipeline using the trained modules and save it.
+ if accelerator.is_main_process:
+ print("Finished! Saving final checkpoint and resume state.")
+ checkpointer.checkpoint(
+ global_step + global_step_offset,
+ server_state['textual_inversion']["text_encoder"],
+ path=f"{basepath}/learned_embeds.bin"
+ )
+
+ save_resume_file(basepath, {
+ "global_step": global_step + global_step_offset,
+ "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
+ }, config)
+
+ accelerator.end_training()
+
+ except (KeyboardInterrupt, StopException):
+ print("Received Streamlit StopException or KeyboardInterrupt")
+
+ if accelerator.is_main_process:
+ print("Interrupted, saving checkpoint and resume state...")
+ checkpointer.checkpoint(global_step + global_step_offset, server_state['textual_inversion']["text_encoder"])
+
+ config['args'] = save_resume_file(basepath, {
+ "global_step": global_step + global_step_offset,
+ "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
+ }, config)
+
+
+ checkpointer.checkpoint(
+ global_step + global_step_offset,
+ server_state['textual_inversion']["text_encoder"],
+ path=f"{basepath}/learned_embeds.bin"
+ )
+
+ quit()
+
+
+def layout():
+
+ with st.form("textual-inversion"):
+ #st.info("Under Construction. :construction_worker:")
+ #parser = argparse.ArgumentParser(description="Simple example of a training script.")
+
+ set_page_title("Textual Inversion - Stable Diffusion Playground")
+
+ config_tab, output_tab, tensorboard_tab = st.tabs(["Textual Inversion Config", "Output", "TensorBoard"])
+
+ with config_tab:
+ col1, col2, col3, col4, col5 = st.columns(5, gap='large')
+
+ if "textual_inversion" not in st.session_state:
+ st.session_state["textual_inversion"] = {}
+
+ if "textual_inversion" not in server_state:
+ server_state["textual_inversion"] = {}
+
+ if "args" not in st.session_state["textual_inversion"]:
+ st.session_state["textual_inversion"]["args"] = {}
+
+
+ with col1:
+ st.session_state["textual_inversion"]["args"]["pretrained_model_name_or_path"] = st.text_input("Pretrained Model Path",
+ value=st.session_state["defaults"].textual_inversion.pretrained_model_name_or_path,
+ help="Path to pretrained model or model identifier from huggingface.co/models.")
+
+ st.session_state["textual_inversion"]["args"]["tokenizer_name"] = st.text_input("Tokenizer Name",
+ value=st.session_state["defaults"].textual_inversion.tokenizer_name,
+ help="Pretrained tokenizer name or path if not the same as model_name")
+
+ st.session_state["textual_inversion"]["args"]["train_data_dir"] = st.text_input("train_data_dir", value="", help="A folder containing the training data.")
+
+ st.session_state["textual_inversion"]["args"]["placeholder_token"] = st.text_input("Placeholder Token", value="", help="A token to use as a placeholder for the concept.")
+
+ st.session_state["textual_inversion"]["args"]["initializer_token"] = st.text_input("Initializer Token", value="", help="A token to use as initializer word.")
+
+ st.session_state["textual_inversion"]["args"]["learnable_property"] = st.selectbox("Learnable Property", ["object", "style"], index=0, help="Choose between 'object' and 'style'")
+
+ st.session_state["textual_inversion"]["args"]["repeats"] = int(st.text_input("Number of times to Repeat", value=100, help="How many times to repeat the training data."))
+
+ with col2:
+ st.session_state["textual_inversion"]["args"]["output_dir"] = st.text_input("Output Directory",
+ value=str(os.path.join("outputs", "textual_inversion")),
+ help="The output directory where the model predictions and checkpoints will be written.")
+
+ st.session_state["textual_inversion"]["args"]["seed"] = seed_to_int(st.text_input("Seed", value=0,
+ help="A seed for reproducible training, if left empty a random one will be generated. Default: 0"))
+
+ st.session_state["textual_inversion"]["args"]["resolution"] = int(st.text_input("Resolution", value=512,
+ help="The resolution for input images, all the images in the train/validation dataset will be resized to this resolution"))
+
+ st.session_state["textual_inversion"]["args"]["center_crop"] = st.checkbox("Center Image", value=True, help="Whether to center crop images before resizing to resolution")
+
+ st.session_state["textual_inversion"]["args"]["train_batch_size"] = int(st.text_input("Train Batch Size", value=1, help="Batch size (per device) for the training dataloader."))
+
+ st.session_state["textual_inversion"]["args"]["num_train_epochs"] = int(st.text_input("Number of Steps to Train", value=100, help="Number of steps to train."))
+
+ st.session_state["textual_inversion"]["args"]["max_train_steps"] = int(st.text_input("Max Number of Steps to Train", value=5000,
+ help="Total number of training steps to perform. If provided, overrides 'Number of Steps to Train'."))
+
+ with col3:
+ st.session_state["textual_inversion"]["args"]["gradient_accumulation_steps"] = int(st.text_input("Gradient Accumulation Steps", value=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass."))
+
+ st.session_state["textual_inversion"]["args"]["learning_rate"] = float(st.text_input("Learning Rate", value=5.0e-04,
+ help="Initial learning rate (after the potential warmup period) to use."))
+
+ st.session_state["textual_inversion"]["args"]["scale_lr"] = st.checkbox("Scale Learning Rate", value=True,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.")
+
+ st.session_state["textual_inversion"]["args"]["lr_scheduler"] = st.text_input("Learning Rate Scheduler", value="constant",
+ help=("The scheduler type to use. Choose between ['linear', 'cosine', 'cosine_with_restarts', 'polynomial',"
+ " 'constant', 'constant_with_warmup']" ))
+
+ st.session_state["textual_inversion"]["args"]["lr_warmup_steps"] = int(st.text_input("Learning Rate Warmup Steps", value=500, help="Number of steps for the warmup in the lr scheduler."))
+
+ st.session_state["textual_inversion"]["args"]["adam_beta1"] = float(st.text_input("Adam Beta 1", value=0.9, help="The beta1 parameter for the Adam optimizer."))
+
+ st.session_state["textual_inversion"]["args"]["adam_beta2"] = float(st.text_input("Adam Beta 2", value=0.999, help="The beta2 parameter for the Adam optimizer."))
+
+ st.session_state["textual_inversion"]["args"]["adam_weight_decay"] = float(st.text_input("Adam Weight Decay", value=1e-2, help="Weight decay to use."))
+
+ st.session_state["textual_inversion"]["args"]["adam_epsilon"] = float(st.text_input("Adam Epsilon", value=1e-08, help="Epsilon value for the Adam optimizer"))
+
+ with col4:
+ st.session_state["textual_inversion"]["args"]["mixed_precision"] = st.selectbox("Mixed Precision", ["no", "fp16", "bf16"], index=1,
+ help="Whether to use mixed precision. Choose" "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+ "and an Nvidia Ampere GPU.")
+
+ st.session_state["textual_inversion"]["args"]["local_rank"] = int(st.text_input("Local Rank", value=1, help="For distributed training: local_rank"))
+
+ st.session_state["textual_inversion"]["args"]["checkpoint_frequency"] = int(st.text_input("Checkpoint Frequency", value=500, help="How often to save a checkpoint and sample image"))
+
+ # stable_sample_batches crashes when saving the samples, so for now it is disabled until it's fixed.
+ #st.session_state["textual_inversion"]["args"]["stable_sample_batches"] = int(st.text_input("Stable Sample Batches", value=0,
+ #help="Number of fixed seed sample batches to generate per checkpoint"))
+
+ st.session_state["textual_inversion"]["args"]["stable_sample_batches"] = 0
+
+ st.session_state["textual_inversion"]["args"]["random_sample_batches"] = int(st.text_input("Random Sample Batches", value=2,
+ help="Number of random seed sample batches to generate per checkpoint"))
+
+ st.session_state["textual_inversion"]["args"]["sample_batch_size"] = int(st.text_input("Sample Batch Size", value=1, help="Number of samples to generate per batch"))
+
+ st.session_state["textual_inversion"]["args"]["sample_steps"] = int(st.text_input("Sample Steps", value=100,
+ help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes."))
+
+ st.session_state["textual_inversion"]["args"]["custom_templates"] = st.text_input("Custom Templates", value="",
+ help="A semicolon-delimited list of custom template to use for samples, using {} as a placeholder for the concept.")
+ with col5:
+ st.session_state["textual_inversion"]["args"]["resume"] = st.checkbox(label="Resume Previous Run?", value=False,
+ help="Resume previous run, if a valid resume.json file is on the output dir \
+ it will be used, otherwise if the 'Resume From' field bellow contains a valid resume.json file \
+ that one will be used.")
+
+ st.session_state["textual_inversion"]["args"]["resume_from"] = st.text_input(label="Resume From", help="Path to a directory to resume training from (ie, logs/token_name)")
+
+ #st.session_state["textual_inversion"]["args"]["resume_checkpoint"] = st.file_uploader("Resume Checkpoint", type=["bin"],
+ #help="Path to a specific checkpoint to resume training from (ie, logs/token_name/checkpoints/something.bin).")
+
+ #st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"] = st.file_uploader("st.session_state["textual_inversion"] File", type=["json"],
+ #help="Path to a JSON st.session_state["textual_inversion"]uration file containing arguments for invoking this script."
+ #"If resume_from is given, its resume.json takes priority over this.")
+ #
+ #print (os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json"))
+ #print (os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json")))
+ if os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"),"resume.json")):
+ st.session_state["textual_inversion"]["args"]["resume_from"] = os.path.join(
+ st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"))
+ #print (st.session_state["textual_inversion"]["args"]["resume_from"])
+
+ if os.path.exists(os.path.join(st.session_state["textual_inversion"]["args"]["output_dir"],st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"), "checkpoints","last.bin")):
+ st.session_state["textual_inversion"]["args"]["resume_checkpoint"] = os.path.join(
+ st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"), "checkpoints","last.bin")
+
+ #if "resume_from" in st.session_state["textual_inversion"]["args"]:
+ #if st.session_state["textual_inversion"]["args"]["resume_from"]:
+ #if os.path.exists(os.path.join(st.session_state["textual_inversion"]['args']['resume_from'], "resume.json")):
+ #with open(os.path.join(st.session_state["textual_inversion"]['args']['resume_from'], "resume.json"), 'rt') as f:
+ #try:
+ #resume_json = json.load(f)["args"]
+ #st.session_state["textual_inversion"]["args"] = OmegaConf.merge(st.session_state["textual_inversion"]["args"], resume_json)
+ #st.session_state["textual_inversion"]["args"]["resume_from"] = os.path.join(
+ #st.session_state["textual_inversion"]["args"]["output_dir"], st.session_state["textual_inversion"]["args"]["placeholder_token"].strip("<>"))
+ #except json.decoder.JSONDecodeError:
+ #pass
+
+ #print(st.session_state["textual_inversion"]["args"])
+ #print(st.session_state["textual_inversion"]["args"]['resume_from'])
+
+ #elif st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"] is not None:
+ #with open(st.session_state["textual_inversion"]["args"]["st.session_state["textual_inversion"]"], 'rt') as f:
+ #args = parser.parse_args(namespace=argparse.Namespace(**json.load(f)["args"]))
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != st.session_state["textual_inversion"]["args"]["local_rank"]:
+ st.session_state["textual_inversion"]["args"]["local_rank"] = env_local_rank
+
+ if st.session_state["textual_inversion"]["args"]["train_data_dir"] is None:
+ st.error("You must specify --train_data_dir")
+
+ if st.session_state["textual_inversion"]["args"]["pretrained_model_name_or_path"] is None:
+ st.error("You must specify --pretrained_model_name_or_path")
+
+ if st.session_state["textual_inversion"]["args"]["placeholder_token"] is None:
+ st.error("You must specify --placeholder_token")
+
+ if st.session_state["textual_inversion"]["args"]["initializer_token"] is None:
+ st.error("You must specify --initializer_token")
+
+ if st.session_state["textual_inversion"]["args"]["output_dir"] is None:
+ st.error("You must specify --output_dir")
+
+ # add a spacer and the submit button for the form.
+
+ st.session_state["textual_inversion"]["message"] = st.empty()
+ st.session_state["textual_inversion"]["progress_bar"] = st.empty()
+
+ st.write("---")
+
+ submit = st.form_submit_button("Run",help="")
+ if submit:
+ if "pipe" in st.session_state:
+ del st.session_state["pipe"]
+ if "model" in st.session_state:
+ del st.session_state["model"]
+
+ set_page_title("Running Textual Inversion - Stable Diffusion WebUI")
+ #st.session_state["textual_inversion"]["message"].info("Textual Inversion Running. For more info check the progress on your console or the Ouput Tab.")
+
+ try:
+ #try:
+ # run textual inversion.
+ config = st.session_state['textual_inversion']
+ textual_inversion(config)
+ #except RuntimeError:
+ #if "pipeline" in server_state["textual_inversion"]:
+ #del server_state['textual_inversion']["checker"]
+ #del server_state['textual_inversion']["unwrapped"]
+ #del server_state['textual_inversion']["pipeline"]
+
+ # run textual inversion.
+ #config = st.session_state['textual_inversion']
+ #textual_inversion(config)
+
+ set_page_title("Textual Inversion - Stable Diffusion WebUI")
+
+ except StopException:
+ set_page_title("Textual Inversion - Stable Diffusion WebUI")
+ print(f"Received Streamlit StopException")
+
+ st.session_state["textual_inversion"]["message"].empty()
+
+ #
+ with output_tab:
+ st.info("Under Construction. :construction_worker:")
+
+ #st.info("Nothing to show yet. Maybe try running some training first.")
+
+ #st.session_state["textual_inversion"]["preview_image"] = st.empty()
+ #st.session_state["textual_inversion"]["progress_bar"] = st.empty()
+
+
+ with tensorboard_tab:
+ #st.info("Under Construction. :construction_worker:")
+
+ # Start TensorBoard
+ st_tensorboard(logdir=os.path.join("outputs", "textual_inversion"), port=8888)
+
diff --git a/webui/streamlit/scripts/txt2img.py b/webui/streamlit/scripts/txt2img.py
new file mode 100644
index 0000000..db3a731
--- /dev/null
+++ b/webui/streamlit/scripts/txt2img.py
@@ -0,0 +1,708 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+# base webui import and utils.
+from sd_utils import st, MemUsageMonitor, server_state, no_rerun, logger, set_page_title, \
+ custom_models_available, RealESRGAN_available, GFPGAN_available, \
+ LDSR_available, load_models, hc, seed_to_int, \
+ get_next_sequence_number, check_prompt_length, torch_gc, \
+ save_sample, generation_callback, process_images, \
+ KDiffusionSampler
+
+# streamlit imports
+from streamlit.runtime.scriptrunner import StopException
+
+#streamlit components section
+import streamlit_nested_layout  # used to allow nested columns; just importing it is enough
+
+#from streamlit.elements import image as STImage
+import streamlit.components.v1 as components
+#from streamlit.runtime.media_file_manager import media_file_manager
+from streamlit.elements.image import image_to_url
+
+#other imports
+
+import base64, uuid
+import os, sys, datetime, time
+from PIL import Image
+import requests
+from slugify import slugify
+from ldm.models.diffusion.ddim import DDIMSampler
+from typing import Union
+from io import BytesIO
+from ldm.models.diffusion.plms import PLMSSampler
+
+
+# streamlit components
+from custom_components import sygil_suggestions
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+sygil_suggestions.init()
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+#
+# Dev mode (server)
+# _component_func = components.declare_component(
+# "sd-gallery",
+# url="http://localhost:3001",
+# )
+
+# Init Vuejs component
+_component_func = components.declare_component(
+ "sd-gallery", "./frontend/dists/sd-gallery/dist")
+
+def sdGallery(images=[], key=None):
+ component_value = _component_func(images=imgsToGallery(images), key=key, default="")
+ return component_value
+
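+# Convert PIL images into URLs served by Streamlit's media manager so the Vue gallery component can display them.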
+def imgsToGallery(images):
+ urls = []
+ for i in images:
+ # random string for id
+ random_id = str(uuid.uuid4())
+ url = image_to_url(
+ image=i,
+ image_id= random_id,
+ width=i.width,
+ clamp=False,
+ channels="RGB",
+ output_format="PNG"
+ )
+ # image_io = BytesIO()
+ # i.save(image_io, 'PNG')
+ # width, height = i.size
+ # image_id = "%s" % (str(images.index(i)))
+ # (data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
+ # this_file = media_file_manager.add(data, mimetype, image_id)
+ # img_str = this_file.url
+ urls.append(url)
+
+ return urls
+
+
+class plugin_info():
+ plugname = "txt2img"
+ description = "Text to Image"
+ isTab = True
+ displayPriority = 1
+
+@logger.catch(reraise=True)
+def stable_horde(outpath, prompt, seed, sampler_name, save_grid, batch_size,
+ n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, GFPGAN_model,
+ use_RealESRGAN, realesrgan_model_name, use_LDSR,
+ LDSR_model_name, ddim_eta, normalize_prompt_weights,
+ save_individual_images, sort_samples, write_info_files,
+ jpg_sample, variant_amount, variant_seed, api_key,
+ nsfw=True, censor_nsfw=False):
+
+ log = []
+
+ log.append("Generating image with Stable Horde.")
+
+ st.session_state["progress_bar_text"].code('\n'.join(log), language='')
+
+ # start time after garbage collection (or before?)
+ start_time = time.time()
+
+ # Record the run start date; it is used later for the folder name.
+ run_start_dt = datetime.datetime.now()
+
+ mem_mon = MemUsageMonitor('MemMon')
+ mem_mon.start()
+
+ os.makedirs(outpath, exist_ok=True)
+
+ sample_path = os.path.join(outpath, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+
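+ # Generation parameters sent to the Stable Horde API; the prompt and NSFW settings are added to the payload below.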
+ params = {
+ "sampler_name": "k_euler",
+ "toggles": [1,4],
+ "cfg_scale": cfg_scale,
+ "seed": str(seed),
+ "width": width,
+ "height": height,
+ "seed_variation": variant_seed if variant_seed else 1,
+ "steps": int(steps),
+ "n": int(n_iter)
+ # You can put extra params here if you wish
+ }
+
+ final_submit_dict = {
+ "prompt": prompt,
+ "params": params,
+ "nsfw": nsfw,
+ "censor_nsfw": censor_nsfw,
+ "trusted_workers": True,
+ "workers": []
+ }
+ log.append(final_submit_dict)
+
+ headers = {"apikey": api_key}
+ logger.debug(final_submit_dict)
+ st.session_state["progress_bar_text"].code('\n'.join(str(log)), language='')
+
+ horde_url = "https://stablehorde.net"
+
+ submit_req = requests.post(f'{horde_url}/api/v2/generate/async', json = final_submit_dict, headers = headers)
+ if submit_req.ok:
+ submit_results = submit_req.json()
+ logger.debug(submit_results)
+
+ log.append(submit_results)
+ st.session_state["progress_bar_text"].code(''.join(str(log)), language='')
+
+ req_id = submit_results['id']
+ is_done = False
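+ # Poll the check endpoint once per second until the request is reported as done, then fetch the results.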
+ while not is_done:
+ chk_req = requests.get(f'{horde_url}/api/v2/generate/check/{req_id}')
+ if not chk_req.ok:
+ logger.error(chk_req.text)
+ return
+ chk_results = chk_req.json()
+ logger.info(chk_results)
+ is_done = chk_results['done']
+ time.sleep(1)
+ retrieve_req = requests.get(f'{horde_url}/api/v2/generate/status/{req_id}')
+ if not retrieve_req.ok:
+ logger.error(retrieve_req.text)
+ return
+ results_json = retrieve_req.json()
+ # logger.debug(results_json)
+ results = results_json['generations']
+
+ output_images = []
+ comments = []
+ prompt_matrix_parts = []
+
+ if not st.session_state['defaults'].general.no_verify_input:
+ try:
+ check_prompt_length(prompt, comments)
+ except:
+ import traceback
+ logger.info("Error verifying input:", file=sys.stderr)
+ logger.info(traceback.format_exc(), file=sys.stderr)
+
+ all_prompts = batch_size * n_iter * [prompt]
+ all_seeds = [seed + x for x in range(len(all_prompts))]
+
+ for iter in range(len(results)):
+ b64img = results[iter]["img"]
+ base64_bytes = b64img.encode('utf-8')
+ img_bytes = base64.b64decode(base64_bytes)
+ img = Image.open(BytesIO(img_bytes))
+
+ sanitized_prompt = slugify(prompt)
+
+ prompts = all_prompts[iter * batch_size:(iter + 1) * batch_size]
+ #captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size]
+ seeds = all_seeds[iter * batch_size:(iter + 1) * batch_size]
+
+ if sort_samples:
+ full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt)
+
+
+ sanitized_prompt = sanitized_prompt[:200-len(full_path)]
+ sample_path_i = os.path.join(sample_path, sanitized_prompt)
+
+ #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}")
+ #print(os.path.join(os.getcwd(), sample_path_i))
+
+ os.makedirs(sample_path_i, exist_ok=True)
+ base_count = get_next_sequence_number(sample_path_i)
+ filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[iter]}"
+ else:
+ full_path = os.path.join(os.getcwd(), sample_path)
+ sample_path_i = sample_path
+ base_count = get_next_sequence_number(sample_path_i)
+ filename = f"{base_count:05}-{steps}_{sampler_name}_{seed}_{sanitized_prompt}"[:200-len(full_path)] #same as before
+
+ save_sample(img, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img=None,
+ denoising_strength=0.75, resize_mode=None, uses_loopback=False, uses_random_seed_loopback=False,
+ save_grid=save_grid,
+ sort_samples=sort_samples, sampler_name=sampler_name, ddim_eta=ddim_eta, n_iter=n_iter,
+ batch_size=batch_size, i=iter, save_individual_images=save_individual_images,
+ model_name="Stable Diffusion v1.5")
+
+ output_images.append(img)
+
+ # update image on the UI so we can see the progress
+ if "preview_image" in st.session_state:
+ st.session_state["preview_image"].image(img)
+
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].empty()
+
+ #if len(results) > 1:
+ #final_filename = f"{iter}_{filename}"
+ #img.save(final_filename)
+ #logger.info(f"Saved {final_filename}")
+ else:
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].error(submit_req.text)
+
+ logger.error(submit_req.text)
+
+ mem_max_used, mem_total = mem_mon.read_and_stop()
+ time_diff = time.time()-start_time
+
+ info = f"""
+ {prompt}
+ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN else ''}{', '+realesrgan_model_name if use_RealESRGAN else ''}
+ {', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
+
+ stats = f'''
+ Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
+ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
+
+ for comment in comments:
+ info += "\n\n" + comment
+
+ #mem_mon.stop()
+ #del mem_mon
+ torch_gc()
+
+ return output_images, seed, info, stats
+
+
+#
+@logger.catch(reraise=True)
+def txt2img(prompt: str, ddim_steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
+ height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
+ save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
+ save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.3', use_RealESRGAN: bool = False,
+ RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", use_LDSR: bool = True, LDSR_model: str = "model",
+ fp = None, variant_amount: float = 0.0,
+ variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True,
+ use_stable_horde: bool = False, stable_horde_key:str = "0000000000"):
+
+ outpath = st.session_state['defaults'].general.outdir_txt2img
+
+ seed = seed_to_int(seed)
+
+ if not use_stable_horde:
+
+ if sampler_name == 'PLMS':
+ sampler = PLMSSampler(server_state["model"])
+ elif sampler_name == 'DDIM':
+ sampler = DDIMSampler(server_state["model"])
+ elif sampler_name == 'k_dpm_2_a':
+ sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral')
+ elif sampler_name == 'k_dpm_2':
+ sampler = KDiffusionSampler(server_state["model"],'dpm_2')
+ elif sampler_name == 'k_dpmpp_2m':
+ sampler = KDiffusionSampler(server_state["model"],'dpmpp_2m')
+ elif sampler_name == 'k_euler_a':
+ sampler = KDiffusionSampler(server_state["model"],'euler_ancestral')
+ elif sampler_name == 'k_euler':
+ sampler = KDiffusionSampler(server_state["model"],'euler')
+ elif sampler_name == 'k_heun':
+ sampler = KDiffusionSampler(server_state["model"],'heun')
+ elif sampler_name == 'k_lms':
+ sampler = KDiffusionSampler(server_state["model"],'lms')
+ else:
+ raise Exception("Unknown sampler: " + sampler_name)
+
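+ # process_images() drives generation through these two callbacks: init() is a no-op because txt2img has no init image, and sample() runs the selected sampler.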
+ def init():
+ pass
+
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x,
+ img_callback=generation_callback if not server_state["bridge"] else None,
+ log_every_t=int(st.session_state.update_preview_frequency if not server_state["bridge"] else 100))
+
+ return samples_ddim
+
+
+ if use_stable_horde:
+ output_images, seed, info, stats = stable_horde(
+ prompt=prompt,
+ seed=seed,
+ outpath=outpath,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ GFPGAN_model=GFPGAN_model,
+ use_RealESRGAN=use_RealESRGAN,
+ realesrgan_model_name=RealESRGAN_model,
+ use_LDSR=use_LDSR,
+ LDSR_model_name=LDSR_model,
+ ddim_eta=ddim_eta,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg,
+ variant_amount=variant_amount,
+ variant_seed=variant_seed,
+ api_key=stable_horde_key
+ )
+ else:
+
+ #try:
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ GFPGAN_model=GFPGAN_model,
+ use_RealESRGAN=use_RealESRGAN,
+ realesrgan_model_name=RealESRGAN_model,
+ use_LDSR=use_LDSR,
+ LDSR_model_name=LDSR_model,
+ ddim_eta=ddim_eta,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg,
+ variant_amount=variant_amount,
+ variant_seed=variant_seed,
+ )
+
+ del sampler
+
+ return output_images, seed, info, stats
+
+ #except RuntimeError as e:
+ #err = e
+ #err_msg = f'CRASHED: Please wait while the program restarts.'
+ #stats = err_msg
+ #return [], seed, 'err', stats
+
+#
+@logger.catch(reraise=True)
+def layout():
+ with st.form("txt2img-inputs"):
+ st.session_state["generation_mode"] = "txt2img"
+
+ input_col1, generate_col1 = st.columns([10,1])
+
+ with input_col1:
+ #prompt = st.text_area("Input Text","")
+ placeholder = "A corgi wearing a top hat as an oil painting."
+ prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
+
+ if "defaults" in st.session_state:
+ if st.session_state["defaults"].general.enable_suggestions:
+ sygil_suggestions.suggestion_area(placeholder)
+
+ if "defaults" in st.session_state:
+ if st.session_state['defaults'].admin.global_negative_prompt:
+ prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
+
+ #print(prompt)
+
+ # creating the page layout using columns
+ col1, col2, col3 = st.columns([2,5,2], gap="large")
+
+ with col1:
+ width = st.slider("Width:", min_value=st.session_state['defaults'].txt2img.width.min_value, max_value=st.session_state['defaults'].txt2img.width.max_value,
+ value=st.session_state['defaults'].txt2img.width.value, step=st.session_state['defaults'].txt2img.width.step)
+ height = st.slider("Height:", min_value=st.session_state['defaults'].txt2img.height.min_value, max_value=st.session_state['defaults'].txt2img.height.max_value,
+ value=st.session_state['defaults'].txt2img.height.value, step=st.session_state['defaults'].txt2img.height.step)
+ cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):", min_value=st.session_state['defaults'].txt2img.cfg_scale.min_value,
+ value=st.session_state['defaults'].txt2img.cfg_scale.value, step=st.session_state['defaults'].txt2img.cfg_scale.step,
+ help="How strongly the image should follow the prompt.")
+
+ seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
+
+ with st.expander("Batch Options"):
+ #batch_count = st.slider("Batch count.", min_value=st.session_state['defaults'].txt2img.batch_count.min_value, max_value=st.session_state['defaults'].txt2img.batch_count.max_value,
+ #value=st.session_state['defaults'].txt2img.batch_count.value, step=st.session_state['defaults'].txt2img.batch_count.step,
+ #help="How many iterations or batches of images to generate in total.")
+
+ #batch_size = st.slider("Batch size", min_value=st.session_state['defaults'].txt2img.batch_size.min_value, max_value=st.session_state['defaults'].txt2img.batch_size.max_value,
+ #value=st.session_state.defaults.txt2img.batch_size.value, step=st.session_state.defaults.txt2img.batch_size.step,
+ #help="How many images are at once in a batch.\
+ #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
+ #Default: 1")
+
+ st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].txt2img.batch_count.value,
+ help="How many iterations or batches of images to generate in total.")
+
+ st.session_state["batch_size"] = st.number_input("Batch size", value=st.session_state.defaults.txt2img.batch_size.value,
+ help="How many images are at once in a batch.\
+ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes \
+ to finish generation as more images are generated at once.\
+ Default: 1")
+
+ with st.expander("Preview Settings"):
+
+ st.session_state["update_preview"] = st.session_state["defaults"].general.update_preview
+ st.session_state["update_preview_frequency"] = st.number_input("Update Image Preview Frequency",
+ min_value=0,
+ value=st.session_state['defaults'].txt2img.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 10 step.")
+
+ with col2:
+ preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
+
+ with preview_tab:
+ #st.write("Image")
+ #Image for testing
+ #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
+ #new_image = image.resize((175, 240))
+ #preview_image = st.image(image)
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ st.session_state["preview_image"] = st.empty()
+
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar_text"].info("Nothing but crickets here, try generating something first.")
+
+ st.session_state["progress_bar"] = st.empty()
+
+ message = st.empty()
+
+ with gallery_tab:
+ st.session_state["gallery"] = st.empty()
+ #st.session_state["gallery"].info("Nothing but crickets here, try generating something first.")
+
+ with col3:
+ # If we have custom models available in the "models/custom"
+ # folder then we show a menu to select which model we want to use; otherwise we use the main SD model.
+ custom_models_available()
+
+ if server_state["CustomModel_available"]:
+ st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"],
+ index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.5")
+
+ st.session_state.sampling_steps = st.number_input("Sampling Steps", value=st.session_state.defaults.txt2img.sampling_steps.value,
+ min_value=st.session_state.defaults.txt2img.sampling_steps.min_value,
+ step=st.session_state['defaults'].txt2img.sampling_steps.step,
+ help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)")
+
+ sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_dpmpp_2m", "k_heun", "PLMS", "DDIM"]
+ sampler_name = st.selectbox("Sampling method", sampler_name_list,
+ index=sampler_name_list.index(st.session_state['defaults'].txt2img.default_sampler), help="Sampling method to use. Default: k_euler")
+
+ with st.expander("Advanced"):
+ with st.expander("Stable Horde"):
+ use_stable_horde = st.checkbox("Use Stable Horde", value=False, help="Use the Stable Horde to generate images. More info can be found at https://stablehorde.net/")
+ stable_horde_key = st.text_input("Stable Horde Api Key", value=st.session_state['defaults'].general.stable_horde_api, type="password",
+ help="Optional Api Key used for the Stable Horde Bridge, if no api key is added the horde will be used anonymously.")
+
+ with st.expander("Output Settings"):
+ separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2img.separate_prompts,
+ help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+
+ normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].txt2img.normalize_prompt_weights,
+ help="Ensure the sum of all weights add up to 1.0")
+
+ save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].txt2img.save_individual_images,
+ help="Save each image generated before any filter or enhancement is applied.")
+
+ save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].txt2img.save_grid, help="Save a grid with all the images generated into a single image.")
+ group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
+
+ write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files,
+ help="Save a file next to the image with informartion about the generation.")
+
+ save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ # check if GFPGAN, RealESRGAN and LDSR are available.
+ #if "GFPGAN_available" not in st.session_state:
+ GFPGAN_available()
+
+ #if "RealESRGAN_available" not in st.session_state:
+ RealESRGAN_available()
+
+ #if "LDSR_available" not in st.session_state:
+ LDSR_available()
+
+ if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
+ with st.expander("Post-Processing"):
+ face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
+ with face_restoration_tab:
+ # GFPGAN used for face restoration
+ if st.session_state["GFPGAN_available"]:
+ #with st.expander("Face Restoration"):
+ #if st.session_state["GFPGAN_available"]:
+ #with st.expander("GFPGAN"):
+ st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN,
+ help="Uses the GFPGAN model to improve faces after the generation.\
+ This greatly improves the quality and consistency of faces but uses\
+ extra VRAM. Disable if you need the extra VRAM.")
+
+ st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
+ index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
+
+ #st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
+
+ else:
+ st.session_state["use_GFPGAN"] = False
+
+ with upscaling_tab:
+ st.session_state['use_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling)
+
+ # RealESRGAN and LDSR used for upscaling.
+ if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
+
+ upscaling_method_list = []
+ if st.session_state["RealESRGAN_available"]:
+ upscaling_method_list.append("RealESRGAN")
+ if st.session_state["LDSR_available"]:
+ upscaling_method_list.append("LDSR")
+
+ #print (st.session_state["RealESRGAN_available"])
+ st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
+ index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
+ if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
+ else 0)
+
+ if st.session_state["RealESRGAN_available"]:
+ with st.expander("RealESRGAN"):
+ if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['use_upscaling']:
+ st.session_state["use_RealESRGAN"] = True
+ else:
+ st.session_state["use_RealESRGAN"] = False
+
+ st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
+ index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
+ else:
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+
+
+ #
+ if st.session_state["LDSR_available"]:
+ with st.expander("LDSR"):
+ if st.session_state["upscaling_method"] == "LDSR" and st.session_state['use_upscaling']:
+ st.session_state["use_LDSR"] = True
+ else:
+ st.session_state["use_LDSR"] = False
+
+ st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
+ index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
+
+ st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2img.LDSR_config.sampling_steps,
+ help="")
+
+ st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.preDownScale,
+ help="")
+
+ st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.postDownScale,
+ help="")
+
+ downsample_method_list = ['Nearest', 'Lanczos']
+ st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
+ index=downsample_method_list.index(st.session_state['defaults'].txt2img.LDSR_config.downsample_method))
+
+ else:
+ st.session_state["use_LDSR"] = False
+ st.session_state["LDSR_model"] = "model"
+
+ with st.expander("Variant"):
+ variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2img.variant_amount.value,
+ min_value=st.session_state['defaults'].txt2img.variant_amount.min_value, max_value=st.session_state['defaults'].txt2img.variant_amount.max_value,
+ step=st.session_state['defaults'].txt2img.variant_amount.step)
+ variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2img.seed,
+ help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+
+ #galleryCont = st.empty()
+
+ # Every form must have a submit button; the extra blank writes are a temporary way to align it with the input field. This should really be done in CSS or some other way.
+ generate_col1.write("")
+ generate_col1.write("")
+ generate_button = generate_col1.form_submit_button("Generate")
+
+ #
+ if generate_button:
+
+ with col2:
+ with no_rerun:
+ if not use_stable_horde:
+ with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
+ load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
+ use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] ,
+ use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
+ CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"])
+
+ #print(st.session_state['use_RealESRGAN'])
+ #print(st.session_state['use_LDSR'])
+ try:
+
+
+ output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, st.session_state["batch_count"], st.session_state["batch_size"],
+ cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
+ save_grid, group_by_prompt, save_as_jpg, st.session_state["use_GFPGAN"], st.session_state['GFPGAN_model'],
+ use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"],
+ use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"],
+ variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files,
+ use_stable_horde=use_stable_horde, stable_horde_key=stable_horde_key)
+
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ with gallery_tab:
+ logger.info(seeds)
+ st.session_state["gallery"].text = ""
+ sdGallery(output_images)
+
+
+ except (StopException,
+ #KeyError
+ ):
+ print(f"Received Streamlit StopException")
+
+ # reset the page title so the percent doesnt stay on it confusing the user.
+ set_page_title(f"Stable Diffusion Playground")
+
+ # This will render all the images at the end of the generation, but it's better if it's moved to a second tab inside col2 and shown as a gallery.
+ # Use the current first tab of col2 to show the preview_img and update it as it's generated.
+ #preview_image.image(output_images)
+
+
diff --git a/webui/streamlit/scripts/txt2vid.py b/webui/streamlit/scripts/txt2vid.py
new file mode 100644
index 0000000..678d508
--- /dev/null
+++ b/webui/streamlit/scripts/txt2vid.py
@@ -0,0 +1,2012 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+# base webui import and utils.
+
+"""
+Implementation of Text to Video based on the
+https://github.com/nateraw/stable-diffusion-videos
+repo and the original gist script from
+https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
+"""
+from sd_utils import st, MemUsageMonitor, server_state, no_rerun, \
+ custom_models_available, RealESRGAN_available, GFPGAN_available, \
+ LDSR_available, hc, logger
+ #seed_to_int, logger, slerp, optimize_update_preview_frequency, \
+ #load_learned_embed_in_clip, load_GFPGAN, RealESRGANModel, set_page_title
+
+
+# streamlit imports
+from streamlit.runtime.scriptrunner import StopException
+#from streamlit.elements import image as STImage
+
+#streamlit components section
+from streamlit_server_state import server_state, server_state_lock
+#from streamlitextras.threader import lock, trigger_rerun, \
+ #streamlit_thread, get_thread, \
+ #last_trigger_time
+
+#other imports
+
+import os, sys, json, re, random, datetime, time, warnings, mimetypes
+from PIL import Image
+import torch
+import numpy as np
+import inspect, timeit
+from torch import autocast
+#from io import BytesIO
+import imageio
+from slugify import slugify
+
+from diffusers import StableDiffusionPipeline, DiffusionPipeline
+#from stable_diffusion_videos import StableDiffusionWalkPipeline
+
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, \
+ PNDMScheduler, DDPMScheduler
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.utils import deprecate
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from typing import Callable, List, Optional, Union
+from pathlib import Path
+from torchvision.transforms.functional import pil_to_tensor
+from torchvision import transforms
+import librosa
+from torchvision.io import write_video
+import torch.nn as nn
+from uuid import uuid4
+
+
+# streamlit components
+from custom_components import sygil_suggestions
+
+# Temp imports
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+sygil_suggestions.init()
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+# remove some annoying deprecation warnings that show every now and then.
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
+
+# this is a fix for Windows users. Without it, JavaScript files will be served with a text/html content-type and the browser will not show any UI
+mimetypes.init()
+mimetypes.add_type('application/javascript', '.js')
+
+class plugin_info():
+ plugname = "txt2vid"
+ description = "Text to Image"
+ isTab = True
+ displayPriority = 1
+
+#
+# -----------------------------------------------------------------------------
+
+def txt2vid_generation_callback(step: int, timestep: int, latents: torch.FloatTensor):
+ #print ("test")
+ #scale and decode the image latents with vae
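+ # 0.18215 is the scaling factor the Stable Diffusion VAE was trained with; dividing the
+ # latents by it restores them to the scale the VAE decoder expects.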
+ cond_latents_2 = 1 / 0.18215 * latents
+ image = server_state["pipe"].vae.decode(cond_latents_2)
+
+ # generate output numpy image as uint8
+ image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
+ image2 = transforms.ToPILImage()(image.squeeze_(0))
+
+ st.session_state["preview_image"].image(image2)
+
+def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=1.0, smooth=0.0):
+ y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
+
+ # librosa.stft hardcoded defaults...
+ # n_fft defaults to 2048
+ # hop length is win_length // 4
+ # win_length defaults to n_fft
+ D = librosa.stft(y, n_fft=2048, hop_length=2048 // 4, win_length=2048)
+
+ # Extract percussive elements
+ D_harmonic, D_percussive = librosa.decompose.hpss(D, margin=margin)
+ y_percussive = librosa.istft(D_percussive, length=len(y))
+
+ # Get normalized melspectrogram
+ spec_raw = librosa.feature.melspectrogram(y=y_percussive, sr=sr)
+ spec_max = np.amax(spec_raw, axis=0)
+ spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)
+
+ # Resize cumsum of spec norm to our desired number of interpolation frames
+ x_norm = np.linspace(0, spec_norm.shape[-1], spec_norm.shape[-1])
+ y_norm = np.cumsum(spec_norm)
+ y_norm /= y_norm[-1]
+ x_resize = np.linspace(0, y_norm.shape[-1], int(duration*fps))
+
+ T = np.interp(x_resize, x_norm, y_norm)
+
+ # Apply smoothing
+ return T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth
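+# Note: the returned array T has int(duration*fps) entries and increases monotonically from 0 to 1;
+# the interpolation advances faster where the percussive energy of the audio is higher. A minimal
+# usage sketch (hypothetical file path): T = get_timesteps_arr("audio.mp3", offset=0, duration=2, fps=30)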
+
+#
+def make_video_pyav(
+ frames_or_frame_dir: Union[str, Path, torch.Tensor],
+ audio_filepath: Union[str, Path] = None,
+ fps: int = 30,
+ audio_offset: int = 0,
+ audio_duration: int = 2,
+ sr: int = 22050,
+ output_filepath: Union[str, Path] = "output.mp4",
+ glob_pattern: str = "*.png",
+ ):
+ """
+ Write a sequence of frames to an mp4 file, optionally muxing in an audio clip read from `audio_filepath`.
+
+ frames_or_frame_dir: (Union[str, Path, torch.Tensor]):
+ Either a directory of images, or a tensor of shape (T, C, H, W) in range [0, 255].
+ """
+
+ # Torchvision write_video doesn't support pathlib paths
+ output_filepath = str(output_filepath)
+
+ if isinstance(frames_or_frame_dir, (str, Path)):
+ frames = None
+ for img in sorted(Path(frames_or_frame_dir).glob(glob_pattern)):
+ frame = pil_to_tensor(Image.open(img)).unsqueeze(0)
+ frames = frame if frames is None else torch.cat([frames, frame])
+ else:
+ frames = frames_or_frame_dir
+
+ # TCHW -> THWC
+ frames = frames.permute(0, 2, 3, 1)
+
+ if audio_filepath:
+ # Read audio, convert to tensor
+ audio, sr = librosa.load(audio_filepath, sr=sr, mono=True, offset=audio_offset, duration=audio_duration)
+ audio_tensor = torch.tensor(audio).unsqueeze(0)
+
+ write_video(
+ output_filepath,
+ frames,
+ fps=fps,
+ audio_array=audio_tensor,
+ audio_fps=sr,
+ audio_codec="aac",
+ options={"crf": "10", "pix_fmt": "yuv420p"},
+ )
+ else:
+ write_video(output_filepath, frames, fps=fps, options={"crf": "10", "pix_fmt": "yuv420p"})
+
+ return output_filepath
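+# A minimal usage sketch (hypothetical paths): after rendering frames into "outputs/frames",
+# make_video_pyav("outputs/frames", fps=30, output_filepath="outputs/clip.mp4") stitches them
+# into a video and returns the path of the written mp4.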
+
+
+class StableDiffusionWalkPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for generating videos by interpolating Stable Diffusion's latent space.
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 512,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: Optional[int] = 1,
+ text_embeddings: Optional[torch.FloatTensor] = None,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+ Args:
+ prompt (`str` or `List[str]`, *optional*, defaults to `None`):
+ The prompt or prompts to guide the image generation. If not provided, `text_embeddings` is required.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ text_embeddings (`torch.FloatTensor`, *optional*, defaults to `None`):
+ Pre-generated text embeddings to be used as inputs for image generation. Can be used in place of
+ `prompt` to avoid re-computing the embeddings. If not provided, the embeddings will be generated from
+ the supplied `prompt`.
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a plain `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if text_embeddings is None:
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ print("The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+ else:
+ batch_size = text_embeddings.shape[0]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""]
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = self.tokenizer.model_max_length
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+ latents_dtype = text_embeddings.dtype
+ if latents is None:
+ if self.device.type == "mps":
+ # randn does not exist on mps
+ latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
+ self.device
+ )
+ else:
+ latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+ latents = latents.to(self.device)
+
+ # set timesteps
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Some schedulers like PNDM have timesteps as arrays
+ # It's more optimized to move all timesteps to correct device beforehand
+ timesteps_tensor = self.scheduler.timesteps.to(self.device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+ self.device
+ )
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+ )
+ else:
+ has_nsfw_concept = None
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
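+ # generate_inputs lerps the two prompt embeddings and slerps the two seed noises along T,
+ # yielding (batch_idx, embeds_batch, noise_batch) chunks of at most batch_size frames each.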
+ def generate_inputs(self, prompt_a, prompt_b, seed_a, seed_b, noise_shape, T, batch_size):
+ embeds_a = self.embed_text(prompt_a)
+ embeds_b = self.embed_text(prompt_b)
+
+ latents_a = self.init_noise(seed_a, noise_shape)
+ latents_b = self.init_noise(seed_b, noise_shape)
+
+ batch_idx = 0
+ embeds_batch, noise_batch = None, None
+ for i, t in enumerate(T):
+ embeds = torch.lerp(embeds_a, embeds_b, t)
+ noise = slerp(device="cuda", t=float(t), v0=latents_a, v1=latents_b, DOT_THRESHOLD=0.9995)
+
+ embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
+ noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise])
+ batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == T.shape[0]
+ if not batch_is_ready:
+ continue
+ yield batch_idx, embeds_batch, noise_batch
+ batch_idx += 1
+ del embeds_batch, noise_batch
+ torch.cuda.empty_cache()
+ embeds_batch, noise_batch = None, None
+
+ def make_clip_frames(
+ self,
+ prompt_a: str,
+ prompt_b: str,
+ seed_a: int,
+ seed_b: int,
+ num_interpolation_steps: int = 5,
+ save_path: Union[str, Path] = "outputs/",
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ eta: float = 0.0,
+ height: int = 512,
+ width: int = 512,
+ upsample: bool = False,
+ batch_size: int = 1,
+ image_file_ext: str = ".png",
+ T: np.ndarray = None,
+ skip: int = 0,
+ callback = None,
+ callback_steps:int = 1,
+ ):
+ save_path = Path(save_path)
+ save_path.mkdir(parents=True, exist_ok=True)
+
+ T = T if T is not None else np.linspace(0.0, 1.0, num_interpolation_steps)
+ if T.shape[0] != num_interpolation_steps:
+ raise ValueError(f"Unexpected T shape, got {T.shape}, expected dim 0 to be {num_interpolation_steps}")
+
+ if upsample:
+ if getattr(self, "upsampler", None) is None:
+ self.upsampler = RealESRGANModel.from_pretrained("nateraw/real-esrgan")
+ self.upsampler.to(self.device)
+
+ batch_generator = self.generate_inputs(
+ prompt_a,
+ prompt_b,
+ seed_a,
+ seed_b,
+ (1, self.unet.in_channels, height // 8, width // 8),
+ T[skip:],
+ batch_size,
+ )
+
+ frame_index = skip
+ for _, embeds_batch, noise_batch in batch_generator:
+ with torch.autocast("cuda"):
+ outputs = self(
+ latents=noise_batch,
+ text_embeddings=embeds_batch,
+ height=height,
+ width=width,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ num_inference_steps=num_inference_steps,
+ output_type="pil" if not upsample else "numpy",
+ callback=callback,
+ callback_steps=callback_steps,
+ )["images"]
+
+ for image in outputs:
+ frame_filepath = save_path / (f"frame%06d{image_file_ext}" % frame_index)
+ image = image if not upsample else self.upsampler(image)
+ image.save(frame_filepath)
+ frame_index += 1
+
+ def walk(
+ self,
+ prompts: Optional[List[str]] = None,
+ seeds: Optional[List[int]] = None,
+ num_interpolation_steps: Optional[Union[int, List[int]]] = 5, # int or list of int
+ output_dir: Optional[str] = "./dreams",
+ name: Optional[str] = None,
+ image_file_ext: Optional[str] = ".png",
+ fps: Optional[int] = 30,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ upsample: Optional[bool] = False,
+ batch_size: Optional[int] = 1,
+ resume: Optional[bool] = False,
+ audio_filepath: str = None,
+ audio_start_sec: Optional[Union[int, float]] = None,
+ margin: Optional[float] = 1.0,
+ smooth: Optional[float] = 0.0,
+ callback=None,
+ callback_steps=1,
+ ):
+ """Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
+ video to interpolate to the intensity of the audio.
+
+ Args:
+ prompts (Optional[List[str]], optional):
+ list of text prompts. Defaults to None.
+ seeds (Optional[List[int]], optional):
+ list of random seeds corresponding to prompts. Defaults to None.
+ num_interpolation_steps (Union[int, List[int]], *optional*):
+ How many interpolation steps between each prompt. Defaults to 5.
+ output_dir (Optional[str], optional):
+ Where to save the video. Defaults to './dreams'.
+ name (Optional[str], optional):
+ Name of the subdirectory of output_dir. Defaults to None.
+ image_file_ext (Optional[str], *optional*, defaults to '.png'):
+ The extension to use when writing video frames.
+ fps (Optional[int], *optional*, defaults to 30):
+ The frames per second in the resulting output videos.
+ num_inference_steps (Optional[int], *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (Optional[float], *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (Optional[float], *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ height (Optional[int], *optional*, defaults to 512):
+ height of the images to generate.
+ width (Optional[int], *optional*, defaults to 512):
+ width of the images to generate.
+ upsample (Optional[bool], *optional*, defaults to False):
+ When True, upsamples images with realesrgan.
+ batch_size (Optional[int], *optional*, defaults to 1):
+ Number of images to generate at once.
+ resume (Optional[bool], *optional*, defaults to False):
+ When True, resumes from the last frame in the output directory based
+ on available prompt config. Requires you to provide the `name` argument.
+ audio_filepath (str, *optional*, defaults to None):
+ Optional path to an audio file to influence the interpolation rate.
+ audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0):
+ Global start time of the provided audio_filepath.
+ margin (Optional[float], *optional*, defaults to 1.0):
+ Margin from librosa hpss to use for audio interpolation.
+ smooth (Optional[float], *optional*, defaults to 0.0):
+ Smoothness of the audio interpolation. 1.0 means linear interpolation.
+
+ This function will create sub directories for each prompt and seed pair.
+
+ For example, if you provide the following prompts and seeds:
+
+ ```
+ prompts = ['a dog', 'a cat', 'a bird']
+ seeds = [1, 2, 3]
+ num_interpolation_steps = 5
+ output_dir = 'output_dir'
+ name = 'name'
+ fps = 5
+ ```
+
+ Then the following directories will be created:
+
+ ```
+ output_dir
+ ├── name
+ │ ├── name_000000
+ │ │ ├── frame000000.png
+ │ │ ├── ...
+ │ │ ├── frame000004.png
+ │ │ ├── name_000000.mp4
+ │ ├── name_000001
+ │ │ ├── frame000000.png
+ │ │ ├── ...
+ │ │ ├── frame000004.png
+ │ │ ├── name_000001.mp4
+ │ ├── ...
+ │ ├── name.mp4
+ │ ├── prompt_config.json
+ ```
+
+ Returns:
+ str: The resulting video filepath. This video includes all sub directories' video clips.
+ """
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ # init the output dir
+ if type(prompts) == str:
+ sanitized_prompt = slugify(prompts)
+ else:
+ sanitized_prompt = slugify(prompts[0])
+
+ full_path = os.path.join(str(output_dir), str(sanitized_prompt))
+
+ if len(full_path) > 220:
+ sanitized_prompt = sanitized_prompt[:220-len(full_path)]
+ full_path = os.path.join(output_dir, sanitized_prompt)
+
+ os.makedirs(full_path, exist_ok=True)
+
+ # Where the final video of all the clips combined will be saved
+ output_filepath = os.path.join(full_path, f"{sanitized_prompt}.mp4")
+
+ # If using same number of interpolation steps between, we turn into list
+ if not resume and isinstance(num_interpolation_steps, int):
+ num_interpolation_steps = [num_interpolation_steps] * (len(prompts) - 1)
+
+ if not resume:
+ audio_start_sec = audio_start_sec or 0
+
+ # Save/reload prompt config
+ prompt_config_path = Path(os.path.join(full_path, "prompt_config.json"))
+ if not resume:
+ prompt_config_path.write_text(
+ json.dumps(
+ dict(
+ prompts=prompts,
+ seeds=seeds,
+ num_interpolation_steps=num_interpolation_steps,
+ fps=fps,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ upsample=upsample,
+ height=height,
+ width=width,
+ audio_filepath=audio_filepath,
+ audio_start_sec=audio_start_sec,
+ ),
+
+ indent=2,
+ sort_keys=False,
+ )
+ )
+ else:
+ data = json.load(open(prompt_config_path))
+ prompts = data["prompts"]
+ seeds = data["seeds"]
+ num_interpolation_steps = data["num_interpolation_steps"]
+ fps = data["fps"]
+ num_inference_steps = data["num_inference_steps"]
+ guidance_scale = data["guidance_scale"]
+ eta = data["eta"]
+ upsample = data["upsample"]
+ height = data["height"]
+ width = data["width"]
+ audio_filepath = data["audio_filepath"]
+ audio_start_sec = data["audio_start_sec"]
+
+ for i, (prompt_a, prompt_b, seed_a, seed_b, num_step) in enumerate(
+ zip(prompts, prompts[1:], seeds, seeds[1:], num_interpolation_steps)
+ ):
+ # {name}_000000 / {name}_000001 / ...
+ save_path = Path(f"{full_path}/{name}_{i:06d}")
+
+ # Where the individual clips will be saved
+ step_output_filepath = Path(f"{save_path}/{name}_{i:06d}.mp4")
+
+ # Determine if we need to resume from a previous run
+ skip = 0
+ if resume:
+ if step_output_filepath.exists():
+ print(f"Skipping {save_path} because frames already exist")
+ continue
+
+ existing_frames = sorted(save_path.glob(f"*{image_file_ext}"))
+ if existing_frames:
+ skip = int(existing_frames[-1].stem[-6:]) + 1
+ if skip + 1 >= num_step:
+ print(f"Skipping {save_path} because frames already exist")
+ continue
+ print(f"Resuming {save_path.name} from frame {skip}")
+
+ audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
+ audio_duration = num_step / fps
+
+ self.make_clip_frames(
+ prompt_a,
+ prompt_b,
+ seed_a,
+ seed_b,
+ num_interpolation_steps=num_step,
+ save_path=save_path,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ height=height,
+ width=width,
+ upsample=upsample,
+ batch_size=batch_size,
+ skip=skip,
+ T=get_timesteps_arr(
+ audio_filepath,
+ offset=audio_offset,
+ duration=audio_duration,
+ fps=fps,
+ margin=margin,
+ smooth=smooth,
+ )
+ if audio_filepath
+ else None,
+ # callback and callback_steps are arguments of make_clip_frames, not of get_timesteps_arr
+ callback=callback,
+ callback_steps=callback_steps,
+ )
+ make_video_pyav(
+ save_path,
+ audio_filepath=audio_filepath,
+ fps=fps,
+ output_filepath=step_output_filepath,
+ glob_pattern=f"*{image_file_ext}",
+ audio_offset=audio_offset,
+ audio_duration=audio_duration,
+ sr=44100,
+ )
+
+ return make_video_pyav(
+ full_path,
+ audio_filepath=audio_filepath,
+ fps=fps,
+ audio_offset=audio_start_sec,
+ audio_duration=sum(num_interpolation_steps) / fps,
+ output_filepath=output_filepath,
+ glob_pattern=f"**/*{image_file_ext}",
+ sr=44100,
+ )
+
+ def embed_text(self, text):
+ """Helper to embed some text"""
+ with torch.autocast("cuda"):
+ text_input = self.tokenizer(
+ text,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ with torch.no_grad():
+ embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ return embed
+
+ def init_noise(self, seed, noise_shape):
+ """Helper to initialize noise"""
+ # randn does not exist on mps, so we create noise on CPU here and move it to the device after initialization
+ if self.device.type == "mps":
+ noise = torch.randn(
+ noise_shape,
+ device='cpu',
+ generator=torch.Generator(device='cpu').manual_seed(seed),
+ ).to(self.device)
+ else:
+ noise = torch.randn(
+ noise_shape,
+ device=self.device,
+ generator=torch.Generator(device=self.device).manual_seed(seed),
+ )
+ return noise
+
+ @classmethod
+ def from_pretrained(cls, *args, tiled=False, **kwargs):
+ """Same as diffusers `from_pretrained` but with a `tiled` option, which makes the generated images tileable"""
+ if tiled:
+
+ def patch_conv(**patch):
+ cls = nn.Conv2d
+ init = cls.__init__
+
+ def __init__(self, *args, **kwargs):
+ return init(self, *args, **kwargs, **patch)
+
+ cls.__init__ = __init__
+
+ patch_conv(padding_mode="circular")
+
+ pipeline = super().from_pretrained(*args, **kwargs)
+ pipeline.tiled = tiled
+ return pipeline
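+# A minimal usage sketch for the pipeline above (a hedged example: it assumes a CUDA device and
+# that the checkpoint named below is accessible locally or on the Hugging Face Hub):
+# pipe = StableDiffusionWalkPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
+# pipe.walk(prompts=["a dog", "a cat"], seeds=[42, 1337], num_interpolation_steps=5, output_dir="./dreams", fps=5)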
+
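+# diffuse() runs the denoising loop for a single frame given fixed text embeddings and a latent,
+# updating the Streamlit preview image and progress widgets along the way, and returns the decoded PIL image.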
+@torch.no_grad()
+def diffuse(
+ pipe,
+ cond_embeddings, # text conditioning, should be (1, 77, 768)
+ cond_latents, # image conditioning, should be (1, 4, 64, 64)
+ num_inference_steps,
+ cfg_scale,
+ eta,
+ fps=30
+ ):
+
+ torch_device = cond_latents.get_device()
+
+ # classifier guidance: add the unconditional embedding
+ max_length = cond_embeddings.shape[1] # 77
+ uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+ text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(pipe.scheduler, LMSDiscreteScheduler):
+ cond_latents = cond_latents * pipe.scheduler.sigmas[0]
+
+ # init the scheduler
+ accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs)
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+
+ step_counter = 0
+ inference_counter = 0
+
+ if "current_chunk_speed" not in st.session_state:
+ st.session_state["current_chunk_speed"] = 0
+
+ if "previous_chunk_speed_list" not in st.session_state:
+ st.session_state["previous_chunk_speed_list"] = [0]
+ st.session_state["previous_chunk_speed_list"].append(st.session_state["current_chunk_speed"])
+
+ if "update_preview_frequency_list" not in st.session_state:
+ st.session_state["update_preview_frequency_list"] = [0]
+ st.session_state["update_preview_frequency_list"].append(st.session_state["update_preview_frequency"])
+
+
+ try:
+ # diffuse!
+ for i, t in enumerate(pipe.scheduler.timesteps):
+ start = timeit.default_timer()
+
+ #status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}")
+
+ # expand the latents for classifier free guidance
+ latent_model_input = torch.cat([cond_latents] * 2)
+ if isinstance(pipe.scheduler, LMSDiscreteScheduler):
+ sigma = pipe.scheduler.sigmas[i]
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+ # cfg
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(pipe.scheduler, LMSDiscreteScheduler):
+ cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
+ else:
+ cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]
+
+
+ #update the preview image if it is enabled and the frequency matches the step_counter
+ if st.session_state["update_preview"]:
+ step_counter += 1
+
+ if step_counter == st.session_state["update_preview_frequency"]:
+ if st.session_state.dynamic_preview_frequency:
+ # wrap the assignment targets in parentheses so all four returned values are unpacked in one statement
+ (st.session_state["current_chunk_speed"],
+ st.session_state["previous_chunk_speed_list"],
+ st.session_state["update_preview_frequency"],
+ st.session_state["avg_update_preview_frequency"]) = optimize_update_preview_frequency(st.session_state["current_chunk_speed"],
+ st.session_state["previous_chunk_speed_list"],
+ st.session_state["update_preview_frequency"],
+ st.session_state["update_preview_frequency_list"])
+
+ #scale and decode the image latents with vae
+ cond_latents_2 = 1 / 0.18215 * cond_latents
+ image = pipe.vae.decode(cond_latents_2)
+
+ # generate output numpy image as uint8
+ image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
+ image2 = transforms.ToPILImage()(image.squeeze_(0))
+
+ st.session_state["preview_image"].image(image2)
+
+ step_counter = 0
+
+ duration = timeit.default_timer() - start
+
+ st.session_state["current_chunk_speed"] = duration
+
+ if duration >= 1:
+ speed = "s/it"
+ else:
+ speed = "it/s"
+ duration = 1 / duration
+
+ total_frames = st.session_state.max_duration_in_seconds * fps
+ total_steps = st.session_state.sampling_steps + st.session_state.num_inference_steps
+
+ if i > st.session_state.sampling_steps:
+ inference_counter += 1
+ inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
+ inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
+ else:
+ inference_progress = ""
+
+ total_percent = int(100 * float(i+1 if i+1 < (num_inference_steps + st.session_state.sampling_steps)
+ else (num_inference_steps + st.session_state.sampling_steps))/float((num_inference_steps + st.session_state.sampling_steps)))
+
+ percent = int(100 * float(i+1 if i+1 < num_inference_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
+ frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < total_frames else total_frames)/float(total_frames))
+
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].text(
+ f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
+ f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
+ f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < total_frames else total_frames}/{total_frames} "
+ f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
+ )
+
+ if "progress_bar" in st.session_state:
+ st.session_state["progress_bar"].progress(total_percent if total_percent < 100 else 100)
+
+ if st.session_state["defaults"].general.show_percent_in_tab_title:
+ set_page_title(f"({percent if percent < 100 else 100}%) Stable Diffusion Playground")
+
+ except KeyError:
+ raise StopException
+
+ #scale and decode the image latents with vae
+ cond_latents_2 = 1 / 0.18215 * cond_latents
+ image = pipe.vae.decode(cond_latents_2)
+
+ # generate output numpy image as uint8
+ image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
+ image2 = transforms.ToPILImage()(image.squeeze_(0))
+
+
+ return image2
+
+#
+def load_diffusers_model(weights_path,torch_device):
+
+ with server_state_lock["model"]:
+ if "model" in server_state:
+ del server_state["model"]
+
+ if "textual_inversion" in st.session_state:
+ del st.session_state['textual_inversion']
+
+ try:
+ with server_state_lock["pipe"]:
+ if "pipe" not in server_state:
+ if "weights_path" in st.session_state and st.session_state["weights_path"] != weights_path:
+ del st.session_state["weights_path"]
+
+ st.session_state["weights_path"] = weights_path
+ server_state['float16'] = st.session_state['defaults'].general.use_float16
+ server_state['no_half'] = st.session_state['defaults'].general.no_half
+ server_state['optimized'] = st.session_state['defaults'].general.optimized
+
+ #if folder "models/diffusers/stable-diffusion-v1-4" exists, load the model from there
+ if weights_path == "CompVis/stable-diffusion-v1-4":
+ model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-4")
+ elif weights_path == "runwayml/stable-diffusion-v1-5":
+ model_path = os.path.join("models", "diffusers", "stable-diffusion-v1-5")
+ else:
+ model_path = weights_path
+
+ if not os.path.exists(model_path + "/model_index.json"):
+ server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
+ weights_path,
+ #use_local_file=True,
+ use_auth_token=st.session_state["defaults"].general.huggingface_token,
+ torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
+ revision="fp16" if not st.session_state['defaults'].general.no_half else None,
+ safety_checker=None, # Very important for videos...lots of false positives while interpolating
+ #custom_pipeline="interpolate_stable_diffusion",
+
+ )
+
+ StableDiffusionPipeline.save_pretrained(server_state["pipe"], model_path)
+ else:
+ server_state["pipe"] = StableDiffusionPipeline.from_pretrained(
+ model_path,
+ #use_local_file=True,
+ torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
+ revision="fp16" if not st.session_state['defaults'].general.no_half else None,
+ safety_checker=None, # Very important for videos...lots of false positives while interpolating
+ #custom_pipeline="interpolate_stable_diffusion",
+ )
+
+ server_state["pipe"].unet.to(torch_device)
+ server_state["pipe"].vae.to(torch_device)
+ server_state["pipe"].text_encoder.to(torch_device)
+
+ #if st.session_state.defaults.general.enable_attention_slicing:
+ server_state["pipe"].enable_attention_slicing()
+
+ if st.session_state.defaults.general.enable_minimal_memory_usage:
+ server_state["pipe"].enable_minimal_memory_usage()
+
+ logger.info("Txt2Vid Model Loaded")
+ else:
+ # if the float16 or no_half options have changed since the last time the model was loaded then we need to reload the model.
+ if ("float16" in server_state and server_state['float16'] != st.session_state['defaults'].general.use_float16) \
+ or ("no_half" in server_state and server_state['no_half'] != st.session_state['defaults'].general.no_half) \
+ or ("optimized" in server_state and server_state['optimized'] != st.session_state['defaults'].general.optimized):
+
+ del server_state['float16']
+ del server_state['no_half']
+ with server_state_lock["pipe"]:
+ del server_state["pipe"]
+ torch_gc()
+
+ del server_state['optimized']
+
+ server_state['float16'] = st.session_state['defaults'].general.use_float16
+ server_state['no_half'] = st.session_state['defaults'].general.no_half
+ server_state['optimized'] = st.session_state['defaults'].general.optimized
+
+ #with no_rerun:
+ load_diffusers_model(weights_path, torch_device)
+ else:
+ logger.info("Txt2Vid Model already Loaded")
+
+ except (EnvironmentError, OSError) as e:
+ if "huggingface_token" not in st.session_state or st.session_state["defaults"].general.huggingface_token == "None":
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].error(
+ "You need a huggingface token in order to use the Text to Video tab. Use the Settings page to add your token under the Huggingface section. "
+ "Make sure you save your settings after adding it."
+ )
+ raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page to add your token under the Huggingface section. "
+ "Make sure you save your settings after adding it.")
+ else:
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].error(e)
+
+#
+def save_video_to_disk(frames, seeds, sanitized_prompt, fps=30,save_video=True, outdir='outputs'):
+ video_path = None
+ if save_video:
+ # write video to memory
+ #output = io.BytesIO()
+ #writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid"), im, extension=".mp4", fps=30)
+ #try:
+ video_path = os.path.join(os.getcwd(), outdir, "txt2vid", f"{seeds}_{sanitized_prompt}{datetime.datetime.now().strftime('%Y%m%d-%H%M%S-') + str(uuid4())[:8]}.mp4")
+ writer = imageio.get_writer(video_path, fps=fps)
+ for frame in frames:
+ writer.append_data(frame)
+
+ writer.close()
+ #except:
+ # print("Can't save video, skipping.")
+
+ return video_path
+#
+def txt2vid(
+ # --------------------------------------
+ # args you probably want to change
+ prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about
+ gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on
+ #name:str = 'test', # name of this project, for the output directory
+ #rootdir:str = st.session_state['defaults'].general.outdir,
+ num_steps:int = 200, # number of steps between each pair of sampled points
+ max_duration_in_seconds:int = 30, # maximum duration of the generated video in seconds
+ num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
+ cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
+ save_video = True,
+ save_video_on_stop = False,
+ outdir='outputs',
+ do_loop = False,
+ use_lerp_for_text = False,
+ seeds = None,
+ quality:int = 100, # for jpeg compression of the output images
+ eta:float = 0.0,
+ width:int = 256,
+ height:int = 256,
+ weights_path = "runwayml/stable-diffusion-v1-5",
+ scheduler="klms", # choices: default, ddim, klms
+ disable_tqdm = False,
+ #-----------------------------------------------
+ beta_start = 0.0001,
+ beta_end = 0.00012,
+ beta_schedule = "scaled_linear",
+ starting_image=None,
+ #-----------------------------------------------
+ # from new version
+ image_file_ext: Optional[str] = ".png",
+ fps: Optional[int] = 30,
+ upsample: Optional[bool] = False,
+ batch_size: Optional[int] = 1,
+ resume: Optional[bool] = False,
+ audio_filepath: str = None,
+ audio_start_sec: Optional[Union[int, float]] = None,
+ margin: Optional[float] = 1.0,
+ smooth: Optional[float] = 0.0,
+ ):
+ """
+ prompt = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about
+ gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on
+ #name:str = 'test', # name of this project, for the output directory
+ #rootdir:str = st.session_state['defaults'].general.outdir,
+ num_steps:int = 200, # number of steps between each pair of sampled points
+ max_duration_in_seconds:int = 10000, # number of frames to write and then exit the script
+ num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
+ cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
+ do_loop = False,
+ use_lerp_for_text = False,
+ seed = None,
+ quality:int = 100, # for jpeg compression of the output images
+ eta:float = 0.0,
+ width:int = 256,
+ height:int = 256,
+ weights_path = "runwayml/stable-diffusion-v1-5",
+ scheduler="klms", # choices: default, ddim, klms
+ disable_tqdm = False,
+ beta_start = 0.0001,
+ beta_end = 0.00012,
+ beta_schedule = "scaled_linear"
+ """
+ mem_mon = MemUsageMonitor('MemMon')
+ mem_mon.start()
+
+
+ seeds = seed_to_int(seeds)
+
+ # We add an extra frame because most
+ # of the time the first frame is just the noise.
+ #max_duration_in_seconds +=1
+
+ assert torch.cuda.is_available()
+ assert height % 8 == 0 and width % 8 == 0
+ torch.manual_seed(seeds)
+ torch_device = f"cuda:{gpu}"
+
+ if type(seeds) == list:
+ prompts = [prompts] * len(seeds)
+ else:
+ seeds = [seeds, random.randint(0, 2**32 - 1)]
+
+ if type(prompts) == list:
+ # init the output dir
+ sanitized_prompt = slugify(prompts[0])
+ else:
+ # init the output dir
+ sanitized_prompt = slugify(prompts)
+
+ full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid", "samples", sanitized_prompt)
+
+ if len(full_path) > 220:
+ sanitized_prompt = sanitized_prompt[:220-len(full_path)]
+ full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid", "samples", sanitized_prompt)
+
+ os.makedirs(full_path, exist_ok=True)
+
+ # Write prompt info to file in output dir so we can keep track of what we did
+ if st.session_state.write_info_files:
+ with open(os.path.join(full_path , f'{slugify(str(seeds))}_config.json' if len(prompts) > 1 else "prompts_config.json"), "w") as outfile:
+ outfile.write(json.dumps(
+ dict(
+ prompts = prompts,
+ gpu = gpu,
+ num_steps = num_steps,
+ max_duration_in_seconds = max_duration_in_seconds,
+ num_inference_steps = num_inference_steps,
+ cfg_scale = cfg_scale,
+ do_loop = do_loop,
+ use_lerp_for_text = use_lerp_for_text,
+ seeds = seeds,
+ quality = quality,
+ eta = eta,
+ width = width,
+ height = height,
+ weights_path = weights_path,
+ scheduler=scheduler,
+ disable_tqdm = disable_tqdm,
+ beta_start = beta_start,
+ beta_end = beta_end,
+ beta_schedule = beta_schedule
+ ),
+ indent=2,
+ sort_keys=False,
+ ))
+
+ #print(scheduler)
+ default_scheduler = PNDMScheduler(
+ beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+ )
+ # ------------------------------------------------------------------------------
+ #Schedulers
+ ddim_scheduler = DDIMScheduler(
+ beta_start=beta_start,
+ beta_end=beta_end,
+ beta_schedule=beta_schedule,
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+
+ klms_scheduler = LMSDiscreteScheduler(
+ beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+ )
+
+ #flaxddims_scheduler = FlaxDDIMScheduler(
+ #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+ #)
+
+ #flaxddpms_scheduler = FlaxDDPMScheduler(
+ #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+ #)
+
+ #flaxpndms_scheduler = FlaxPNDMScheduler(
+ #beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+ #)
+
+ ddpms_scheduler = DDPMScheduler(
+ beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+ )
+
+ SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler,
+ klms=klms_scheduler,
+ ddpms=ddpms_scheduler,
+ #flaxddims=flaxddims_scheduler,
+ #flaxddpms=flaxddpms_scheduler,
+ #flaxpndms=flaxpndms_scheduler,
+ )
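+ # The `scheduler` argument ("default", "ddim", "klms" or "ddpms") selects one of the entries above,
+ # which is then assigned to server_state["pipe"].scheduler after the model has been loaded.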
+
+ with no_rerun:
+ with st.session_state["progress_bar_text"].container():
+ with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
+ load_diffusers_model(weights_path, torch_device)
+
+ if "pipe" not in server_state:
+ logger.error("The txt2vid pipe was not found in server_state after loading; generation cannot continue.")
+
+ server_state["pipe"].scheduler = SCHEDULERS[scheduler]
+
+ server_state["pipe"].use_multiprocessing_for_evaluation = False
+ server_state["pipe"].use_multiprocessed_decoding = False
+
+ #if do_loop:
+ ##Makes the last prompt loop back to first prompt
+ #prompts = [prompts, prompts]
+ #seeds = [seeds, seeds]
+ #first_seed, *seeds = seeds
+ #prompts.append(prompts)
+ #seeds.append(first_seed)
+
+ with torch.autocast('cuda'):
+ # get the conditional text embeddings based on the prompt
+ text_input = server_state["pipe"].tokenizer(prompts, padding="max_length", max_length=server_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
+ cond_embeddings = server_state["pipe"].text_encoder(text_input.input_ids.to(torch_device) )[0]
+
+ #
+ if st.session_state.defaults.general.use_sd_concepts_library:
+
+ prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', str(prompts))
+
+ if prompt_tokens:
+ # compviz
+ #tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
+ #text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer
+
+ # diffusers
+ tokenizer = server_state["pipe"].tokenizer
+ text_encoder = server_state["pipe"].text_encoder
+
+ ext = ('pt', 'bin')
+ #print (prompt_tokens)
+
+ if len(prompt_tokens) > 1:
+ for token_name in prompt_tokens:
+ embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name)
+ if os.path.exists(embedding_path):
+ for files in os.listdir(embedding_path):
+ if files.endswith(ext):
+ load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
+ else:
+ embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0])
+ if os.path.exists(embedding_path):
+ for files in os.listdir(embedding_path):
+ if files.endswith(ext):
+ load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")
+
+ # sample a source
+ init1 = torch.randn((1, server_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
+
+
+ # iterate the loop
+ frames = []
+ frame_index = 0
+
+ second_count = 1
+
+ st.session_state["total_frames_avg_duration"] = []
+ st.session_state["total_frames_avg_speed"] = []
+
+ try:
+ # code for the new StableDiffusionWalkPipeline implementation.
+ start = timeit.default_timer()
+ generation_start = start # overall start time; `start` is reused below as a per-frame timer
+
+ # The preview image works, but this is not the right way to use the pipeline; it also does not work properly, as it only makes one image and then exits.
+ #with torch.autocast("cuda"):
+ #StableDiffusionWalkPipeline.__call__(self=server_state["pipe"],
+ #prompt=prompts, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=cfg_scale,
+ #negative_prompt="", num_images_per_prompt=1, eta=0.0,
+ #callback=txt2vid_generation_callback, callback_steps=1,
+ #num_interpolation_steps=num_steps,
+ #fps=30,
+ #image_file_ext = ".png",
+ #output_dir=full_path, # Where images/videos will be saved
+ ##name='animals_test', # Subdirectory of output_dir where images/videos will be saved
+ #upsample = False,
+ ##do_loop=do_loop, # Change to True if you want last prompt to loop back to first prompt
+ #resume = False,
+ #audio_filepath = None,
+ #audio_start_sec = None,
+ #margin = 1.0,
+ #smooth = 0.0, )
+
+ # Works correctly, generating all frames, but does not show the preview image;
+ # we also have no control over the generation and can't stop it until the end.
+ #with torch.autocast("cuda"):
+ #print (prompts)
+ #video_path = server_state["pipe"].walk(
+ #prompt=prompts,
+ #seeds=seeds,
+ #num_interpolation_steps=num_steps,
+ #height=height, # use multiples of 64 if > 512. Multiples of 8 if < 512.
+ #width=width, # use multiples of 64 if > 512. Multiples of 8 if < 512.
+ #batch_size=4,
+ #fps=30,
+ #image_file_ext = ".png",
+ #eta = 0.0,
+ #output_dir=full_path, # Where images/videos will be saved
+ ##name='test', # Subdirectory of output_dir where images/videos will be saved
+ #guidance_scale=cfg_scale, # Higher adheres to prompt more, lower lets model take the wheel
+ #num_inference_steps=num_inference_steps, # Number of diffusion steps per image generated. 50 is good default
+ #upsample = False,
+ ##do_loop=do_loop, # Change to True if you want last prompt to loop back to first prompt
+ #resume = False,
+ #audio_filepath = None,
+ #audio_start_sec = None,
+ #margin = 1.0,
+ #smooth = 0.0,
+ #callback=txt2vid_generation_callback, # our callback function will be called with the arguments callback(step, timestep, latents)
+ #callback_steps=1 # our callback function will be called once this many steps are processed in a single frame
+ #)
+
+ # old code
+ total_frames = st.session_state.max_duration_in_seconds * fps
+
+ while frame_index+1 <= total_frames:
+ st.session_state["frame_duration"] = 0
+ st.session_state["frame_speed"] = 0
+ st.session_state["current_frame"] = frame_index
+
+ #print(f"Second: {second_count+1}/{max_duration_in_seconds}")
+
+ # sample the destination
+ init2 = torch.randn((1, server_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
+
+ for i, t in enumerate(np.linspace(0, 1, num_steps)):
+ start = timeit.default_timer()
+ logger.info(f"COUNT: {frame_index+1}/{total_frames}")
+
+ if use_lerp_for_text:
+ init = torch.lerp(init1, init2, float(t))
+ else:
+ init = slerp(gpu, float(t), init1, init2)
+
+ #init = slerp(gpu, float(t), init1, init2)
+
+ with autocast("cuda"):
+ image = diffuse(server_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta, fps=fps)
+
+ if st.session_state["save_individual_images"] and not st.session_state["use_GFPGAN"] and not st.session_state["use_RealESRGAN"]:
+ #im = Image.fromarray(image)
+ outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
+ image.save(outpath, quality=quality)
+
+ # send the image to the UI to update it
+ #st.session_state["preview_image"].image(im)
+
+ #append the frames to the frames list so we can use them later.
+ frames.append(np.asarray(image))
+
+
+ #
+ #try:
+ #if st.session_state["use_GFPGAN"] and server_state["GFPGAN"] is not None and not st.session_state["use_RealESRGAN"]:
+ if st.session_state["use_GFPGAN"] and server_state["GFPGAN"] is not None:
+ #print("Running GFPGAN on image ...")
+ if "progress_bar_text" in st.session_state:
+ st.session_state["progress_bar_text"].text("Running GFPGAN on image ...")
+ #skip_save = True # #287 >_>
+ torch_gc()
+ cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(np.array(image)[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
+ gfpgan_sample = restored_img[:,:,::-1]
+ gfpgan_image = Image.fromarray(gfpgan_sample)
+
+ outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
+ gfpgan_image.save(outpath, quality=quality)
+
+ #append the frames to the frames list so we can use them later.
+ frames.append(np.asarray(gfpgan_image))
+ try:
+ st.session_state["preview_image"].image(gfpgan_image)
+ except KeyError:
+ logger.error("Can't get session_state, skipping image preview.")
+ #except (AttributeError, KeyError):
+ #print("Cant perform GFPGAN, skipping.")
+
+ #increase frame_index counter.
+ frame_index += 1
+
+ st.session_state["current_frame"] = frame_index
+
+ duration = timeit.default_timer() - start
+
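+					# report speed the way tqdm does: seconds per iteration when a frame takes longer
+					# than a second, iterations per second otherwise.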
+ if duration >= 1:
+ speed = "s/it"
+ else:
+ speed = "it/s"
+ duration = 1 / duration
+
+ st.session_state["frame_duration"] = duration
+ st.session_state["frame_speed"] = speed
+ if frame_index+1 > total_frames:
+ break
+
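+			# the destination latent becomes the start of the next segment, so the walk through latent space stays continuous.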
+ init1 = init2
+
+ # save the video after the generation is done.
+ video_path = save_video_to_disk(frames, seeds, sanitized_prompt, save_video=save_video, outdir=outdir)
+
+ except StopException:
+		# reset the page title so the progress percentage doesn't stay on it and confuse the user.
+		set_page_title("Stable Diffusion Playground")
+
+ if save_video_on_stop:
+ logger.info("Streamlit Stop Exception Received. Saving video")
+ video_path = save_video_to_disk(frames, seeds, sanitized_prompt, save_video=save_video, outdir=outdir)
+ else:
+ video_path = None
+
+
+ #if video_path and "preview_video" in st.session_state:
+ ## show video preview on the UI
+ #st.session_state["preview_video"].video(open(video_path, 'rb').read())
+
+ mem_max_used, mem_total = mem_mon.read_and_stop()
+ time_diff = time.time()- start
+
+ info = f"""
+ {prompts}
+ Sampling Steps: {num_steps}, Sampler: {scheduler}, CFG scale: {cfg_scale}, Seed: {seeds}, Max Duration In Seconds: {max_duration_in_seconds}""".strip()
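+		# -(x // -1_048_576) is ceiling division: bytes converted to MiB, rounded up.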
+ stats = f'''
+	Took { round(time_diff, 2) }s total ({ round(time_diff/(max_duration_in_seconds),2) }s per second of video)
+ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
+
+ return video_path, seeds, info, stats
+
+#
+def layout():
+ with st.form("txt2vid-inputs"):
+ st.session_state["generation_mode"] = "txt2vid"
+
+ input_col1, generate_col1 = st.columns([10,1])
+ with input_col1:
+ #prompt = st.text_area("Input Text","")
+ placeholder = "A corgi wearing a top hat as an oil painting."
+ prompt = st.text_area("Input Text","", placeholder=placeholder, height=54)
+
+ if "defaults" in st.session_state:
+ if st.session_state["defaults"].general.enable_suggestions:
+ sygil_suggestions.suggestion_area(placeholder)
+
+ if "defaults" in st.session_state:
+ if st.session_state['defaults'].admin.global_negative_prompt:
+ prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}"
+
+			# Every form must have a submit button; the extra blank writes are a temporary way to align it with the input field. This should eventually be done in CSS or some other way.
+ generate_col1.write("")
+ generate_col1.write("")
+ generate_button = generate_col1.form_submit_button("Generate")
+
+ # creating the page layout using columns
+ col1, col2, col3 = st.columns([2,5,2], gap="large")
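+		# three-column page layout: basic options on the left, preview/gallery in the middle, model and sampler settings on the right.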
+
+ with col1:
+ width = st.slider("Width:", min_value=st.session_state['defaults'].txt2vid.width.min_value, max_value=st.session_state['defaults'].txt2vid.width.max_value,
+ value=st.session_state['defaults'].txt2vid.width.value, step=st.session_state['defaults'].txt2vid.width.step)
+ height = st.slider("Height:", min_value=st.session_state['defaults'].txt2vid.height.min_value, max_value=st.session_state['defaults'].txt2vid.height.max_value,
+ value=st.session_state['defaults'].txt2vid.height.value, step=st.session_state['defaults'].txt2vid.height.step)
+ cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):", min_value=st.session_state['defaults'].txt2vid.cfg_scale.min_value,
+ value=st.session_state['defaults'].txt2vid.cfg_scale.value,
+ step=st.session_state['defaults'].txt2vid.cfg_scale.step,
+ help="How strongly the image should follow the prompt.")
+
+ #uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
+ #help="Upload an image which will be used for the image to image generation.")
+ seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2vid.seed, help=" The seed to use, if left blank a random seed will be generated.")
+ #batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2vid.batch_count,
+ # step=1, help="How many iterations or batches of images to generate in total.")
+ #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=st.session_state['defaults'].txt2vid.batch_size, step=1,
+ #help="How many images are at once in a batch.\
+ #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
+ #Default: 1")
+
+ st.session_state["max_duration_in_seconds"] = st.number_input("Max Duration In Seconds:", value=st.session_state['defaults'].txt2vid.max_duration_in_seconds,
+ help="Specify the max duration in seconds you want your video to be.")
+
+ st.session_state["fps"] = st.number_input("Frames per Second (FPS):", value=st.session_state['defaults'].txt2vid.fps,
+ help="Specify the frame rate of the video.")
+
+ with st.expander("Preview Settings"):
+ #st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2vid.update_preview,
+ #help="If enabled the image preview will be updated during the generation instead of at the end. \
+ #You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
+ #By default this is enabled and the frequency is set to 1 step.")
+
+ st.session_state["update_preview"] = st.session_state["defaults"].general.update_preview
+ st.session_state["update_preview_frequency"] = st.number_input("Update Image Preview Frequency",
+ min_value=0,
+ value=st.session_state['defaults'].txt2vid.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 1 step.")
+
+ st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency,
+ help="This option tries to find the best value at which we can update \
+																			 the preview image during generation while minimizing the impact it has on performance. Default: True")
+
+
+ #
+
+
+
+ with col2:
+ preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
+
+ with preview_tab:
+ #st.write("Image")
+ #Image for testing
+ #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
+ #new_image = image.resize((175, 240))
+ #preview_image = st.image(image)
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ st.session_state["preview_image"] = st.empty()
+
+ st.session_state["loading"] = st.empty()
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar"] = st.empty()
+
+ #generate_video = st.empty()
+ st.session_state["preview_video"] = st.empty()
+ preview_video = st.session_state["preview_video"]
+
+ message = st.empty()
+
+ with gallery_tab:
+ st.write('Here should be the image gallery, if I could make a grid in streamlit.')
+
+ with col3:
+			# If we have custom models available in the "models/custom" folder,
+			# show a menu to select which model to use; otherwise use the main SD model.
+ custom_models_available()
+ if server_state["CustomModel_available"]:
+ custom_model = st.selectbox("Custom Model:", st.session_state["defaults"].txt2vid.custom_models_list,
+ index=st.session_state["defaults"].txt2vid.custom_models_list.index(st.session_state["defaults"].txt2vid.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.5")
+ else:
+ custom_model = "runwayml/stable-diffusion-v1-5"
+
+ #st.session_state["weights_path"] = custom_model
+ #else:
+ #custom_model = "runwayml/stable-diffusion-v1-5"
+ #st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}"
+
+ st.session_state.sampling_steps = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2vid.sampling_steps.value,
+ min_value=st.session_state['defaults'].txt2vid.sampling_steps.min_value,
+ step=st.session_state['defaults'].txt2vid.sampling_steps.step, help="Number of steps between each pair of sampled points")
+
+ st.session_state.num_inference_steps = st.number_input("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps.value,
+ min_value=st.session_state['defaults'].txt2vid.num_inference_steps.min_value,
+ step=st.session_state['defaults'].txt2vid.num_inference_steps.step,
+ help="Higher values (e.g. 100, 200 etc) can create better images.")
+
+ #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ #sampler_name = st.selectbox("Sampling method", sampler_name_list,
+ #index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
+ scheduler_name_list = ["klms", "ddim", "ddpms",
+ #"flaxddims", "flaxddpms", "flaxpndms"
+ ]
+ scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
+ index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
+
+ beta_scheduler_type_list = ["scaled_linear", "linear"]
+ beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list,
+ index=beta_scheduler_type_list.index(st.session_state['defaults'].txt2vid.beta_scheduler_type), help="Schedule Type to use. Default: linear")
+
+
+ #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
+
+ #with basic_tab:
+ #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
+ #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
+
+ with st.expander("Advanced"):
+ with st.expander("Output Settings"):
+ st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2vid.separate_prompts,
+ help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+ st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.",
+ value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
+
+ st.session_state["save_individual_images"] = st.checkbox("Save individual images.",
+ value=st.session_state['defaults'].txt2vid.save_individual_images,
+ help="Save each image generated before any filter or enhancement is applied.")
+
+ st.session_state["save_video"] = st.checkbox("Save video",value=st.session_state['defaults'].txt2vid.save_video,
+ help="Save a video with all the images generated as frames at the end of the generation.")
+
+ save_video_on_stop = st.checkbox("Save video on Stop",value=st.session_state['defaults'].txt2vid.save_video_on_stop,
+ help="Save a video with all the images generated as frames when we hit the stop button during a generation.")
+
+ st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. When using a prompt \
+ matrix each prompt combination will have its own folder.")
+
+ st.session_state["write_info_files"] = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2vid.write_info_files,
+ help="Save a file next to the image with informartion about the generation.")
+
+ st.session_state["do_loop"] = st.checkbox("Do Loop", value=st.session_state['defaults'].txt2vid.do_loop,
+ help="Loop the prompt making two prompts from a single one.")
+
+ st.session_state["use_lerp_for_text"] = st.checkbox("Use Lerp Instead of Slerp", value=st.session_state['defaults'].txt2vid.use_lerp_for_text,
+ help="Uses torch.lerp() instead of slerp. When interpolating between related prompts. \
+ e.g. 'a lion in a grassy meadow' -> 'a bear in a grassy meadow' tends to keep the meadow \
+ the whole way through when lerped, but slerping will often find a path where the meadow \
+ disappears in the middle")
+
+ st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ #
+ if "GFPGAN_available" not in st.session_state:
+ GFPGAN_available()
+
+ if "RealESRGAN_available" not in st.session_state:
+ RealESRGAN_available()
+
+ if "LDSR_available" not in st.session_state:
+ LDSR_available()
+
+ if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
+ with st.expander("Post-Processing"):
+ face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"])
+ with face_restoration_tab:
+ # GFPGAN used for face restoration
+ if st.session_state["GFPGAN_available"]:
+ #with st.expander("Face Restoration"):
+ #if st.session_state["GFPGAN_available"]:
+ #with st.expander("GFPGAN"):
+ st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN,
+ help="Uses the GFPGAN model to improve faces after the generation.\
+																	This greatly improves the quality and consistency of faces but uses\
+ extra VRAM. Disable if you need the extra VRAM.")
+
+ st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"],
+ index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model))
+
+ #st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='')
+
+ else:
+ st.session_state["use_GFPGAN"] = False
+
+ with upscaling_tab:
+ st.session_state['us_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2vid.use_upscaling)
+ # RealESRGAN and LDSR used for upscaling.
+ if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]:
+
+ upscaling_method_list = []
+ if st.session_state["RealESRGAN_available"]:
+ upscaling_method_list.append("RealESRGAN")
+ if st.session_state["LDSR_available"]:
+ upscaling_method_list.append("LDSR")
+
+ st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list,
+ index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method)
+ if st.session_state['defaults'].general.upscaling_method in upscaling_method_list
+ else 0)
+
+ if st.session_state["RealESRGAN_available"]:
+ with st.expander("RealESRGAN"):
+ if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['us_upscaling']:
+ st.session_state["use_RealESRGAN"] = True
+ else:
+ st.session_state["use_RealESRGAN"] = False
+
+ st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"],
+ index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model))
+ else:
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+
+
+ #
+ if st.session_state["LDSR_available"]:
+ with st.expander("LDSR"):
+ if st.session_state["upscaling_method"] == "LDSR" and st.session_state['us_upscaling']:
+ st.session_state["use_LDSR"] = True
+ else:
+ st.session_state["use_LDSR"] = False
+
+ st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"],
+ index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model))
+
+ st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2vid.LDSR_config.sampling_steps,
+ help="")
+
+ st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].txt2vid.LDSR_config.preDownScale,
+ help="")
+
+ st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].txt2vid.LDSR_config.postDownScale,
+ help="")
+
+ downsample_method_list = ['Nearest', 'Lanczos']
+ st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list,
+ index=downsample_method_list.index(st.session_state['defaults'].txt2vid.LDSR_config.downsample_method))
+
+ else:
+ st.session_state["use_LDSR"] = False
+ st.session_state["LDSR_model"] = "model"
+
+ with st.expander("Variant"):
+ st.session_state["variant_amount"] = st.number_input("Variant Amount:", value=st.session_state['defaults'].txt2vid.variant_amount.value,
+ min_value=st.session_state['defaults'].txt2vid.variant_amount.min_value,
+ max_value=st.session_state['defaults'].txt2vid.variant_amount.max_value,
+ step=st.session_state['defaults'].txt2vid.variant_amount.step)
+
+ st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2vid.seed,
+ help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+
+ #st.session_state["beta_start"] = st.slider("Beta Start:", value=st.session_state['defaults'].txt2vid.beta_start.value,
+ #min_value=st.session_state['defaults'].txt2vid.beta_start.min_value,
+ #max_value=st.session_state['defaults'].txt2vid.beta_start.max_value,
+ #step=st.session_state['defaults'].txt2vid.beta_start.step, format=st.session_state['defaults'].txt2vid.beta_start.format)
+ #st.session_state["beta_end"] = st.slider("Beta End:", value=st.session_state['defaults'].txt2vid.beta_end.value,
+ #min_value=st.session_state['defaults'].txt2vid.beta_end.min_value, max_value=st.session_state['defaults'].txt2vid.beta_end.max_value,
+ #step=st.session_state['defaults'].txt2vid.beta_end.step, format=st.session_state['defaults'].txt2vid.beta_end.format)
+
+ if generate_button:
+ #print("Loading models")
+		# load the models the first time we hit the generate button; they stay loaded after that, so this only happens once.
+ #load_models(False, st.session_state["use_GFPGAN"], True, st.session_state["RealESRGAN_model"])
+ #with no_rerun:
+ if st.session_state["use_GFPGAN"]:
+ if "GFPGAN" in server_state:
+ logger.info("GFPGAN already loaded")
+ else:
+ with col2:
+ with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]):
+ # Load GFPGAN
+ if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
+ try:
+ load_GFPGAN()
+ logger.info("Loaded GFPGAN")
+ except Exception:
+ import traceback
+ logger.error("Error loading GFPGAN:", file=sys.stderr)
+ logger.error(traceback.format_exc(), file=sys.stderr)
+ else:
+ if "GFPGAN" in server_state:
+ del server_state["GFPGAN"]
+
+ #try:
+ # run video generation
+ video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
+ num_steps=st.session_state.sampling_steps, max_duration_in_seconds=st.session_state.max_duration_in_seconds,
+ num_inference_steps=st.session_state.num_inference_steps,
+ cfg_scale=cfg_scale, save_video_on_stop=save_video_on_stop,
+ outdir=st.session_state["defaults"].general.outdir,
+ do_loop=st.session_state["do_loop"],
+ use_lerp_for_text=st.session_state["use_lerp_for_text"],
+ seeds=seed, quality=100, eta=0.0, width=width,
+ height=height, weights_path=custom_model, scheduler=scheduler_name,
+ disable_tqdm=False, beta_start=st.session_state['defaults'].txt2vid.beta_start.value,
+ beta_end=st.session_state['defaults'].txt2vid.beta_end.value,
+ beta_schedule=beta_scheduler_type, starting_image=None, fps=st.session_state.fps)
+
+ if video and save_video_on_stop:
+ if os.path.exists(video): # temporary solution to bypass exception
+ # show video preview on the UI after we hit the stop button
+ # currently not working as session_state is cleared on StopException
+ preview_video.video(open(video, 'rb').read())
+
+ #message.success('Done!', icon="✅")
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
+
+ #if 'latestVideos' in st.session_state:
+ #for i in video:
+ ##push the new image to the list of latest images and remove the oldest one
+ ##remove the last index from the list\
+ #st.session_state['latestVideos'].pop()
+ ##add the new image to the start of the list
+ #st.session_state['latestVideos'].insert(0, i)
+ #PlaceHolder.empty()
+
+ #with PlaceHolder.container():
+ #col1, col2, col3 = st.columns(3)
+ #col1_cont = st.container()
+ #col2_cont = st.container()
+ #col3_cont = st.container()
+
+ #with col1_cont:
+ #with col1:
+ #st.image(st.session_state['latestVideos'][0])
+ #st.image(st.session_state['latestVideos'][3])
+ #st.image(st.session_state['latestVideos'][6])
+ #with col2_cont:
+ #with col2:
+ #st.image(st.session_state['latestVideos'][1])
+ #st.image(st.session_state['latestVideos'][4])
+ #st.image(st.session_state['latestVideos'][7])
+ #with col3_cont:
+ #with col3:
+ #st.image(st.session_state['latestVideos'][2])
+ #st.image(st.session_state['latestVideos'][5])
+ #st.image(st.session_state['latestVideos'][8])
+ #historyGallery = st.empty()
+
+ ## check if output_images length is the same as seeds length
+ #with gallery_tab:
+ #st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
+
+
+ #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
+
+ #except (StopException, KeyError):
+ #print(f"Received Streamlit StopException")
+
+
diff --git a/webui/streamlit/scripts/webui_streamlit.py b/webui/streamlit/scripts/webui_streamlit.py
new file mode 100644
index 0000000..95e04fb
--- /dev/null
+++ b/webui/streamlit/scripts/webui_streamlit.py
@@ -0,0 +1,277 @@
+# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/).
+
+# Copyright 2022 Sygil-Dev team.
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+# base webui import and utils.
+#import streamlit as st
+
+# We import hydralit this way to replace what we previously had
+# with native streamlit, as it lets us swap things 1:1.
+from sd_utils import st, hc, load_configs, load_css, set_logger_verbosity,\
+ logger, quiesce_logger, set_page_title, random
+
+# streamlit imports
+import streamlit_nested_layout
+
+#streamlit components section
+#from st_on_hover_tabs import on_hover_tabs
+#from streamlit_server_state import server_state, server_state_lock
+
+#other imports
+import argparse
+#from sd_utils.bridge import run_bridge
+
+# import custom components
+from custom_components import draggable_number_input
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+load_configs()
+
+help = """
+A double dash (`--`) is used to separate streamlit arguments from app arguments.
+As a result, running "streamlit run webui_streamlit.py --headless"
+will show the help for streamlit itself and not pass any argument to our app;
+we need to use "streamlit run webui_streamlit.py -- --headless"
+in order to pass a command argument to this app."""
+parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+parser.add_argument("--headless", action='store_true', help="Don't launch web server, util if you just want to run the stable horde bridge.", default=False)
+
+parser.add_argument("--bridge", action='store_true', help="don't launch web server, but make this instance into a Horde bridge.", default=False)
+parser.add_argument('--horde_api_key', action="store", required=False, type=str, help="The API key corresponding to the owner of this Horde instance")
+parser.add_argument('--horde_name', action="store", required=False, type=str, help="The server name for the Horde. It will be shown to the world and there can be only one.")
+parser.add_argument('--horde_url', action="store", required=False, type=str, help="The Stable Horde URL where the bridge will pick up prompts and send the finished generations.")
+parser.add_argument('--horde_priority_usernames',type=str, action='append', required=False, help="Usernames which get priority use in this horde instance. The owner's username is always in this list.")
+parser.add_argument('--horde_max_power',type=int, required=False, help="How much power this instance has to generate pictures. Min: 2")
+parser.add_argument('--horde_sfw', action='store_true', required=False, help="Set to true if you do not want this worker generating NSFW images.")
+parser.add_argument('--horde_blacklist', nargs='+', required=False, help="List the words that you want to blacklist.")
+parser.add_argument('--horde_censorlist', nargs='+', required=False, help="List the words that you want to censor.")
+parser.add_argument('--horde_censor_nsfw', action='store_true', required=False, help="Set to true if you want this bridge worker to censor NSFW images.")
+parser.add_argument('--horde_model', action='store', required=False, help="Which model to run on this horde.")
+parser.add_argument('-v', '--verbosity', action='count', default=0, help="The default logging level is ERROR or higher. This value increases the amount of logging seen on your screen.")
+parser.add_argument('-q', '--quiet', action='count', default=0, help="The default logging level is ERROR or higher. This value decreases the amount of logging seen on your screen.")
+opt = parser.parse_args()
+
+#with server_state_lock["bridge"]:
+ #server_state["bridge"] = opt.bridge
+
+@logger.catch(reraise=True)
+def layout():
+ """Layout functions to define all the streamlit layout here."""
+ if not st.session_state["defaults"].debug.enable_hydralit:
+ st.set_page_config(page_title="Stable Diffusion Playground", layout="wide", initial_sidebar_state="collapsed")
+
+ #app = st.HydraApp(title='Stable Diffusion WebUI', favicon="", sidebar_state="expanded", layout="wide",
+ #hide_streamlit_markers=False, allow_url_nav=True , clear_cross_app_sessions=False)
+
+
+	# load the CSS from an external file; the function can take a local path or a remote URL, which is useful when running on cloud infra that might not have access to the local path.
+ load_css(True, 'frontend/css/streamlit.main.css')
+
+ #
+ # specify the primary menu definition
+ menu_data = [
+ {'id': 'Stable Diffusion', 'label': 'Stable Diffusion', 'icon': 'bi bi-grid-1x2-fill'},
+ {'id': 'Train','label':"Train", 'icon': "bi bi-lightbulb-fill", 'submenu':[
+ {'id': 'Textual Inversion', 'label': 'Textual Inversion', 'icon': 'bi bi-lightbulb-fill'},
+			{'id': 'Fine Tuning', 'label': 'Fine Tuning', 'icon': 'bi bi-lightbulb-fill'},
+ ]},
+ {'id': 'Model Manager', 'label': 'Model Manager', 'icon': 'bi bi-cloud-arrow-down-fill'},
+ {'id': 'Tools','label':"Tools", 'icon': "bi bi-tools", 'submenu':[
+ {'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
+ {'id': 'Barfi/BaklavaJS', 'label': 'Barfi/BaklavaJS', 'icon': 'bi bi-diagram-3-fill'},
+ #{'id': 'API Server', 'label': 'API Server', 'icon': 'bi bi-server'},
+ ]},
+ {'id': 'Settings', 'label': 'Settings', 'icon': 'bi bi-gear-fill'},
+ ]
+
+ over_theme = {'txc_inactive': '#FFFFFF', "menu_background":'#000000'}
+
+ menu_id = hc.nav_bar(
+ menu_definition=menu_data,
+ #home_name='Home',
+ #login_name='Logout',
+ hide_streamlit_markers=False,
+ override_theme=over_theme,
+ sticky_nav=True,
+ sticky_mode='pinned',
+ )
+
+ #
+ #if menu_id == "Home":
+ #st.info("Under Construction. :construction_worker:")
+
+ if menu_id == "Stable Diffusion":
+ # set the page url and title
+ #st.experimental_set_query_params(page='stable-diffusion')
+ try:
+ set_page_title("Stable Diffusion Playground")
+ except NameError:
+ st.experimental_rerun()
+
+ txt2img_tab, img2img_tab, txt2vid_tab, img2txt_tab, post_processing_tab, concept_library_tab = st.tabs(["Text-to-Image", "Image-to-Image",
+ #"Inpainting",
+ "Text-to-Video", "Image-To-Text",
+ "Post-Processing","Concept Library"])
+ #with home_tab:
+ #from home import layout
+ #layout()
+
+ with txt2img_tab:
+ from txt2img import layout
+ layout()
+
+ with img2img_tab:
+ from img2img import layout
+ layout()
+
+ #with inpainting_tab:
+ #from inpainting import layout
+ #layout()
+
+ with txt2vid_tab:
+ from txt2vid import layout
+ layout()
+
+ with img2txt_tab:
+ from img2txt import layout
+ layout()
+
+ with post_processing_tab:
+ from post_processing import layout
+ layout()
+
+ with concept_library_tab:
+ from sd_concept_library import layout
+ layout()
+
+ #
+ elif menu_id == 'Model Manager':
+ set_page_title("Model Manager - Stable Diffusion Playground")
+
+ from ModelManager import layout
+ layout()
+
+ elif menu_id == 'Textual Inversion':
+ from textual_inversion import layout
+ layout()
+
+	elif menu_id == 'Fine Tuning':
+ #from textual_inversion import layout
+ #layout()
+ st.info("Under Construction. :construction_worker:")
+
+ elif menu_id == 'API Server':
+ set_page_title("API Server - Stable Diffusion Playground")
+ from APIServer import layout
+ layout()
+
+ elif menu_id == 'Barfi/BaklavaJS':
+ set_page_title("Barfi/BaklavaJS - Stable Diffusion Playground")
+ from barfi_baklavajs import layout
+ layout()
+
+ elif menu_id == 'Settings':
+ set_page_title("Settings - Stable Diffusion Playground")
+
+ from Settings import layout
+ layout()
+
+	# call the draggable number input component module at the end so it works on all pages
+ draggable_number_input.load()
+
+
+if __name__ == '__main__':
+ set_logger_verbosity(opt.verbosity)
+ quiesce_logger(opt.quiet)
+
+ if not opt.headless:
+ layout()
+
+ #with server_state_lock["bridge"]:
+ #if server_state["bridge"]:
+ #try:
+ #import bridgeData as cd
+ #except ModuleNotFoundError as e:
+ #logger.warning("No bridgeData found. Falling back to default where no CLI args are set.")
+ #logger.debug(str(e))
+ #except SyntaxError as e:
+ #logger.warning("bridgeData found, but is malformed. Falling back to default where no CLI args are set.")
+ #logger.debug(str(e))
+ #except Exception as e:
+ #logger.warning("No bridgeData found, use default where no CLI args are set")
+ #logger.debug(str(e))
+ #finally:
+ #try: # check if cd exists (i.e. bridgeData loaded properly)
+ #cd
+ #except: # if not, create defaults
+ #class temp(object):
+ #def __init__(self):
+ #random.seed()
+ #self.horde_url = "https://stablehorde.net"
+ ## Give a cool name to your instance
+ #self.horde_name = f"Automated Instance #{random.randint(-100000000, 100000000)}"
+ ## The api_key identifies a unique user in the horde
+ #self.horde_api_key = "0000000000"
+ ## Put other users whose prompts you want to prioritize.
+ ## The owner's username is always included so you don't need to add it here, unless you want it to have lower priority than another user
+ #self.horde_priority_usernames = []
+ #self.horde_max_power = 8
+ #self.nsfw = True
+ #self.censor_nsfw = False
+ #self.blacklist = []
+ #self.censorlist = []
+ #self.models_to_load = ["stable_diffusion"]
+ #cd = temp()
+ #horde_api_key = opt.horde_api_key if opt.horde_api_key else cd.horde_api_key
+ #horde_name = opt.horde_name if opt.horde_name else cd.horde_name
+ #horde_url = opt.horde_url if opt.horde_url else cd.horde_url
+ #horde_priority_usernames = opt.horde_priority_usernames if opt.horde_priority_usernames else cd.horde_priority_usernames
+ #horde_max_power = opt.horde_max_power if opt.horde_max_power else cd.horde_max_power
+ ## Not used yet
+ #horde_models = [opt.horde_model] if opt.horde_model else cd.models_to_load
+ #try:
+ #horde_nsfw = not opt.horde_sfw if opt.horde_sfw else cd.horde_nsfw
+ #except AttributeError:
+ #horde_nsfw = True
+ #try:
+ #horde_censor_nsfw = opt.horde_censor_nsfw if opt.horde_censor_nsfw else cd.horde_censor_nsfw
+ #except AttributeError:
+ #horde_censor_nsfw = False
+ #try:
+ #horde_blacklist = opt.horde_blacklist if opt.horde_blacklist else cd.horde_blacklist
+ #except AttributeError:
+ #horde_blacklist = []
+ #try:
+ #horde_censorlist = opt.horde_censorlist if opt.horde_censorlist else cd.horde_censorlist
+ #except AttributeError:
+ #horde_censorlist = []
+ #if horde_max_power < 2:
+ #horde_max_power = 2
+ #horde_max_pixels = 64*64*8*horde_max_power
+ #logger.info(f"Joining Horde with parameters: Server Name '{horde_name}'. Horde URL '{horde_url}'. Max Pixels {horde_max_pixels}")
+
+ #try:
+ #thread = threading.Thread(target=run_bridge(1, horde_api_key, horde_name, horde_url,
+ #horde_priority_usernames, horde_max_pixels,
+ #horde_nsfw, horde_censor_nsfw, horde_blacklist,
+ #horde_censorlist), args=())
+ #thread.daemon = True
+ #thread.start()
+ ##run_bridge(1, horde_api_key, horde_name, horde_url, horde_priority_usernames, horde_max_pixels, horde_nsfw, horde_censor_nsfw, horde_blacklist, horde_censorlist)
+ #except KeyboardInterrupt:
+ #print(f"Keyboard Interrupt Received. Ending Bridge")
\ No newline at end of file