diff --git a/webui/streamlit/frontend/css/streamlit.main.css b/webui/streamlit/frontend/css/streamlit.main.css new file mode 100644 index 0000000..9162d05 --- /dev/null +++ b/webui/streamlit/frontend/css/streamlit.main.css @@ -0,0 +1,178 @@ +/* +This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/). + +Copyright 2022 Sygil-Dev team. +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . +*/ + +/*********************************************************** +* Additional CSS for streamlit builtin components * +************************************************************/ + +/* Tab name (e.g. Text-to-Image) //improve legibility*/ +button[data-baseweb="tab"] { + font-size: 25px; +} + +/* Image Container (only appear after run finished)//center the image, especially better looks in wide screen */ +.css-1kyxreq{ + justify-content: center; +} + + +/* Streamlit header */ +.css-1avcm0n { + background-color: transparent; +} + +/* Main streamlit container (below header) //reduce the empty spaces*/ +.css-18e3th9 { + padding-top: 1rem; +} + + + +/*********************************************************** +* Additional CSS for streamlit custom/3rd party components * +************************************************************/ +/* For stream_on_hover */ +section[data-testid="stSidebar"] > div:nth-of-type(1) { + background-color: #111; +} + +button[kind="header"] { + background-color: transparent; + color: rgb(180, 167, 141); +} + +@media (hover) { + /* header element */ + header[data-testid="stHeader"] { + /* display: none;*/ /*suggested behavior by streamlit hover components*/ + pointer-events: none; /* disable interaction of the transparent background */ + } + + /* The button on the streamlit navigation menu */ + button[kind="header"] { + /* display: none;*/ /*suggested behavior by streamlit hover components*/ + pointer-events: auto; /* enable interaction of the button even if parents intereaction disabled */ + } + + /* added to avoid main sectors (all element to the right of sidebar from) moving */ + section[data-testid="stSidebar"] { + width: 3.5% !important; + min-width: 3.5% !important; + } + + /* The navigation menu specs and size */ + section[data-testid="stSidebar"] > div { + height: 100%; + width: 2% !important; + min-width: 100% !important; + position: relative; + z-index: 1; + top: 0; + left: 0; + background-color: #111; + overflow-x: hidden; + transition: 0.5s ease-in-out; + padding-top: 0px; + white-space: nowrap; + } + + /* The navigation menu open and close on hover and size */ + section[data-testid="stSidebar"] > div:hover { + width: 300px !important; + } +} + +@media (max-width: 272px) { + section[data-testid="stSidebar"] > div { + width: 15rem; + } +} + +/*********************************************************** +* Additional CSS for other elements +************************************************************/ +button[data-baseweb="tab"] { + font-size: 20px; +} + +@media (min-width: 1200px){ +h1 { + font-size: 
1.75rem; +} +} +#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 { + width: 50rem; + align-self: center; +} +div.gallery:hover { + border: 1px solid #777; +} +.css-dg4u6x p { + font-size: 0.8rem; + text-align: center; + position: relative; + top: 6px; +} + +.row-widget.stButton { + text-align: center; +} + +/******************************************************************** + Hide anchor links on titles +*********************************************************************/ +/* +.css-15zrgzn { + display: none + } +.css-eczf16 { + display: none + } +.css-jn99sy { + display: none + } + +/* Make the text area widget have a similar height as the text input field */ +.st-dy{ + height: 54px; + min-height: 25px; +} +.css-17useex{ + gap: 3px; + +} + +/* Remove some empty spaces to make the UI more compact. */ +.css-18e3th9{ + padding-left: 10px; + padding-right: 30px; + position: unset !important; /* Fixes the layout/page going up when an expander or another item is expanded and then collapsed */ +} +.css-k1vhr4{ + padding-top: initial; +} +.css-ret2ud{ + padding-left: 10px; + padding-right: 30px; + gap: initial; + display: initial; +} + +.css-w5z5an{ + gap: 1px; +} diff --git a/webui/streamlit/scripts/APIServer.py b/webui/streamlit/scripts/APIServer.py new file mode 100644 index 0000000..a4da6c5 --- /dev/null +++ b/webui/streamlit/scripts/APIServer.py @@ -0,0 +1,34 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/). + +# Copyright 2022 Sygil-Dev team. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# base webui import and utils. +#from sd_utils import * +from sd_utils import st +# streamlit imports + +#streamlit components section + +#other imports +#from fastapi import FastAPI +#import uvicorn + +# Temp imports + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + + +def layout(): + st.info("Under Construction. :construction_worker:") \ No newline at end of file diff --git a/webui/streamlit/scripts/ModelManager.py b/webui/streamlit/scripts/ModelManager.py new file mode 100644 index 0000000..3db0105 --- /dev/null +++ b/webui/streamlit/scripts/ModelManager.py @@ -0,0 +1,121 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/). + +# Copyright 2022 Sygil-Dev team. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. 
+ +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# base webui import and utils. +from sd_utils import st, logger +# streamlit imports + + +#other imports +import os, requests +from requests.auth import HTTPBasicAuth +from requests import HTTPError +from stqdm import stqdm + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- +def download_file(file_name, file_path, file_url): + if not os.path.exists(file_path): + os.makedirs(file_path) + + if not os.path.exists(os.path.join(file_path , file_name)): + print('Downloading ' + file_name + '...') + # TODO - add progress bar in streamlit + # download file with `requests`` + if file_name == "Stable Diffusion v1.5": + if "huggingface_token" not in st.session_state or st.session_state["defaults"].general.huggingface_token == "None": + if "progress_bar_text" in st.session_state: + st.session_state["progress_bar_text"].error( + "You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token." + ) + raise OSError("You need a huggingface token in order to use the Text to Video tab. Use the Settings page from the sidebar on the left to add your token.") + + try: + with requests.get(file_url, auth = HTTPBasicAuth('token', st.session_state.defaults.general.huggingface_token) if "huggingface.co" in file_url else None, stream=True) as r: + r.raise_for_status() + with open(os.path.join(file_path, file_name), 'wb') as f: + for chunk in stqdm(r.iter_content(chunk_size=8192), backend=True, unit="kb"): + f.write(chunk) + except HTTPError as e: + if "huggingface.co" in file_url: + if "resolve"in file_url: + repo_url = file_url.split("resolve")[0] + + st.session_state["progress_bar_text"].error( + f"You need to accept the license for the model in order to be able to download it. 
" + f"Please visit {repo_url} and accept the lincense there, then try again to download the model.") + + logger.error(e) + + else: + print(file_name + ' already exists.') + + +def download_model(models, model_name): + """ Download all files from model_list[model_name] """ + for file in models[model_name]: + download_file(file['file_name'], file['file_path'], file['file_url']) + return + + +def layout(): + #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="") + + colms = st.columns((1, 3, 3, 5, 5)) + columns = ["№", 'Model Name', 'Save Location', "Download", 'Download Link'] + + models = st.session_state["defaults"].model_manager.models + + for col, field_name in zip(colms, columns): + # table header + col.write(field_name) + + for x, model_name in enumerate(models): + col1, col2, col3, col4, col5 = st.columns((1, 3, 3, 3, 6)) + col1.write(x) # index + col2.write(models[model_name]['model_name']) + col3.write(models[model_name]['save_location']) + with col4: + files_exist = 0 + for file in models[model_name]['files']: + if "save_location" in models[model_name]['files'][file]: + os.path.exists(os.path.join(models[model_name]['files'][file]['save_location'] , models[model_name]['files'][file]['file_name'])) + files_exist += 1 + elif os.path.exists(os.path.join(models[model_name]['save_location'] , models[model_name]['files'][file]['file_name'])): + files_exist += 1 + files_needed = [] + for file in models[model_name]['files']: + if "save_location" in models[model_name]['files'][file]: + if not os.path.exists(os.path.join(models[model_name]['files'][file]['save_location'] , models[model_name]['files'][file]['file_name'])): + files_needed.append(file) + elif not os.path.exists(os.path.join(models[model_name]['save_location'] , models[model_name]['files'][file]['file_name'])): + files_needed.append(file) + if len(files_needed) > 0: + if st.button('Download', key=models[model_name]['model_name'], help='Download ' + models[model_name]['model_name']): + for file in files_needed: + if "save_location" in models[model_name]['files'][file]: + download_file(models[model_name]['files'][file]['file_name'], models[model_name]['files'][file]['save_location'], models[model_name]['files'][file]['download_link']) + else: + download_file(models[model_name]['files'][file]['file_name'], models[model_name]['save_location'], models[model_name]['files'][file]['download_link']) + st.experimental_rerun() + else: + st.empty() + else: + st.write('✅') + + # diff --git a/webui/streamlit/scripts/Settings.py b/webui/streamlit/scripts/Settings.py new file mode 100644 index 0000000..0b228cb --- /dev/null +++ b/webui/streamlit/scripts/Settings.py @@ -0,0 +1,899 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/). + +# Copyright 2022 Sygil-Dev team. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# base webui import and utils. 
+from sd_utils import st, custom_models_available, logger, human_readable_size + +# streamlit imports + +# streamlit components section +import streamlit_nested_layout +from streamlit_server_state import server_state + +# other imports +from omegaconf import OmegaConf +import torch +import os, toml + +# end of imports +# --------------------------------------------------------------------------------------------------------------- + +@logger.catch(reraise=True) +def layout(): + #st.header("Settings") + + with st.form("Settings"): + general_tab, txt2img_tab, img2img_tab, img2txt_tab, txt2vid_tab, image_processing, textual_inversion_tab, concepts_library_tab = st.tabs( + ['General', "Text-To-Image", "Image-To-Image", "Image-To-Text", "Text-To-Video", "Image processing", "Textual Inversion", "Concepts Library"]) + + with general_tab: + col1, col2, col3, col4, col5 = st.columns(5, gap='large') + + device_list = [] + device_properties = [(i, torch.cuda.get_device_properties(i)) for i in range(torch.cuda.device_count())] + for device in device_properties: + id = device[0] + name = device[1].name + total_memory = device[1].total_memory + + device_list.append(f"{id}: {name} ({human_readable_size(total_memory, decimal_places=0)})") + + with col1: + st.title("General") + st.session_state['defaults'].general.gpu = int(st.selectbox("GPU", device_list, index=st.session_state['defaults'].general.gpu, + help=f"Select which GPU to use. Default: {device_list[0]}").split(":")[0]) + + st.session_state['defaults'].general.outdir = str(st.text_input("Output directory", value=st.session_state['defaults'].general.outdir, + help="Relative directory on which the output images after a generation will be placed. Default: 'outputs'")) + + # If we have custom models available on the "models/custom" + # folder then we show a menu to select which model we want to use, otherwise we use the main model for SD + custom_models_available() + + if server_state["CustomModel_available"]: + st.session_state.defaults.general.default_model = st.selectbox("Default Model:", server_state["custom_models"], + index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model), + help="Select the model you want to use. If you have placed custom models \ + on your 'models/custom' folder they will be shown here as well. The model name that will be shown here \ + is the same as the name the file for the model has on said folder, \ + it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") + else: + st.session_state.defaults.general.default_model = st.selectbox("Default Model:", [st.session_state['defaults'].general.default_model], + help="Select the model you want to use. If you have placed custom models \ + on your 'models/custom' folder they will be shown here as well. \ + The model name that will be shown here is the same as the name\ + the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4") + + st.session_state['defaults'].general.default_model_config = st.text_input("Default Model Config", value=st.session_state['defaults'].general.default_model_config, + help="Default model config file for inference. 
Default: 'configs/stable-diffusion/v1-inference.yaml'") + + st.session_state['defaults'].general.default_model_path = st.text_input("Default Model Path", value=st.session_state['defaults'].general.default_model_path, + help="Default model path. Default: 'models/ldm/stable-diffusion-v1/model.ckpt'") + + st.session_state['defaults'].general.GFPGAN_dir = st.text_input("Default GFPGAN directory", value=st.session_state['defaults'].general.GFPGAN_dir, + help="Default GFPGAN directory. Default: './models/gfpgan'") + + st.session_state['defaults'].general.RealESRGAN_dir = st.text_input("Default RealESRGAN directory", value=st.session_state['defaults'].general.RealESRGAN_dir, + help="Default RealESRGAN directory. Default: './models/realesrgan'") + + RealESRGAN_model_list = ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"] + st.session_state['defaults'].general.RealESRGAN_model = st.selectbox("RealESRGAN model", RealESRGAN_model_list, + index=RealESRGAN_model_list.index(st.session_state['defaults'].general.RealESRGAN_model), + help="Default RealESRGAN model. Default: 'RealESRGAN_x4plus'") + Upscaler_list = ["RealESRGAN", "LDSR"] + st.session_state['defaults'].general.upscaling_method = st.selectbox("Upscaler", Upscaler_list, index=Upscaler_list.index( + st.session_state['defaults'].general.upscaling_method), help="Default upscaling method. Default: 'RealESRGAN'") + + with col2: + st.title("Performance") + + st.session_state["defaults"].general.gfpgan_cpu = st.checkbox("GFPGAN - CPU", value=st.session_state['defaults'].general.gfpgan_cpu, + help="Run GFPGAN on the cpu. Default: False") + + st.session_state["defaults"].general.esrgan_cpu = st.checkbox("ESRGAN - CPU", value=st.session_state['defaults'].general.esrgan_cpu, + help="Run ESRGAN on the cpu. Default: False") + + st.session_state["defaults"].general.extra_models_cpu = st.checkbox("Extra Models - CPU", value=st.session_state['defaults'].general.extra_models_cpu, + help="Run extra models (GFPGAN/ESRGAN) on cpu. Default: False") + + st.session_state["defaults"].general.extra_models_gpu = st.checkbox("Extra Models - GPU", value=st.session_state['defaults'].general.extra_models_gpu, + help="Run extra models (GFPGAN/ESRGAN) on gpu. \ + Check and save in order to be able to select the GPU that each model will use. Default: False") + if st.session_state["defaults"].general.extra_models_gpu: + st.session_state['defaults'].general.gfpgan_gpu = int(st.selectbox("GFPGAN GPU", device_list, index=st.session_state['defaults'].general.gfpgan_gpu, + help=f"Select which GPU to use. Default: {device_list[st.session_state['defaults'].general.gfpgan_gpu]}", + key="gfpgan_gpu").split(":")[0]) + + st.session_state["defaults"].general.esrgan_gpu = int(st.selectbox("ESRGAN - GPU", device_list, index=st.session_state['defaults'].general.esrgan_gpu, + help=f"Select which GPU to use. Default: {device_list[st.session_state['defaults'].general.esrgan_gpu]}", + key="esrgan_gpu").split(":")[0]) + + st.session_state["defaults"].general.no_half = st.checkbox("No Half", value=st.session_state['defaults'].general.no_half, + help="DO NOT switch the model to 16-bit floats. Default: False") + + st.session_state["defaults"].general.use_cudnn = st.checkbox("Use cudnn", value=st.session_state['defaults'].general.use_cudnn, + help="Switch the pytorch backend to use cudnn, this should help with fixing Nvidia 16xx cards getting " + "a black or green image. 
Default: False") + + st.session_state["defaults"].general.use_float16 = st.checkbox("Use float16", value=st.session_state['defaults'].general.use_float16, + help="Switch the model to 16-bit floats. Default: False") + + + precision_list = ['full', 'autocast'] + st.session_state["defaults"].general.precision = st.selectbox("Precision", precision_list, index=precision_list.index(st.session_state['defaults'].general.precision), + help="Evaluates at this precision. Default: autocast") + + st.session_state["defaults"].general.optimized = st.checkbox("Optimized Mode", value=st.session_state['defaults'].general.optimized, + help="Loads the model onto the device piecemeal instead of all at once to reduce VRAM usage\ + at the cost of performance. Default: False") + + st.session_state["defaults"].general.optimized_turbo = st.checkbox("Optimized Turbo Mode", value=st.session_state['defaults'].general.optimized_turbo, + help="Alternative optimization mode that does not save as much VRAM but \ + runs significantly faster. Default: False") + + st.session_state["defaults"].general.optimized_config = st.text_input("Optimized Config", value=st.session_state['defaults'].general.optimized_config, + help=f"Loads alternative optimized configuration for inference. \ + Default: optimizedSD/v1-inference.yaml") + + st.session_state["defaults"].general.enable_attention_slicing = st.checkbox("Enable Attention Slicing", value=st.session_state['defaults'].general.enable_attention_slicing, + help="Enable sliced attention computation. When this option is enabled, the attention module will \ + split the input tensor in slices, to compute attention in several steps. This is useful to save some \ + memory in exchange for a small speed decrease. Only works on the txt2vid tab right now. Default: False") + + st.session_state["defaults"].general.enable_minimal_memory_usage = st.checkbox("Enable Minimal Memory Usage", value=st.session_state['defaults'].general.enable_minimal_memory_usage, + help="Moves only the unet to fp16 and to CUDA, while keeping lighter models on the CPU \ + (Not properly implemented and currently not working, check this \ + link 'https://github.com/huggingface/diffusers/pull/537' for more information on it). Default: False") + + # st.session_state["defaults"].general.update_preview = st.checkbox("Update Preview Image", value=st.session_state['defaults'].general.update_preview, + # help="Enables the preview image to be updated and shown to the user on the UI during the generation.\ + # If checked, once you save the settings an option to specify the frequency at which the image is updated\ + # in steps will be shown, this is helpful to reduce the negative effect this option has on performance. \ + # Default: True") + st.session_state["defaults"].general.update_preview = True + st.session_state["defaults"].general.update_preview_frequency = st.number_input("Update Preview Frequency", + min_value=0, + value=st.session_state['defaults'].general.update_preview_frequency, + help="Specify the frequency at which the image is updated in steps, this is helpful to reduce the \ + negative effect updating the preview image has on performance. Default: 10") + + with col3: + st.title("Others") + st.session_state["defaults"].general.use_sd_concepts_library = st.checkbox("Use the Concepts Library", value=st.session_state['defaults'].general.use_sd_concepts_library, + help="Use the Concepts Library embeddings. If checked, once the settings are saved an option will\ + appear to specify the directory where the concepts are stored. 
Default: True") + + if st.session_state["defaults"].general.use_sd_concepts_library: + st.session_state['defaults'].general.sd_concepts_library_folder = st.text_input("Concepts Library Folder", + value=st.session_state['defaults'].general.sd_concepts_library_folder, + help="Relative folder in which the concepts library embeddings are stored. \ + Default: 'models/custom/sd-concepts-library'") + + st.session_state['defaults'].general.LDSR_dir = st.text_input("LDSR Folder", value=st.session_state['defaults'].general.LDSR_dir, + help="Folder where LDSR is located. Default: './models/ldsr'") + + st.session_state["defaults"].general.save_metadata = st.checkbox("Save Metadata", value=st.session_state['defaults'].general.save_metadata, + help="Save metadata on the output image. Default: True") + save_format_list = ["png","jpg", "jpeg","webp"] + st.session_state["defaults"].general.save_format = st.selectbox("Save Format", save_format_list, index=save_format_list.index(st.session_state['defaults'].general.save_format), + help="Format that will be used when saving the output images. Default: 'png'") + + st.session_state["defaults"].general.skip_grid = st.checkbox("Skip Grid", value=st.session_state['defaults'].general.skip_grid, + help="Skip saving the grid output image. Default: False") + if not st.session_state["defaults"].general.skip_grid: + + + st.session_state["defaults"].general.grid_quality = st.number_input("Grid Quality", value=st.session_state['defaults'].general.grid_quality, + help="Quality used when saving the grid output image. Default: 95") + + st.session_state["defaults"].general.skip_save = st.checkbox("Skip Save", value=st.session_state['defaults'].general.skip_save, + help="Skip saving the output image. Default: False") + + st.session_state["defaults"].general.n_rows = st.number_input("Number of Grid Rows", value=st.session_state['defaults'].general.n_rows, + help="Number of rows the grid will have when saving the grid output image. Default: '-1'") + + st.session_state["defaults"].general.no_verify_input = st.checkbox("Do not Verify Input", value=st.session_state['defaults'].general.no_verify_input, + help="Do not verify input to check if it's too long. Default: False") + + st.session_state["defaults"].general.show_percent_in_tab_title = st.checkbox("Show Percent in tab title", value=st.session_state['defaults'].general.show_percent_in_tab_title, + help="Add the progress percent value to the page title on the tab on your browser. " + "This is useful in case you need to know how the generation is going while doing something else " + "in another tab on your browser. Default: True") + + st.session_state["defaults"].general.enable_suggestions = st.checkbox("Enable Suggestions Box", value=st.session_state['defaults'].general.enable_suggestions, + help="Adds a suggestion box under the prompt when clicked. Default: True") + + st.session_state["defaults"].daisi_app.running_on_daisi_io = st.checkbox("Running on Daisi.io?", value=st.session_state['defaults'].daisi_app.running_on_daisi_io, + help="Specify if we are running on app.Daisi.io. Default: False") + + with col4: + st.title("Streamlit Config") + + default_theme_list = ["light", "dark"] + st.session_state["defaults"].general.default_theme = st.selectbox("Default Theme", default_theme_list, index=default_theme_list.index(st.session_state['defaults'].general.default_theme), + help="Default theme to use as base for streamlit. 
Default: dark") + st.session_state["streamlit_config"]["theme"]["base"] = st.session_state["defaults"].general.default_theme + + + if not st.session_state['defaults'].admin.hide_server_setting: + with st.expander("Server", True): + + st.session_state["streamlit_config"]['server']['headless'] = st.checkbox("Run Headless", value=st.session_state["streamlit_config"]['server']['headless'], + help="If false, will attempt to open a browser window on start. \ + Default: false unless (1) we are on a Linux box where DISPLAY is unset, \ + or (2) we are running in the Streamlit Atom plugin.") + + st.session_state["streamlit_config"]['server']['port'] = st.number_input("Port", value=st.session_state["streamlit_config"]['server']['port'], + help="The port where the server will listen for browser connections. Default: 8501") + + st.session_state["streamlit_config"]['server']['baseUrlPath'] = st.text_input("Base Url Path", value=st.session_state["streamlit_config"]['server']['baseUrlPath'], + help="The base path for the URL where Streamlit should be served from. Default: '' ") + + st.session_state["streamlit_config"]['server']['enableCORS'] = st.checkbox("Enable CORS", value=st.session_state['streamlit_config']['server']['enableCORS'], + help="Enables support for Cross-Origin Resource Sharing (CORS) protection, for added security. \ + Due to conflicts between CORS and XSRF, if `server.enableXsrfProtection` is on and `server.enableCORS` \ + is off at the same time, we will prioritize `server.enableXsrfProtection`. Default: true") + + st.session_state["streamlit_config"]['server']['enableXsrfProtection'] = st.checkbox("Enable Xsrf Protection", + value=st.session_state['streamlit_config']['server']['enableXsrfProtection'], + help="Enables support for Cross-Site Request Forgery (XSRF) protection, \ + for added security. Due to conflicts between CORS and XSRF, \ + if `server.enableXsrfProtection` is on and `server.enableCORS` is off at \ + the same time, we will prioritize `server.enableXsrfProtection`. Default: true") + + st.session_state["streamlit_config"]['server']['maxUploadSize'] = st.number_input("Max Upload Size", value=st.session_state["streamlit_config"]['server']['maxUploadSize'], + help="Max size, in megabytes, for files uploaded with the file_uploader. Default: 200") + + st.session_state["streamlit_config"]['server']['maxMessageSize'] = st.number_input("Max Message Size", value=st.session_state["streamlit_config"]['server']['maxMessageSize'], + help="Max size, in megabytes, of messages that can be sent via the WebSocket connection. Default: 200") + + st.session_state["streamlit_config"]['server']['enableWebsocketCompression'] = st.checkbox("Enable Websocket Compression", + value=st.session_state["streamlit_config"]['server']['enableWebsocketCompression'], + help="Enables support for websocket compression. Default: false") + if not st.session_state['defaults'].admin.hide_browser_setting: + with st.expander("Browser", expanded=True): + st.session_state["streamlit_config"]['browser']['serverAddress'] = st.text_input("Server Address", + value=st.session_state["streamlit_config"]['browser']['serverAddress'] if "serverAddress" in st.session_state["streamlit_config"]['browser'] else "localhost", + help="Internet address where users should point their browsers in order \ + to connect to the app. Can be IP address or DNS name and path.\ + This is used to: - Set the correct URL for CORS and XSRF protection purposes. \ + - Show the URL on the terminal - Open the browser. 
Default: 'localhost'") + + st.session_state["defaults"].general.streamlit_telemetry = st.checkbox("Enable Telemetry", value=st.session_state['defaults'].general.streamlit_telemetry, + help="Enables or Disables streamlit telemetry. Default: False") + st.session_state["streamlit_config"]["browser"]["gatherUsageStats"] = st.session_state["defaults"].general.streamlit_telemetry + + st.session_state["streamlit_config"]['browser']['serverPort'] = st.number_input("Server Port", value=st.session_state["streamlit_config"]['browser']['serverPort'], + help="Port where users should point their browsers in order to connect to the app. \ + This is used to: - Set the correct URL for CORS and XSRF protection purposes. \ + - Show the URL on the terminal - Open the browser \ + Default: whatever value is set in server.port.") + + with col5: + st.title("Huggingface") + st.session_state["defaults"].general.huggingface_token = st.text_input("Huggingface Token", value=st.session_state['defaults'].general.huggingface_token, type="password", + help="Your Huggingface Token, it's used to download the model for the diffusers library which \ + is used on the Text To Video tab. This token will be saved to your user config file\ + and WILL NOT be share with us or anyone. You can get your access token \ + at https://huggingface.co/settings/tokens. Default: None") + + st.title("Stable Horde") + st.session_state["defaults"].general.stable_horde_api = st.text_input("Stable Horde Api", value=st.session_state["defaults"].general.stable_horde_api, type="password", + help="First Register an account at https://stablehorde.net/register which will generate for you \ + an API key. Store that key somewhere safe. \n \ + If you do not want to register, you can use `0000000000` as api_key to connect anonymously.\ + However anonymous accounts have the lowest priority when there's too many concurrent requests! \ + To increase your priority you will need a unique API key and then to increase your Kudos \ + read more about them at https://dbzer0.com/blog/the-kudos-based-economy-for-the-koboldai-horde/.") + + with txt2img_tab: + col1, col2, col3, col4, col5 = st.columns(5, gap='medium') + + with col1: + st.title("Slider Parameters") + + # Width + st.session_state["defaults"].txt2img.width.value = st.number_input("Default Image Width", value=st.session_state['defaults'].txt2img.width.value, + help="Set the default width for the generated image. Default is: 512") + + st.session_state["defaults"].txt2img.width.min_value = st.number_input("Minimum Image Width", value=st.session_state['defaults'].txt2img.width.min_value, + help="Set the default minimum value for the width slider. Default is: 64") + + st.session_state["defaults"].txt2img.width.max_value = st.number_input("Maximum Image Width", value=st.session_state['defaults'].txt2img.width.max_value, + help="Set the default maximum value for the width slider. Default is: 2048") + + # Height + st.session_state["defaults"].txt2img.height.value = st.number_input("Default Image Height", value=st.session_state['defaults'].txt2img.height.value, + help="Set the default height for the generated image. Default is: 512") + + st.session_state["defaults"].txt2img.height.min_value = st.number_input("Minimum Image Height", value=st.session_state['defaults'].txt2img.height.min_value, + help="Set the default minimum value for the height slider. 
Default is: 64") + + st.session_state["defaults"].txt2img.height.max_value = st.number_input("Maximum Image Height", value=st.session_state['defaults'].txt2img.height.max_value, + help="Set the default maximum value for the height slider. Default is: 2048") + + with col2: + # CFG + st.session_state["defaults"].txt2img.cfg_scale.value = st.number_input("Default CFG Scale", value=st.session_state['defaults'].txt2img.cfg_scale.value, + help="Set the default value for the CFG Scale. Default is: 7.5") + + st.session_state["defaults"].txt2img.cfg_scale.min_value = st.number_input("Minimum CFG Scale Value", value=st.session_state['defaults'].txt2img.cfg_scale.min_value, + help="Set the default minimum value for the CFG scale slider. Default is: 1") + + st.session_state["defaults"].txt2img.cfg_scale.step = st.number_input("CFG Slider Steps", value=st.session_state['defaults'].txt2img.cfg_scale.step, + help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5") + # Sampling Steps + st.session_state["defaults"].txt2img.sampling_steps.value = st.number_input("Default Sampling Steps", value=st.session_state['defaults'].txt2img.sampling_steps.value, + help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)") + + st.session_state["defaults"].txt2img.sampling_steps.min_value = st.number_input("Minimum Sampling Steps", + value=st.session_state['defaults'].txt2img.sampling_steps.min_value, + help="Set the default minimum value for the sampling steps slider. Default is: 1") + + st.session_state["defaults"].txt2img.sampling_steps.step = st.number_input("Sampling Slider Steps", + value=st.session_state['defaults'].txt2img.sampling_steps.step, + help="Set the default value for the number of steps on the sampling steps slider. Default is: 10") + + with col3: + st.title("General Parameters") + + # Batch Count + st.session_state["defaults"].txt2img.batch_count.value = st.number_input("Batch count", value=st.session_state['defaults'].txt2img.batch_count.value, + help="How many iterations or batches of images to generate in total.") + + st.session_state["defaults"].txt2img.batch_size.value = st.number_input("Batch size", value=st.session_state.defaults.txt2img.batch_size.value, + help="How many images are at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \ + takes to finish generation as more images are generated at once.\ + Default: 1") + + default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + st.session_state["defaults"].txt2img.default_sampler = st.selectbox("Default Sampler", + default_sampler_list, index=default_sampler_list.index( + st.session_state['defaults'].txt2img.default_sampler), + help="Defaut sampler to use for txt2img. Default: k_euler") + + st.session_state['defaults'].txt2img.seed = st.text_input("Default Seed", value=st.session_state['defaults'].txt2img.seed, help="Default seed.") + + with col4: + + st.session_state["defaults"].txt2img.separate_prompts = st.checkbox("Separate Prompts", + value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate Prompts. Default: False") + + st.session_state["defaults"].txt2img.normalize_prompt_weights = st.checkbox("Normalize Prompt Weights", + value=st.session_state['defaults'].txt2img.normalize_prompt_weights, + help="Choose to normalize prompt weights. 
Default: True") + + st.session_state["defaults"].txt2img.save_individual_images = st.checkbox("Save Individual Images", + value=st.session_state['defaults'].txt2img.save_individual_images, + help="Choose to save individual images. Default: True") + + st.session_state["defaults"].txt2img.save_grid = st.checkbox("Save Grid Images", value=st.session_state['defaults'].txt2img.save_grid, + help="Choose to save the grid images. Default: True") + + st.session_state["defaults"].txt2img.group_by_prompt = st.checkbox("Group By Prompt", value=st.session_state['defaults'].txt2img.group_by_prompt, + help="Choose to save images grouped by their prompt. Default: False") + + st.session_state["defaults"].txt2img.save_as_jpg = st.checkbox("Save As JPG", value=st.session_state['defaults'].txt2img.save_as_jpg, + help="Choose to save images as jpegs. Default: False") + + st.session_state["defaults"].txt2img.write_info_files = st.checkbox("Write Info Files For Images", value=st.session_state['defaults'].txt2img.write_info_files, + help="Choose to write the info files along with the generated images. Default: True") + + st.session_state["defaults"].txt2img.use_GFPGAN = st.checkbox( + "Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Choose to use GFPGAN. Default: False") + + st.session_state["defaults"].txt2img.use_upscaling = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling, + help="Choose to turn on upscaling by default. Default: False") + + st.session_state["defaults"].txt2img.update_preview = True + st.session_state["defaults"].txt2img.update_preview_frequency = st.number_input("Preview Image Update Frequency", + min_value=0, + value=st.session_state['defaults'].txt2img.update_preview_frequency, + help="Set the default value for the frrquency of the preview image updates. Default is: 10") + + with col5: + st.title("Variation Parameters") + + st.session_state["defaults"].txt2img.variant_amount.value = st.number_input("Default Variation Amount", + value=st.session_state['defaults'].txt2img.variant_amount.value, + help="Set the default variation to use. Default is: 0.0") + + st.session_state["defaults"].txt2img.variant_amount.min_value = st.number_input("Minimum Variation Amount", + value=st.session_state['defaults'].txt2img.variant_amount.min_value, + help="Set the default minimum value for the variation slider. Default is: 0.0") + + st.session_state["defaults"].txt2img.variant_amount.max_value = st.number_input("Maximum Variation Amount", + value=st.session_state['defaults'].txt2img.variant_amount.max_value, + help="Set the default maximum value for the variation slider. Default is: 1.0") + + st.session_state["defaults"].txt2img.variant_amount.step = st.number_input("Variation Slider Steps", + value=st.session_state['defaults'].txt2img.variant_amount.step, + help="Set the default value for the number of steps on the variation slider. Default is: 1") + + st.session_state['defaults'].txt2img.variant_seed = st.text_input("Default Variation Seed", value=st.session_state['defaults'].txt2img.variant_seed, + help="Default variation seed.") + + with img2img_tab: + col1, col2, col3, col4, col5 = st.columns(5, gap='medium') + + with col1: + st.title("Image Editing") + + # Denoising + st.session_state["defaults"].img2img.denoising_strength.value = st.number_input("Default Denoising Amount", + value=st.session_state['defaults'].img2img.denoising_strength.value, + help="Set the default denoising to use. 
Default is: 0.75") + + st.session_state["defaults"].img2img.denoising_strength.min_value = st.number_input("Minimum Denoising Amount", + value=st.session_state['defaults'].img2img.denoising_strength.min_value, + help="Set the default minimum value for the denoising slider. Default is: 0.0") + + st.session_state["defaults"].img2img.denoising_strength.max_value = st.number_input("Maximum Denoising Amount", + value=st.session_state['defaults'].img2img.denoising_strength.max_value, + help="Set the default maximum value for the denoising slider. Default is: 1.0") + + st.session_state["defaults"].img2img.denoising_strength.step = st.number_input("Denoising Slider Steps", + value=st.session_state['defaults'].img2img.denoising_strength.step, + help="Set the default value for the number of steps on the denoising slider. Default is: 0.01") + + # Masking + st.session_state["defaults"].img2img.mask_mode = st.number_input("Default Mask Mode", value=st.session_state['defaults'].img2img.mask_mode, + help="Set the default mask mode to use. 0 = Keep Masked Area, 1 = Regenerate Masked Area. Default is: 0") + + st.session_state["defaults"].img2img.mask_restore = st.checkbox("Default Mask Restore", value=st.session_state['defaults'].img2img.mask_restore, + help="Mask Restore. Default: False") + + st.session_state["defaults"].img2img.resize_mode = st.number_input("Default Resize Mode", value=st.session_state['defaults'].img2img.resize_mode, + help="Set the default resizing mode. 0 = Just Resize, 1 = Crop and Resize, 3 = Resize and Fill. Default is: 0") + + with col2: + st.title("Slider Parameters") + + # Width + st.session_state["defaults"].img2img.width.value = st.number_input("Default Outputted Image Width", value=st.session_state['defaults'].img2img.width.value, + help="Set the default width for the generated image. Default is: 512") + + st.session_state["defaults"].img2img.width.min_value = st.number_input("Minimum Outputted Image Width", value=st.session_state['defaults'].img2img.width.min_value, + help="Set the default minimum value for the width slider. Default is: 64") + + st.session_state["defaults"].img2img.width.max_value = st.number_input("Maximum Outputted Image Width", value=st.session_state['defaults'].img2img.width.max_value, + help="Set the default maximum value for the width slider. Default is: 2048") + + # Height + st.session_state["defaults"].img2img.height.value = st.number_input("Default Outputted Image Height", value=st.session_state['defaults'].img2img.height.value, + help="Set the default height for the generated image. Default is: 512") + + st.session_state["defaults"].img2img.height.min_value = st.number_input("Minimum Outputted Image Height", value=st.session_state['defaults'].img2img.height.min_value, + help="Set the default minimum value for the height slider. Default is: 64") + + st.session_state["defaults"].img2img.height.max_value = st.number_input("Maximum Outputted Image Height", value=st.session_state['defaults'].img2img.height.max_value, + help="Set the default maximum value for the height slider. Default is: 2048") + + # CFG + st.session_state["defaults"].img2img.cfg_scale.value = st.number_input("Default Img2Img CFG Scale", value=st.session_state['defaults'].img2img.cfg_scale.value, + help="Set the default value for the CFG Scale. 
Default is: 7.5") + + st.session_state["defaults"].img2img.cfg_scale.min_value = st.number_input("Minimum Img2Img CFG Scale Value", + value=st.session_state['defaults'].img2img.cfg_scale.min_value, + help="Set the default minimum value for the CFG scale slider. Default is: 1") + + with col3: + st.session_state["defaults"].img2img.cfg_scale.step = st.number_input("Img2Img CFG Slider Steps", + value=st.session_state['defaults'].img2img.cfg_scale.step, + help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5") + + # Sampling Steps + st.session_state["defaults"].img2img.sampling_steps.value = st.number_input("Default Img2Img Sampling Steps", + value=st.session_state['defaults'].img2img.sampling_steps.value, + help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)") + + st.session_state["defaults"].img2img.sampling_steps.min_value = st.number_input("Minimum Img2Img Sampling Steps", + value=st.session_state['defaults'].img2img.sampling_steps.min_value, + help="Set the default minimum value for the sampling steps slider. Default is: 1") + + st.session_state["defaults"].img2img.sampling_steps.step = st.number_input("Img2Img Sampling Slider Steps", + value=st.session_state['defaults'].img2img.sampling_steps.step, + help="Set the default value for the number of steps on the sampling steps slider. Default is: 10") + + # Batch Count + st.session_state["defaults"].img2img.batch_count.value = st.number_input("Img2img Batch count", value=st.session_state["defaults"].img2img.batch_count.value, + help="How many iterations or batches of images to generate in total.") + + st.session_state["defaults"].img2img.batch_size.value = st.number_input("Img2img Batch size", value=st.session_state["defaults"].img2img.batch_size.value, + help="How many images are at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \ + takes to finish generation as more images are generated at once.\ + Default: 1") + with col4: + # Inference Steps + st.session_state["defaults"].img2img.num_inference_steps.value = st.number_input("Default Inference Steps", + value=st.session_state['defaults'].img2img.num_inference_steps.value, + help="Set the default number of inference steps to use. Default is: 200") + + st.session_state["defaults"].img2img.num_inference_steps.min_value = st.number_input("Minimum Sampling Steps", + value=st.session_state['defaults'].img2img.num_inference_steps.min_value, + help="Set the default minimum value for the inference steps slider. Default is: 10") + + st.session_state["defaults"].img2img.num_inference_steps.max_value = st.number_input("Maximum Sampling Steps", + value=st.session_state['defaults'].img2img.num_inference_steps.max_value, + help="Set the default maximum value for the inference steps slider. Default is: 500") + + st.session_state["defaults"].img2img.num_inference_steps.step = st.number_input("Inference Slider Steps", + value=st.session_state['defaults'].img2img.num_inference_steps.step, + help="Set the default value for the number of steps on the inference steps slider.\ + Default is: 10") + + # Find Noise Steps + st.session_state["defaults"].img2img.find_noise_steps.value = st.number_input("Default Find Noise Steps", + value=st.session_state['defaults'].img2img.find_noise_steps.value, + help="Set the default number of find noise steps to use. 
Default is: 100") + + st.session_state["defaults"].img2img.find_noise_steps.min_value = st.number_input("Minimum Find Noise Steps", + value=st.session_state['defaults'].img2img.find_noise_steps.min_value, + help="Set the default minimum value for the find noise steps slider. Default is: 0") + + st.session_state["defaults"].img2img.find_noise_steps.step = st.number_input("Find Noise Slider Steps", + value=st.session_state['defaults'].img2img.find_noise_steps.step, + help="Set the default value for the number of steps on the find noise steps slider. \ + Default is: 100") + + with col5: + st.title("General Parameters") + + default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + st.session_state["defaults"].img2img.sampler_name = st.selectbox("Default Img2Img Sampler", default_sampler_list, + index=default_sampler_list.index(st.session_state['defaults'].img2img.sampler_name), + help="Defaut sampler to use for img2img. Default: k_euler") + + st.session_state['defaults'].img2img.seed = st.text_input("Default Img2Img Seed", value=st.session_state['defaults'].img2img.seed, help="Default seed.") + + st.session_state["defaults"].img2img.separate_prompts = st.checkbox("Separate Img2Img Prompts", value=st.session_state['defaults'].img2img.separate_prompts, + help="Separate Prompts. Default: False") + + st.session_state["defaults"].img2img.normalize_prompt_weights = st.checkbox("Normalize Img2Img Prompt Weights", + value=st.session_state['defaults'].img2img.normalize_prompt_weights, + help="Choose to normalize prompt weights. Default: True") + + st.session_state["defaults"].img2img.save_individual_images = st.checkbox("Save Individual Img2Img Images", + value=st.session_state['defaults'].img2img.save_individual_images, + help="Choose to save individual images. Default: True") + + st.session_state["defaults"].img2img.save_grid = st.checkbox("Save Img2Img Grid Images", + value=st.session_state['defaults'].img2img.save_grid, help="Choose to save the grid images. Default: True") + + st.session_state["defaults"].img2img.group_by_prompt = st.checkbox("Group By Img2Img Prompt", + value=st.session_state['defaults'].img2img.group_by_prompt, + help="Choose to save images grouped by their prompt. Default: False") + + st.session_state["defaults"].img2img.save_as_jpg = st.checkbox("Save Img2Img As JPG", value=st.session_state['defaults'].img2img.save_as_jpg, + help="Choose to save images as jpegs. Default: False") + + st.session_state["defaults"].img2img.write_info_files = st.checkbox("Write Info Files For Img2Img Images", + value=st.session_state['defaults'].img2img.write_info_files, + help="Choose to write the info files along with the generated images. Default: True") + + st.session_state["defaults"].img2img.use_GFPGAN = st.checkbox( + "Img2Img Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Choose to use GFPGAN. Default: False") + + st.session_state["defaults"].img2img.use_RealESRGAN = st.checkbox("Img2Img Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN, + help="Choose to use RealESRGAN. Default: False") + + st.session_state["defaults"].img2img.update_preview = True + st.session_state["defaults"].img2img.update_preview_frequency = st.number_input("Img2Img Preview Image Update Frequency", + min_value=0, + value=st.session_state['defaults'].img2img.update_preview_frequency, + help="Set the default value for the frrquency of the preview image updates. 
Default is: 10") + + st.title("Variation Parameters") + + st.session_state["defaults"].img2img.variant_amount = st.number_input("Default Img2Img Variation Amount", + value=st.session_state['defaults'].img2img.variant_amount, + help="Set the default variation to use. Default is: 0.0") + + # I THINK THESE ARE MISSING FROM THE CONFIG FILE + # st.session_state["defaults"].img2img.variant_amount.min_value = st.number_input("Minimum Img2Img Variation Amount", + # value=st.session_state['defaults'].img2img.variant_amount.min_value, help="Set the default minimum value for the variation slider. Default is: 0.0")) + + # st.session_state["defaults"].img2img.variant_amount.max_value = st.number_input("Maximum Img2Img Variation Amount", + # value=st.session_state['defaults'].img2img.variant_amount.max_value, help="Set the default maximum value for the variation slider. Default is: 1.0")) + + # st.session_state["defaults"].img2img.variant_amount.step = st.number_input("Img2Img Variation Slider Steps", + # value=st.session_state['defaults'].img2img.variant_amount.step, help="Set the default value for the number of steps on the variation slider. Default is: 1")) + + st.session_state['defaults'].img2img.variant_seed = st.text_input("Default Img2Img Variation Seed", + value=st.session_state['defaults'].img2img.variant_seed, help="Default variation seed.") + + with img2txt_tab: + col1 = st.columns(1, gap="large") + + st.title("Image-To-Text") + + st.session_state["defaults"].img2txt.batch_size = st.number_input("Default Img2Txt Batch Size", value=st.session_state['defaults'].img2txt.batch_size, + help="Set the default batch size for Img2Txt. Default is: 420?") + + st.session_state["defaults"].img2txt.blip_image_eval_size = st.number_input("Default Blip Image Size Evaluation", + value=st.session_state['defaults'].img2txt.blip_image_eval_size, + help="Set the default value for the blip image evaluation size. Default is: 512") + + with txt2vid_tab: + col1, col2, col3, col4, col5 = st.columns(5, gap="medium") + + with col1: + st.title("Slider Parameters") + + # Width + st.session_state["defaults"].txt2vid.width.value = st.number_input("Default txt2vid Image Width", + value=st.session_state['defaults'].txt2vid.width.value, + help="Set the default width for the generated image. Default is: 512") + + st.session_state["defaults"].txt2vid.width.min_value = st.number_input("Minimum txt2vid Image Width", + value=st.session_state['defaults'].txt2vid.width.min_value, + help="Set the default minimum value for the width slider. Default is: 64") + + st.session_state["defaults"].txt2vid.width.max_value = st.number_input("Maximum txt2vid Image Width", + value=st.session_state['defaults'].txt2vid.width.max_value, + help="Set the default maximum value for the width slider. Default is: 2048") + + # Height + st.session_state["defaults"].txt2vid.height.value = st.number_input("Default txt2vid Image Height", + value=st.session_state['defaults'].txt2vid.height.value, + help="Set the default height for the generated image. Default is: 512") + + st.session_state["defaults"].txt2vid.height.min_value = st.number_input("Minimum txt2vid Image Height", + value=st.session_state['defaults'].txt2vid.height.min_value, + help="Set the default minimum value for the height slider. Default is: 64") + + st.session_state["defaults"].txt2vid.height.max_value = st.number_input("Maximum txt2vid Image Height", + value=st.session_state['defaults'].txt2vid.height.max_value, + help="Set the default maximum value for the height slider. 
Default is: 2048") + + # CFG + st.session_state["defaults"].txt2vid.cfg_scale.value = st.number_input("Default txt2vid CFG Scale", + value=st.session_state['defaults'].txt2vid.cfg_scale.value, + help="Set the default value for the CFG Scale. Default is: 7.5") + + st.session_state["defaults"].txt2vid.cfg_scale.min_value = st.number_input("Minimum txt2vid CFG Scale Value", + value=st.session_state['defaults'].txt2vid.cfg_scale.min_value, + help="Set the default minimum value for the CFG scale slider. Default is: 1") + + st.session_state["defaults"].txt2vid.cfg_scale.step = st.number_input("txt2vid CFG Slider Steps", + value=st.session_state['defaults'].txt2vid.cfg_scale.step, + help="Set the default value for the number of steps on the CFG scale slider. Default is: 0.5") + + with col2: + # Sampling Steps + st.session_state["defaults"].txt2vid.sampling_steps.value = st.number_input("Default txt2vid Sampling Steps", + value=st.session_state['defaults'].txt2vid.sampling_steps.value, + help="Set the default number of sampling steps to use. Default is: 30 (with k_euler)") + + st.session_state["defaults"].txt2vid.sampling_steps.min_value = st.number_input("Minimum txt2vid Sampling Steps", + value=st.session_state['defaults'].txt2vid.sampling_steps.min_value, + help="Set the default minimum value for the sampling steps slider. Default is: 1") + + st.session_state["defaults"].txt2vid.sampling_steps.step = st.number_input("txt2vid Sampling Slider Steps", + value=st.session_state['defaults'].txt2vid.sampling_steps.step, + help="Set the default value for the number of steps on the sampling steps slider. Default is: 10") + + # Batch Count + st.session_state["defaults"].txt2vid.batch_count.value = st.number_input("txt2vid Batch count", value=st.session_state['defaults'].txt2vid.batch_count.value, + help="How many iterations or batches of images to generate in total.") + + st.session_state["defaults"].txt2vid.batch_size.value = st.number_input("txt2vid Batch size", value=st.session_state.defaults.txt2vid.batch_size.value, + help="How many images are at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it \ + takes to finish generation as more images are generated at once.\ + Default: 1") + + # Inference Steps + st.session_state["defaults"].txt2vid.num_inference_steps.value = st.number_input("Default Txt2Vid Inference Steps", + value=st.session_state['defaults'].txt2vid.num_inference_steps.value, + help="Set the default number of inference steps to use. Default is: 200") + + st.session_state["defaults"].txt2vid.num_inference_steps.min_value = st.number_input("Minimum Txt2Vid Sampling Steps", + value=st.session_state['defaults'].txt2vid.num_inference_steps.min_value, + help="Set the default minimum value for the inference steps slider. Default is: 10") + + st.session_state["defaults"].txt2vid.num_inference_steps.max_value = st.number_input("Maximum Txt2Vid Sampling Steps", + value=st.session_state['defaults'].txt2vid.num_inference_steps.max_value, + help="Set the default maximum value for the inference steps slider. Default is: 500") + st.session_state["defaults"].txt2vid.num_inference_steps.step = st.number_input("Txt2Vid Inference Slider Steps", + value=st.session_state['defaults'].txt2vid.num_inference_steps.step, + help="Set the default value for the number of steps on the inference steps slider. 
Default is: 10") + + with col3: + st.title("General Parameters") + + st.session_state['defaults'].txt2vid.default_model = st.text_input("Default Txt2Vid Model", value=st.session_state['defaults'].txt2vid.default_model, + help="Default: CompVis/stable-diffusion-v1-4") + + # INSERT CUSTOM_MODELS_LIST HERE + + default_sampler_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"] + st.session_state["defaults"].txt2vid.default_sampler = st.selectbox("Default txt2vid Sampler", default_sampler_list, + index=default_sampler_list.index(st.session_state['defaults'].txt2vid.default_sampler), + help="Defaut sampler to use for txt2vid. Default: k_euler") + + st.session_state['defaults'].txt2vid.seed = st.text_input("Default txt2vid Seed", value=st.session_state['defaults'].txt2vid.seed, help="Default seed.") + + st.session_state['defaults'].txt2vid.scheduler_name = st.text_input("Default Txt2Vid Scheduler", + value=st.session_state['defaults'].txt2vid.scheduler_name, help="Default scheduler.") + + st.session_state["defaults"].txt2vid.separate_prompts = st.checkbox("Separate txt2vid Prompts", + value=st.session_state['defaults'].txt2vid.separate_prompts, help="Separate Prompts. Default: False") + + st.session_state["defaults"].txt2vid.normalize_prompt_weights = st.checkbox("Normalize txt2vid Prompt Weights", + value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, + help="Choose to normalize prompt weights. Default: True") + + st.session_state["defaults"].txt2vid.save_individual_images = st.checkbox("Save Individual txt2vid Images", + value=st.session_state['defaults'].txt2vid.save_individual_images, + help="Choose to save individual images. Default: True") + + st.session_state["defaults"].txt2vid.save_video = st.checkbox("Save Txt2Vid Video", value=st.session_state['defaults'].txt2vid.save_video, + help="Choose to save the Txt2Vid video. Default: True") + + st.session_state["defaults"].txt2vid.save_video_on_stop = st.checkbox("Save video on Stop", value=st.session_state['defaults'].txt2vid.save_video_on_stop, + help="Save a video with all the images generated as frames when we hit the stop button \ + during a generation.") + + st.session_state["defaults"].txt2vid.group_by_prompt = st.checkbox("Group By txt2vid Prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt, + help="Choose to save images grouped by their prompt. Default: False") + + st.session_state["defaults"].txt2vid.save_as_jpg = st.checkbox("Save txt2vid As JPG", value=st.session_state['defaults'].txt2vid.save_as_jpg, + help="Choose to save images as jpegs. Default: False") + + # Need more info for the Help dialog... + st.session_state["defaults"].txt2vid.do_loop = st.checkbox("Loop Generations", value=st.session_state['defaults'].txt2vid.do_loop, + help="Choose to loop or something, IDK.... Default: False") + + st.session_state["defaults"].txt2vid.max_duration_in_seconds = st.number_input("Txt2Vid Max Duration in Seconds", value=st.session_state['defaults'].txt2vid.max_duration_in_seconds, + help="Set the default value for the max duration in seconds for the video generated. Default is: 30") + + st.session_state["defaults"].txt2vid.write_info_files = st.checkbox("Write Info Files For txt2vid Images", value=st.session_state['defaults'].txt2vid.write_info_files, + help="Choose to write the info files along with the generated images. 
Default: True") + + st.session_state["defaults"].txt2vid.use_GFPGAN = st.checkbox("txt2vid Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, + help="Choose to use GFPGAN. Default: False") + + st.session_state["defaults"].txt2vid.use_RealESRGAN = st.checkbox("txt2vid Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN, + help="Choose to use RealESRGAN. Default: False") + + st.session_state["defaults"].txt2vid.update_preview = True + st.session_state["defaults"].txt2vid.update_preview_frequency = st.number_input("txt2vid Preview Image Update Frequency", + value=st.session_state['defaults'].txt2vid.update_preview_frequency, + help="Set the default value for the frequency of the preview image updates. Default is: 10") + + with col4: + st.title("Variation Parameters") + + st.session_state["defaults"].txt2vid.variant_amount.value = st.number_input("Default txt2vid Variation Amount", + value=st.session_state['defaults'].txt2vid.variant_amount.value, + help="Set the default variation to use. Default is: 0.0") + + st.session_state["defaults"].txt2vid.variant_amount.min_value = st.number_input("Minimum txt2vid Variation Amount", + value=st.session_state['defaults'].txt2vid.variant_amount.min_value, + help="Set the default minimum value for the variation slider. Default is: 0.0") + + st.session_state["defaults"].txt2vid.variant_amount.max_value = st.number_input("Maximum txt2vid Variation Amount", + value=st.session_state['defaults'].txt2vid.variant_amount.max_value, + help="Set the default maximum value for the variation slider. Default is: 1.0") + + st.session_state["defaults"].txt2vid.variant_amount.step = st.number_input("txt2vid Variation Slider Steps", + value=st.session_state['defaults'].txt2vid.variant_amount.step, + help="Set the default value for the number of steps on the variation slider. Default is: 1") + + st.session_state['defaults'].txt2vid.variant_seed = st.text_input("Default txt2vid Variation Seed", + value=st.session_state['defaults'].txt2vid.variant_seed, help="Default variation seed.") + + with col5: + st.title("Beta Parameters") + + # Beta Start + st.session_state["defaults"].txt2vid.beta_start.value = st.number_input("Default txt2vid Beta Start Value", + value=st.session_state['defaults'].txt2vid.beta_start.value, + help="Set the default beta start value. Default is: 0.0") + + st.session_state["defaults"].txt2vid.beta_start.min_value = st.number_input("Minimum txt2vid Beta Start Amount", + value=st.session_state['defaults'].txt2vid.beta_start.min_value, + help="Set the default minimum value for the beta start slider. Default is: 0.0") + + st.session_state["defaults"].txt2vid.beta_start.max_value = st.number_input("Maximum txt2vid Beta Start Amount", + value=st.session_state['defaults'].txt2vid.beta_start.max_value, + help="Set the default maximum value for the beta start slider. Default is: 1.0") + + st.session_state["defaults"].txt2vid.beta_start.step = st.number_input("txt2vid Beta Start Slider Steps", value=st.session_state['defaults'].txt2vid.beta_start.step, + help="Set the default value for the number of steps on the beta start slider. Default is: 1") + + st.session_state["defaults"].txt2vid.beta_start.format = st.text_input("Default txt2vid Beta Start Format", value=st.session_state['defaults'].txt2vid.beta_start.format, + help="Set the default Beta Start Format. 
Default is: %.5f") + + # Beta End + st.session_state["defaults"].txt2vid.beta_end.value = st.number_input("Default txt2vid Beta End Value", value=st.session_state['defaults'].txt2vid.beta_end.value, + help="Set the default beta end value. Default is: 0.0") + + st.session_state["defaults"].txt2vid.beta_end.min_value = st.number_input("Minimum txt2vid Beta End Amount", value=st.session_state['defaults'].txt2vid.beta_end.min_value, + help="Set the default minimum value for the beta end slider. Default is: 0.0") + + st.session_state["defaults"].txt2vid.beta_end.max_value = st.number_input("Maximum txt2vid Beta End Amount", value=st.session_state['defaults'].txt2vid.beta_end.max_value, + help="Set the default maximum value for the beta end slider. Default is: 1.0") + + st.session_state["defaults"].txt2vid.beta_end.step = st.number_input("txt2vid Beta End Slider Steps", value=st.session_state['defaults'].txt2vid.beta_end.step, + help="Set the default value for the number of steps on the beta end slider. Default is: 1") + + st.session_state["defaults"].txt2vid.beta_end.format = st.text_input("Default txt2vid Beta End Format", value=st.session_state['defaults'].txt2vid.beta_end.format, + help="Set the default Beta End Format. Default is: %.5f") + + with image_processing: + col1, col2, col3, col4, col5 = st.columns(5, gap="large") + + with col1: + st.title("GFPGAN") + + st.session_state["defaults"].gfpgan.strength = st.number_input("Default GFPGAN Strength", value=st.session_state['defaults'].gfpgan.strength, + help="Set the default global strength for GFPGAN. Default is: 100") + with col2: + st.title("GoBig") + with col3: + st.title("RealESRGAN") + with col4: + st.title("LDSR") + with col5: + st.title("GoLatent") + + with textual_inversion_tab: + st.title("Textual Inversion") + + st.session_state['defaults'].textual_inversion.pretrained_model_name_or_path = st.text_input("Default Textual Inversion Model Path", + value=st.session_state['defaults'].textual_inversion.pretrained_model_name_or_path, + help="Default: models/ldm/stable-diffusion-v1-4") + + st.session_state['defaults'].textual_inversion.tokenizer_name = st.text_input("Default Textual Inversion Tokenizer Name", value=st.session_state['defaults'].textual_inversion.tokenizer_name, + help="Default tokenizer name.") + + with concepts_library_tab: + st.title("Concepts Library") + #st.info("Under Construction. :construction_worker:") + col1, col2, col3, col4, col5 = st.columns(5, gap='large') + with col1: + st.session_state["defaults"].concepts_library.concepts_per_page = st.number_input("Concepts Per Page", value=st.session_state['defaults'].concepts_library.concepts_per_page, + help="Number of concepts per page to show on the Concepts Library. Default: 12") + + # add space for the buttons at the bottom + st.markdown("---") + + # We need a submit button to save the Settings + # as well as one to reset them to the defaults, just in case. 
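+ # Saving serializes the whole in-memory defaults tree back to configs/webui/userconfig_streamlit.yaml with OmegaConf and reloads it to check the round-trip; Reset simply re-reads the stock webui_streamlit.yaml and reruns the app.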
+ _, _, save_button_col, reset_button_col, _, _ = st.columns([1, 1, 1, 1, 1, 1], gap="large") + with save_button_col: + save_button = st.form_submit_button("Save") + + with reset_button_col: + reset_button = st.form_submit_button("Reset") + + if save_button: + OmegaConf.save(config=st.session_state.defaults, f="configs/webui/userconfig_streamlit.yaml") + loaded = OmegaConf.load("configs/webui/userconfig_streamlit.yaml") + assert st.session_state.defaults == loaded + + # + if (os.path.exists(".streamlit/config.toml")): + with open(".streamlit/config.toml", "w") as toml_file: + toml.dump(st.session_state["streamlit_config"], toml_file) + + if reset_button: + st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml") + st.experimental_rerun() diff --git a/webui/streamlit/scripts/barfi_baklavajs.py b/webui/streamlit/scripts/barfi_baklavajs.py new file mode 100644 index 0000000..9d3f802 --- /dev/null +++ b/webui/streamlit/scripts/barfi_baklavajs.py @@ -0,0 +1,91 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sandbox-webui/). + +# Copyright 2022 Sygil-Dev team. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# base webui import and utils. +#from sd_utils import * +from sd_utils import st +# streamlit imports + +#streamlit components section + +#other imports +from barfi import st_barfi, barfi_schemas, Block + +# Temp imports + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + + +def layout(): + #st.info("Under Construction. 
:construction_worker:") + + #from barfi import st_barfi, Block + + #add = Block(name='Addition') + #sub = Block(name='Subtraction') + #mul = Block(name='Multiplication') + #div = Block(name='Division') + + #barfi_result = st_barfi(base_blocks= [add, sub, mul, div]) + # or if you want to use a category to organise them in the frontend sub-menu + #barfi_result = st_barfi(base_blocks= {'Op 1': [add, sub], 'Op 2': [mul, div]}) + + col1, col2, col3 = st.columns([1, 8, 1]) + + with col2: + feed = Block(name='Feed') + feed.add_output() + def feed_func(self): + self.set_interface(name='Output 1', value=4) + feed.add_compute(feed_func) + + splitter = Block(name='Splitter') + splitter.add_input() + splitter.add_output() + splitter.add_output() + def splitter_func(self): + in_1 = self.get_interface(name='Input 1') + value = (in_1/2) + self.set_interface(name='Output 1', value=value) + self.set_interface(name='Output 2', value=value) + splitter.add_compute(splitter_func) + + mixer = Block(name='Mixer') + mixer.add_input() + mixer.add_input() + mixer.add_output() + def mixer_func(self): + in_1 = self.get_interface(name='Input 1') + in_2 = self.get_interface(name='Input 2') + value = (in_1 + in_2) + self.set_interface(name='Output 1', value=value) + mixer.add_compute(mixer_func) + + result = Block(name='Result') + result.add_input() + def result_func(self): + in_1 = self.get_interface(name='Input 1') + result.add_compute(result_func) + + load_schema = st.selectbox('Select a saved schema:', barfi_schemas()) + + compute_engine = st.checkbox('Activate barfi compute engine', value=False) + + barfi_result = st_barfi(base_blocks=[feed, result, mixer, splitter], + compute_engine=compute_engine, load_schema=load_schema) + + if barfi_result: + st.write(barfi_result) diff --git a/webui/streamlit/scripts/custom_components/dragable_number_input/index.html b/webui/streamlit/scripts/custom_components/dragable_number_input/index.html new file mode 100644 index 0000000..c26e3fd --- /dev/null +++ b/webui/streamlit/scripts/custom_components/dragable_number_input/index.html @@ -0,0 +1,134 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/webui/streamlit/scripts/custom_components/draggable_number_input/__init__.py b/webui/streamlit/scripts/custom_components/draggable_number_input/__init__.py new file mode 100644 index 0000000..5e79185 --- /dev/null +++ b/webui/streamlit/scripts/custom_components/draggable_number_input/__init__.py @@ -0,0 +1,11 @@ +import os +import streamlit.components.v1 as components + +def load(pixel_per_step = 50): + parent_dir = os.path.dirname(os.path.abspath(__file__)) + file = os.path.join(parent_dir, "main.js") + + with open(file) as f: + javascript_main = f.read() + javascript_main = javascript_main.replace("%%pixelPerStep%%",str(pixel_per_step)) + components.html(f"") \ No newline at end of file diff --git a/webui/streamlit/scripts/custom_components/draggable_number_input/main.js b/webui/streamlit/scripts/custom_components/draggable_number_input/main.js new file mode 100644 index 0000000..574d133 --- /dev/null +++ b/webui/streamlit/scripts/custom_components/draggable_number_input/main.js @@ -0,0 +1,192 @@ +// iframe parent +var parentDoc = window.parent.document + +// check for mouse pointer locking support, not a requirement but improves the overall experience +var havePointerLock = 'pointerLockElement' in parentDoc || + 'mozPointerLockElement' in parentDoc || + 'webkitPointerLockElement' in parentDoc; + +// the pointer locking exit function +parentDoc.exitPointerLock = 
parentDoc.exitPointerLock || parentDoc.mozExitPointerLock || parentDoc.webkitExitPointerLock; + +// how far should the mouse travel for a step in pixel +var pixelPerStep = %%pixelPerStep%%; +// how many steps did the mouse move in as float +var movementDelta = 0.0; +// value when drag started +var lockedValue = 0.0; +// minimum value from field +var lockedMin = 0.0; +// maximum value from field +var lockedMax = 0.0; +// how big should the field steps be +var lockedStep = 0.0; +// the currently locked in field +var lockedField = null; + +// lock box to just request pointer lock for one element +var lockBox = document.createElement("div"); +lockBox.classList.add("lockbox"); +parentDoc.body.appendChild(lockBox); +lockBox.requestPointerLock = lockBox.requestPointerLock || lockBox.mozRequestPointerLock || lockBox.webkitRequestPointerLock; + +function Lock(field) +{ + var rect = field.getBoundingClientRect(); + lockBox.style.left = (rect.left-2.5)+"px"; + lockBox.style.top = (rect.top-2.5)+"px"; + + lockBox.style.width = (rect.width+2.5)+"px"; + lockBox.style.height = (rect.height+5)+"px"; + + lockBox.requestPointerLock(); +} + +function Unlock() +{ + parentDoc.exitPointerLock(); + lockBox.style.left = "0px"; + lockBox.style.top = "0px"; + + lockBox.style.width = "0px"; + lockBox.style.height = "0px"; + lockedField.focus(); +} + +parentDoc.addEventListener('mousedown', (e) => { + // if middle is down + if(e.button === 1) + { + if(e.target.tagName === 'INPUT' && e.target.type === 'number') + { + e.preventDefault(); + var field = e.target; + if(havePointerLock) + Lock(field); + + // save current field + lockedField = e.target; + // add class for styling + lockedField.classList.add("value-dragging"); + // reset movement delta + movementDelta = 0.0; + // set to 0 if field is empty + if(lockedField.value === '') + lockedField.value = 0.0; + + // save current field value + lockedValue = parseFloat(lockedField.value); + + if(lockedField.min === '' || lockedField.min === '-Infinity') + lockedMin = -99999999.0; + else + lockedMin = parseFloat(lockedField.min); + + if(lockedField.max === '' || lockedField.max === 'Infinity') + lockedMax = 99999999.0; + else + lockedMax = parseFloat(lockedField.max); + + if(lockedField.step === '' || lockedField.step === 'Infinity') + lockedStep = 1.0; + else + lockedStep = parseFloat(lockedField.step); + + // lock pointer if available + if(havePointerLock) + Lock(lockedField); + + // add drag event + parentDoc.addEventListener("mousemove", onDrag, false); + } + } +}); + +function onDrag(e) +{ + if(lockedField !== null) + { + // add movement to delta + movementDelta += e.movementX / pixelPerStep; + if(lockedField === NaN) + return; + // set new value + let value = lockedValue + Math.floor(Math.abs(movementDelta)) * lockedStep * Math.sign(movementDelta); + lockedField.focus(); + lockedField.select(); + parentDoc.execCommand('insertText', false /*no UI*/, Math.min(Math.max(value, lockedMin), lockedMax)); + } +} + +parentDoc.addEventListener('mouseup', (e) => { + // if mouse is up + if(e.button === 1) + { + // release pointer lock if available + if(havePointerLock) + Unlock(); + + if(lockedField !== null && lockedField !== NaN) + { + // stop drag event + parentDoc.removeEventListener("mousemove", onDrag, false); + // remove class for styling + lockedField.classList.remove("value-dragging"); + // remove reference + lockedField = null; + } + } +}); + +// only execute once (even though multiple iframes exist) +if(!parentDoc.hasOwnProperty("dragableInitialized")) +{ + var 
parentCSS = +` +/* Make input-instruction not block mouse events */ +.input-instructions,.input-instructions > *{ + pointer-events: none; + user-select: none; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; +} + +.lockbox { + background-color: transparent; + position: absolute; + pointer-events: none; + user-select: none; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; + border-left: dotted 2px rgb(255,75,75); + border-top: dotted 2px rgb(255,75,75); + border-bottom: dotted 2px rgb(255,75,75); + border-right: dotted 1px rgba(255,75,75,0.2); + border-top-left-radius: 0.25rem; + border-bottom-left-radius: 0.25rem; + z-index: 1000; +} +`; + + // get parent document head + var head = parentDoc.getElementsByTagName('head')[0]; + // add style tag + var s = document.createElement('style'); + // set type attribute + s.setAttribute('type', 'text/css'); + // add css forwarded from python + if (s.styleSheet) { // IE + s.styleSheet.cssText = parentCSS; + } else { // the world + s.appendChild(document.createTextNode(parentCSS)); + } + // add style to head + head.appendChild(s); + // set flag so this only runs once + parentDoc["dragableInitialized"] = true; +} + diff --git a/webui/streamlit/scripts/custom_components/sygil_suggestions/__init__.py b/webui/streamlit/scripts/custom_components/sygil_suggestions/__init__.py new file mode 100644 index 0000000..e25e1ac --- /dev/null +++ b/webui/streamlit/scripts/custom_components/sygil_suggestions/__init__.py @@ -0,0 +1,46 @@ +import os +from collections import defaultdict +import streamlit.components.v1 as components + +# where to save the downloaded key_phrases +key_phrases_file = "data/tags/key_phrases.json" +# the loaded key phrase json as text +key_phrases_json = "" +# where to save the downloaded key_phrases +thumbnails_file = "data/tags/thumbnails.json" +# the loaded key phrase json as text +thumbnails_json = "" + +def init(): + global key_phrases_json, thumbnails_json + with open(key_phrases_file) as f: + key_phrases_json = f.read() + with open(thumbnails_file) as f: + thumbnails_json = f.read() + +def suggestion_area(placeholder): + # get component path + parent_dir = os.path.dirname(os.path.abspath(__file__)) + # get file paths + javascript_file = os.path.join(parent_dir, "main.js") + stylesheet_file = os.path.join(parent_dir, "main.css") + parent_stylesheet_file = os.path.join(parent_dir, "parent.css") + + # load file texts + with open(javascript_file) as f: + javascript_main = f.read() + with open(stylesheet_file) as f: + stylesheet_main = f.read() + with open(parent_stylesheet_file) as f: + parent_stylesheet = f.read() + + # add suggestion area div box + html = "
javascript failed
" + # add loaded style + html += f"" + # set default variables + html += f"" + # add main java script + html += f"\n" + # add component to site + components.html(html, width=None, height=None, scrolling=True) \ No newline at end of file diff --git a/webui/streamlit/scripts/custom_components/sygil_suggestions/main.css b/webui/streamlit/scripts/custom_components/sygil_suggestions/main.css new file mode 100644 index 0000000..c8729b4 --- /dev/null +++ b/webui/streamlit/scripts/custom_components/sygil_suggestions/main.css @@ -0,0 +1,81 @@ +* +{ + padding: 0px; + margin: 0px; + user-select: none; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; +} + +body +{ + width: 100%; + height: 100%; + padding-left: calc( 1em - 1px ); + padding-top: calc( 1em - 1px ); + overflow: hidden; +} + +/* width */ +::-webkit-scrollbar { + width: 7px; +} + +/* Track */ +::-webkit-scrollbar-track { + background: rgb(10, 13, 19); +} + +/* Handle */ +::-webkit-scrollbar-thumb { + background: #6c6e72; + border-radius: 3px; +} + +/* Handle on hover */ +::-webkit-scrollbar-thumb:hover { + background: #6c6e72; +} + +#scroll_area +{ + display: flex; + overflow-x: hidden; + overflow-y: auto; +} + +#suggestion_area +{ + overflow-x: hidden; + width: calc( 100% - 2em - 2px ); + margin-bottom: calc( 1em + 13px ); + min-height: 50px; +} + +span +{ + border: 1px solid rgba(250, 250, 250, 0.2); + border-radius: 0.25rem; + font-size: 1rem; + font-family: "Source Sans Pro", sans-serif; + + background-color: rgb(38, 39, 48); + color: white; + display: inline-block; + padding: 0.5rem; + margin-right: 3px; + cursor: pointer; + user-select: none; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; +} + +span:hover +{ + color: rgb(255,75,75); + border-color: rgb(255,75,75); +} \ No newline at end of file diff --git a/webui/streamlit/scripts/custom_components/sygil_suggestions/main.js b/webui/streamlit/scripts/custom_components/sygil_suggestions/main.js new file mode 100644 index 0000000..9826702 --- /dev/null +++ b/webui/streamlit/scripts/custom_components/sygil_suggestions/main.js @@ -0,0 +1,1048 @@ + +// parent document +var parentDoc = window.parent.document; +// iframe element in parent document +var frame = window.frameElement; +// the area to put the suggestions in +var suggestionArea = document.getElementById('suggestion_area'); +var scrollArea = document.getElementById('scroll_area'); +// button height is read when the first button gets created +var buttonHeight = -1; +// the maximum size of the iframe in buttons (3 x buttons height) +var maxHeightInButtons = 3; +// the prompt field connected to this iframe +var promptField = null; +// the category of suggestions +var activeCategory = contextCategory; + +var conditionalButtons = null; + +var contextCategory = "[context]"; + +var frameHeight = "calc( 3em - 3px + {} )"; + +var filterGroups = {nsfw_mild: "nsfw_mild", nsfw_basic: "nsfw_basic", nsfw_strict: "nsfw_strict", gore_mild: "gore_mild", gore_basic: "gore_basic", gore_strict: "gore_strict"}; +var activeFilters = [filterGroups.nsfw_mild, filterGroups.nsfw_basic, filterGroups.gore_mild]; + +var triggers = {empty: "empty", nsfw: "nsfw", nude: "nude"}; +var activeContext = []; + +var triggerIndex = {}; + +var wordMap = {}; +var tagMap = {}; + +// could pass in an array of specific stylesheets for optimization +function getAllCSSVariableNames(styleSheets = parentDoc.styleSheets){ + var cssVars = []; + // loop each 
stylesheet + for(var i = 0; i < styleSheets.length; i++){ + // loop stylesheet's cssRules + try{ // try/catch used because 'hasOwnProperty' doesn't work + for( var j = 0; j < styleSheets[i].cssRules.length; j++){ + try{ + //console.log(styleSheets[i].cssRules[j].selectorText); + // loop stylesheet's cssRules' style (property names) + for(var k = 0; k < styleSheets[i].cssRules[j].style.length; k++){ + let name = styleSheets[i].cssRules[j].style[k]; + // test name for css variable signiture and uniqueness + if(name.startsWith('--') && cssVars.indexOf(name) == -1){ + cssVars.push(name); + } + } + } catch (error) {} + } + } catch (error) {} + } + return cssVars; +} + +function currentFrameAbsolutePosition() { + let currentWindow = window; + let currentParentWindow; + let positions = []; + let rect; + + while (currentWindow !== window.top) { + currentParentWindow = currentWindow.parent; + for (let idx = 0; idx < currentParentWindow.frames.length; idx++) + if (currentParentWindow.frames[idx] === currentWindow) { + for (let frameElement of currentParentWindow.document.getElementsByTagName('iframe')) { + if (frameElement.contentWindow === currentWindow) { + rect = frameElement.getBoundingClientRect(); + positions.push({x: rect.x, y: rect.y}); + } + } + currentWindow = currentParentWindow; + break; + } + } + return positions.reduce((accumulator, currentValue) => { + return { + x: accumulator.x + currentValue.x, + y: accumulator.y + currentValue.y + }; + }, { x: 0, y: 0 }); +} + +// check if element is visible +function isVisible(e) { + return !!( e.offsetWidth || e.offsetHeight || e.getClientRects().length ); +} + +// remove everything from the suggestion area +function ClearSuggestionArea(text = "") +{ + suggestionArea.innerHTML = text; + conditionalButtons = []; +} + +// update iframe size depending on button rows +function UpdateSize() +{ + // calculate maximum height + var maxHeight = buttonHeight * maxHeightInButtons; + + var height = suggestionArea.lastChild.offsetTop + buttonHeight; + // apply height to iframe + frame.style.height = frameHeight.replace("{}", Math.min(height,maxHeight)+"px"); + scrollArea.style.height = frame.style.height; +} + +// add a button to the suggestion area +function AddButton(label, action, dataTooltip = null, tooltipImage = null, pattern = null, data = null) +{ + // create span + var button = document.createElement("span"); + // label it + button.innerHTML = label; + if(data != null) + { + // add category attribute to button, will be read on click + button.setAttribute("data",data); + } + if(pattern != null) + { + // add category attribute to button, will be read on click + button.setAttribute("pattern",pattern); + } + if(dataTooltip != null) + { + // add category attribute to button, will be read on click + button.setAttribute("tooltip-text",dataTooltip); + } + if(tooltipImage != null) + { + // add category attribute to button, will be read on click + button.setAttribute("tooltip-image",tooltipImage); + } + // add button function + button.addEventListener('click', action, false); + button.addEventListener('mouseover', ButtonHoverEnter); + button.addEventListener('mouseout', ButtonHoverExit); + // add button to suggestion area + suggestionArea.appendChild(button); + // get buttonHeight if not set + if(buttonHeight < 0) + buttonHeight = button.offsetHeight; + return button; +} + +// find visible prompt field to connect to this iframe +function GetPromptField() +{ + // get all prompt fields, the %% placeholder %% is set in python + var all = 
parentDoc.querySelectorAll('textarea[placeholder="'+placeholder+'"]'); + // filter visible + for(var i = 0; i < all.length; i++) + { + if(isVisible(all[i])) + { + promptField = all[i]; + promptField.addEventListener('input', OnChange, false); + promptField.addEventListener('click', OnClick, false); + promptField.addEventListener('keyup', OnKey, false); + break; + } + } +} + +function OnChange(e) +{ + ButtonConditions(); + ButtonUpdateContext(true); +} + +function OnClick(e) +{ + ButtonUpdateContext(true); +} + +function OnKey(e) +{ + if (e.keyCode == '37' || e.keyCode == '38' || e.keyCode == '39' || e.keyCode == '40') { + ButtonUpdateContext(false); + } +} + +function getCaretPosition(ctrl) { + // IE < 9 Support + if (document.selection) { + ctrl.focus(); + var range = document.selection.createRange(); + var rangelen = range.text.length; + range.moveStart('character', -ctrl.value.length); + var start = range.text.length - rangelen; + return { + 'start': start, + 'end': start + rangelen + }; + } // IE >=9 and other browsers + else if (ctrl.selectionStart || ctrl.selectionStart == '0') { + return { + 'start': ctrl.selectionStart, + 'end': ctrl.selectionEnd + }; + } else { + return { + 'start': 0, + 'end': 0 + }; + } +} + +function setCaretPosition(ctrl, start, end) { + // IE >= 9 and other browsers + if (ctrl.setSelectionRange) { + ctrl.focus(); + ctrl.setSelectionRange(start, end); + } + // IE < 9 + else if (ctrl.createTextRange) { + var range = ctrl.createTextRange(); + range.collapse(true); + range.moveEnd('character', end); + range.moveStart('character', start); + range.select(); + } +} + +function isEmptyOrSpaces(str){ + return str === null || str.match(/^ *$/) !== null; +} + +function ButtonUpdateContext(changeCategory) +{ + let targetCategory = contextCategory; + let text = promptField.value; + if(document.activeElement === promptField) + { + var pos = getCaretPosition(promptField).end; + text = promptField.value.slice(0, pos); + } + + activeContext = []; + + var parts = text.split(/[\.?!,]/); + if(activeCategory == "Artists" && !isEmptyOrSpaces(parts[parts.length-1])) + { + return; + } + if(text == "") + { + activeContext.push(triggers.empty); + } + if(text.endsWith("by")) + { + changeCategory = true; + targetCategory = "Artists"; + activeContext.push("Artists"); + } + else + { + var parts = text.split(/[\.,!?;]/); + parts = parts.reverse(); + + parts.forEach( part => + { + var words = part.split(" "); + words = words.reverse(); + words.forEach( word => + { + word = word.replace(/[^a-zA-Z0-9 \._\-]/g, '').trim().toLowerCase(); + word = WordToKey(word); + if(wordMap.hasOwnProperty(word)) + { + activeContext = activeContext.concat(wordMap[word]).unique(); + } + }); + }); + } + + if(activeContext.length == 0) + { + if(activeCategory == contextCategory) + { + activeCategory = ""; + ShowMenu(); + } + } + else if(changeCategory) + { + activeCategory = targetCategory; + ShowMenu(); + } + else if(activeCategory == contextCategory) + ShowMenu(); +} + +// when pressing a button, give the focus back to the prompt field +function KeepFocus(e) +{ + e.preventDefault(); + promptField.focus(); +} + +function selectCategory(e) +{ + KeepFocus(e); + // set category from attribute + activeCategory = e.target.getAttribute("data"); + // rebuild menu + ShowMenu(); +} + +function leaveCategory(e) +{ + KeepFocus(e); + activeCategory = ""; + // rebuild menu + ShowMenu(); +} + +// [...]=block "..."=requirement ...=add {|}=cursor {}=insert .,!?;=start +// [{} {|}] +// [,by {}{|}]["by "* and by {}{|}] +// [, 
{}{|}] + +function PatternWalk(text, pattern) +{ + var parts = text.split(/[\,!?;]/); + var part = parts[parts.length - 1]; + + var indent = 0; + var outPattern = ""; + var requirement = "" + var mode = ""; + var patternFailed = false; + var partIndex = 0; + for( let i = 0; i < pattern.length; i++) + { + if(mode == "") + { + if(pattern[i] == "[") + { + indent++; + mode = "pattern"; + console.log("pattern start:"); + } + } + else if(indent > 0) + { + if(pattern[i] == "[") + { + indent++; + } + else if(mode == "pattern") + { + if(patternFailed) + { + if(pattern[i] == "]") + { + indent--; + if(indent == 0) + { + mode = ""; + outPattern = ""; + partIndex = 0; + patternFailed = false; + part = parts[parts.length - 1]; + } + } + else + { + } + } + else + { + if(pattern[i] == "\"") + { + mode = "requirement"; + } + else if(pattern[i] == "]") + { + indent--; + if(indent == 0) + { + mode = ""; + return outPattern; + } + } + else if(pattern[i] == "," || pattern[i] == "!" || pattern[i] == "?" || pattern[i] == ";" ) + { + let textToCheck = (text+outPattern).trim(); + + if(textToCheck.endsWith("and")) + { + outPattern += "{_}"; + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("with")) + { + outPattern += "{_}"; + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("of")) + { + outPattern += "{_}"; + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("at")) + { + outPattern += "{_}"; + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("and a")) + { + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("with a")) + { + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("of a")) + { + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("at a")) + { + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("and an")) + { + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("with an")) + { + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("of an")) + { + part = ""; + partIndex = 0; + } + else if(textToCheck.endsWith("at an")) + { + part = ""; + partIndex = 0; + } + else if(!textToCheck.endsWith(pattern[i])) + { + outPattern += pattern[i]; + part = ""; + partIndex = 0; + } + } + else if(pattern[i] == "{") + { + outPattern += pattern[i]; + mode = "write"; + } + else if(pattern[i] == "." && pattern[i+1] == "*" || pattern[i] == "*") + { + let minLength = false; + if(pattern[i] == "." 
&& pattern[i+1] == "*") + { + minLength = true; + i++; + } + var o = pattern.slice(i+1).search(/[^\w\s]/); + var subpattern = pattern.slice(i+1,i+1+o); + + var index = part.lastIndexOf(subpattern); + var subPatternIndex = subpattern.length; + while(index == -1) + { + if(subPatternIndex <= 1) + { + patternFailed = true; + break; + } + + subPatternIndex--; + var slice = subpattern.slice(0,subPatternIndex); + index = part.lastIndexOf(slice); + } + if(!patternFailed) + { + if(minLength && index == 0) + { + patternFailed = true; + } + partIndex += index; + } + else + { + } + } + else + { + if(partIndex >= part.length) + { + outPattern += pattern[i]; + } + else if(part[partIndex] == pattern[i]) + { + partIndex++; + } + else + { + patternFailed = true; + } + } + } + } + else if(mode == "requirement") + { + if(pattern[i] == "\"") + { + if(!part.includes(requirement)) + { + patternFailed = true; + } + else + { + partIndex = part.indexOf(requirement)+requirement.length; + } + mode = "pattern"; + requirement = ""; + } + else + { + requirement += pattern[i]; + } + } + else if(mode == "write") + { + if(pattern[i] == "}") + { + outPattern += pattern[i]; + mode = "pattern"; + } + else + { + outPattern += pattern[i]; + } + } + } + else if(pattern[i] == "[") + indent++; + } + // fallback + return ", {}"; +} + +function InsertPhrase(phrase, pattern) +{ + var text = promptField.value ?? ""; + if(document.activeElement === promptField) + { + var pos = getCaretPosition(promptField).end; + text = promptField.value.slice(0, pos); + } + var insert = PatternWalk(text,pattern); + insert = insert.replace('{}',phrase); + + let firstLetter = phrase.trim()[0]; + + if(firstLetter == "a" || firstLetter == "e" || firstLetter == "i" || firstLetter == "o" || firstLetter == "u") + insert = insert.replace('{_}',"an"); + else + insert = insert.replace('{_}',"a"); + + insert = insert.replace(/{[^|]/,""); + insert = insert.replace(/[^|]}/,""); + + var caret = (text+insert).indexOf("{|}"); + insert = insert.replace('{|}',""); + // inserting via execCommand is required, this triggers all native browser functionality as if the user wrote into the prompt field. + parentDoc.execCommand('insertText', false, insert); + setCaretPosition(promptField, caret, caret); +} + +function SelectPhrase(e) +{ + KeepFocus(e); + var pattern = e.target.getAttribute("pattern"); + var phrase = e.target.getAttribute("data"); + + InsertPhrase(phrase,pattern); +} + +function CheckButtonCondition(condition) +{ + var pos = getCaretPosition(promptField).end; + var text = promptField.value.slice(0, pos); + if(condition === "empty") + { + return text == ""; + } +} + +function ButtonConditions() +{ + conditionalButtons.forEach(entry => + { + let filtered = !CheckButtonCondition(entry.condition); + + if(entry.filterGroup != null) + { + entry.filterGroup.split(",").forEach( (group) => + { + + if(activeFilters.includes(group.trim().toLowerCase())) + { + filtered = filtered || true; + return; + } + }); + } + if(filtered) + entry.element.style.display = "none"; + else + entry.element.style.display = "inline-block"; + }); +} + +function ButtonHoverEnter(e) +{ + var text = e.target.getAttribute("tooltip-text"); + var image = e.target.getAttribute("tooltip-image"); + ShowTooltip(text, e.target, image) +} + +function ButtonHoverExit(e) +{ + HideTooltip(); +} + +function ShowTooltip(text, target, image = "") +{ + var cleanedName = image == null? 
null : image.replace(/[^a-zA-Z0-9 \._\-]/g, ''); + if((text == "" || text == null) && (image == "" || image == null || thumbnails[cleanedName] === undefined)) + return; + + var currentFramePosition = currentFrameAbsolutePosition(); + var rect = target.getBoundingClientRect(); + var element = parentDoc["phraseTooltip"]; + element.innerText = text; + if(image != "" && image != null && thumbnails[cleanedName] !== undefined) + { + + var img = parentDoc.createElement('img'); + img.src = GetThumbnailURL(cleanedName); + element.appendChild(img) + } + element.style.display = "flex"; + element.style.top = (rect.bottom+currentFramePosition.y)+"px"; + element.style.left = (rect.right+currentFramePosition.x)+"px"; + element.style.width = "inherit"; + element.style.height = "inherit"; +} + +function base64toBlob(base64Data, contentType) { + contentType = contentType || ''; + var sliceSize = 1024; + var byteCharacters = atob(base64Data); + var bytesLength = byteCharacters.length; + var slicesCount = Math.ceil(bytesLength / sliceSize); + var byteArrays = new Array(slicesCount); + + for (var sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) { + var begin = sliceIndex * sliceSize; + var end = Math.min(begin + sliceSize, bytesLength); + + var bytes = new Array(end - begin); + for (var offset = begin, i = 0; offset < end; ++i, ++offset) { + bytes[i] = byteCharacters[offset].charCodeAt(0); + } + byteArrays[sliceIndex] = new Uint8Array(bytes); + } + return new Blob(byteArrays, { type: contentType }); +} + +function GetThumbnailURL(image) +{ + if(parentDoc["keyPhraseSuggestionsLoadedBlobs"].hasOwnProperty(image)) + { + return parentDoc["keyPhraseSuggestionsLoadedBlobs"][image]; + } + else + { + let url = URL.createObjectURL(GetThumbnail(image)); + parentDoc["keyPhraseSuggestionsLoadedBlobs"][image] = url; + return url; + } +} + +function GetThumbnail(image) +{ + return base64toBlob(thumbnails[image], 'image/webp'); +} + +function HideTooltip() +{ + var element = parentDoc["phraseTooltip"]; + element.style.display= "none"; + element.innerHTML = ""; + element.style.top = "0px"; + element.style.left = "0px"; + element.style.width = "0px"; + element.style.height = "0px"; +} + +function RemoveDouble(str, symbol) +{ + let doubleSymbole = symbol+symbol; + while(str.includes(doubleSymbole)) + { + str = str.replace(doubleSymbole, symbol); + } + return str; +} + +function ReplaceAll(str, toReplace, seperator, symbol) +{ + toReplace.split(seperator).forEach( (replaceSymbol) => + { + str = str.replace(replaceSymbol, symbol); + }); + return str; +} + +function WordToKey(word) +{ + if(word.endsWith("s")) + word = word.slice(0, -1); + word = word.replace("'", ""); + if(word.endsWith("s")) + word = word.slice(0, -1); + word = ReplaceAll(word, "sch;sh;ch;ll;gg;r;l;j;g", ';', 'h'); + word = ReplaceAll(word, "sw;ss;zz;qu;kk;k;z;q;s;x", ';','c'); + word = ReplaceAll(word, "pp;bb;tt;th;ff;p;t;b;f;v", ';','d'); + word = ReplaceAll(word, "yu;yo;oo;u;y;w", ';','o'); + word = ReplaceAll(word, "ee;ie;a;i", ';','e'); + word = ReplaceAll(word, "mm;nn;n", ';','n'); + word = RemoveDouble(word, "l"); + word = RemoveDouble(word, "c"); + word = RemoveDouble(word, "e"); + word = RemoveDouble(word, "m"); + word = RemoveDouble(word, "j"); + word = RemoveDouble(word, "o"); + word = RemoveDouble(word, "d"); + word = RemoveDouble(word, "f"); + return word; +} + +Array.prototype.unique = function() { + var a = this.concat(); + for(var i=0; i + { + trigger = trigger.replace(/[^a-zA-Z0-9 \._\-]/g, '').trim().toLowerCase(); + 
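+ // triggers that are not one of the reserved keywords (empty, nsfw, nude) are normalized with WordToKey below, so plural and spelling variants of the same word map onto a single index key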
if(!triggers.hasOwnProperty(trigger)) + { + trigger = WordToKey(trigger); + } + if(triggerIndex.hasOwnProperty(trigger)) + { + triggerIndex[trigger].push( { category: category, index: i }); + } + else + { + triggerIndex[trigger] = []; + triggerIndex[trigger].push( { category: category, index: i }); + } + }); + } + + /*let words = entry["phrase"].split(" "); + let wordCount = words.length; + for(let e = 0; e < wordCount; e++) + { + let wordKey = WordToKey(words[e].replace(/[^a-zA-Z0-9 \._\-]/g, '').trim().toLowerCase()); + + if(wordKey.length < 2) + continue; + + if(!wordMap.hasOwnProperty(wordKey)) + { + wordMap[wordKey] = []; + } + + let entrySearchTags = entry["search_tags"].split(","); + entrySearchTags.push(category); + entrySearchTags.forEach( search_tag => + { + if(search_tag != null && search_tag != "") + { + if(search_tag.endsWith("'s")) + search_tag = search_tag.slice(0, -2); + if(search_tag.endsWith("s")) + search_tag = search_tag.slice(0, -1); + search_tag = search_tag.replace(/[^a-zA-Z0-9 \._\-]/g, '').trim().toLowerCase(); + wordMap[wordKey].push(search_tag); + if(!tagMap.hasOwnProperty(search_tag)) + { + tagMap[search_tag] = []; + } + tagMap[search_tag].push({ category: category, index: i }); + tagMap[search_tag] = tagMap[search_tag].unique(); + } + }); + wordMap[wordKey] = wordMap[wordKey].unique(); + }*/ + } + } +} + +function ConditionalButton(entry, button) +{ + if(entry["show_if"] != "" || entry["filter_group"] != "") + conditionalButtons.push({element:button,condition:entry["show_if"], filterGroup:entry["filter_group"]}); +} + +// generate menu in suggestion area +function ShowMenu() +{ + // clear all buttons from menu + ClearSuggestionArea(); + HideTooltip(); + + // if no chategory is selected + if(activeCategory == "") + { + if(activeContext.length != 0) + { + AddButton("Context", selectCategory, "A dynamicly updating category based on the current prompt.", null, null, contextCategory); + } + for (var category in keyPhrases) + { + AddButton(category, selectCategory, keyPhrases[category]["description"], null, null, category); + } + // change iframe size after buttons have been added + UpdateSize(); + } + else if(activeCategory == contextCategory) + { + // add a button to leave the chategory + var backbutton = AddButton("↑ back", leaveCategory); + activeContext.forEach( context => + { + if(tagMap.hasOwnProperty(context)) + { + var words = tagMap[context].unique(); + words.forEach( word => + { + var entry = keyPhrases[word.category]["entries"][word.index]; + var tempPattern = keyPhrases[word.category]["pattern"]; + + if(entry["pattern_override"] != "") + { + tempPattern = entry["pattern_override"]; + } + + var button = AddButton(entry["phrase"], SelectPhrase, entry["description"], entry["phrase"],tempPattern, entry["phrase"]); + + ConditionalButton(entry, button); + }); + } + if(triggerIndex.hasOwnProperty(context)) + { + var triggered = triggerIndex[context]; + triggered.forEach( trigger => + { + var entry = keyPhrases[trigger.category]["entries"][trigger.index]; + var tempPattern = keyPhrases[trigger.category]["pattern"]; + + if(entry["pattern_override"] != "") + { + tempPattern = entry["pattern_override"]; + } + + var button = AddButton(entry["phrase"], SelectPhrase, entry["description"], entry["phrase"],tempPattern, entry["phrase"]); + + ConditionalButton(entry, button); + }); + } + }); + + ButtonConditions(); + // change iframe size after buttons have been added + UpdateSize(); + } + // if a chategory is selected + else + { + // add a button to leave the chategory + 
var backbutton = AddButton("↑ back", leaveCategory); + var pattern = keyPhrases[activeCategory]["pattern"]; + keyPhrases[activeCategory]["entries"].forEach(entry => + { + var tempPattern = pattern; + if(entry["pattern_override"] != "") + { + tempPattern = entry["pattern_override"]; + } + + var button = AddButton(entry["phrase"], SelectPhrase, entry["description"], entry["phrase"],tempPattern, entry["phrase"]); + + ConditionalButton(entry, button); + }); + ButtonConditions(); + // change iframe size after buttons have been added + UpdateSize(); + } +} + +// listen for clicks on the prompt field +parentDoc.addEventListener("click", (e) => +{ + // skip if this frame is not visible + if(!isVisible(frame)) + return; + + // if the iframes prompt field is not set, get it and set it + if(promptField === null) + { + GetPromptField(); + ButtonUpdateContext(true); + } + + // get the field with focus + var target = parentDoc.activeElement; + + // if the field with focus is a prompt field, the %% placeholder %% is set in python + if( target.placeholder === placeholder) + { + // generate menu + ShowMenu(); + frame.style.borderBottomWidth = '13px'; + } + else + { + // else hide the iframe + frame.style.height = "0px"; + frame.style.borderBottomWidth = '0px'; + } +}); + +function AppendStyle(targetDoc, id, content) +{ + // get parent document head + var head = targetDoc.getElementsByTagName('head')[0]; + + // add style tag + var style = targetDoc.createElement('style'); + // set type attribute + style.setAttribute('type', 'text/css'); + style.id = id; + // add css forwarded from python + if (style.styleSheet) { // IE + style.styleSheet.cssText = content; + } else { // the world + style.appendChild(parentDoc.createTextNode(content)); + } + // add style to head + head.appendChild(style); +} + +// Transfer all styles +var head = document.getElementsByTagName("head")[0]; +var parentStyle = parentDoc.getElementsByTagName("style"); +for (var i = 0; i < parentStyle.length; i++) + head.appendChild(parentStyle[i].cloneNode(true)); +var parentLinks = parentDoc.querySelectorAll('link[rel="stylesheet"]'); +for (var i = 0; i < parentLinks.length; i++) + head.appendChild(parentLinks[i].cloneNode(true)); + +// add custom style to iframe +frame.classList.add("suggestion-frame"); +// clear suggestion area to remove the "javascript failed" message +ClearSuggestionArea(); +// collapse the iframe by default +frame.style.height = "0px"; +frame.style.borderBottomWidth = '0px'; + +BuildTriggerIndex(); + +// only execute once (even though multiple iframes exist) +if(!parentDoc.hasOwnProperty('keyPhraseSuggestionsInitialized')) +{ + AppendStyle(parentDoc, "key-phrase-suggestions", parentCSS); + + var tooltip = parentDoc.createElement('div'); + tooltip.id = "phrase-tooltip"; + parentDoc.body.appendChild(tooltip); + parentDoc["phraseTooltip"] = tooltip; + // set flag so this only runs once + parentDoc["keyPhraseSuggestionsLoadedBlobs"] = {}; + parentDoc["keyPhraseSuggestionsInitialized"] = true; + + var cssVars = getAllCSSVariableNames(); + computedStyle = getComputedStyle(parentDoc.documentElement); + + parentDoc["keyPhraseSuggestionsCSSvariables"] = ":root{"; + + cssVars.forEach( (rule) => + { + parentDoc["keyPhraseSuggestionsCSSvariables"] += rule+": "+computedStyle.getPropertyValue(rule)+";"; + }); + parentDoc["keyPhraseSuggestionsCSSvariables"] += "}"; +} + +AppendStyle(document, "variables", parentDoc["keyPhraseSuggestionsCSSvariables"]); \ No newline at end of file diff --git 
a/webui/streamlit/scripts/custom_components/sygil_suggestions/parent.css b/webui/streamlit/scripts/custom_components/sygil_suggestions/parent.css new file mode 100644 index 0000000..6e2b285 --- /dev/null +++ b/webui/streamlit/scripts/custom_components/sygil_suggestions/parent.css @@ -0,0 +1,84 @@ +.suggestion-frame +{ + position: absolute; + + /* make as small as possible */ + margin: 0px; + padding: 0px; + min-height: 0px; + line-height: 0; + + /* animate transitions of the height property */ + -webkit-transition: height 1s; + -moz-transition: height 1s; + -ms-transition: height 1s; + -o-transition: height 1s; + transition: height 1s, border-bottom-width 1s; + + /* block selection */ + user-select: none; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; + + z-index: 700; + + outline: 1px solid rgba(250, 250, 250, 0.2); + outline-offset: 0px; + border-radius: 0.25rem; + background: rgb(14, 17, 23); + + box-sizing: border-box; + -moz-box-sizing: border-box; + -webkit-box-sizing: border-box; + border-bottom: solid 13px rgb(14, 17, 23) !important; + border-left: solid 13px rgb(14, 17, 23) !important; +} + +#phrase-tooltip +{ + display: none; + pointer-events: none; + position: absolute; + border-bottom-left-radius: 0.5rem; + border-top-right-radius: 0.5rem; + border-bottom-right-radius: 0.5rem; + border: solid rgb(255,75,75) 2px; + background-color: rgb(38, 39, 48); + color: rgb(255,75,75); + font-size: 1rem; + font-family: "Source Sans Pro", sans-serif; + padding: 0.5rem; + + cursor: default; + user-select: none; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; + z-index: 1000; +} + +#phrase-tooltip:has(img) +{ + transform: scale(1.25, 1.25); + -ms-transform: scale(1.25, 1.25); + -webkit-transform: scale(1.25, 1.25); +} + +#phrase-tooltip>img +{ + pointer-events: none; + border-bottom-left-radius: 0.5rem; + border-top-right-radius: 0.5rem; + border-bottom-right-radius: 0.5rem; + + cursor: default; + user-select: none; + -moz-user-select: none; + -khtml-user-select: none; + -webkit-user-select: none; + -o-user-select: none; + z-index: 1500; +} \ No newline at end of file diff --git a/webui/streamlit/scripts/img2img.py b/webui/streamlit/scripts/img2img.py new file mode 100644 index 0000000..4cee56a --- /dev/null +++ b/webui/streamlit/scripts/img2img.py @@ -0,0 +1,752 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/). + +# Copyright 2022 Sygil-Dev team. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# base webui import and utils. 
+from sd_utils import st, server_state, no_rerun, \ + custom_models_available, RealESRGAN_available, GFPGAN_available, LDSR_available + #generation_callback, process_images, KDiffusionSampler, \ + #load_models, hc, seed_to_int, logger, \ + #resize_image, get_matched_noise, CFGMaskedDenoiser, ImageFilter, set_page_title + +# streamlit imports +from streamlit.runtime.scriptrunner import StopException + +#other imports +import cv2 +from PIL import Image, ImageOps +import torch +import k_diffusion as K +import numpy as np +import time +import torch +import skimage +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +# streamlit components +from custom_components import sygil_suggestions +from streamlit_drawable_canvas import st_canvas + +# Temp imports + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +sygil_suggestions.init() + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. + from transformers import logging + + logging.set_verbosity_error() +except: + pass + +def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3, + mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM', + n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8, + seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None, + variant_amount: float = 0.0, variant_seed: int = None, ddim_eta:float = 0.0, + write_info_files:bool = True, separate_prompts:bool = False, normalize_prompt_weights:bool = True, + save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True, + save_as_jpg: bool = True, use_GFPGAN: bool = True, GFPGAN_model: str = 'GFPGANv1.4', + use_RealESRGAN: bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", + use_LDSR: bool = True, LDSR_model: str = "model", + loopback: bool = False, + random_seed_loopback: bool = False + ): + + outpath = st.session_state['defaults'].general.outdir_img2img + seed = seed_to_int(seed) + + batch_size = 1 + + if sampler_name == 'PLMS': + sampler = PLMSSampler(server_state["model"]) + elif sampler_name == 'DDIM': + sampler = DDIMSampler(server_state["model"]) + elif sampler_name == 'k_dpm_2_a': + sampler = KDiffusionSampler(server_state["model"],'dpm_2_ancestral') + elif sampler_name == 'k_dpm_2': + sampler = KDiffusionSampler(server_state["model"],'dpm_2') + elif sampler_name == 'k_dpmpp_2m': + sampler = KDiffusionSampler(server_state["model"],'dpmpp_2m') + elif sampler_name == 'k_euler_a': + sampler = KDiffusionSampler(server_state["model"],'euler_ancestral') + elif sampler_name == 'k_euler': + sampler = KDiffusionSampler(server_state["model"],'euler') + elif sampler_name == 'k_heun': + sampler = KDiffusionSampler(server_state["model"],'heun') + elif sampler_name == 'k_lms': + sampler = KDiffusionSampler(server_state["model"],'lms') + else: + raise Exception("Unknown sampler: " + sampler_name) + + def process_init_mask(init_mask: Image): + if init_mask.mode == "RGBA": + init_mask = init_mask.convert('RGBA') + background = Image.new('RGBA', init_mask.size, (0, 0, 0)) + init_mask = Image.alpha_composite(background, init_mask) + init_mask = init_mask.convert('RGB') + return init_mask + + init_img = init_info + init_mask = None + if mask_mode 
== 0: + if init_info_mask: + init_mask = process_init_mask(init_info_mask) + elif mask_mode == 1: + if init_info_mask: + init_mask = process_init_mask(init_info_mask) + init_mask = ImageOps.invert(init_mask) + elif mask_mode == 2: + init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1') + init_mask = init_img_transparency + init_mask = init_mask.convert("RGB") + init_mask = resize_image(resize_mode, init_mask, width, height) + init_mask = init_mask.convert("RGB") + + assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(denoising_strength * ddim_steps) + + if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None: + noise_q = 0.99 + color_variation = 0.0 + mask_blend_factor = 1.0 + + np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing + np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64) + np_mask_rgb -= np.min(np_mask_rgb) + np_mask_rgb /= np.max(np_mask_rgb) + np_mask_rgb = 1. - np_mask_rgb + np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64) + blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) + blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.) + #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants + #np_mask_rgb = np_mask_rgb + blurred + np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.) + np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.) + + noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation) + blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor) + noised = noise_rgb[:] + blend_mask_rgb **= (2.) + noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb + + np_mask_grey = np.sum(np_mask_rgb, axis=2)/3. + ref_mask = np_mask_grey < 1e-3 + + all_mask = np.ones((height, width), dtype=bool) + noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1) + + init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB") + st.session_state["editor_image"].image(init_img) # debug + + def init(): + image = init_img.convert('RGB') + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + + mask_channel = None + if init_mask: + alpha = resize_image(resize_mode, init_mask, width // 8, height // 8) + mask_channel = alpha.split()[-1] + + mask = None + if mask_channel is not None: + mask = np.array(mask_channel).astype(np.float32) / 255.0 + mask = (1 - mask) + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) + mask = torch.from_numpy(mask).to(server_state["device"]) + + if st.session_state['defaults'].general.optimized: + server_state["modelFS"].to(server_state["device"] ) + + init_image = 2. * image - 1. 
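+ # the line above maps the image from [0, 1] to [-1, 1], the value range the first-stage (VAE) encoder expects; the tensor is then moved to the model device and encoded into latent space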
+ init_image = init_image.to(server_state["device"]) + init_latent = (server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).get_first_stage_encoding((server_state["model"] if not st.session_state['defaults'].general.optimized else server_state["modelFS"]).encode_first_stage(init_image)) # move to latent space + + if st.session_state['defaults'].general.optimized: + mem = torch.cuda.memory_allocated()/1e6 + server_state["modelFS"].to("cpu") + while(torch.cuda.memory_allocated()/1e6 >= mem): + time.sleep(1) + + return init_latent, mask, + + def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name): + t_enc_steps = t_enc + obliterate = False + if ddim_steps == t_enc_steps: + t_enc_steps = t_enc_steps - 1 + obliterate = True + + if sampler_name != 'DDIM': + x0, z_mask = init_data + + sigmas = sampler.model_wrap.get_sigmas(ddim_steps) + noise = x * sigmas[ddim_steps - t_enc_steps - 1] + + xi = x0 + noise + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=xi.device) + xi = (z_mask * noise) + ((1-z_mask) * xi) + + sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:] + model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap) + samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, + extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, + 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, + callback=generation_callback if not server_state["bridge"] else None) + else: + + x0, z_mask = init_data + + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False) + z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(server_state["device"] )) + + # Obliterate masked image + if z_mask is not None and obliterate: + random = torch.randn(z_mask.shape, device=z_enc.device) + z_enc = (z_mask * random) + ((1-z_mask) * z_enc) + + # decode it + samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=unconditional_conditioning, + z_mask=z_mask, x0=x0) + return samples_ddim + + + + if loopback: + output_images, info = None, None + history = [] + initial_seed = None + + do_color_correction = False + try: + from skimage import exposure + do_color_correction = True + except: + logger.error("Install scikit-image to perform color correction on loopback") + + for i in range(n_iter): + if do_color_correction and i == 0: + correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) + + # RealESRGAN can only run on the final iteration + is_final_iteration = i == n_iter - 1 + + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=1, + n_iter=1, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + GFPGAN_model=GFPGAN_model, + use_RealESRGAN=use_RealESRGAN and is_final_iteration, # Forcefully disable upscaling when using loopback + realesrgan_model_name=RealESRGAN_model, + use_LDSR=use_LDSR, + LDSR_model_name=LDSR_model, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + mask_blur_strength=mask_blur_strength, + 
mask_restore=mask_restore, + denoising_strength=denoising_strength, + noise_mode=noise_mode, + find_noise_steps=find_noise_steps, + resize_mode=resize_mode, + uses_loopback=loopback, + uses_random_seed_loopback=random_seed_loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + if initial_seed is None: + initial_seed = seed + + input_image = init_img + init_img = output_images[0] + + if do_color_correction and correction_target is not None: + init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(init_img), + cv2.COLOR_RGB2LAB + ), + correction_target, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8")) + if mask_restore is True and init_mask is not None: + color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength)) + color_mask = color_mask.convert('L') + source_image = input_image.convert('RGB') + target_image = init_img.convert('RGB') + + init_img = Image.composite(source_image, target_image, color_mask) + + if not random_seed_loopback: + seed = seed + 1 + else: + seed = seed_to_int(None) + + denoising_strength = max(denoising_strength * 0.95, 0.1) + history.append(init_img) + + output_images = history + seed = initial_seed + + else: + output_images, seed, info, stats = process_images( + outpath=outpath, + func_init=init, + func_sample=sample, + prompt=prompt, + seed=seed, + sampler_name=sampler_name, + save_grid=save_grid, + batch_size=batch_size, + n_iter=n_iter, + steps=ddim_steps, + cfg_scale=cfg_scale, + width=width, + height=height, + prompt_matrix=separate_prompts, + use_GFPGAN=use_GFPGAN, + GFPGAN_model=GFPGAN_model, + use_RealESRGAN=use_RealESRGAN, + realesrgan_model_name=RealESRGAN_model, + use_LDSR=use_LDSR, + LDSR_model_name=LDSR_model, + normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, + init_img=init_img, + init_mask=init_mask, + mask_blur_strength=mask_blur_strength, + denoising_strength=denoising_strength, + noise_mode=noise_mode, + find_noise_steps=find_noise_steps, + mask_restore=mask_restore, + resize_mode=resize_mode, + uses_loopback=loopback, + sort_samples=group_by_prompt, + write_info_files=write_info_files, + jpg_sample=save_as_jpg + ) + + del sampler + + return output_images, seed, info, stats + +# +def layout(): + with st.form("img2img-inputs"): + st.session_state["generation_mode"] = "img2img" + + img2img_input_col, img2img_generate_col = st.columns([10,1]) + with img2img_input_col: + #prompt = st.text_area("Input Text","") + placeholder = "A corgi wearing a top hat as an oil painting." + prompt = st.text_area("Input Text","", placeholder=placeholder, height=54) + + if "defaults" in st.session_state: + if st.session_state["defaults"].general.enable_suggestions: + sygil_suggestions.suggestion_area(placeholder) + + if "defaults" in st.session_state: + if st.session_state['defaults'].admin.global_negative_prompt: + prompt += f"### {st.session_state['defaults'].admin.global_negative_prompt}" + + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. 
+ img2img_generate_col.write("") + img2img_generate_col.write("") + generate_button = img2img_generate_col.form_submit_button("Generate") + + + # creating the page layout using columns + col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([2,4,4], gap="medium") + + with col1_img2img_layout: + # If we have custom models available on the "models/custom" + #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD + custom_models_available() + if server_state["CustomModel_available"]: + st.session_state["custom_model"] = st.selectbox("Custom Model:", server_state["custom_models"], + index=server_state["custom_models"].index(st.session_state['defaults'].general.default_model), + help="Select the model you want to use. This option is only available if you have custom models \ + on your 'models/custom' folder. The model name that will be shown here is the same as the name\ + the file for the model has on said folder, it is recommended to give the .ckpt file a name that \ + will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.5") + else: + st.session_state["custom_model"] = "Stable Diffusion v1.5" + + + st.session_state["sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].img2img.sampling_steps.value, + min_value=st.session_state['defaults'].img2img.sampling_steps.min_value, + step=st.session_state['defaults'].img2img.sampling_steps.step) + + sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_dpmpp_2m", "k_heun", "PLMS", "DDIM"] + st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list, + index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.") + + width = st.slider("Width:", min_value=st.session_state['defaults'].img2img.width.min_value, max_value=st.session_state['defaults'].img2img.width.max_value, + value=st.session_state['defaults'].img2img.width.value, step=st.session_state['defaults'].img2img.width.step) + height = st.slider("Height:", min_value=st.session_state['defaults'].img2img.height.min_value, max_value=st.session_state['defaults'].img2img.height.max_value, + value=st.session_state['defaults'].img2img.height.value, step=st.session_state['defaults'].img2img.height.step) + seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help=" The seed to use, if left blank a random seed will be generated.") + + cfg_scale = st.number_input("CFG (Classifier Free Guidance Scale):", min_value=st.session_state['defaults'].img2img.cfg_scale.min_value, + value=st.session_state['defaults'].img2img.cfg_scale.value, + step=st.session_state['defaults'].img2img.cfg_scale.step, + help="How strongly the image should follow the prompt.") + + st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength.value, + min_value=st.session_state['defaults'].img2img.denoising_strength.min_value, + max_value=st.session_state['defaults'].img2img.denoising_strength.max_value, + step=st.session_state['defaults'].img2img.denoising_strength.step) + + + mask_expander = st.empty() + with mask_expander.expander("Inpainting/Outpainting"): + mask_mode_list = ["Outpainting", "Inpainting", "Image alpha"] + mask_mode = st.selectbox("Painting Mode", mask_mode_list, index=st.session_state["defaults"].img2img.mask_mode, + help="Select how you want your image to be 
masked/painted.\"Inpainting\" modifies the image where the mask is white.\n\ + \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent." + ) + mask_mode = mask_mode_list.index(mask_mode) + + + noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"] + noise_mode = st.selectbox("Noise Mode", noise_mode_list, index=noise_mode_list.index(st.session_state['defaults'].img2img.noise_mode), help="") + #noise_mode = noise_mode_list.index(noise_mode) + find_noise_steps = st.number_input("Find Noise Steps", value=st.session_state['defaults'].img2img.find_noise_steps.value, + min_value=st.session_state['defaults'].img2img.find_noise_steps.min_value, + step=st.session_state['defaults'].img2img.find_noise_steps.step) + + # Specify canvas parameters in application + drawing_mode = st.selectbox( + "Drawing tool:", + ( + "freedraw", + "transform", + #"line", + "rect", + "circle", + #"polygon", + ), + ) + + stroke_width = st.slider("Stroke width: ", 1, 100, 50) + stroke_color = st.color_picker("Stroke color hex: ", value="#EEEEEE") + bg_color = st.color_picker("Background color hex: ", "#7B6E6E") + + display_toolbar = st.checkbox("Display toolbar", True) + #realtime_update = st.checkbox("Update in realtime", True) + + with st.expander("Batch Options"): + st.session_state["batch_count"] = st.number_input("Batch count.", value=st.session_state['defaults'].img2img.batch_count.value, + help="How many iterations or batches of images to generate in total.") + + st.session_state["batch_size"] = st.number_input("Batch size", value=st.session_state.defaults.img2img.batch_size.value, + help="How many images are at once in a batch.\ + It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\ + Default: 1") + + with st.expander("Preview Settings"): + st.session_state["update_preview"] = st.session_state["defaults"].general.update_preview + st.session_state["update_preview_frequency"] = st.number_input("Update Image Preview Frequency", + min_value=0, + value=st.session_state['defaults'].img2img.update_preview_frequency, + help="Frequency in steps at which the the preview image is updated. 
By default the frequency \ + is set to 1 step.") + # + with st.expander("Advanced"): + with st.expander("Output Settings"): + separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts, + help="Separate multiple prompts using the `|` character, and get all combinations of them.") + normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights, + help="Ensure the sum of all weights adds up to 1.0") + loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from the previous batch when creating the next batch.") + random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Use a new random seed for every loopback iteration.") + img2img_mask_restore = st.checkbox("Only modify regenerated parts of image", + value=st.session_state['defaults'].img2img.mask_restore, + help="Enable to restore the unmasked parts of the image with the input; it may not blend as well but preserves detail.") + save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images, + help="Save each image generated before any filter or enhancement is applied.") + save_grid = st.checkbox("Save grid", value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.") + group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt, + help="Saves all the images with the same prompt into the same folder. \ + When using a prompt matrix each prompt combination will have its own folder.") + write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files, + help="Save a file next to the image with information about the generation.") + save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.") + + # + # check if GFPGAN, RealESRGAN and LDSR are available. + if "GFPGAN_available" not in st.session_state: + GFPGAN_available() + + if "RealESRGAN_available" not in st.session_state: + RealESRGAN_available() + + if "LDSR_available" not in st.session_state: + LDSR_available() + + if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]: + with st.expander("Post-Processing"): + face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"]) + with face_restoration_tab: + # GFPGAN used for face restoration + if st.session_state["GFPGAN_available"]: + #with st.expander("Face Restoration"): + #if st.session_state["GFPGAN_available"]: + #with st.expander("GFPGAN"): + st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, + help="Uses the GFPGAN model to improve faces after the generation.\ + This greatly improves the quality and consistency of faces but uses\ + extra VRAM.
Disable if you need the extra VRAM.") + + st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"], + index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model)) + + #st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='') + + else: + st.session_state["use_GFPGAN"] = False + + with upscaling_tab: + st.session_state['us_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].img2img.use_upscaling) + + # RealESRGAN and LDSR used for upscaling. + if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]: + + upscaling_method_list = [] + if st.session_state["RealESRGAN_available"]: + upscaling_method_list.append("RealESRGAN") + if st.session_state["LDSR_available"]: + upscaling_method_list.append("LDSR") + + st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list, + index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method) + if st.session_state['defaults'].general.upscaling_method in upscaling_method_list + else 0) + + if st.session_state["RealESRGAN_available"]: + with st.expander("RealESRGAN"): + if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['us_upscaling']: + st.session_state["use_RealESRGAN"] = True + else: + st.session_state["use_RealESRGAN"] = False + + st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"], + index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model)) + else: + st.session_state["use_RealESRGAN"] = False + st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" + + + # + if st.session_state["LDSR_available"]: + with st.expander("LDSR"): + if st.session_state["upscaling_method"] == "LDSR" and st.session_state['us_upscaling']: + st.session_state["use_LDSR"] = True + else: + st.session_state["use_LDSR"] = False + + st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"], + index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model)) + + st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].img2img.LDSR_config.sampling_steps, + help="") + + st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].img2img.LDSR_config.preDownScale, + help="") + + st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].img2img.LDSR_config.postDownScale, + help="") + + downsample_method_list = ['Nearest', 'Lanczos'] + st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list, + index=downsample_method_list.index(st.session_state['defaults'].img2img.LDSR_config.downsample_method)) + + else: + st.session_state["use_LDSR"] = False + st.session_state["LDSR_model"] = "model" + + with st.expander("Variant"): + variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01) + variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed, + help="The seed to use when generating a variant, if left blank a random seed will be generated.") + + + with col2_img2img_layout: + editor_tab = st.tabs(["Editor"]) + + editor_image = st.empty() + 
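+            # Keeping a handle to this placeholder in session_state lets img2img() push a debug preview of the noise-matched init image into it (see the noise_mode 2/3 branch above).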
st.session_state["editor_image"] = editor_image + + st.form_submit_button("Refresh") + + #if "canvas" not in st.session_state: + st.session_state["canvas"] = st.empty() + + masked_image_holder = st.empty() + image_holder = st.empty() + + uploaded_images = st.file_uploader( + "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'], + help="Upload an image which will be used for the image to image generation.", + ) + if uploaded_images: + image = Image.open(uploaded_images).convert('RGB') + new_img = image.resize((width, height)) + #image_holder.image(new_img) + + #mask_holder = st.empty() + + #uploaded_masks = st.file_uploader( + #"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp", 'jfif'], + #help="Upload an mask image which will be used for masking the image to image generation.", + #) + + # + # Create a canvas component + with st.session_state["canvas"]: + st.session_state["uploaded_masks"] = st_canvas( + fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity + stroke_width=stroke_width, + stroke_color=stroke_color, + background_color=bg_color, + background_image=image if uploaded_images else None, + update_streamlit=True, + width=width, + height=height, + drawing_mode=drawing_mode, + initial_drawing=st.session_state["uploaded_masks"].json_data if "uploaded_masks" in st.session_state else None, + display_toolbar= display_toolbar, + key="full_app", + ) + + #try: + ##print (type(st.session_state["uploaded_masks"])) + #if st.session_state["uploaded_masks"] != None: + #mask_expander.expander("Mask", expanded=True) + #mask = Image.fromarray(st.session_state["uploaded_masks"].image_data) + + #st.image(mask) + + #if mask.mode == "RGBA": + #mask = mask.convert('RGBA') + #background = Image.new('RGBA', mask.size, (0, 0, 0)) + #mask = Image.alpha_composite(background, mask) + #mask = mask.resize((width, height)) + #except AttributeError: + #pass + + with col3_img2img_layout: + result_tab = st.tabs(["Result"]) + + # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally. + preview_image = st.empty() + st.session_state["preview_image"] = preview_image + + #st.session_state["loading"] = st.empty() + + st.session_state["progress_bar_text"] = st.empty() + st.session_state["progress_bar"] = st.empty() + + message = st.empty() + + #if uploaded_images: + #image = Image.open(uploaded_images).convert('RGB') + ##img_array = np.array(image) # if you want to pass it to OpenCV + #new_img = image.resize((width, height)) + #st.image(new_img, use_column_width=True) + + + if generate_button: + #print("Loading models") + # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry. 
+ with col3_img2img_layout: + with no_rerun: + with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]): + load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"], + use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] , + use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"], + CustomModel_available=server_state["CustomModel_available"], custom_model=st.session_state["custom_model"]) + + if uploaded_images: + #image = Image.fromarray(image).convert('RGBA') + #new_img = image.resize((width, height)) + ###img_array = np.array(image) # if you want to pass it to OpenCV + #image_holder.image(new_img) + new_mask = None + + if st.session_state["uploaded_masks"]: + mask = Image.fromarray(st.session_state["uploaded_masks"].image_data) + new_mask = mask.resize((width, height)) + + #masked_image_holder.image(new_mask) + try: + output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode, + mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"], + sampler_name=st.session_state["sampler_name"], n_iter=st.session_state["batch_count"], + cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed, + seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width, + height=height, variant_amount=variant_amount, + ddim_eta=st.session_state.defaults.img2img.ddim_eta, write_info_files=write_info_files, + separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights, + save_individual_images=save_individual_images, save_grid=save_grid, + group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=st.session_state["use_GFPGAN"], + GFPGAN_model=st.session_state["GFPGAN_model"], + use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"], + use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"], + loopback=loopback + ) + + #show a message when the generation is complete. + message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅") + + except (StopException, + #KeyError + ): + logger.info(f"Received Streamlit StopException") + # reset the page title so the percent doesnt stay on it confusing the user. + set_page_title(f"Stable Diffusion Playground") + + # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery. + # use the current col2 first tab to show the preview_img and update it as its generated. + #preview_image.image(output_images, width=750) + +#on import run init diff --git a/webui/streamlit/scripts/img2txt.py b/webui/streamlit/scripts/img2txt.py new file mode 100644 index 0000000..1268ebf --- /dev/null +++ b/webui/streamlit/scripts/img2txt.py @@ -0,0 +1,460 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/). + +# Copyright 2022 Sygil-Dev team. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+ +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +# --------------------------------------------------------------------------------------------------------------------------------------------------- +""" +CLIP Interrogator made by @pharmapsychotic modified to work with our WebUI. + +# CLIP Interrogator by @pharmapsychotic +Twitter: https://twitter.com/pharmapsychotic +Github: https://github.com/pharmapsychotic/clip-interrogator + +Description: +What do the different OpenAI CLIP models see in an image? What might be a good text prompt to create similar images using CLIP guided diffusion +or another text to image model? The CLIP Interrogator is here to get you answers! + +Please consider buying him a coffee via [ko-fi](https://ko-fi.com/pharmapsychotic) or following him on [twitter](https://twitter.com/pharmapsychotic). + +And if you're looking for more Ai art tools check out my [Ai generative art tools list](https://pharmapsychotic.com/tools.html). + +""" +# --------------------------------------------------------------------------------------------------------------------------------------------------- + +# base webui import and utils. +from sd_utils import st, logger, server_state, server_state_lock, random + +# streamlit imports + +# streamlit components section +import streamlit_nested_layout + +# other imports + +import clip +import open_clip +import gc +import os +import pandas as pd +#import requests +import torch +from PIL import Image +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +from ldm.models.blip import blip_decoder +#import hashlib + +# end of imports +# --------------------------------------------------------------------------------------------------------------- + +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +blip_image_eval_size = 512 + +st.session_state["log"] = [] + +def load_blip_model(): + logger.info("Loading BLIP Model") + if "log" not in st.session_state: + st.session_state["log"] = [] + + st.session_state["log"].append("Loading BLIP Model") + st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') + + if "blip_model" not in server_state: + with server_state_lock['blip_model']: + server_state["blip_model"] = blip_decoder(pretrained="models/blip/model__base_caption.pth", + image_size=blip_image_eval_size, vit='base', med_config="configs/blip/med_config.json") + + server_state["blip_model"] = server_state["blip_model"].eval() + + server_state["blip_model"] = server_state["blip_model"].to(device).half() + + logger.info("BLIP Model Loaded") + st.session_state["log"].append("BLIP Model Loaded") + st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') + else: + logger.info("BLIP Model already loaded") + st.session_state["log"].append("BLIP Model already loaded") + st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') + + +def generate_caption(pil_image): + + load_blip_model() + + gpu_image = transforms.Compose([ # type: ignore + transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), # type: ignore + transforms.ToTensor(), # 
type: ignore + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) # type: ignore + ])(pil_image).unsqueeze(0).to(device).half() + + with torch.no_grad(): + caption = server_state["blip_model"].generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5) + + return caption[0] + +def load_list(filename): + with open(filename, 'r', encoding='utf-8', errors='replace') as f: + items = [line.strip() for line in f.readlines()] + return items + +def rank(model, image_features, text_array, top_count=1): + top_count = min(top_count, len(text_array)) + text_tokens = clip.tokenize([text for text in text_array]).cuda() + with torch.no_grad(): + text_features = model.encode_text(text_tokens).float() + text_features /= text_features.norm(dim=-1, keepdim=True) + + similarity = torch.zeros((1, len(text_array))).to(device) + for i in range(image_features.shape[0]): + similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) + similarity /= image_features.shape[0] + + top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) + return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] + + +def clear_cuda(): + torch.cuda.empty_cache() + gc.collect() + + +def batch_rank(model, image_features, text_array, batch_size=st.session_state["defaults"].img2txt.batch_size): + batch_size = min(batch_size, len(text_array)) + batch_count = int(len(text_array) / batch_size) + batches = [text_array[i*batch_size:(i+1)*batch_size] for i in range(batch_count)] + ranks = [] + for batch in batches: + ranks += rank(model, image_features, batch) + return ranks + +def interrogate(image, models): + load_blip_model() + + logger.info("Generating Caption") + st.session_state["log"].append("Generating Caption") + st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') + caption = generate_caption(image) + + if st.session_state["defaults"].general.optimized: + del server_state["blip_model"] + clear_cuda() + + logger.info("Caption Generated") + st.session_state["log"].append("Caption Generated") + st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') + + if len(models) == 0: + logger.info(f"\n\n{caption}") + return + + table = [] + bests = [[('', 0)]]*7 + + logger.info("Ranking Text") + st.session_state["log"].append("Ranking Text") + st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') + + for model_name in models: + with torch.no_grad(), torch.autocast('cuda', dtype=torch.float16): + logger.info(f"Interrogating with {model_name}...") + st.session_state["log"].append(f"Interrogating with {model_name}...") + st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') + + if model_name not in server_state["clip_models"]: + if not st.session_state["defaults"].img2txt.keep_all_models_loaded: + model_to_delete = [] + for model in server_state["clip_models"]: + if model != model_name: + model_to_delete.append(model) + for model in model_to_delete: + del server_state["clip_models"][model] + del server_state["preprocesses"][model] + clear_cuda() + if model_name == 'ViT-H-14': + server_state["clip_models"][model_name], _, server_state["preprocesses"][model_name] = \ + open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s32b_b79k', cache_dir='models/clip') + elif model_name == 'ViT-g-14': + server_state["clip_models"][model_name], _, 
server_state["preprocesses"][model_name] = \ + open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s12b_b42k', cache_dir='models/clip') + else: + server_state["clip_models"][model_name], server_state["preprocesses"][model_name] = \ + clip.load(model_name, device=device, download_root='models/clip') + server_state["clip_models"][model_name] = server_state["clip_models"][model_name].cuda().eval() + + images = server_state["preprocesses"][model_name](image).unsqueeze(0).cuda() + + + image_features = server_state["clip_models"][model_name].encode_image(images).float() + + image_features /= image_features.norm(dim=-1, keepdim=True) + + if st.session_state["defaults"].general.optimized: + clear_cuda() + + ranks = [] + ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["mediums"])) + ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, ["by "+artist for artist in server_state["artists"]])) + ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["trending_list"])) + ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["movements"])) + ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["flavors"])) + #ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["domains"])) + #ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subreddits"])) + ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["techniques"])) + ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["tags"])) + + # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["genres"])) + # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["styles"])) + # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["subjects"])) + # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["colors"])) + # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["moods"])) + # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["themes"])) + # ranks.append(batch_rank(server_state["clip_models"][model_name], image_features, server_state["keywords"])) + + #print (bests) + #print (ranks) + + for i in range(len(ranks)): + confidence_sum = 0 + for ci in range(len(ranks[i])): + confidence_sum += ranks[i][ci][1] + if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))): + bests[i] = ranks[i] + + for best in bests: + best.sort(key=lambda x: x[1], reverse=True) + # prune to 3 + best = best[:3] + + row = [model_name] + + for r in ranks: + row.append(', '.join([f"{x[0]} ({x[1]:0.1f}%)" for x in r])) + + #for rank in ranks: + # rank.sort(key=lambda x: x[1], reverse=True) + # row.append(f'{rank[0][0]} {rank[0][1]:.2f}%') + + table.append(row) + + if st.session_state["defaults"].general.optimized: + del server_state["clip_models"][model_name] + gc.collect() + + st.session_state["prediction_table"][st.session_state["processed_image_count"]].dataframe(pd.DataFrame( + table, columns=["Model", "Medium", "Artist", "Trending", "Movement", "Flavors", "Techniques", "Tags"])) + + medium = bests[0][0][0] + artist = bests[1][0][0] + trending = bests[2][0][0] 
+ movement = bests[3][0][0] + flavors = bests[4][0][0] + #domains = bests[5][0][0] + #subreddits = bests[6][0][0] + techniques = bests[5][0][0] + tags = bests[6][0][0] + + + if caption.startswith(medium): + st.session_state["text_result"][st.session_state["processed_image_count"]].code( + f"\n\n{caption} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="") + else: + st.session_state["text_result"][st.session_state["processed_image_count"]].code( + f"\n\n{caption}, {medium} {artist}, {trending}, {movement}, {techniques}, {flavors}, {tags}", language="") + + logger.info("Finished Interrogating.") + st.session_state["log"].append("Finished Interrogating.") + st.session_state["log_message"].code('\n'.join(st.session_state["log"]), language='') + + +def img2txt(): + models = [] + + if st.session_state["ViT-L/14"]: + models.append('ViT-L/14') + if st.session_state["ViT-H-14"]: + models.append('ViT-H-14') + if st.session_state["ViT-g-14"]: + models.append('ViT-g-14') + + if st.session_state["ViTB32"]: + models.append('ViT-B/32') + if st.session_state['ViTB16']: + models.append('ViT-B/16') + + if st.session_state["ViTL14_336px"]: + models.append('ViT-L/14@336px') + if st.session_state["RN101"]: + models.append('RN101') + if st.session_state["RN50"]: + models.append('RN50') + if st.session_state["RN50x4"]: + models.append('RN50x4') + if st.session_state["RN50x16"]: + models.append('RN50x16') + if st.session_state["RN50x64"]: + models.append('RN50x64') + + # if str(image_path_or_url).startswith('http://') or str(image_path_or_url).startswith('https://'): + #image = Image.open(requests.get(image_path_or_url, stream=True).raw).convert('RGB') + # else: + #image = Image.open(image_path_or_url).convert('RGB') + + #thumb = st.session_state["uploaded_image"].image.copy() + #thumb.thumbnail([blip_image_eval_size, blip_image_eval_size]) + # display(thumb) + + st.session_state["processed_image_count"] = 0 + + for i in range(len(st.session_state["uploaded_image"])): + + interrogate(st.session_state["uploaded_image"][i].pil_image, models=models) + # increase counter. + st.session_state["processed_image_count"] += 1 +# + + +def layout(): + #set_page_title("Image-to-Text - Stable Diffusion WebUI") + #st.info("Under Construction. 
:construction_worker:") + # + if "clip_models" not in server_state: + server_state["clip_models"] = {} + if "preprocesses" not in server_state: + server_state["preprocesses"] = {} + data_path = "data/" + if "artists" not in server_state: + server_state["artists"] = load_list(os.path.join(data_path, 'img2txt', 'artists.txt')) + if "flavors" not in server_state: + server_state["flavors"] = random.choices(load_list(os.path.join(data_path, 'img2txt', 'flavors.txt')), k=2000) + if "mediums" not in server_state: + server_state["mediums"] = load_list(os.path.join(data_path, 'img2txt', 'mediums.txt')) + if "movements" not in server_state: + server_state["movements"] = load_list(os.path.join(data_path, 'img2txt', 'movements.txt')) + if "sites" not in server_state: + server_state["sites"] = load_list(os.path.join(data_path, 'img2txt', 'sites.txt')) + #server_state["domains"] = load_list(os.path.join(data_path, 'img2txt', 'domains.txt')) + #server_state["subreddits"] = load_list(os.path.join(data_path, 'img2txt', 'subreddits.txt')) + if "techniques" not in server_state: + server_state["techniques"] = load_list(os.path.join(data_path, 'img2txt', 'techniques.txt')) + if "tags" not in server_state: + server_state["tags"] = load_list(os.path.join(data_path, 'img2txt', 'tags.txt')) + #server_state["genres"] = load_list(os.path.join(data_path, 'img2txt', 'genres.txt')) + # server_state["styles"] = load_list(os.path.join(data_path, 'img2txt', 'styles.txt')) + # server_state["subjects"] = load_list(os.path.join(data_path, 'img2txt', 'subjects.txt')) + if "trending_list" not in server_state: + server_state["trending_list"] = [site for site in server_state["sites"]] + server_state["trending_list"].extend(["trending on "+site for site in server_state["sites"]]) + server_state["trending_list"].extend(["featured on "+site for site in server_state["sites"]]) + server_state["trending_list"].extend([site+" contest winner" for site in server_state["sites"]]) + with st.form("img2txt-inputs"): + st.session_state["generation_mode"] = "img2txt" + + # st.write("---") + # creating the page layout using columns + col1, col2 = st.columns([1, 4], gap="large") + + with col1: + st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True) + + with st.expander("CLIP models", expanded=True): + st.session_state["ViT-L/14"] = st.checkbox("ViT-L/14", value=True, help="ViT-L/14 model.") + st.session_state["ViT-H-14"] = st.checkbox("ViT-H-14", value=False, help="ViT-H-14 model.") + st.session_state["ViT-g-14"] = st.checkbox("ViT-g-14", value=False, help="ViT-g-14 model.") + + + + with st.expander("Others"): + st.info("For DiscoDiffusion and JAX enable all the same models here as you intend to use when generating your images.") + + st.session_state["ViTL14_336px"] = st.checkbox("ViTL14_336px", value=False, help="ViTL14_336px model.") + st.session_state["ViTB16"] = st.checkbox("ViTB16", value=False, help="ViTB16 model.") + st.session_state["ViTB32"] = st.checkbox("ViTB32", value=False, help="ViTB32 model.") + st.session_state["RN50"] = st.checkbox("RN50", value=False, help="RN50 model.") + st.session_state["RN50x4"] = st.checkbox("RN50x4", value=False, help="RN50x4 model.") + st.session_state["RN50x16"] = st.checkbox("RN50x16", value=False, help="RN50x16 model.") + st.session_state["RN50x64"] = st.checkbox("RN50x64", value=False, help="RN50x64 model.") + st.session_state["RN101"] = st.checkbox("RN101", value=False, help="RN101 model.") + + # + # 
st.subheader("Logs:") + + st.session_state["log_message"] = st.empty() + st.session_state["log_message"].code('', language="") + + with col2: + st.subheader("Image") + + image_col1, image_col2 = st.columns([10,25]) + with image_col1: + refresh = st.form_submit_button("Update Preview Image", help='Refresh the image preview to show your uploaded image instead of the default placeholder.') + + if st.session_state["uploaded_image"]: + #print (type(st.session_state["uploaded_image"])) + # if len(st.session_state["uploaded_image"]) == 1: + st.session_state["input_image_preview"] = [] + st.session_state["input_image_preview_container"] = [] + st.session_state["prediction_table"] = [] + st.session_state["text_result"] = [] + + for i in range(len(st.session_state["uploaded_image"])): + st.session_state["input_image_preview_container"].append(i) + st.session_state["input_image_preview_container"][i] = st.empty() + + with st.session_state["input_image_preview_container"][i].container(): + col1_output, col2_output = st.columns([2, 10], gap="medium") + with col1_output: + st.session_state["input_image_preview"].append(i) + st.session_state["input_image_preview"][i] = st.empty() + st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB') + + st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True) + + with st.session_state["input_image_preview_container"][i].container(): + + with col2_output: + + st.session_state["prediction_table"].append(i) + st.session_state["prediction_table"][i] = st.empty() + st.session_state["prediction_table"][i].table() + + st.session_state["text_result"].append(i) + st.session_state["text_result"][i] = st.empty() + st.session_state["text_result"][i].code("", language="") + + else: + #st.session_state["input_image_preview"].code('', language="") + st.image("images/streamlit/img2txt_placeholder.png", clamp=True) + + with image_col2: + # + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. + # generate_col1.title("") + # generate_col1.title("") + generate_button = st.form_submit_button("Generate!", help="Start interrogating the images to generate a prompt from each of the selected images") + + if generate_button: + # if model, pipe, RealESRGAN or GFPGAN is in st.session_state remove the model and pipe form session_state so that they are reloaded. + if "model" in server_state and st.session_state["defaults"].general.optimized: + del server_state["model"] + if "pipe" in server_state and st.session_state["defaults"].general.optimized: + del server_state["pipe"] + if "RealESRGAN" in server_state and st.session_state["defaults"].general.optimized: + del server_state["RealESRGAN"] + if "GFPGAN" in server_state and st.session_state["defaults"].general.optimized: + del server_state["GFPGAN"] + + # run clip interrogator + img2txt() diff --git a/webui/streamlit/scripts/post_processing.py b/webui/streamlit/scripts/post_processing.py new file mode 100644 index 0000000..61de288 --- /dev/null +++ b/webui/streamlit/scripts/post_processing.py @@ -0,0 +1,368 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sandbox-webui/). + +# Copyright 2022 Sygil-Dev team. 
+# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# base webui import and utils. +#from sd_utils import * +from sd_utils import st, server_state, \ + RealESRGAN_available, GFPGAN_available, LDSR_available + +# streamlit imports + +#streamlit components section +import hydralit_components as hc + +#other imports +import os +from PIL import Image +import torch + +# Temp imports + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +def post_process(use_GFPGAN=True, GFPGAN_model='', use_RealESRGAN=False, realesrgan_model_name="", use_LDSR=False, LDSR_model_name=""): + + for i in range(len(st.session_state["uploaded_image"])): + #st.session_state["uploaded_image"][i].pil_image + + if use_GFPGAN and server_state["GFPGAN"] is not None and not use_RealESRGAN and not use_LDSR: + if "progress_bar_text" in st.session_state: + st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"]))) + + if "progress_bar" in st.session_state: + st.session_state["progress_bar"].progress( + int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"])))) + + if server_state["GFPGAN"].name != GFPGAN_model: + load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + torch_gc() + + with torch.autocast('cuda'): + cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True) + + gfpgan_sample = restored_img[:,:,::-1] + gfpgan_image = Image.fromarray(gfpgan_sample) + + #if st.session_state["GFPGAN_strenght"]: + #gfpgan_sample = Image.blend(image, gfpgan_image, st.session_state["GFPGAN_strenght"]) + + gfpgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan' + + gfpgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{gfpgan_filename}.png")) + + # + elif use_RealESRGAN and server_state["RealESRGAN"] is not None and not use_GFPGAN: + if "progress_bar_text" in st.session_state: + st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." 
% (i+1, len(st.session_state["uploaded_image"]))) + + if "progress_bar" in st.session_state: + st.session_state["progress_bar"].progress( + int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"])))) + + torch_gc() + + if server_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = server_state["RealESRGAN"].enhance(st.session_state["uploaded_image"][i].pil_image) + esrgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-esrgan4x' + esrgan_sample = output[:,:,::-1] + esrgan_image = Image.fromarray(esrgan_sample) + + esrgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{esrgan_filename}.png")) + + # + elif use_LDSR and "LDSR" in server_state and not use_GFPGAN: + logger.info ("Running LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"]))) + if "progress_bar_text" in st.session_state: + st.session_state["progress_bar_text"].text("Running LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"]))) + if "progress_bar" in st.session_state: + st.session_state["progress_bar"].progress( + int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"])))) + + torch_gc() + + if server_state["LDSR"].name != LDSR_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + result = server_state["LDSR"].superResolution(st.session_state["uploaded_image"][i].pil_image, ddimSteps = st.session_state["ldsr_sampling_steps"], + preDownScale = st.session_state["preDownScale"], postDownScale = st.session_state["postDownScale"], + downsample_method=st.session_state["downsample_method"]) + + ldsr_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-ldsr4x' + + result.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{ldsr_filename}.png")) + + # + elif use_LDSR and "LDSR" in server_state and use_GFPGAN and "GFPGAN" in server_state: + logger.info ("Running GFPGAN+LDSR on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"]))) + if "progress_bar_text" in st.session_state: + st.session_state["progress_bar_text"].text("Running GFPGAN+LDSR on image %d of %d..." 
% (i+1, len(st.session_state["uploaded_image"]))) + + if "progress_bar" in st.session_state: + st.session_state["progress_bar"].progress( + int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"])))) + + if server_state["GFPGAN"].name != GFPGAN_model: + load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + torch_gc() + cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True) + + gfpgan_sample = restored_img[:,:,::-1] + gfpgan_image = Image.fromarray(gfpgan_sample) + + if server_state["LDSR"].name != LDSR_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_LDSR=use_LDSR, LDSR_model=LDSR_model_name, use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + #LDSR.superResolution(gfpgan_image, ddimSteps=100, preDownScale='None', postDownScale='None', downsample_method="Lanczos") + result = server_state["LDSR"].superResolution(gfpgan_image, ddimSteps = st.session_state["ldsr_sampling_steps"], + preDownScale = st.session_state["preDownScale"], postDownScale = st.session_state["postDownScale"], + downsample_method=st.session_state["downsample_method"]) + + ldsr_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan-ldsr2x' + + result.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{ldsr_filename}.png")) + + elif use_RealESRGAN and server_state["RealESRGAN"] is not None and use_GFPGAN and server_state["GFPGAN"] is not None: + if "progress_bar_text" in st.session_state: + st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(st.session_state["uploaded_image"]))) + + if "progress_bar" in st.session_state: + st.session_state["progress_bar"].progress( + int(100 * float(i+1 if i+1 < len(st.session_state["uploaded_image"]) else len(st.session_state["uploaded_image"]))/float(len(st.session_state["uploaded_image"])))) + + torch_gc() + cropped_faces, restored_faces, restored_img = server_state["GFPGAN"].enhance(st.session_state["uploaded_image"][i].pil_image, has_aligned=False, only_center_face=False, paste_back=True) + gfpgan_sample = restored_img[:,:,::-1] + + if server_state["RealESRGAN"].model.name != realesrgan_model_name: + #try_loading_RealESRGAN(realesrgan_model_name) + load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name) + + output, img_mode = server_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1]) + gfpgan_esrgan_filename = st.session_state["uploaded_image"][i].name.split('.')[0] + '-gfpgan-esrgan4x' + gfpgan_esrgan_sample = output[:,:,::-1] + gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample) + + gfpgan_esrgan_image.save(os.path.join(st.session_state["defaults"].post_processing.outdir_post_processing, f"{gfpgan_esrgan_filename}.png")) + + + +def layout(): + #st.info("Under Construction. 
:construction_worker:") + st.session_state["progress_bar_text"] = st.empty() + #st.session_state["progress_bar_text"].info("Nothing but crickets here, try generating something first.") + + st.session_state["progress_bar"] = st.empty() + + with st.form("post-processing-inputs"): + # creating the page layout using columns + col1, col2 = st.columns([1, 4], gap="medium") + + with col1: + st.session_state["uploaded_image"] = st.file_uploader('Input Image', type=['png', 'jpg', 'jpeg', 'jfif', 'webp'], accept_multiple_files=True) + + + # check if GFPGAN, RealESRGAN and LDSR are available. + #if "GFPGAN_available" not in st.session_state: + GFPGAN_available() + + #if "RealESRGAN_available" not in st.session_state: + RealESRGAN_available() + + #if "LDSR_available" not in st.session_state: + LDSR_available() + + if st.session_state["GFPGAN_available"] or st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]: + face_restoration_tab, upscaling_tab = st.tabs(["Face Restoration", "Upscaling"]) + with face_restoration_tab: + # GFPGAN used for face restoration + if st.session_state["GFPGAN_available"]: + #with st.expander("Face Restoration"): + #if st.session_state["GFPGAN_available"]: + #with st.expander("GFPGAN"): + st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, + help="Uses the GFPGAN model to improve faces after the generation.\ + This greatly improve the quality and consistency of faces but uses\ + extra VRAM. Disable if you need the extra VRAM.") + + st.session_state["GFPGAN_model"] = st.selectbox("GFPGAN model", st.session_state["GFPGAN_models"], + index=st.session_state["GFPGAN_models"].index(st.session_state['defaults'].general.GFPGAN_model)) + + #st.session_state["GFPGAN_strenght"] = st.slider("Effect Strenght", min_value=1, max_value=100, value=1, step=1, help='') + + else: + st.session_state["use_GFPGAN"] = False + + with upscaling_tab: + st.session_state['use_upscaling'] = st.checkbox("Use Upscaling", value=st.session_state['defaults'].txt2img.use_upscaling) + + # RealESRGAN and LDSR used for upscaling. 
+ if st.session_state["RealESRGAN_available"] or st.session_state["LDSR_available"]: + + upscaling_method_list = [] + if st.session_state["RealESRGAN_available"]: + upscaling_method_list.append("RealESRGAN") + if st.session_state["LDSR_available"]: + upscaling_method_list.append("LDSR") + + #print (st.session_state["RealESRGAN_available"]) + st.session_state["upscaling_method"] = st.selectbox("Upscaling Method", upscaling_method_list, + index=upscaling_method_list.index(st.session_state['defaults'].general.upscaling_method) + if st.session_state['defaults'].general.upscaling_method in upscaling_method_list + else 0) + + if st.session_state["RealESRGAN_available"]: + with st.expander("RealESRGAN"): + if st.session_state["upscaling_method"] == "RealESRGAN" and st.session_state['use_upscaling']: + st.session_state["use_RealESRGAN"] = True + else: + st.session_state["use_RealESRGAN"] = False + + st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", st.session_state["RealESRGAN_models"], + index=st.session_state["RealESRGAN_models"].index(st.session_state['defaults'].general.RealESRGAN_model)) + else: + st.session_state["use_RealESRGAN"] = False + st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus" + + + # + if st.session_state["LDSR_available"]: + with st.expander("LDSR"): + if st.session_state["upscaling_method"] == "LDSR" and st.session_state['use_upscaling']: + st.session_state["use_LDSR"] = True + else: + st.session_state["use_LDSR"] = False + + st.session_state["LDSR_model"] = st.selectbox("LDSR model", st.session_state["LDSR_models"], + index=st.session_state["LDSR_models"].index(st.session_state['defaults'].general.LDSR_model)) + + st.session_state["ldsr_sampling_steps"] = st.number_input("Sampling Steps", value=st.session_state['defaults'].txt2img.LDSR_config.sampling_steps, + help="") + + st.session_state["preDownScale"] = st.number_input("PreDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.preDownScale, + help="") + + st.session_state["postDownScale"] = st.number_input("postDownScale", value=st.session_state['defaults'].txt2img.LDSR_config.postDownScale, + help="") + + downsample_method_list = ['Nearest', 'Lanczos'] + st.session_state["downsample_method"] = st.selectbox("Downsample Method", downsample_method_list, + index=downsample_method_list.index(st.session_state['defaults'].txt2img.LDSR_config.downsample_method)) + + else: + st.session_state["use_LDSR"] = False + st.session_state["LDSR_model"] = "model" + + #process = st.form_submit_button("Process Images", help="") + + # + with st.expander("Output Settings", True): + #st.session_state['defaults'].post_processing.save_original_images = st.checkbox("Save input images.", value=st.session_state['defaults'].post_processing.save_original_images, + #help="Save each original/input image next to the Post Processed image. 
" + #"This might be helpful for comparing the before and after images.") + + st.session_state['defaults'].post_processing.outdir_post_processing = st.text_input("Output Dir",value=st.session_state['defaults'].post_processing.outdir_post_processing, + help="Folder where the images will be saved after post processing.") + + with col2: + st.subheader("Image") + + image_col1, image_col2, image_col3 = st.columns([2, 2, 2], gap="small") + with image_col1: + refresh = st.form_submit_button("Refresh", help='Refresh the image preview to show your uploaded image.') + + if st.session_state["uploaded_image"]: + #print (type(st.session_state["uploaded_image"])) + # if len(st.session_state["uploaded_image"]) == 1: + st.session_state["input_image_preview"] = [] + st.session_state["input_image_caption"] = [] + st.session_state["output_image_preview"] = [] + st.session_state["output_image_caption"] = [] + st.session_state["input_image_preview_container"] = [] + st.session_state["prediction_table"] = [] + st.session_state["text_result"] = [] + + for i in range(len(st.session_state["uploaded_image"])): + st.session_state["input_image_preview_container"].append(i) + st.session_state["input_image_preview_container"][i] = st.empty() + + with st.session_state["input_image_preview_container"][i].container(): + col1_output, col2_output, col3_output = st.columns([2, 2, 2], gap="medium") + with col1_output: + st.session_state["output_image_caption"].append(i) + st.session_state["output_image_caption"][i] = st.empty() + #st.session_state["output_image_caption"][i] = st.session_state["uploaded_image"][i].name + + st.session_state["input_image_caption"].append(i) + st.session_state["input_image_caption"][i] = st.empty() + #st.session_state["input_image_caption"][i].caption(") + + st.session_state["input_image_preview"].append(i) + st.session_state["input_image_preview"][i] = st.empty() + st.session_state["uploaded_image"][i].pil_image = Image.open(st.session_state["uploaded_image"][i]).convert('RGB') + + st.session_state["input_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True) + + with col2_output: + st.session_state["output_image_preview"].append(i) + st.session_state["output_image_preview"][i] = st.empty() + + st.session_state["output_image_preview"][i].image(st.session_state["uploaded_image"][i].pil_image, use_column_width=True, clamp=True) + + with st.session_state["input_image_preview_container"][i].container(): + + with col3_output: + + #st.session_state["prediction_table"].append(i) + #st.session_state["prediction_table"][i] = st.empty() + #st.session_state["prediction_table"][i].table(pd.DataFrame(columns=["Model", "Filename", "Progress"])) + + st.session_state["text_result"].append(i) + st.session_state["text_result"][i] = st.empty() + st.session_state["text_result"][i].code("", language="") + + #else: + ##st.session_state["input_image_preview"].code('', language="") + #st.image("images/streamlit/img2txt_placeholder.png", clamp=True) + + with image_col3: + # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way. 
+ process = st.form_submit_button("Process Images!") + + if process: + with hc.HyLoader('Loading Models...', hc.Loaders.standard_loaders,index=[0]): + #load_models(use_LDSR=st.session_state["use_LDSR"], LDSR_model=st.session_state["LDSR_model"], + #use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"] , + #use_RealESRGAN=st.session_state["use_RealESRGAN"], RealESRGAN_model=st.session_state["RealESRGAN_model"]) + + if st.session_state["use_GFPGAN"]: + load_GFPGAN(model_name=st.session_state["GFPGAN_model"]) + + if st.session_state["use_RealESRGAN"]: + load_RealESRGAN(st.session_state["RealESRGAN_model"]) + + if st.session_state["use_LDSR"]: + load_LDSR(st.session_state["LDSR_model"]) + + post_process(use_GFPGAN=st.session_state["use_GFPGAN"], GFPGAN_model=st.session_state["GFPGAN_model"], + use_RealESRGAN=st.session_state["use_RealESRGAN"], realesrgan_model_name=st.session_state["RealESRGAN_model"], + use_LDSR=st.session_state["use_LDSR"], LDSR_model_name=st.session_state["LDSR_model"]) \ No newline at end of file diff --git a/webui/streamlit/scripts/sd_concept_library.py b/webui/streamlit/scripts/sd_concept_library.py new file mode 100644 index 0000000..f8c1217 --- /dev/null +++ b/webui/streamlit/scripts/sd_concept_library.py @@ -0,0 +1,260 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/). + +# Copyright 2022 Sygil-Dev team. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# base webui import and utils. 
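A general note on the form wiring used in the post-processing layout above: Streamlit requires every st.form to contain at least one st.form_submit_button (here both "Refresh" and "Process Images!" are submit buttons), and widgets inside the form only deliver their values when one of those buttons is pressed. A minimal, self-contained sketch of that contract, with illustrative names unrelated to this file:

import streamlit as st

with st.form("example-form"):
    # Values are buffered; the script does not rerun while the user edits them.
    prompt = st.text_input("Prompt")
    submitted = st.form_submit_button("Run")

if submitted:
    # One rerun happens on submit, with every form value available at once.
    st.write(f"Submitted: {prompt}")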
+from sd_utils import st + +# streamlit imports +import streamlit.components.v1 as components +#other imports + +import os, math +from PIL import Image + +# Temp imports +#from basicsr.utils.registry import ARCH_REGISTRY + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +# Init Vuejs component +_component_func = components.declare_component( + "sd-concepts-browser", "./frontend/dists/concept-browser/dist") + + +def sdConceptsBrowser(concepts, key=None): + component_value = _component_func(concepts=concepts, key=key, default="") + return component_value + + +@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True) +def getConceptsFromPath(page, conceptPerPage, searchText=""): + #print("getConceptsFromPath", "page:", page, "conceptPerPage:", conceptPerPage, "searchText:", searchText) + # get the path where the concepts are stored + path = os.path.join( + os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder) + acceptedExtensions = ('jpeg', 'jpg', "png") + concepts = [] + + if os.path.exists(path): + # List all folders (concepts) in the path + folders = [f for f in os.listdir( + path) if os.path.isdir(os.path.join(path, f))] + filteredFolders = folders + + # Filter the folders by the search text + if searchText != "": + filteredFolders = [ + f for f in folders if searchText.lower() in f.lower()] + else: + filteredFolders = [] + + conceptIndex = 1 + for folder in filteredFolders: + # handle pagination + if conceptIndex > (page * conceptPerPage): + continue + if conceptIndex <= ((page - 1) * conceptPerPage): + conceptIndex += 1 + continue + + concept = { + "name": folder, + "token": "<" + folder + ">", + "images": [], + "type": "" + } + + # type of concept is inside type_of_concept.txt + typePath = os.path.join(path, folder, "type_of_concept.txt") + binFile = os.path.join(path, folder, "learned_embeds.bin") + + # Continue if the concept is not valid or the download has failed (no type_of_concept.txt or no binFile) + if not os.path.exists(typePath) or not os.path.exists(binFile): + continue + + with open(typePath, "r") as f: + concept["type"] = f.read() + + # List all files in the concept/concept_images folder + files = [f for f in os.listdir(os.path.join(path, folder, "concept_images")) if os.path.isfile( + os.path.join(path, folder, "concept_images", f))] + # Retrieve only the 4 first images + for file in files: + + # Skip if we already have 4 images + if len(concept["images"]) >= 4: + break + + if file.endswith(acceptedExtensions): + try: + # Add a copy of the image to avoid file locking + originalImage = Image.open(os.path.join( + path, folder, "concept_images", file)) + + # Maintain the aspect ratio (max 200x200) + resizedImage = originalImage.copy() + resizedImage.thumbnail((200, 200), Image.Resampling.LANCZOS) + + # concept["images"].append(resizedImage) + + concept["images"].append(imageToBase64(resizedImage)) + # Close original image + originalImage.close() + except: + print("Error while loading image", file, "in concept", folder, "(The file may be corrupted). 
Skipping it.") + + concepts.append(concept) + conceptIndex += 1 + # print all concepts name + #print("Results:", [c["name"] for c in concepts]) + return concepts + +@st.cache(persist=True, allow_output_mutation=True, show_spinner=False, suppress_st_warning=True) +def imageToBase64(image): + import io + import base64 + buffered = io.BytesIO() + image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + return img_str + + +@st.experimental_memo(persist="disk", show_spinner=False, suppress_st_warning=True) +def getTotalNumberOfConcepts(searchText=""): + # get the path where the concepts are stored + path = os.path.join( + os.getcwd(), st.session_state['defaults'].general.sd_concepts_library_folder) + concepts = [] + + if os.path.exists(path): + # List all folders (concepts) in the path + folders = [f for f in os.listdir( + path) if os.path.isdir(os.path.join(path, f))] + filteredFolders = folders + + # Filter the folders by the search text + if searchText != "": + filteredFolders = [ + f for f in folders if searchText.lower() in f.lower()] + else: + filteredFolders = [] + return len(filteredFolders) + + + +def layout(): + # 2 tabs, one for Concept Library and one for the Download Manager + tab_library, tab_downloader = st.tabs(["Library", "Download Manager"]) + + # Concept Library + with tab_library: + downloaded_concepts_count = getTotalNumberOfConcepts() + concepts_per_page = st.session_state["defaults"].concepts_library.concepts_per_page + + if not "results" in st.session_state: + st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, "") + + # Pagination controls + if not "cl_current_page" in st.session_state: + st.session_state["cl_current_page"] = 1 + + # Search + if not 'cl_search_text' in st.session_state: + st.session_state["cl_search_text"] = "" + + if not 'cl_search_results_count' in st.session_state: + st.session_state["cl_search_results_count"] = downloaded_concepts_count + + # Search bar + _search_col, _refresh_col = st.columns([10, 2]) + with _search_col: + search_text_input = st.text_input("Search", "", placeholder=f'Search for a concept ({downloaded_concepts_count} available)', label_visibility="hidden") + if search_text_input != st.session_state["cl_search_text"]: + # Search text has changed + st.session_state["cl_search_text"] = search_text_input + st.session_state["cl_current_page"] = 1 + st.session_state["cl_search_results_count"] = getTotalNumberOfConcepts(st.session_state["cl_search_text"]) + st.session_state["results"] = getConceptsFromPath(1, concepts_per_page, st.session_state["cl_search_text"]) + + with _refresh_col: + # Super weird fix to align the refresh button with the search bar ( Please streamlit, add css support.. ) + _refresh_col.write("") + _refresh_col.write("") + if st.button("Refresh concepts", key="refresh_concepts", help="Refresh the concepts folders. 
Use this if you have added new concepts manually or deleted some."): + getTotalNumberOfConcepts.clear() + getConceptsFromPath.clear() + st.experimental_rerun() + + + # Show results + results_empty = st.empty() + + # Pagination + pagination_empty = st.empty() + + # Layouts + with pagination_empty: + with st.container(): + if len(st.session_state["results"]) > 0: + last_page = math.ceil(st.session_state["cl_search_results_count"] / concepts_per_page) + _1, _2, _3, _4, _previous_page, _current_page, _next_page, _9, _10, _11, _12 = st.columns([1,1,1,1,1,2,1,1,1,1,1]) + + # Previous page + with _previous_page: + if st.button("Previous", key="cl_previous_page"): + st.session_state["cl_current_page"] -= 1 + if st.session_state["cl_current_page"] <= 0: + st.session_state["cl_current_page"] = last_page + st.session_state["results"] = getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"]) + + # Current page + with _current_page: + _current_page_container = st.empty() + + # Next page + with _next_page: + if st.button("Next", key="cl_next_page"): + st.session_state["cl_current_page"] += 1 + if st.session_state["cl_current_page"] > last_page: + st.session_state["cl_current_page"] = 1 + st.session_state["results"] = getConceptsFromPath(st.session_state["cl_current_page"], concepts_per_page, st.session_state["cl_search_text"]) + + # Current page + with _current_page_container: + st.markdown(f'

<p style="text-align: center">Page {st.session_state["cl_current_page"]} of {last_page}</p>

', unsafe_allow_html=True) + # st.write(f"Page {st.session_state['cl_current_page']} of {last_page}", key="cl_current_page") + + with results_empty: + with st.container(): + if downloaded_concepts_count == 0: + st.write("You don't have any concepts in your library ") + st.markdown("To add concepts to your library, download some from the [sd-concepts-library](https://github.com/Sygil-Dev/sd-concepts-library) \ + repository and save the content of `sd-concepts-library` into ```./models/custom/sd-concepts-library``` or just create your own concepts :wink:.", unsafe_allow_html=False) + else: + if len(st.session_state["results"]) == 0: + st.write("No concept found in the library matching your search: " + st.session_state["cl_search_text"]) + else: + # display number of results + if st.session_state["cl_search_text"]: + st.write(f"Found {st.session_state['cl_search_results_count']} {'concepts' if st.session_state['cl_search_results_count'] > 1 else 'concept' } matching your search") + sdConceptsBrowser(st.session_state['results'], key="results") + + + with tab_downloader: + st.write("Not implemented yet") + + return False diff --git a/webui/streamlit/scripts/sd_utils/__init__.py b/webui/streamlit/scripts/sd_utils/__init__.py new file mode 100644 index 0000000..f38124b --- /dev/null +++ b/webui/streamlit/scripts/sd_utils/__init__.py @@ -0,0 +1,405 @@ +# This file is part of sygil-webui (https://github.com/Sygil-Dev/sygil-webui/). + +# Copyright 2022 Sygil-Dev team. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# base webui import and utils. 
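A note on the pagination used by getConceptsFromPath in the concepts-library script above: the conceptIndex bookkeeping (skip everything at or below (page - 1) * conceptPerPage, stop past page * conceptPerPage) is equivalent to slicing the filtered folder list, and last_page in the layout is ceil(result_count / concepts_per_page). A small self-contained sketch of that arithmetic, assuming 1-based page numbers as in the code above:

import math

def page_slice(items, page, per_page):
    # 1-based pages: page 1 -> items[0:per_page], page 2 -> items[per_page:2*per_page], ...
    start = (page - 1) * per_page
    return items[start:start + per_page]

folders = [f"concept-{i}" for i in range(25)]
assert page_slice(folders, 1, 10) == folders[0:10]
assert page_slice(folders, 3, 10) == folders[20:25]
assert math.ceil(len(folders) / 10) == 3  # last_page for 25 results at 10 per page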
+#from webui_streamlit import st +import hydralit as st + +# streamlit imports +from streamlit.runtime.scriptrunner import StopException +#from streamlit.runtime.scriptrunner import script_run_context + +#streamlit components section +from streamlit_server_state import server_state, server_state_lock, no_rerun +import hydralit_components as hc +from hydralit import HydraHeadApp +import streamlit_nested_layout +#from streamlitextras.threader import lock, trigger_rerun, \ + #streamlit_thread, get_thread, \ + #last_trigger_time + +#other imports + +import warnings +import json + +import base64, cv2 +import os, sys, re, random, datetime, time, math, toml +import gc +from PIL import Image, ImageFont, ImageDraw, ImageFilter +from PIL.PngImagePlugin import PngInfo +from scipy import integrate +import torch +from torchdiffeq import odeint +import k_diffusion as K +import math, requests +import mimetypes +import numpy as np +import pynvml +import threading +import torch, torchvision +from torch import autocast +from torchvision import transforms +import torch.nn as nn +from omegaconf import OmegaConf +import yaml +from pathlib import Path +from contextlib import nullcontext +from einops import rearrange, repeat +from ldm.util import instantiate_from_config +from retry import retry +from slugify import slugify +import skimage +import piexif +import piexif.helper +from tqdm import trange +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import ismap +#from abc import ABC, abstractmethod +from io import BytesIO +from packaging import version +from pathlib import Path +from huggingface_hub import hf_hub_download +import shutup + +#import librosa +from nataili.util.logger import logger, set_logger_verbosity, quiesce_logger +from nataili.esrgan import esrgan + + +#try: + #from realesrgan import RealESRGANer + #from basicsr.archs.rrdbnet_arch import RRDBNet +#except ImportError as e: + #logger.error("You tried to import realesrgan without having it installed properly. To install Real-ESRGAN, run:\n\n" + #"pip install realesrgan") + +# Temp imports +#from basicsr.utils.registry import ARCH_REGISTRY + + +# end of imports +#--------------------------------------------------------------------------------------------------------------- + +# remove all the annoying python warnings. +shutup.please() + +# the following lines should help fixing an issue with nvidia 16xx cards. +if "defaults" in st.session_state: + if st.session_state["defaults"].general.use_cudnn: + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True + +try: + # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. + from transformers import logging + + logging.set_verbosity_error() +except: + pass + +# disable diffusers telemetry +os.environ["DISABLE_TELEMETRY"] = "YES" + +# remove some annoying deprecation warnings that show every now and then. +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +# some of those options should not be changed at all because they would break the model, so I removed them from options. 
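For context on the two constants defined right below: opt_C is the number of latent channels and opt_f the VAE downsampling factor used by Stable Diffusion, so a WxH image is denoised in a latent of shape (opt_C, H/opt_f, W/opt_f). A quick sanity-check sketch:

opt_C, opt_f = 4, 8  # latent channels and downsampling factor, mirroring the definitions below

def latent_shape(width, height, batch_size=1):
    # Shape of the tensor the sampler actually works on.
    return (batch_size, opt_C, height // opt_f, width // opt_f)

assert latent_shape(512, 512) == (1, 4, 64, 64)
assert latent_shape(768, 512) == (1, 4, 64, 96)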
+opt_C = 4
+opt_f = 8
+
+# The model manager loads and unloads the SD models and has features to download them or find their location
+#model_manager = ModelManager()
+
+def load_configs():
+    if not "defaults" in st.session_state:
+        st.session_state["defaults"] = {}
+
+    st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
+
+    if (os.path.exists("configs/webui/userconfig_streamlit.yaml")):
+        user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
+
+        if "version" in user_defaults.general:
+            if version.parse(user_defaults.general.version) < version.parse(st.session_state["defaults"].general.version):
+                logger.error("The version of the user config file is older than the version of the defaults config file. "
+                             "This means there were big changes we made on the config. "
+                             "We are removing this file and recreating it from the defaults in order to make sure things work properly.")
+                os.remove("configs/webui/userconfig_streamlit.yaml")
+                st.experimental_rerun()
+        else:
+            logger.error("The user config file has no version field, which means it is older than the current defaults config file. "
+                         "This means there were big changes we made on the config. "
+                         "We are removing this file and recreating it from the defaults in order to make sure things work properly.")
+            os.remove("configs/webui/userconfig_streamlit.yaml")
+            st.experimental_rerun()
+
+        try:
+            st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults)
+        except KeyError:
+            st.experimental_rerun()
+    else:
+        OmegaConf.save(config=st.session_state.defaults, f="configs/webui/userconfig_streamlit.yaml")
+        loaded = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
+        assert st.session_state.defaults == loaded
+
+    if (os.path.exists(".streamlit/config.toml")):
+        st.session_state["streamlit_config"] = toml.load(".streamlit/config.toml")
+
+    #if st.session_state["defaults"].daisi_app.running_on_daisi_io:
+        #if os.path.exists("scripts/modeldownload.py"):
+            #import modeldownload
+            #modeldownload.updateModels()
+
+    if "keep_all_models_loaded" in st.session_state.defaults.general:
+        with server_state_lock["keep_all_models_loaded"]:
+            server_state["keep_all_models_loaded"] = st.session_state["defaults"].general.keep_all_models_loaded
+    else:
+        st.session_state["defaults"].general.keep_all_models_loaded = False
+        with server_state_lock["keep_all_models_loaded"]:
+            server_state["keep_all_models_loaded"] = st.session_state["defaults"].general.keep_all_models_loaded
+
+load_configs()
+
+#
+#if st.session_state["defaults"].debug.enable_hydralit:
+    #navbar_theme = {'txc_inactive': '#FFFFFF','menu_background':'#0e1117','txc_active':'black','option_active':'red'}
+    #app = st.HydraApp(title='Stable Diffusion WebUI', favicon="", use_cookie_cache=False, sidebar_state="expanded", layout="wide", navbar_theme=navbar_theme,
+    #hide_streamlit_markers=False, allow_url_nav=True , clear_cross_app_sessions=False, use_loader=False)
+#else:
+    #app = None
+
+# The save format setting may carry an optional quality suffix separated by a colon, e.g. "webp:-100" (a negative quality selects lossless webp).
+grid_format = str(st.session_state["defaults"].general.save_format).lower().split(':')
+grid_lossless = False
+grid_quality = st.session_state["defaults"].general.grid_quality
+if grid_format[0] == 'png':
+    grid_ext = 'png'
+    grid_format = 'png'
+elif grid_format[0] in ['jpg', 'jpeg']:
+    grid_quality = int(grid_format[1]) if len(grid_format) > 1 else grid_quality
+    grid_ext = 'jpg'
+    grid_format = 'jpeg'
+elif grid_format[0] == 'webp':
+    grid_quality = int(grid_format[1]) if len(grid_format) > 1 else grid_quality
+    grid_ext = 'webp'
+    grid_format = 'webp'
+    if grid_quality < 0: # e.g. webp:-100 for lossless mode
+        grid_lossless = True
+        grid_quality = abs(grid_quality)
+
+# Same parsing for individual image saves, defaulting to quality 100 when no suffix is given.
+save_format = str(st.session_state["defaults"].general.save_format).lower().split(':')
+save_lossless = False
+save_quality = 100
+if save_format[0] == 'png':
+    save_ext = 'png'
+    save_format = 'png'
+elif save_format[0] in ['jpg', 'jpeg']:
+    save_quality = int(save_format[1]) if len(save_format) > 1 else 100
+    save_ext = 'jpg'
+    save_format = 'jpeg'
+elif save_format[0] == 'webp':
+    save_quality = int(save_format[1]) if len(save_format) > 1 else 100
+    save_ext = 'webp'
+    save_format = 'webp'
+    if save_quality < 0: # e.g. webp:-100 for lossless mode
+        save_lossless = True
+        save_quality = abs(save_quality)
+
+# this should force GFPGAN and RealESRGAN onto the selected gpu as well
+os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
+os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu)
+
+
+# functions to load css locally OR remotely start here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection
+# TODO, maybe look into async loading the file especially for remote fetching
+def local_css(file_name):
+    with open(file_name) as f:
+        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
+
+def remote_css(url):
+    st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)
+
+def load_css(isLocal, nameOrURL):
+    if(isLocal):
+        local_css(nameOrURL)
+    else:
+        remote_css(nameOrURL)
+
+def set_page_title(title):
+    """
+    Simple function that allows us to change the title dynamically.
+    Normally you can use `st.set_page_config` to change the title but it can only be used once per app.
+    """
+
+    st.sidebar.markdown(unsafe_allow_html=True, body=f"""
+