mirror of
https://github.com/sd-webui/stable-diffusion-webui.git
synced 2024-12-15 23:31:59 +03:00
The Merge (#1201)
* resolve conflict with master * - Added option to select custom models instead of just using the default one, if you want to use a custom model just place your .ckpt file in "models/custom" and the UI will detect it and let you switch between stable diffusion and your custom model, make sure to give the filename a proper name that is easy to distinguish from other models because that name will be used on the UI. - Implemented basic Text To Video tab, will continue to improve it as it is really basic right now. - Improved the model loading, you now should see less frequently errors about it not been loaded correctly. * fix: advanced editor (#827), close #811 refactor js_Call hook to take all gradio arguments * Added num_inference_steps to config file and fixed incorrectly calls to the config file from the txt2vid tab calling txt2img instead. * update readme as per installation step & format * proposed streamlit code organization changes I want people of all skill levels to be able to contribute This is one way the code could be split up with the aim of making it easy to understand and contribute especially for people on the lower end of the skill spectrum All i've done is split things, I think renaming and reorganising is still needed * Fixed missing diffusers dependency for Streamlit * Streamlit: Allow user defaults to be specified in a userconfig_streamlit.yaml file. * Changed Streamit yaml default configs Changed `update_preview_frequency` from every 1 step to every 5 steps. This results in a massive gain in performance (roughly going from 2-3 times slower to only 10-15% slower) while still showing good image generation output. Changed default GFPGAN and realESRGAN settings to be off by default. That way, users can decide if they want to use them on, and what images they wish to do so. 
* Made sure img2txt and img2img checkboxes respect YAML defaults * Move location of user file to configs/webui folder * Fixed the path in webui_streamlit.py * Display Info and Stats when render is complete, similar to what Gradio shows. * Add info and stats to img2img * chore: update maintenance scripts and docs (#891) * automate conda_env_name as per name in yaml * Embed installation links directly in README.md Include links to Windows, Linux, and Google Colab installations. * Fix conda update in webui.sh for pip bug * Add info about new PRs Co-authored-by: Hafiidz <3688500+Hafiidz@users.noreply.github.com> Co-authored-by: Tom Pham <54967380+TomPham97@users.noreply.github.com> Co-authored-by: GRMrGecko <grmrgecko@gmail.com> * Improvements to the txt2vid tab. * Urgent Fix to PR:860 * Update attention.py * Update FUNDING.yml * when in outcrop mode, mask added regions and fill in with voroni noise for better outpainting * frontend: display current device info (#889) Displays the current device info at the bottom of the page. For users who run multiple instances of `sd-webui` on the same system (for multiple GPUs), it helps to know which of the active `CUDA_VISIBLE_DEVICES` is being used. * Fixed aspect ratio box not being updated on txt2img tab, for issue 219 from old repo (#812) * Metadata cleanup - Maintain metadata within UI (#845) * Metadata cleanup - Maintain metadata within UI This commit, when combined with Gradio 3.2.1b1+, maintains image metadata as an image is passed throughout the UI. For example, if you generate an image, send it to Image Lab, upscale it, fix faces, and then drag the resulting image back in to Image Lab, it will still remember the image generation parameters. When the image is saved, the metadata will be stripped from it if save-metadata is not enabled. If the image is saved by *dragging* out of the UI on to the filesystem it may maintain its metadata. Note: I have ran into UI responsiveness issues with upgrading Gradio. 
Seems there may be some Gradio queue management issues. *Without* the Gradio update this commit will maintain current functionality, but will not keep metadata when dragging an image between UI components.
* Move ImageMetadata into its own file. Cleans up webui, enables webui_streamlit et al. to use it as well.
* Fix typo
* Add filename formatting argument (#908)
* Update webui.py: filename formatting argument
* Update scripts/webui.py Co-authored-by: Thomas Mello <work.mello@gmail.com>
* Tiling parameter (#911)
* tiling
* default to False
* fix: filename format parameter (#923)
* For issue :884, ensure webui.cmd before init src
* Remove embeddings file path
* Add mask_restore to restore images based on mask, fixing #665 (#898)
* Add mask_restore option to give users the option to restore images based on mask, fixing #665. Before commit c73fdd78 (Implement masking during sampling to improve blending, #308) the image mask was applied after sampling, so masked parts that were not regenerated actually stayed the same. Since c73fdd78 the masked img2img changes the whole image, even in masked areas. This gives better-looking results at first glance, but results in image degradation when applied a few times. See issue #665. In a workflow of repeated masked img2img, users may want to use this option to keep the parts of the image they actually want to keep, without image degradation. A final masked img2img or whole-image img2img with mask_restore disabled will give the better blending of "Implement masking during sampling".
* revert changes of a7be43ba in change_image_editor_mode
* fix ui_functions.change_image_editor_mode by adding gr.update to the end of the list it returns
* revert inserted newlines and whitespaces to match format of previous code
* improve caption of new option mask_restore "Only modify regenerated parts of image"
* fix ui_functions.change_image_editor_mode by adding gr.update to the end of the list it returns. An old copy of the function exists in webui.py; this superfluous function was mistakenly changed by the earlier commit b6a9e16b
* remove unused functions that are near duplicates of functions in ui_functions.py * Added CSS to center the image in the txt2img interface * add img2img option for color correction. (#936) color correction is already used for loopback to prevent color drift with the first image as correction target. the option allows to use the color correction even without loopback mode. it helps keeping the colors similar to the input image. * Image transparency is used as mask for inpainting * fix: lost imports from #921 * Changed StreamIt to `k_euler` 30 steps as default * Fixed an issue with the txt2vid model. * Removed old files from a split test we deed that are not needed anymore, we plan to do the split differently. * Changed the scheduler for the txt2vid tab back to LMS, for now we can only use that. * Better support for large batches in optimized mode * Removed some unused lines from the css file for the streamlit version. * Changed the diffusers version to be 0.2.4 or lower as a new version breaks the txt2vid generation. * Added the models/custom folder to gitignore to ignore custom models. * Added two new scripts that will be used for the new implementation of the txt2vid tab which uses the latest version of the diffusers library. * - Improved the progress bar for the txt2vid tab, it now shows more information during generation. - Changed the guidance_scale variable to be cfg_scale. * Perform masked image restoration for GFPGAN, RealESRGAN, fixing #947 * Perform masked image restoration when using GFPGAN or RealESRGAN, fixing #947. Also fixes bug in image display when using masked image restoration with RealESRGAN. When the image is upscaled using RealESRGAN the image restoration can not use the original image because it has wrong resolution. In this case the image restoration will restore the non-regenerated parts of the image with an RealESRGAN upscaled version of the original input image. 
Modifications from GFPGAN or color correction in (un)masked parts are also restored to the original image by mask blending. * Update scripts/webui.py Co-authored-by: Thomas Mello <work.mello@gmail.com> * fix: sampler name in GoBig #988 * add sampler_name defaults to img2img * add metadata to other file output file types * remove deprecated kwargs/parameter * refactor: sort out dependencies Co-Authored-By: oc013 <101832295+oc013@users.noreply.github.com> Co-Authored-By: Aarni Koskela <akx@iki.fi> Co-Authored-By: oc013 <101832295+oc013@users.noreply.github.com> Co-Authored-By: Aarni Koskela <akx@iki.fi> * webui: detect scoped-down GPU environment (#993) * webui: detect scoped-down GPU environment check if we're using a scoped-down GPU environment (pynvml does not listen to CUDA_VISIBLE_DEVICES) so that we can measure memory on the correct GPU * remove unnecessary import * Added piexif dependency. * Changed the minimum value for the Sampling Steps and Inference Steps to 10 and added step with a value of 10 to make it easier to move the slider as it will require a higher maximum value than in other tabs for good results on the text2vid tab. * Commented an import that is not used for now but will be used soon. * write same metadata to file and yaml * include piexif in environment needed for exif labelling of non-png files * fix individual image file format saves * introduces a general config setting save_format similar to grid_format for individual file saves * Add NSFW filter to avoid unexpected (#955) * Add NSFW filter to avoid unexpected * Fix img2img configuration numbering * Added some basic layout for the Model Manager tab and added there the models that most people use to make it easy to download instead of having to go do the wiki or searching through discord for links, it also shows the path where you are supposed to put those models for them to work. 
* webui: display the GPU in use during startup (#994) * webui: display the GPU in use during startup tell the user which GPU the code is actually going to use before spending lots of time loading everything onto the GPU * typo * add some info messages * evaluate current GPU properly * add debug flag gating not everyone wants or needs to see debug messages :) * add in stray debug msg * Docker updates - Add LDSR, streamlit, other updates for new repository * Update util.py * Docker - Set PYTHONPATH to parent directory to avoid `No module named frontend` error * Add missing comma for nsfw toggle in img2img (#1028) * Multiple improvements to the txt2vid tab. - Improved txt2vid speed by 2 times. - Added DDIM scheduler. - Added sliders for beta_start and beta_end to have more control over these parameters on the scheduler. - Added option to select the scheduler type from scaled_linear or linear. - Added option to save info files for the txt2vid tab and improved the information saved to include most of the parameters used to run the generation. - You can now download any model from the huggingface website to use on the txt2vid tab, just add the name to the custom_models_list on the config file. * webui: add prompt output to console (#1031) * webui: add prompt output to console show the user what prompt is currently being rendered * fix prompt print location * support negative prompts separated by ### e.g. "shopping mall ### people" will try to generate an image of a mall without people in it. * Docker validate model files if not a symlink in case user has VALIDATE_MODELS=false set (#1038) * - Added changes made by @Hafiidz on the ui-improvements branch to the css for the streamli-on-hover-tabs component. * Added streamlit-on-Hover-tabs and streamlit-option-menu dependencies to the environment.yaml file. * Changed some values to be dynamic instead of a fixed value so they are more responsive. 
* Changed the cmd script to use the dark theme by default when launching the streamlit UI. * Removed the padding at the top of the sidebar so we can have more free space. * - Added code for @Hafiidz's changes made on the css. * Fixed an error with the metadata not able to be saved because of the seed was not converted to a string before so it had no attribute encode on it. * add masking to streamlit img2img, find_noise_for_image, matched_noise * Use the webui script directories as PWD (#946) * add Gradio API endpoint settings (#1055) * add Gradio API endpoint settings * Add comments crediting code authors. (probably not enough, but better than none) * Renamed the save_grid option for txt2vid on the config file to be save_video, this will be used to determine if the user wants to save a video at the end of the generation or not, similar to the save_grid that is used on txt2img and img2img but for video. * -Added the Update Image Preview option to be part of the current tab options under Preview Settings. - Added Dynamic Preview Frequency option for the txt2vid tab which tries to find the lowest value for update_preview_frequency at which we can update the preview image during generation while at the same time minimizing the impact it has in performance. - Added option to save a video file on the outputs/txt2vid-samples folder after the generation is complete similar to how the save_grid option works on other tabs. - Added a video preview which shows a video on the txt2vid tab when the generation is completed. - Formated some lines of code to make it use less space and fit on the a single screen. - Added a script called Settings.py to the script folder in which Settings for the Setting page will be placed. Empty for now. * Commented some print statements that were used for debugging and forgot to remove previously. 
* fix: disable live prompt parsing
* Fix issue where loopback was using batch mode
* Fix indentation error that prevents mask_restore from working unless ESRGAN is turned on
* Fixed Sidebar CSS for 4K displays
* img2img mask fixes and fix image2noise normalization
* Made it so the sampling_steps is added to num_inference_steps, otherwise it would not match the value you set for it on the slider.
* Changed the loading of the model on the txt2vid tab so the half models are only loaded if the no_half option on the config file is set to False.
* fix: launcher batch file fix #920, fix #605
  - Allow reading environment.yaml file in either LF or CRLF
  - Only update environment if environment.yaml changes
  - Remove custom_conda_path to discourage changing source file
  - Fix unable to launch webui due to frontend module missing (#605)
* Update README.md (#1075) fix typo
* half precision streamlit txt2vid `RuntimeError: expected scalar type Half but found Float` with both `torch_dtype=torch.float16` and `revision="fp16"`
* Add mask restore feature to streamlit, prevent color correction from modifying initial image when mask_restore is turned on
* Add mask_restore to streamlit config
* JobManager: Fix typo breaking jobs close #858 close #1041
* JobManager: Buttons skip queue (#1092) Have JobManager buttons skip Gradio's queue, since otherwise they aren't sending JobManager button presses.
* The webui_streamlit.py file has been split into multiple modules, each containing its own code, making it easier to work with than a single big file. The list of modules is as follows:
  - webui_streamlit.py: contains the main layout as well as the functions that load the css which is needed by the layout.
  - webui_streamlit_old.py: contains the code for the previous version of the WebUI. Will be removed once the new UI code starts to get used and if everything works as it should.
  - txt2img.py: contains the code for the txt2img tab.
  - img2img.py: contains the code for the img2img tab.
  - txt2vid.py: contains the code for the txt2vid tab.
  - sd_utils.py: contains utility functions used by more than one module; any function that meets this condition should be placed here.
  - ModelManager.py: contains the code for the Model Manager page on the sidebar menu.
  - Settings.py: contains the code for the Settings page on the sidebar menu.
  - home.py: contains the code for the Home tab, history and gallery implemented by @devilismyfriend.
  - imglab.py: contains the code for the Image Lab tab implemented by @devilismyfriend
* fix: patch docker conda install pip requirements (#1094) (cherry picked from commit fab5765fe4) Co-authored-by: Sérgio <smaisidoro@gmail.com>
* Using the Optimization from Dogettx (#974)
* Update attention.py
* change to dogettx Co-authored-by: hlky <106811348+hlky@users.noreply.github.com>
* Update Dockerfile (#1101) add expose for streamlit port
* Publish Streamlit ports (#1102) (cherry picked from commit 833a91047d) Co-authored-by: Charlie <outlookhazy@users.noreply.github.com>
* Forgot to call the layout() for the Model Manager tab after the import, so it was not being used and the tab was shown as empty.
* Removed the "find_noise_for_image.py" and "matched_noise.py" scripts as their content is now part of "sd_utils.py"
* - Added the functions to load the optimized models, this "should" make it so optimized and turbo mode work now but needs to be tested more.
  - Added function to load LDSR.
* Fixed some imports.
* Fixed the info message on the txt2img tab not showing the info but just showing the text "Done"
* Made the default settings from the config file be stored inside st.session_state to avoid loading them multiple times when calling the "sd_utils.py" file from other modules.
* Moved defaults to the webui_streamlit.py file and fixed some imports.
* Removed condition to check if the defaults are in the st.session_state dictionary, this is not needed and would cause issues with it not being reloaded when the user changes something on it.
* Modified the way the default settings are loaded from the config file so we only load them in the webui_streamlit.py file and use st.session_state to access them from anywhere else. This makes it so the config can be modified externally like before the code split, and the changes will be picked up on the next rerun of the UI.
* fix: [streamlit] optimization mode
* temp disable nvml support for multiple gpus
* Fixed defaults not being loaded correctly or missing in some places.
* Add a separate update script instead of git pull on startup (#1106)
* - Fixed max_frame not being properly used and instead sampling_steps was the variable being used.
  - Fixed several issues with the wrong variable being used in multiple places.
  - Added option to toggle some extra options from the config file for when the model is loading on the txt2vid tab.
* Re-merge #611 - View/Cancel in-progress diffusions (#796) * JobManager: Re-merge #611 PR #611 seems to have got lost in the shuffle after the transition to 'dev'. This commit re-merges the feature branch. This adds support for viewing preview images as the image generates, as well as cancelling in-progress images and a couple fixes and clean-ups. * JobManager: Clear jobs that fail to start Sometimes if a job fails to start it will get stuck in the active job list. This commit ensures that jobs that raise exceptions are cleared, and also adds a start timer to clear out jobs that fail to start within a reasonable amount of time. * chore: add breaks to cmds for readability (#1134) * Added custom models list to the txt2img tab. * Small fix to the custom model list. * Corrected breaking issues introduced in #1136 to txt2img and made state variables consistent with img2img. Fixed a bug where switching models after running would not reload the used model. * Formatted tabs as spaces * Fixed update_preview_frequency and update_preview using defaults from webui_streamlit.yaml instead of state variables from UI. * Prompt user if they want to restore changes (#1137) - After stashing any changes and pulling updates, ask user if they wish to pop changes - If user declines the restore, drop the stash to prevent the case of an ever growing stash pile * Added streamlit_nested_layout component as dependency and imported on the webui_streamli.py file to allow us to use nested columns and expanders. * - Added the Home tab made by @devilismyfriend - Added gallery tab on txt2img. * Added case insensitivity to restore prompt (#1152) * Calculate aspect ratio and pixel count on start (#1157) * Fix errors rendering galleries when there are not enough images to render * Fix the gallery back/next buttons and add a refresh button * Fix invalid invocation of find_noise_for_image * Removed the Home tab until the gallery is fixed. * Fixed a missing import on the ModelManager script. 
* Added discord server link to the Readme.md * - Increased the max value for the width and height sliders on the txt2img tab. - Fixed a leftover line from removing the home tab. * Update conda environment on startup always (#1171) * Update environment on startup always * Message to explicitly state no environment.yaml update required Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> * environment update from .cmd * Update .gitignore * Enable negative prompts on streamlit * - Bumped the version of diffusers used on the txt2vid tab to be now v0.3.0. - Added initial file for the textual inversion tab. * add missing argument to GoBig sample function, fixes #1183 (#1184) * cherry-pick @Any-Winter-4079's https://github.com/lstein/stable-diffusion/pull/540. this is a collaboration incorporating a lot of people's contributions -- including for example @Doggettx and the original code from @neonsecret on which the Doggetx optimizations were based (see https://github.com/lstein/stable-diffusion/issues/431, https://github.com/sd-webui/stable-diffusion-webui/pull/771\#issuecomment-1239716055). Takes exactly the same amount of time to run 8 steps as original CompVis code does (10.4 secs, ~1.25s/it). (#1177) Co-authored-by: Alex Birch <birch-san@users.noreply.github.com> * allow webp uploads to img2img tab #991 * Don't attempt mask restoration when there is no mask given (#1186) * When running a batch with preview turned on, produce a grid of preview images * When early terminating, generation_callback gets invoked but st.session_state is empty. When this happens, just bail. * Collect images for final display This is a collection of several changes to enhance image display: * When using GFPGAN or RealESRGAN, only the final output will be displayed. 
* In batch>1 mode, each final image will be collected into an image grid for display * The image is constrained to a reasonable size to ensure that batch grids of RealESRGAN'd images don't end up spitting out a massive image that the browser then has to handle. * Additionally, the progress bar indicator is updated as each image is post-processed. * Display the final image before running postprocessing, and don't preview when i=0 * Added a config option to use embeddings from the huggingface stable diffusion concept library. * Added option to enable enable_attention_slicing and enable_minimal_memory_usage, this for now only works on txt2vid which uses diffusers. * Basic implementation for the Concept Library tab made by cloning the Home tab. * Temporarily hide sd_concept_library due to missing layout() * st.session_state["defaults"] fix * Used loaded_model state variable in .yaml generation (#1196) Used loaded_model state variable in .yaml generation * Streamlit txt2img page settings now follow defaults (#1195) * Some options on the Streamlit txt2img page now follow the defaults from the relevant config files. * Fixed a copy-paste gone wrong in my previous commit. 
* st.session_state["defaults"] fix Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> * default img2img denoising strength increased * slider_steps and slider_bounds in defaults config slider_steps and slider_bounds in defaults config * fix: copy to clipboard button Co-authored-by: ZeroCool940711 <alejandrogilelias940711@gmail.com> Co-authored-by: ZeroCool <ZeroCool940711@users.noreply.github.com> Co-authored-by: Hafiidz <3688500+Hafiidz@users.noreply.github.com> Co-authored-by: hlky <106811348+hlky@users.noreply.github.com> Co-authored-by: Joshua Kimsey <jkimsey95@gmail.com> Co-authored-by: Tony Beeman <beeman@gmail.com> Co-authored-by: Tom Pham <54967380+TomPham97@users.noreply.github.com> Co-authored-by: GRMrGecko <grmrgecko@gmail.com> Co-authored-by: TingTingin <36141041+TingTingin@users.noreply.github.com> Co-authored-by: Logan zoellner <nagolinc@gmail.com> Co-authored-by: M <mchaker@users.noreply.github.com> Co-authored-by: James Pound <jamespoundiv@gmail.com> Co-authored-by: cobryan05 <13701027+cobryan05@users.noreply.github.com> Co-authored-by: Michoko <michoko@hotmail.com> Co-authored-by: VulumeCode <2590984+VulumeCode@users.noreply.github.com> Co-authored-by: xaedes <xaedes@googlemail.com> Co-authored-by: Michael Hearn <git@mikehearn.net> Co-authored-by: Soul-Burn <sugoibaka@gmail.com> Co-authored-by: JJ <jjisnow@gmail.com> Co-authored-by: oc013 <101832295+oc013@users.noreply.github.com> Co-authored-by: Aarni Koskela <akx@iki.fi> Co-authored-by: osi1880vr <87379616+osi1880vr@users.noreply.github.com> Co-authored-by: Rae Fu <rraefu@gmail.com> Co-authored-by: Brian Semrau <brian.semrau@gmail.com> Co-authored-by: Matt Soucy <git@msoucy.me> Co-authored-by: endomorphosis <endomorphosis@users.noreply.github.com> Co-authored-by: unnamedplugins <79282950+unnamedplugins@users.noreply.github.com> Co-authored-by: Syahmi Azhar <prsyahmi@gmail.com> Co-authored-by: Ahmad Abdullah <83442967+ahmad1284@users.noreply.github.com> Co-authored-by: Sérgio 
<smaisidoro@gmail.com> Co-authored-by: Charlie <outlookhazy@users.noreply.github.com> Co-authored-by: protoplm <protoplmz@gmail.com> Co-authored-by: Ascended <dspradau@gmail.com> Co-authored-by: JuanLagu <32816584+JuanLagu@users.noreply.github.com> Co-authored-by: Chris Heald <cheald@gmail.com> Co-authored-by: Charles Galant <cgalant@gmail.com> Co-authored-by: Alex Birch <birch-san@users.noreply.github.com> Co-authored-by: protoplm <57930981+protoplm@users.noreply.github.com> Co-authored-by: Dekker3D <dekker3d@gmail.com>
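The commit message above describes negative prompts separated by `###`, with the example "shopping mall ### people". A minimal sketch of how such a prompt could be split is given below. The helper name `split_negative_prompt` is hypothetical and not taken from the repository, and the sketch assumes that only the first `###` acts as the separator.

```python
from typing import Tuple


def split_negative_prompt(prompt: str, separator: str = "###") -> Tuple[str, str]:
    """Split a prompt into (positive, negative) parts at the first separator.

    Hypothetical helper for illustration only; the repository's own parsing
    may differ. A prompt without a separator yields an empty negative part.
    """
    positive, _, negative = prompt.partition(separator)
    return positive.strip(), negative.strip()


# The example from the commit message: a mall without people in it.
positive_part, negative_part = split_negative_prompt("shopping mall ### people")
print(positive_part)  # shopping mall
print(negative_part)  # people
```

Using `str.partition` keeps the no-separator case simple: the whole prompt is returned as the positive part and the negative part is empty.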
parent ea6b422bff
commit a797312183
3
.dockerignore
Normal file
@@ -0,0 +1,3 @@
models/
outputs/
src/
@@ -6,9 +6,13 @@ CONDA_FORCE_UPDATE=false
# (useful to set to false after you're sure the model files are already in place)
VALIDATE_MODELS=true

#Automatically relaunch the webui on crashes
# Automatically relaunch the webui on crashes
WEBUI_RELAUNCH=true

#Pass cli arguments to webui.py e.g:
#WEBUI_ARGS=--gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1
# Which webui to launch
# WEBUI_SCRIPT=webui_streamlit.py
WEBUI_SCRIPT=webui.py

# Pass cli arguments to webui.py e.g:
# WEBUI_ARGS=--optimized --extra-models-cpu --gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1
WEBUI_ARGS=
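The settings in this hunk are plain `KEY=value` lines, with `WEBUI_SCRIPT` selecting which UI to launch and `WEBUI_ARGS` carrying its command-line arguments. As a rough illustration only, the sketch below parses such lines and resolves the script and argument list they describe. The container's real launcher is `entrypoint.sh`, a shell script whose contents are not shown here, so the helper names `parse_env` and `resolve_launch` and the sample values are assumptions.

```python
import shlex
from typing import Dict, List, Tuple


def parse_env(text: str) -> Dict[str, str]:
    """Parse KEY=value lines, skipping blank lines and '#' comments."""
    values: Dict[str, str] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        # Split on the first '=' only, so values such as --gpu=1 stay intact.
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


def resolve_launch(env: Dict[str, str]) -> Tuple[str, List[str]]:
    """Return (script, argument list); falls back to webui.py with no args."""
    script = env.get("WEBUI_SCRIPT", "webui.py")
    args = shlex.split(env.get("WEBUI_ARGS", ""))
    return script, args


# Illustrative sample; WEBUI_ARGS is empty in the diff, it is filled in here
# only to show how the arguments are tokenized.
SAMPLE_ENV = """\
VALIDATE_MODELS=true
# Automatically relaunch the webui on crashes
WEBUI_RELAUNCH=true
# WEBUI_SCRIPT=webui_streamlit.py
WEBUI_SCRIPT=webui.py
WEBUI_ARGS=--optimized --gpu=1
"""

print(resolve_launch(parse_env(SAMPLE_ENV)))
```

Commented-out lines such as `# WEBUI_SCRIPT=webui_streamlit.py` are ignored, so switching UIs amounts to changing which of the two lines is active.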
7
.gitignore
vendored
@@ -47,13 +47,18 @@ MANIFEST
.env_updated
condaenv.*.requirements.txt

# Visual Studio directories
.vs/
.vscode/

# =========================================================================== #
# Repo-specific
# =========================================================================== #
/configs/webui/userconfig_streamlit.yaml
/custom-conda-path.txt
/src/*
/outputs/*
/outputs
/model_cache
/log/**/*.png
/log/log.csv
/flagged/*
@@ -1,6 +1,10 @@
FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04

ENV DEBIAN_FRONTEND=noninteractive
ENV DEBIAN_FRONTEND=noninteractive \
PYTHONUNBUFFERED=1 \
PYTHONIOENCODING=UTF-8 \
CONDA_DIR=/opt/conda

WORKDIR /sd

SHELL ["/bin/bash", "-c"]
@@ -11,7 +15,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/*

# Install miniconda
ENV CONDA_DIR /opt/conda
RUN wget -O ~/miniconda.sh -q --show-progress --progress=bar:force https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
/bin/bash ~/miniconda.sh -b -p $CONDA_DIR && \
rm ~/miniconda.sh
@@ -20,7 +23,7 @@ ENV PATH=$CONDA_DIR/bin:$PATH
# Install font for prompt matrix
COPY /data/DejaVuSans.ttf /usr/share/fonts/truetype/

EXPOSE 7860
EXPOSE 7860 8501

COPY ./entrypoint.sh /sd/
ENTRYPOINT /sd/entrypoint.sh
@@ -46,8 +46,8 @@ Features:

* Gradio GUI: Idiot-proof, fully featured frontend for both txt2img and img2img generation
* No more manually typing parameters, now all you have to do is write your prompt and adjust sliders
* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui#gfpgan)Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second
* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui#realesrgan) Boosts the resolution of images with a built-in RealESRGAN option
* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second
* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Boosts the resolution of images with a built-in RealESRGAN option
* :computer: esrgan/gfpgan on cpu support :computer:
* Textual inversion 🔥: [info](https://textual-inversion.github.io/) - requires enabling, see [here](https://github.com/hlky/sd-enable-textual-inversion), script works as usual without it enabled
* Advanced img2img editor :art: :fire: :art:
@@ -106,7 +106,7 @@ that are not in original script.

### GFPGAN
Lets you improve faces in pictures using the GFPGAN model. There is a checkbox in every tab to use GFPGAN at 100%, and
also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strongthe effect is.
also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strong the effect is.

![](images/GFPGAN.png)

@@ -12,8 +12,9 @@ txt2img:
# 5: Write sample info files
# 6: write sample info to log file
# 7: jpg samples
# 8: Fix faces using GFPGAN
# 9: Upscale images using RealESRGAN
# 8: Filter NSFW content
# 9: Fix faces using GFPGAN
# 10: Upscale images using RealESRGAN
toggles: [1, 2, 3, 4, 5]
sampler_name: k_lms
ddim_eta: 0.0 # legacy name, applies to all algorithms.
@@ -40,8 +41,10 @@ img2img:
# 6: Sort samples by prompt
# 7: Write sample info files
# 8: jpg samples
# 9: Fix faces using GFPGAN
# 10: Upscale images using Real-ESRGAN
# 9: Color correction
# 10: Filter NSFW content
# 11: Fix faces using GFPGAN
# 12: Upscale images using Real-ESRGAN
toggles: [1, 4, 5, 6, 7]
sampler_name: k_lms
ddim_eta: 0.0
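The `toggles` lists refer to options by position, which is why inserting "Filter NSFW content" renumbers every entry after it. The sketch below illustrates that lookup using only the txt2img indices visible in this hunk (5 through 10). Indices below 5 are not shown here, so they are reported as unknown rather than guessed. The names `describe_toggles`, `TXT2IMG_BEFORE` and `TXT2IMG_AFTER` are hypothetical.

```python
from typing import Dict, List

# txt2img option labels as listed in the hunk, before and after the change.
TXT2IMG_BEFORE: Dict[int, str] = {
    5: "Write sample info files",
    6: "write sample info to log file",
    7: "jpg samples",
    8: "Fix faces using GFPGAN",
    9: "Upscale images using RealESRGAN",
}
TXT2IMG_AFTER: Dict[int, str] = {
    5: "Write sample info files",
    6: "write sample info to log file",
    7: "jpg samples",
    8: "Filter NSFW content",
    9: "Fix faces using GFPGAN",
    10: "Upscale images using RealESRGAN",
}


def describe_toggles(toggles: List[int], names: Dict[int, str]) -> List[str]:
    """Map toggle indices to labels; indices outside the hunk stay unknown."""
    return [names.get(i, f"<index {i} not shown in this hunk>") for i in toggles]


# A saved list that enabled index 8 changes meaning across this commit.
print(describe_toggles([8], TXT2IMG_BEFORE))  # ['Fix faces using GFPGAN']
print(describe_toggles([8], TXT2IMG_AFTER))   # ['Filter NSFW content']
```

One consequence worth noting: a customized `toggles` list written against the old numbering would silently select different options after the update, so such lists need to be re-checked by hand.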
@@ -1,14 +1,19 @@
# UI defaults configuration file. It is automatically loaded if located at configs/webui/webui_streamlit.yaml.
# Any changes made here will be available automatically on the web app without having to stop it.
# You may add overrides in a file named "userconfig_streamlit.yaml" in this folder, which can contain any subset
# of the properties below.
general:
gpu: 0
outdir: outputs
ckpt: "models/ldm/stable-diffusion-v1/model.ckpt"
fp:
name: 'embeddings/alex/embeddings_gs-11000.pt'
default_model: "Stable Diffusion v1.4"
default_model_config: "configs/stable-diffusion/v1-inference.yaml"
default_model_path: "models/ldm/stable-diffusion-v1/model.ckpt"
use_sd_concepts_library: True
sd_concepts_library_folder: "models/custom/sd-concepts-library"
GFPGAN_dir: "./src/gfpgan"
RealESRGAN_dir: "./src/realesrgan"
RealESRGAN_model: "RealESRGAN_x4plus"
LDSR_dir: "./src/latent-diffusion"
outdir_txt2img: outputs/txt2img-samples
outdir_img2img: outputs/img2img-samples
gfpgan_cpu: False
@ -16,43 +21,104 @@ general:
|
||||
extra_models_cpu: False
|
||||
extra_models_gpu: False
|
||||
save_metadata: True
|
||||
save_format: "png"
|
||||
skip_grid: False
|
||||
skip_save: False
|
||||
grid_format: "jpg:95"
|
||||
n_rows: -1
|
||||
no_verify_input: False
|
||||
no_half: False
|
||||
use_float16: False
|
||||
precision: "autocast"
|
||||
optimized: False
|
||||
optimized_turbo: False
|
||||
optimized_config: "optimizedSD/v1-inference.yaml"
|
||||
enable_attention_slicing: False
|
||||
enable_minimal_memory_usage : False
|
||||
update_preview: True
|
||||
update_preview_frequency: 1
|
||||
update_preview_frequency: 5
|
||||
|
||||
txt2img:
|
||||
prompt:
|
||||
height: 512
|
||||
width: 512
|
||||
cfg_scale: 5.0
|
||||
cfg_scale: 7.5
|
||||
seed: ""
|
||||
batch_count: 1
|
||||
batch_size: 1
|
||||
sampling_steps: 50
|
||||
default_sampler: "k_lms"
|
||||
sampling_steps: 30
|
||||
default_sampler: "k_euler"
|
||||
separate_prompts: False
|
||||
update_preview: True
|
||||
update_preview_frequency: 5
|
||||
normalize_prompt_weights: True
|
||||
save_individual_images: True
|
||||
save_grid: True
|
||||
group_by_prompt: True
|
||||
save_as_jpg: False
|
||||
use_GFPGAN: True
|
||||
use_RealESRGAN: True
|
||||
use_GFPGAN: False
|
||||
use_RealESRGAN: False
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
variant_amount: 0.0
|
||||
variant_seed: ""
|
||||
write_info_files: True
|
||||
slider_steps: {
|
||||
sampling: 1
|
||||
}
|
||||
slider_bounds: {
|
||||
sampling: {
|
||||
lower: 1,
|
||||
upper: 150
|
||||
}
|
||||
}
|
||||
|
||||
txt2vid:
|
||||
default_model: "CompVis/stable-diffusion-v1-4"
|
||||
custom_models_list: ["CompVis/stable-diffusion-v1-4", "naclbit/trinart_stable_diffusion_v2", "hakurei/waifu-diffusion", "osanseviero/BigGAN-deep-128"]
|
||||
prompt:
|
||||
height: 512
|
||||
width: 512
|
||||
cfg_scale: 7.5
|
||||
seed: ""
|
||||
batch_count: 1
|
||||
batch_size: 1
|
||||
sampling_steps: 30
|
||||
num_inference_steps: 200
|
||||
default_sampler: "k_euler"
|
||||
scheduler_name: "klms"
|
||||
separate_prompts: False
|
||||
update_preview: True
|
||||
update_preview_frequency: 5
|
||||
dynamic_preview_frequency: True
|
||||
normalize_prompt_weights: True
|
||||
save_individual_images: True
|
||||
save_video: True
|
||||
group_by_prompt: True
|
||||
write_info_files: True
|
||||
do_loop: False
|
||||
save_as_jpg: False
|
||||
use_GFPGAN: False
|
||||
use_RealESRGAN: False
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
variant_amount: 0.0
|
||||
variant_seed: ""
|
||||
beta_start: 0.00085
|
||||
beta_end: 0.012
|
||||
beta_scheduler_type: "linear"
|
||||
max_frames: 1000
|
||||
slider_steps: {
|
||||
sampling: 1
|
||||
}
|
||||
slider_bounds: {
|
||||
sampling: {
|
||||
lower: 1,
|
||||
upper: 150
|
||||
}
|
||||
}
|
||||
|
||||
img2img:
|
||||
prompt:
|
||||
sampling_steps: 50
|
||||
sampling_steps: 30
|
||||
# Adding an int to toggles enables the corresponding feature.
|
||||
# 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them)
|
||||
# 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0)
|
||||
@ -65,11 +131,12 @@ img2img:
|
||||
# 8: jpg samples
|
||||
# 9: Fix faces using GFPGAN
|
||||
# 10: Upscale images using Real-ESRGAN
|
||||
sampler_name: k_lms
|
||||
denoising_strength: 0.45
|
||||
sampler_name: "k_euler"
|
||||
denoising_strength: 0.75
|
||||
# 0: Keep masked area
|
||||
# 1: Regenerate only masked area
|
||||
mask_mode: 0
|
||||
mask_restore: False
|
||||
# 0: Just resize
|
||||
# 1: Crop and resize
|
||||
# 2: Resize and fill
|
||||
@ -77,7 +144,7 @@ img2img:
|
||||
# Leave blank for random seed:
|
||||
seed: ""
|
||||
ddim_eta: 0.0
|
||||
cfg_scale: 5.0
|
||||
cfg_scale: 7.5
|
||||
batch_count: 1
|
||||
batch_size: 1
|
||||
height: 512
|
||||
@ -87,17 +154,28 @@ img2img:
|
||||
loopback: True
|
||||
random_seed_loopback: True
|
||||
separate_prompts: False
|
||||
update_preview: True
|
||||
update_preview_frequency: 5
|
||||
normalize_prompt_weights: True
|
||||
save_individual_images: True
|
||||
save_grid: True
|
||||
group_by_prompt: True
|
||||
save_as_jpg: False
|
||||
use_GFPGAN: True
|
||||
use_RealESRGAN: True
|
||||
use_GFPGAN: False
|
||||
use_RealESRGAN: False
|
||||
RealESRGAN_model: "RealESRGAN_x4plus"
|
||||
variant_amount: 0.0
|
||||
variant_seed: ""
|
||||
write_info_files: True
|
||||
slider_steps: {
|
||||
sampling: 1
|
||||
}
|
||||
slider_bounds: {
|
||||
sampling: {
|
||||
lower: 1,
|
||||
upper: 150
|
||||
}
|
||||
}
|
||||
|
||||
gfpgan:
|
||||
strength: 100
|
||||
|
||||
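The header comment of this file says that `userconfig_streamlit.yaml` may contain any subset of the properties and acts as an override. The sketch below illustrates that merge behaviour with plain dictionaries. How the app actually loads and merges the two files is not shown in this diff, so `merge_overrides` is a hypothetical helper; the sample values are the new defaults from the hunk.

```python
import copy


def merge_overrides(defaults, overrides):
    """Recursively overlay `overrides` on a copy of `defaults`."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are sections: merge key by key instead of replacing.
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


defaults = {
    "general": {"gpu": 0, "update_preview_frequency": 5},
    "txt2img": {"cfg_scale": 7.5, "sampling_steps": 30, "use_GFPGAN": False},
}
# A user file only needs the keys it wants to change.
user = {"txt2img": {"sampling_steps": 50}}

config = merge_overrides(defaults, user)
print(config["txt2img"])
```

Merging section by section, rather than replacing a whole section, is what lets the user file stay small: untouched keys such as `cfg_scale` keep their default values.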
|
@ -2,7 +2,7 @@ version: '3.3'
|
||||
|
||||
services:
|
||||
stable-diffusion:
|
||||
container_name: sd
|
||||
container_name: sd-webui
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
@ -12,6 +12,7 @@ services:
|
||||
volumes:
|
||||
- .:/sd
|
||||
- ./outputs:/sd/outputs
|
||||
- ./model_cache:/sd/model_cache
|
||||
- conda_env:/opt/conda
|
||||
- root_profile:/root
|
||||
ports:
|
||||
@ -21,7 +22,7 @@ services:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- capabilities: [gpu]
|
||||
- capabilities: [ gpu ]
|
||||
|
||||
volumes:
|
||||
conda_env:
|
||||
|
9
docker-reset.sh
Normal file → Executable file
@ -10,12 +10,13 @@ echo $(pwd)
|
||||
read -p "Is the directory above correct to run reset on? (y/n) " -n 1 DIRCONFIRM
|
||||
if [[ $DIRCONFIRM =~ ^[Yy]$ ]]; then
|
||||
docker compose down
|
||||
docker image rm stable-diffusion_stable-diffusion:latest
|
||||
docker volume rm stable-diffusion_conda_env
|
||||
docker volume rm stable-diffusion_root_profile
|
||||
docker image rm stable-diffusion-webui_stable-diffusion:latest
|
||||
docker volume rm stable-diffusion-webui_conda_env
|
||||
docker volume rm stable-diffusion-webui_root_profile
|
||||
echo "Remove ./src"
|
||||
sudo rm -rf src
|
||||
sudo rm -rf latent_diffusion.egg-info
|
||||
sudo rm -rf gfpgan
|
||||
sudo rm -rf sd_webui.egg-info
|
||||
sudo rm .env_updated
|
||||
else
|
||||
echo "Exited without resetting"
|
||||
|
@ -3,26 +3,36 @@
|
||||
# Starts the gui inside the docker container using the conda env
|
||||
#
|
||||
|
||||
# set -x
|
||||
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
cd $SCRIPT_DIR
|
||||
export PYTHONPATH=$SCRIPT_DIR
|
||||
|
||||
MODEL_DIR="${SCRIPT_DIR}/model_cache"
|
||||
# Array of model files to pre-download
|
||||
# local filename
|
||||
# local path in container (no trailing slash)
|
||||
# download URL
|
||||
# sha256sum
|
||||
MODEL_FILES=(
|
||||
'model.ckpt /sd/models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'
|
||||
'GFPGANv1.3.pth /sd/src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70'
|
||||
'RealESRGAN_x4plus.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1'
|
||||
'RealESRGAN_x4plus_anime_6B.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
|
||||
'model.ckpt models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'
|
||||
'GFPGANv1.3.pth src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70'
|
||||
'RealESRGAN_x4plus.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1'
|
||||
'RealESRGAN_x4plus_anime_6B.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
|
||||
'project.yaml src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1 9d6ad53c5dafeb07200fb712db14b813b527edd262bc80ea136777bdb41be2ba'
|
||||
'model.ckpt src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1 c209caecac2f97b4bb8f4d726b70ac2ac9b35904b7fc99801e1f5e61f9210c13'
|
||||
)
|
||||
|
||||
# Conda environment installs/updates
|
||||
# @see https://github.com/ContinuumIO/docker-images/issues/89#issuecomment-467287039
|
||||
ENV_NAME="ldm"
|
||||
ENV_FILE="/sd/environment.yaml"
|
||||
ENV_FILE="${SCRIPT_DIR}/environment.yaml"
|
||||
ENV_UPDATED=0
|
||||
ENV_MODIFIED=$(date -r $ENV_FILE "+%s")
|
||||
ENV_MODIFED_FILE="/sd/.env_updated"
|
||||
ENV_MODIFED_FILE="${SCRIPT_DIR}/.env_updated"
|
||||
if [[ -f $ENV_MODIFED_FILE ]]; then ENV_MODIFIED_CACHED=$(<${ENV_MODIFED_FILE}); else ENV_MODIFIED_CACHED=0; fi
|
||||
export PIP_EXISTS_ACTION=w
|
||||
|
||||
# Create/update conda env if needed
|
||||
if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then
|
||||
@ -51,54 +61,67 @@ conda info | grep active
|
||||
# Function to checks for valid hash for model files and download/replaces if invalid or does not exist
|
||||
validateDownloadModel() {
|
||||
local file=$1
|
||||
local path=$2
|
||||
local path="${SCRIPT_DIR}/${2}"
|
||||
local url=$3
|
||||
local hash=$4
|
||||
|
||||
echo "checking ${file}..."
|
||||
sha256sum --check --status <<< "${hash} ${path}/${file}"
|
||||
sha256sum --check --status <<< "${hash} ${MODEL_DIR}/${file}.${hash}"
|
||||
if [[ $? == "1" ]]; then
|
||||
echo "Downloading: ${url} please wait..."
|
||||
mkdir -p ${path}
|
||||
wget --output-document=${path}/${file} --no-verbose --show-progress --progress=dot:giga ${url}
|
||||
wget --output-document=${MODEL_DIR}/${file}.${hash} --no-verbose --show-progress --progress=dot:giga ${url}
|
||||
ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file}
|
||||
if [[ -e "${path}/${file}" ]]; then
|
||||
echo "saved ${file}"
|
||||
else
|
||||
echo "error saving ${path}/${file}!"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
if [[ ! -e ${path}/${file} || ! -L ${path}/${file} ]]; then
|
||||
mkdir -p ${path}
|
||||
ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file}
|
||||
echo -e "linked valid ${file}\n"
|
||||
else
|
||||
echo -e "${file} is valid!\n"
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
# Validate model files
|
||||
if [[ -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then
|
||||
echo "Validating model files..."
|
||||
for models in "${MODEL_FILES[@]}"; do
|
||||
echo "Validating model files..."
|
||||
for models in "${MODEL_FILES[@]}"; do
|
||||
model=($models)
|
||||
if [[ ! -e ${model[1]}/${model[0]} || ! -L ${model[1]}/${model[0]} || -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then
|
||||
validateDownloadModel ${model[0]} ${model[1]} ${model[2]} ${model[3]}
|
||||
done
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
# Launch web gui
|
||||
cd /sd
|
||||
|
||||
if [[ -z $WEBUI_ARGS ]]; then
|
||||
launch_message="entrypoint.sh: Launching..."
|
||||
if [[ ! -z $WEBUI_SCRIPT && $WEBUI_SCRIPT == "webui_streamlit.py" ]]; then
|
||||
launch_command="streamlit run scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
|
||||
else
|
||||
launch_message="entrypoint.sh: Launching with arguments ${WEBUI_ARGS}"
|
||||
launch_command="python scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
|
||||
fi
|
||||
|
||||
launch_message="entrypoint.sh: Run ${launch_command}..."
|
||||
if [[ -z $WEBUI_RELAUNCH || $WEBUI_RELAUNCH == "true" ]]; then
|
||||
n=0
|
||||
while true; do
|
||||
|
||||
echo $launch_message
|
||||
|
||||
if (( $n > 0 )); then
|
||||
echo "Relaunch count: ${n}"
|
||||
fi
|
||||
python -u scripts/webui.py $WEBUI_ARGS
|
||||
|
||||
$launch_command
|
||||
|
||||
echo "entrypoint.sh: Process is ending. Relaunching in 0.5s..."
|
||||
((n++))
|
||||
sleep 0.5
|
||||
done
|
||||
else
|
||||
echo $launch_message
|
||||
python -u scripts/webui.py $WEBUI_ARGS
|
||||
$launch_command
|
||||
fi
|
||||
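The reworked `validateDownloadModel` above stores each download in `model_cache` under the name `<file>.<hash>`, checks it with `sha256sum`, and symlinks it into the path the app expects. The same idea can be sketched in Python as follows. This is an illustration rather than a port of the script: `ensure_model` and the caller-supplied `fetch` callable (a stand-in for the `wget` call) are hypothetical names, and creating symlinks may need extra privileges on Windows.

```python
import hashlib
import os


def sha256_of(path):
    """Hex SHA-256 digest of a file, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def ensure_model(cache_dir, target_dir, name, expected_hash, fetch):
    """Validate the cached file, refetch it if needed, then (re)link it."""
    cached = os.path.join(cache_dir, f"{name}.{expected_hash}")
    if not os.path.isfile(cached) or sha256_of(cached) != expected_hash:
        os.makedirs(cache_dir, exist_ok=True)
        fetch(cached)  # hypothetical downloader; writes the file to `cached`
    os.makedirs(target_dir, exist_ok=True)
    link = os.path.join(target_dir, name)
    if os.path.lexists(link):
        os.remove(link)  # mimic `ln -sf`: replace an existing file or link
    os.symlink(cached, link)
    return link
```

Keying the cache file on its hash means a changed checksum in `MODEL_FILES` leads to a fresh download instead of silently reusing a stale file, while a valid cached copy is only relinked.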
|
@ -3,39 +3,47 @@ channels:
|
||||
- pytorch
|
||||
- defaults
|
||||
dependencies:
|
||||
- git
|
||||
- python=3.8.5
|
||||
- pip=20.3
|
||||
- cudatoolkit=11.3
|
||||
- git
|
||||
- numpy=1.22.3
|
||||
- pip=20.3
|
||||
- python=3.8.5
|
||||
- pytorch=1.11.0
|
||||
- scikit-image=0.19.2
|
||||
- torchvision=0.12.0
|
||||
- numpy=1.19.2
|
||||
- pip:
|
||||
- albumentations==0.4.3
|
||||
- opencv-python==4.1.2.30
|
||||
- opencv-python-headless==4.1.2.30
|
||||
- pudb==2019.2
|
||||
- imageio==2.9.0
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- pytorch-lightning==1.4.2
|
||||
- omegaconf==2.1.1
|
||||
- test-tube>=0.7.5
|
||||
- einops==0.3.0
|
||||
- torch-fidelity==0.3.0
|
||||
- transformers==4.19.2
|
||||
- torchmetrics==0.6.0
|
||||
- kornia==0.6
|
||||
- gradio==3.1.6
|
||||
- accelerate==0.12.0
|
||||
- pynvml==11.4.1
|
||||
- basicsr>=1.3.4.0
|
||||
- facexlib>=0.2.3
|
||||
- python-slugify>=6.1.2
|
||||
- streamlit>=1.12.2
|
||||
- retry>=0.9.2
|
||||
- -e .
|
||||
- -e git+https://github.com/CompVis/taming-transformers#egg=taming-transformers
|
||||
- -e git+https://github.com/openai/CLIP#egg=clip
|
||||
- -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN
|
||||
- -e git+https://github.com/xinntao/Real-ESRGAN#egg=realesrgan
|
||||
- -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
|
||||
- -e .
|
||||
- -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
|
||||
- accelerate==0.12.0
|
||||
- albumentations==0.4.3
|
||||
- basicsr>=1.3.4.0
|
||||
- diffusers==0.3.0
|
||||
- einops==0.3.0
|
||||
- facexlib>=0.2.3
|
||||
- gradio==3.1.6
|
||||
- imageio-ffmpeg==0.4.2
|
||||
- imageio==2.9.0
|
||||
- kornia==0.6
|
||||
- omegaconf==2.1.1
|
||||
- opencv-python-headless==4.6.0.66
|
||||
- pandas==1.4.3
|
||||
- piexif==1.1.3
|
||||
- pudb==2019.2
|
||||
- pynvml==11.4.1
|
||||
- python-slugify>=6.1.2
|
||||
- pytorch-lightning==1.4.2
|
||||
- retry>=0.9.2
|
||||
- streamlit>=1.12.2
|
||||
- streamlit-on-Hover-tabs==1.0.1
|
||||
- streamlit-option-menu==0.3.2
|
||||
- streamlit_nested_layout
|
||||
- test-tube>=0.7.5
|
||||
- tensorboard
|
||||
- torch-fidelity==0.3.0
|
||||
- torchmetrics==0.6.0
|
||||
- transformers==4.19.2
|
||||
|
@ -1,15 +1,111 @@
|
||||
.css-18e3th9 {
|
||||
padding-top: 2rem;
|
||||
padding-bottom: 10rem;
|
||||
padding-left: 5rem;
|
||||
padding-right: 5rem;
|
||||
}
|
||||
.css-1d391kg {
|
||||
padding-top: 3.5rem;
|
||||
padding-right: 1rem;
|
||||
padding-bottom: 3.5rem;
|
||||
padding-left: 1rem;
|
||||
}
|
||||
/***********************************************************
|
||||
* Additional CSS for streamlit builtin components *
|
||||
************************************************************/
|
||||
|
||||
/* Tab name (e.g. Text-to-Image) */
|
||||
button[data-baseweb="tab"] {
|
||||
font-size: 25px;
|
||||
font-size: 25px; //improve legibility
|
||||
}
|
||||
|
||||
/* Image Container (only appear after run finished) */
|
||||
.css-du1fp8 {
|
||||
justify-content: center; //center the image, especially better looks in wide screen
|
||||
}
|
||||
|
||||
/* Streamlit header */
|
||||
.css-1avcm0n {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
/* Main streamlit container (below header) */
|
||||
.css-18e3th9 {
|
||||
padding-top: 2rem; //reduce the empty spaces
|
||||
}
|
||||
|
||||
/* @media only for widescreen, to ensure enough space to see all */
|
||||
@media (min-width: 1024px) {
|
||||
/* Main streamlit container (below header) */
|
||||
.css-18e3th9 {
|
||||
padding-top: 0px; //reduce the empty spaces, can go fully to the top on widescreen devices
|
||||
}
|
||||
}
|
||||
|
||||
/***********************************************************
|
||||
* Additional CSS for streamlit custom/3rd party components *
|
||||
************************************************************/
|
||||
/* For stream_on_hover */
|
||||
section[data-testid="stSidebar"] > div:nth-of-type(1) {
|
||||
background-color: #111;
|
||||
}
|
||||
|
||||
button[kind="header"] {
|
||||
background-color: transparent;
|
||||
color: rgb(180, 167, 141);
|
||||
}
|
||||
|
||||
@media (hover) {
|
||||
/* header element */
|
||||
header[data-testid="stHeader"] {
|
||||
/* display: none;*/ /*suggested behavior by streamlit hover components*/
|
||||
pointer-events: none; /* disable interaction of the transparent background */
|
||||
}
|
||||
|
||||
/* The button on the streamlit navigation menu */
|
||||
button[kind="header"] {
|
||||
/* display: none;*/ /*suggested behavior by streamlit hover components*/
|
||||
pointer-events: auto; /* enable interaction of the button even if parents intereaction disabled */
|
||||
}
|
||||
|
||||
/* added to avoid main sectors (all element to the right of sidebar from) moving */
|
||||
section[data-testid="stSidebar"] {
|
||||
width: 3.5% !important;
|
||||
min-width: 3.5% !important;
|
||||
}
|
||||
|
||||
/* The navigation menu specs and size */
|
||||
section[data-testid="stSidebar"] > div {
|
||||
height: 100%;
|
||||
width: 2% !important;
|
||||
min-width: 100% !important;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
top: 0;
|
||||
left: 0;
|
||||
background-color: #111;
|
||||
overflow-x: hidden;
|
||||
transition: 0.5s ease-in-out;
|
||||
padding-top: 0px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
/* The navigation menu open and close on hover and size */
|
||||
section[data-testid="stSidebar"] > div:hover {
|
||||
width: 300px !important;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 272px) {
|
||||
section[data-testid="stSidebar"] > div {
|
||||
width: 15rem;
|
||||
}
|
||||
}
|
||||
|
||||
/***********************************************************
|
||||
* Additional CSS for other elements
|
||||
************************************************************/
|
||||
button[data-baseweb="tab"] {
|
||||
font-size: 20px;
|
||||
}
|
||||
|
||||
@media (min-width: 1200px){
|
||||
h1 {
|
||||
font-size: 1.75rem;
|
||||
}
|
||||
}
|
||||
#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 {
|
||||
width: 50rem;
|
||||
align-self: center;
|
||||
}
|
||||
div.gallery:hover {
|
||||
border: 1px solid #777;
|
||||
}
|
@ -3,6 +3,8 @@ from frontend.css_and_js import css, js, call_JS, js_parse_prompt, js_copy_txt2i
|
||||
from frontend.job_manager import JobManager
|
||||
import frontend.ui_functions as uifn
|
||||
import uuid
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda x: x, txt2img_defaults={},
|
||||
@ -36,8 +38,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
value=txt2img_defaults['cfg_scale'], elem_id='cfg_slider')
|
||||
txt2img_seed = gr.Textbox(label="Seed (blank to randomize)", lines=1, max_lines=1,
|
||||
value=txt2img_defaults["seed"])
|
||||
txt2img_batch_size = gr.Slider(minimum=1, maximum=50, step=1,
|
||||
label='Images per batch',
|
||||
value=txt2img_defaults['batch_size'])
|
||||
txt2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1,
|
||||
label='Number of images to generate',
|
||||
label='Number of batches to generate',
|
||||
value=txt2img_defaults['n_iter'])
|
||||
|
||||
txt2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None
|
||||
@ -51,10 +56,14 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
gr.Markdown(
|
||||
"Select an image from the gallery, then click one of the buttons below to perform an action.")
|
||||
with gr.Row(elem_id='txt2img_actions_row'):
|
||||
gr.Button("Copy to clipboard").click(fn=None,
|
||||
gr.Button("Copy to clipboard").click(
|
||||
fn=None,
|
||||
inputs=output_txt2img_gallery,
|
||||
outputs=[],
|
||||
# _js=js_copy_to_clipboard( 'txt2img_gallery_output')
|
||||
_js=call_JS(
|
||||
"copyImageFromGalleryToClipboard",
|
||||
fromId="txt2img_gallery_output"
|
||||
)
|
||||
)
|
||||
output_txt2img_copy_to_input_btn = gr.Button("Push to img2img")
|
||||
output_txt2img_to_imglab = gr.Button("Send to Lab", visible=True)
|
||||
@ -91,9 +100,6 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
with gr.TabItem('Advanced'):
|
||||
txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles,
|
||||
value=txt2img_toggle_defaults, type="index")
|
||||
txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1,
|
||||
label='Batch size (how many images are in a batch; memory-hungry)',
|
||||
value=txt2img_defaults['batch_size'])
|
||||
txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model',
|
||||
choices=['RealESRGAN_x4plus',
|
||||
'RealESRGAN_x4plus_anime_6B'],
|
||||
@ -124,20 +130,27 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
inputs=txt2img_inputs,
|
||||
outputs=txt2img_outputs
|
||||
)
|
||||
use_queue = False
|
||||
else:
|
||||
use_queue = True
|
||||
|
||||
txt2img_btn.click(
|
||||
txt2img_func,
|
||||
txt2img_inputs,
|
||||
txt2img_outputs
|
||||
txt2img_outputs,
|
||||
api_name='txt2img',
|
||||
queue=use_queue
|
||||
)
|
||||
txt2img_prompt.submit(
|
||||
txt2img_func,
|
||||
txt2img_inputs,
|
||||
txt2img_outputs
|
||||
txt2img_outputs,
|
||||
queue=use_queue
|
||||
)
|
||||
|
||||
# txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
|
||||
# txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
|
||||
txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
|
||||
txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
|
||||
txt2img_dimensions_info_text_box.value = uifn.update_dimensions_info(txt2img_width.value, txt2img_height.value)
|
||||
|
||||
# Temporarily disable prompt parsing until memory issues could be solved
|
||||
# See #676
|
||||
@ -189,8 +202,9 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
with gr.TabItem("Editor Options"):
|
||||
with gr.Row():
|
||||
# disable Uncrop for now
|
||||
# choices=["Mask", "Crop", "Uncrop"]
|
||||
img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"],
|
||||
choices=["Mask", "Crop", "Uncrop"]
|
||||
#choices=["Mask", "Crop"]
|
||||
img2img_image_editor_mode = gr.Radio(choices=choices,
|
||||
label="Image Editor Mode",
|
||||
value="Mask", elem_id='edit_mode_select',
|
||||
visible=True)
|
||||
@ -199,9 +213,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
value=img2img_mask_modes[img2img_defaults['mask_mode']],
|
||||
visible=True)
|
||||
|
||||
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
|
||||
img2img_mask_restore = gr.Checkbox(label="Only modify regenerated parts of image",
|
||||
value=img2img_defaults['mask_restore'],
|
||||
visible=True)
|
||||
|
||||
img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1,
|
||||
label="How much blurry should the mask be? (to avoid hard edges)",
|
||||
value=3, visible=False)
|
||||
value=3, visible=True)
|
||||
|
||||
img2img_resize = gr.Radio(label="Resize mode",
|
||||
choices=["Just resize", "Crop and resize",
|
||||
@ -293,7 +311,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
img2img_height
|
||||
],
|
||||
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
|
||||
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
|
||||
img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_restore]
|
||||
)
|
||||
|
||||
# img2img_image_editor_mode.change(
|
||||
@ -334,8 +352,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
)
|
||||
|
||||
img2img_func = img2img
|
||||
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
|
||||
img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
|
||||
img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength,
|
||||
img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles,
|
||||
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
|
||||
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
|
||||
img2img_image_editor, img2img_image_mask, img2img_embeddings]
|
||||
@ -349,11 +367,16 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
inputs=img2img_inputs,
|
||||
outputs=img2img_outputs,
|
||||
)
|
||||
use_queue = False
|
||||
else:
|
||||
use_queue = True
|
||||
|
||||
img2img_btn_mask.click(
|
||||
img2img_func,
|
||||
img2img_inputs,
|
||||
img2img_outputs
|
||||
img2img_outputs,
|
||||
api_name="img2img",
|
||||
queue=use_queue
|
||||
)
|
||||
|
||||
def img2img_submit_params():
|
||||
@ -383,6 +406,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
outputs=img2img_dimensions_info_text_box)
|
||||
img2img_height.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
|
||||
outputs=img2img_dimensions_info_text_box)
|
||||
img2img_dimensions_info_text_box.value = uifn.update_dimensions_info(img2img_width.value, img2img_height.value)
|
||||
|
||||
with gr.TabItem("Image Lab", id='imgproc_tab'):
|
||||
gr.Markdown("Post-process results")
|
||||
@ -397,8 +421,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
# value=gfpgan_defaults['strength'])
|
||||
# select folder with images to process
|
||||
with gr.TabItem('Batch Process'):
|
||||
imgproc_folder = gr.File(label="Batch Process", file_count="multiple", source="upload",
|
||||
interactive=True, type="file")
|
||||
imgproc_folder = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
|
||||
imgproc_pngnfo = gr.Textbox(label="PNG Metadata", placeholder="PngNfo", visible=False,
|
||||
max_lines=5)
|
||||
with gr.Row():
|
||||
@ -540,7 +563,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
imgproc_width, imgproc_cfg, imgproc_denoising, imgproc_seed,
|
||||
imgproc_gfpgan_strength, imgproc_ldsr_steps, imgproc_ldsr_pre_downSample,
|
||||
imgproc_ldsr_post_downSample],
|
||||
[imgproc_output])
|
||||
[imgproc_output], api_name="imgproc")
|
||||
|
||||
imgproc_source.change(
|
||||
uifn.get_png_nfo,
|
||||
@ -631,11 +654,12 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
|
||||
"""
|
||||
gr.HTML("""
|
||||
<div id="90" style="max-width: 100%; font-size: 14px; text-align: center;" class="output-markdown gr-prose border-solid border border-gray-200 rounded gr-panel">
|
||||
<p>For help and advanced usage guides, visit the <a href="https://github.com/sd-webui/stable-diffusion-webui/wiki" target="_blank">Project Wiki</a></p>
|
||||
<p><a href="https://github.com/sd-webui/stable-diffusion-webui">Stable Diffusion WebUI</a> is an open-source project.
|
||||
If you would like to contribute to development or test bleeding edge builds, use the <a href="https://github.com/sd-webui/stable-diffusion-webui/tree/dev" target="_blank">dev branch</a>.</p>
|
||||
<p>For help and advanced usage guides, visit the <a href="https://github.com/hlky/stable-diffusion-webui/wiki" target="_blank">Project Wiki</a></p>
|
||||
<p>Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the <a href="https://github.com/hlky/stable-diffusion" target="_blank">main repository</a>.
|
||||
If you would like to contribute to development or test bleeding edge builds, you can visit the <a href="https://github.com/hlky/stable-diffusion-webui" target="_blank">developement repository</a>.</p>
|
||||
<p>Device ID {current_device_index}: {current_device_name}<br/>{total_device_count} total devices</p>
|
||||
</div>
|
||||
""")
|
||||
""".format(current_device_name=torch.cuda.get_device_name(), current_device_index=torch.cuda.current_device(), total_device_count=torch.cuda.device_count()))
|
||||
# Hack: Detect the load event on the frontend
|
||||
# Won't be needed in the next version of gradio
|
||||
# See the relevant PR: https://github.com/gradio-app/gradio/pull/2108
|
||||
|
57
frontend/image_metadata.py
Normal file
@ -0,0 +1,57 @@
|
||||
''' Class to store image generation parameters to be stored as metadata in the image'''
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Dict, Optional
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
import copy
|
||||
|
||||
@dataclass
|
||||
class ImageMetadata:
|
||||
prompt: str = None
|
||||
seed: str = None
|
||||
width: str = None
|
||||
height: str = None
|
||||
steps: str = None
|
||||
cfg_scale: str = None
|
||||
normalize_prompt_weights: str = None
|
||||
denoising_strength: str = None
|
||||
GFPGAN: str = None
|
||||
|
||||
def as_png_info(self) -> PngInfo:
|
||||
info = PngInfo()
|
||||
for key, value in self.as_dict().items():
|
||||
info.add_text(key, value)
|
||||
return info
|
||||
|
||||
def as_dict(self) -> Dict[str, str]:
|
||||
return {f"SD:{key}": str(value) for key, value in asdict(self).items() if value is not None}
|
||||
|
||||
@classmethod
|
||||
def set_on_image(cls, image: Image, metadata: ImageMetadata) -> None:
|
||||
''' Sets metadata on image, in both text form and as an ImageMetadata object '''
|
||||
if metadata:
|
||||
image.info = metadata.as_dict()
|
||||
else:
|
||||
metadata = ImageMetadata()
|
||||
image.info["ImageMetadata"] = copy.copy(metadata)
|
||||
|
||||
@classmethod
|
||||
def get_from_image(cls, image: Image) -> Optional[ImageMetadata]:
|
||||
''' Gets metadata from an image, first looking for an ImageMetadata,
|
||||
then if not found tries to construct one from the info '''
|
||||
metadata = image.info.get("ImageMetadata", None)
|
||||
if not metadata:
|
||||
found_metadata = False
|
||||
metadata = ImageMetadata()
|
||||
for key, value in image.info.items():
|
||||
if key.lower().startswith("sd:"):
|
||||
key = key[3:]
|
||||
if f"{key}" in metadata.__dict__:
|
||||
metadata.__dict__[key] = value
|
||||
found_metadata = True
|
||||
if not found_metadata:
|
||||
metadata = None
|
||||
if not metadata:
|
||||
print("Couldn't find metadata on image")
|
||||
return metadata
|
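The new `ImageMetadata` class above writes each set field as a text entry with an `SD:` prefix and rebuilds the object by scanning for that prefix. The round trip can be sketched with the standard library alone. `Meta` and `from_info` are simplified, hypothetical stand-ins with three fields, and a plain dict takes the place of `image.info`, so Pillow is not needed to run the example.

```python
from dataclasses import asdict, dataclass, fields
from typing import Dict, Optional


@dataclass
class Meta:
    prompt: Optional[str] = None
    seed: Optional[str] = None
    steps: Optional[str] = None

    def as_dict(self) -> Dict[str, str]:
        # Same convention as the diff: "SD:" prefix, unset fields skipped,
        # every value stored as a string.
        return {f"SD:{k}": str(v) for k, v in asdict(self).items() if v is not None}


def from_info(info: Dict[str, str]) -> Optional[Meta]:
    """Rebuild a Meta from text entries, ignoring keys without the prefix."""
    known = {f.name for f in fields(Meta)}
    meta, found = Meta(), False
    for key, value in info.items():
        if key.lower().startswith("sd:") and key[3:] in known:
            setattr(meta, key[3:], value)
            found = True
    return meta if found else None


info = Meta(prompt="a cat", seed=42).as_dict()
print(info)
# Unrelated entries such as "dpi" are ignored on the way back.
print(from_info({**info, "dpi": "72"}))
```

Because values are stored as text, a numeric seed comes back as a string; callers that need the number have to convert it themselves.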
@ -1,7 +1,7 @@
 ''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations '''
 from __future__ import annotations
 import gradio as gr
-from gradio.components import Component, Gallery
+from gradio.components import Component, Gallery, Slider
 from threading import Event, Timer
 from typing import Callable, List, Dict, Tuple, Optional, Any
 from dataclasses import dataclass, field
@ -9,6 +9,7 @@ from functools import partial
 from PIL.Image import Image
 import uuid
 import traceback
+import time
 
 
 @dataclass(eq=True, frozen=True)
@ -30,9 +31,21 @@ class JobInfo:
|
||||
session_key: str
|
||||
job_token: Optional[int] = None
|
||||
images: List[Image] = field(default_factory=list)
|
||||
active_image: Image = None
|
||||
rec_steps_enabled: bool = False
|
||||
rec_steps_imgs: List[Image] = field(default_factory=list)
|
||||
rec_steps_intrvl: int = None
|
||||
rec_steps_to_gallery: bool = False
|
||||
rec_steps_to_file: bool = False
|
||||
should_stop: Event = field(default_factory=Event)
|
||||
refresh_active_image_requested: Event = field(default_factory=Event)
|
||||
refresh_active_image_done: Event = field(default_factory=Event)
|
||||
stop_cur_iter: Event = field(default_factory=Event)
|
||||
active_iteration_cnt: int = field(default_factory=int)
|
||||
job_status: str = field(default_factory=str)
|
||||
finished: bool = False
|
||||
started: bool = False
|
||||
timestamp: float = None
|
||||
removed_output_idxs: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@ -76,7 +89,7 @@ class JobManagerUi:
|
||||
'''
|
||||
return self._job_manager._wrap_func(
|
||||
func=func, inputs=inputs, outputs=outputs,
|
||||
refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text
|
||||
job_ui=self
|
||||
)
|
||||
|
||||
_refresh_btn: gr.Button
|
||||
@ -84,10 +97,19 @@ class JobManagerUi:
|
||||
_status_text: gr.Textbox
|
||||
_stop_all_session_btn: gr.Button
|
||||
_free_done_sessions_btn: gr.Button
|
||||
_active_image: gr.Image
|
||||
_active_image_stop_btn: gr.Button
|
||||
_active_image_refresh_btn: gr.Button
|
||||
_rec_steps_intrvl_sldr: gr.Slider
|
||||
_rec_steps_checkbox: gr.Checkbox
|
||||
_save_rec_steps_to_gallery_chkbx: gr.Checkbox
|
||||
_save_rec_steps_to_file_chkbx: gr.Checkbox
|
||||
_job_manager: JobManager
|
||||
|
||||
|
||||
class JobManager:
|
||||
JOB_MAX_START_TIME = 5.0 # How long can a job be stuck 'starting' before assuming it isn't running
|
||||
|
||||
def __init__(self, max_jobs: int):
|
||||
self._max_jobs: int = max_jobs
|
||||
self._avail_job_tokens: List[Any] = list(range(max_jobs))
|
||||
@ -102,11 +124,23 @@ class JobManager:
|
||||
'''
|
||||
assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("Current Session"):
|
||||
with gr.TabItem("Job Controls"):
|
||||
with gr.Row():
|
||||
stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary")
|
||||
refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary")
|
||||
stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary")
|
||||
refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary")
|
||||
status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False)
|
||||
with gr.Row():
|
||||
active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary")
|
||||
active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary")
|
||||
active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image")
|
||||
with gr.TabItem("Batch Progress Settings"):
|
||||
with gr.Row():
|
||||
record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid")
|
||||
record_steps_interval_slider = gr.Slider(
|
||||
value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1)
|
||||
with gr.Row() as record_steps_box:
|
||||
steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery")
|
||||
steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File")
|
||||
with gr.TabItem("Maintenance"):
|
||||
with gr.Row():
|
||||
gr.Markdown(
|
||||
@ -118,9 +152,15 @@ class JobManager:
|
||||
free_done_sessions_btn = gr.Button(
|
||||
"Clear Finished Jobs", elem_id="clear_finished", variant="secondary"
|
||||
)
|
||||
|
||||
return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text,
|
||||
_stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn,
|
||||
_job_manager=self)
|
||||
_active_image=active_image, _active_image_stop_btn=active_image_stop_btn,
|
||||
_active_image_refresh_btn=active_image_refresh_btn,
|
||||
_rec_steps_checkbox=record_steps_checkbox,
|
||||
_save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
|
||||
_save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
|
||||
_rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self)
|
||||
|
||||
def clear_all_finished_jobs(self):
|
||||
''' Removes all currently finished jobs, across all sessions.
|
||||
@ -134,6 +174,7 @@ class JobManager:
|
||||
for session in self._sessions.values():
|
||||
for job in session.jobs.values():
|
||||
job.should_stop.set()
|
||||
job.stop_cur_iter.set()
|
||||
|
||||
def _get_job_token(self, block: bool = False) -> Optional[int]:
|
||||
''' Attempts to acquire a job token, optionally blocking until available '''
|
||||
@ -175,6 +216,26 @@ class JobManager:
|
||||
job_info.should_stop.set()
|
||||
return "Stopping after current batch finishes"
|
||||
|
||||
def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
''' Updates information from the active iteration '''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
if job_info is None:
|
||||
return [None, f"Session {session_key} was not running function {func_key}"]
|
||||
|
||||
job_info.refresh_active_image_requested.set()
|
||||
if job_info.refresh_active_image_done.wait(timeout=20.0):
|
||||
job_info.refresh_active_image_done.clear()
|
||||
return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"]
|
||||
return [gr.Image.update(visible=False), "Timed out getting image"]
|
||||
|
||||
def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
''' Marks that the active iteration should be stopped'''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
if job_info is None:
|
||||
return [None, f"Session {session_key} was not running function {func_key}"]
|
||||
job_info.stop_cur_iter.set()
|
||||
return [gr.Image.update(visible=False), "Stopping current iteration"]
|
||||
|
||||
def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]:
|
||||
''' Helper to get the SessionInfo and JobInfo. '''
|
||||
session_info = self._sessions.get(session_key, None)
|
||||
@ -207,19 +268,22 @@ class JobManager:
|
||||
|
||||
def _pre_call_func(
|
||||
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
|
||||
status_text: gr.Textbox, session_key: str) -> List[Component]:
|
||||
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
|
||||
session_key: str) -> List[Component]:
|
||||
''' Called when a job is about to start '''
|
||||
session_info, job_info = self._get_call_info(func_key, session_key)
|
||||
|
||||
# If we didn't already get a token then queue up for one
|
||||
if job_info.job_token is None:
|
||||
job_info.token = self._get_job_token(block=True)
|
||||
job_info.job_token = self._get_job_token(block=True)
|
||||
|
||||
# Buttons don't seem to update unless value is set on them as well...
|
||||
return {output_dummy_obj: triggerChangeEvent(),
|
||||
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
|
||||
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates")
|
||||
status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"),
|
||||
active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value),
|
||||
active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value),
|
||||
}
|
||||
|
||||
def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
@ -228,12 +292,19 @@ class JobManager:
|
||||
if session_info is None or job_info is None:
|
||||
return []
|
||||
|
||||
job_info.started = True
|
||||
try:
|
||||
if job_info.should_stop.is_set():
|
||||
raise Exception(f"Job {job_info} requested a stop before execution began")
|
||||
outputs = job_info.func(*job_info.inputs, job_info=job_info)
|
||||
except Exception as e:
|
||||
job_info.job_status = f"Error: {e}"
|
||||
print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}")
|
||||
outputs = []
|
||||
raise
|
||||
finally:
|
||||
job_info.finished = True
|
||||
session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
|
||||
self._release_job_token(job_info.job_token)
|
||||
|
||||
# Filter the function output for any removed outputs
|
||||
filtered_output = []
|
||||
@ -241,11 +312,6 @@ class JobManager:
|
||||
if idx not in job_info.removed_output_idxs:
|
||||
filtered_output.append(output)
|
||||
|
||||
job_info.finished = True
|
||||
session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
|
||||
|
||||
self._release_job_token(job_info.job_token)
|
||||
|
||||
# The wrapper added a dummy JSON output. Append a random text string
|
||||
# to fire the dummy object's 'change' event to notify that the job is done
|
||||
filtered_output.append(triggerChangeEvent())
|
||||
@ -254,12 +320,16 @@ class JobManager:
|
||||
|
||||
def _post_call_func(
|
||||
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
|
||||
status_text: gr.Textbox, session_key: str) -> List[Component]:
|
||||
status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
|
||||
session_key: str) -> List[Component]:
|
||||
''' Called when a job completes '''
|
||||
return {output_dummy_obj: triggerChangeEvent(),
|
||||
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
|
||||
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
|
||||
status_text: gr.Textbox.update(value="Generation has finished!")
|
||||
status_text: gr.Textbox.update(value="Generation has finished!"),
|
||||
active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value),
|
||||
active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value),
|
||||
active_image: gr.Image.update(visible=False)
|
||||
}
|
||||
|
||||
def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]:
|
||||
@ -270,21 +340,17 @@ class JobManager:
|
||||
if session_info is None or job_info is None:
|
||||
return []
|
||||
|
||||
if job_info.finished:
|
||||
session_info.finished_jobs.pop(func_key)
|
||||
|
||||
return job_info.images
|
||||
|
||||
def _wrap_func(
|
||||
self, func: Callable, inputs: List[Component], outputs: List[Component],
|
||||
refresh_btn: gr.Button = None, stop_btn: gr.Button = None,
|
||||
status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]:
|
||||
def _wrap_func(self, func: Callable, inputs: List[Component],
|
||||
outputs: List[Component],
|
||||
job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]:
|
||||
''' Handles JobManagerUi's wrap_func '''
|
||||
|
||||
assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context"
|
||||
|
||||
# Create a unique key for this job
|
||||
func_key = FuncKey(job_id=uuid.uuid4(), func=func)
|
||||
func_key = FuncKey(job_id=uuid.uuid4().hex, func=func)
|
||||
|
||||
# Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/)
|
||||
if self._session_key is None:
|
||||
@ -302,31 +368,59 @@ class JobManager:
|
||||
del outputs[idx]
|
||||
break
|
||||
|
||||
# Add the session key to the inputs
|
||||
inputs += [self._session_key]
|
||||
|
||||
# Create dummy objects
|
||||
update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject")
|
||||
update_gallery_obj.change(
|
||||
partial(self._update_gallery_event, func_key),
|
||||
[self._session_key],
|
||||
[gallery_comp]
|
||||
[gallery_comp],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if refresh_btn:
|
||||
refresh_btn.variant = 'secondary'
|
||||
refresh_btn.click(
|
||||
if job_ui._refresh_btn:
|
||||
job_ui._refresh_btn.variant = 'secondary'
|
||||
job_ui._refresh_btn.click(
|
||||
partial(self._refresh_func, func_key),
|
||||
[self._session_key],
|
||||
[update_gallery_obj, status_text]
|
||||
[update_gallery_obj, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if stop_btn:
|
||||
stop_btn.variant = 'secondary'
|
||||
stop_btn.click(
|
||||
if job_ui._stop_btn:
|
||||
job_ui._stop_btn.variant = 'secondary'
|
||||
job_ui._stop_btn.click(
|
||||
partial(self._stop_wrapped_func, func_key),
|
||||
[self._session_key],
|
||||
[status_text]
|
||||
[job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._active_image and job_ui._active_image_refresh_btn:
|
||||
job_ui._active_image_refresh_btn.click(
|
||||
partial(self._refresh_cur_iter_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._active_image, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._active_image_stop_btn:
|
||||
job_ui._active_image_stop_btn.click(
|
||||
partial(self._stop_cur_iter_func, func_key),
|
||||
[self._session_key],
|
||||
[job_ui._active_image, job_ui._status_text],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._stop_all_session_btn:
|
||||
job_ui._stop_all_session_btn.click(
|
||||
self.stop_all_jobs, [], [],
|
||||
queue=False
|
||||
)
|
||||
|
||||
if job_ui._free_done_sessions_btn:
|
||||
job_ui._free_done_sessions_btn.click(
|
||||
self.clear_all_finished_jobs, [], [],
|
||||
queue=False
|
||||
)
|
||||
|
||||
# (ab)use gr.JSON to forward events.
|
||||
@ -343,7 +437,8 @@ class JobManager:
|
||||
# Since some parameters are optional it makes sense to use the 'dict' return value type, which requires
|
||||
# the Component as a key... so group together the UI components that the event listeners are going to update
|
||||
# to make it easy to append to function calls and outputs
|
||||
job_ui_params = [refresh_btn, stop_btn, status_text]
|
||||
job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text,
|
||||
job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn]
|
||||
job_ui_outputs = [comp for comp in job_ui_params if comp is not None]
|
||||
|
||||
# Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call,
|
||||
@ -352,44 +447,70 @@ class JobManager:
|
||||
post_call_dummyobj.change(
|
||||
partial(self._post_call_func, func_key, update_gallery_obj, *job_ui_params),
|
||||
[self._session_key],
|
||||
[update_gallery_obj] + job_ui_outputs
|
||||
[update_gallery_obj] + job_ui_outputs,
|
||||
queue=False
|
||||
)
|
||||
|
||||
call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_runCall")
|
||||
call_dummyobj.change(
|
||||
partial(self._call_func, func_key),
|
||||
[self._session_key],
|
||||
outputs + [post_call_dummyobj]
|
||||
outputs + [post_call_dummyobj],
|
||||
queue=False
|
||||
)
|
||||
|
||||
pre_call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_preCall")
|
||||
pre_call_dummyobj.change(
|
||||
partial(self._pre_call_func, func_key, call_dummyobj, *job_ui_params),
|
||||
[self._session_key],
|
||||
[call_dummyobj] + job_ui_outputs
|
||||
[call_dummyobj] + job_ui_outputs,
|
||||
queue=False
|
||||
)
|
||||
|
||||
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
|
||||
# Add any components that we want the runtime values for
|
||||
added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx,
|
||||
job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr]
|
||||
|
||||
def wrapped_func(*inputs):
|
||||
session_key = inputs[-1]
|
||||
inputs = inputs[:-1]
|
||||
# Now replace the original function with one that creates a JobInfo and triggers the dummy obj
|
||||
def wrapped_func(*wrapped_inputs):
|
||||
# Remove the added_inputs (pop opposite order of list)
|
||||
|
||||
wrapped_inputs = list(wrapped_inputs)
|
||||
rec_steps_interval: int = wrapped_inputs.pop()
|
||||
save_rec_steps_file: bool = wrapped_inputs.pop()
|
||||
save_rec_steps_grid: bool = wrapped_inputs.pop()
|
||||
record_steps_enabled: bool = wrapped_inputs.pop()
|
||||
session_key: str = wrapped_inputs.pop()
|
||||
job_inputs = tuple(wrapped_inputs)
|
||||
|
||||
# Get or create a session for this key
|
||||
session_info = self._sessions.setdefault(session_key, SessionInfo())
|
||||
|
||||
# Is this session already running this job?
|
||||
if func_key in session_info.jobs:
|
||||
return {status_text: "This session is already running that function!"}
|
||||
job_info = session_info.jobs[func_key]
|
||||
# If the job seems stuck in 'starting' then go ahead and toss it
|
||||
if not job_info.started and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME:
|
||||
job_info.should_stop.set()
|
||||
job_info.stop_cur_iter.set()
|
||||
session_info.jobs.pop(func_key)
|
||||
return {job_ui._status_text: "Canceled possibly hung job. Try again"}
|
||||
return {job_ui._status_text: "This session is already running that function!"}
|
||||
|
||||
# Is this a new run of a previously finished job? Clear old info
|
||||
if func_key in session_info.finished_jobs:
|
||||
session_info.finished_jobs.pop(func_key)
|
||||
|
||||
job_token = self._get_job_token(block=False)
|
||||
job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
|
||||
job_token=job_token)
|
||||
job = JobInfo(
|
||||
inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
|
||||
job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval,
|
||||
rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file, timestamp=time.time())
|
||||
session_info.jobs[func_key] = job
|
||||
|
||||
ret = {pre_call_dummyobj: triggerChangeEvent()}
|
||||
if job_token is None:
|
||||
ret[status_text] = "Job is queued"
|
||||
ret[job_ui._status_text] = "Job is queued"
|
||||
return ret
|
||||
|
||||
return wrapped_func, inputs, [pre_call_dummyobj, status_text]
|
||||
return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text]
|
||||
|
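`_refresh_cur_iter_func` above sets `refresh_active_image_requested` and then waits on `refresh_active_image_done`. The worker side of that handshake is not part of this hunk, so the loop below is an assumption about how a sampler might answer the request. `ProgressChannel`, `worker`, and `request_preview` are hypothetical names used only for this sketch.

```python
import threading


class ProgressChannel:
    """Hypothetical stand-in for the refresh-related fields of JobInfo."""

    def __init__(self):
        self.requested = threading.Event()  # UI asks for a preview
        self.done = threading.Event()       # worker signals the preview is ready
        self.active_image = None


def worker(channel, steps):
    """Assumed worker loop: answer a pending preview request between steps."""
    for step in range(steps):
        # ... one sampling step would run here ...
        if channel.requested.is_set():
            channel.requested.clear()
            channel.active_image = f"preview@{step}"  # placeholder for a real image
            channel.done.set()


def request_preview(channel, timeout=1.0):
    """UI side: request a preview and wait up to `timeout` seconds for it."""
    channel.requested.set()
    if channel.done.wait(timeout=timeout):
        channel.done.clear()
        return channel.active_image
    return None  # timed out, as in the "Timed out getting image" branch
```

Using two events keeps the UI callback from blocking indefinitely: the wait has a timeout, and the worker only does extra work when a request is pending.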
@ -9,10 +9,10 @@ import re
 def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height):
     if choice == "Mask":
         update_image_result = update_image_mask(cropped_image, resize_mode, width, height)
-        return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
+        return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
 
     update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height)
-    return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
+    return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
 
 def update_image_mask(cropped_image, resize_mode, width, height):
     resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
BIN
images/nsfw.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 25 KiB |
@ -7,6 +7,8 @@ from einops import rearrange, repeat
|
||||
|
||||
from ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -167,30 +169,98 @@ class CrossAttention(nn.Module):
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.einsum_op = self.einsum_op_cuda
|
||||
else:
|
||||
self.mem_total = psutil.virtual_memory().total / (1024**3)
|
||||
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
|
||||
|
||||
def einsum_op_compvis(self, q, k, v, r1):
|
||||
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
r1 = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v, r1):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v, r1):
|
||||
if self.mem_total >= 8 and q.shape[1] <= 4096:
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
else:
|
||||
slice_size = 1
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = min(q.shape[0], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_cuda(self, q, k, v, r1):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
if steps > 64:
|
||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = min(q.shape[1], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
del x
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
r1 = self.einsum_op(q, k, v, r1)
|
||||
del q, k, v
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
return self.to_out(r2)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
@ -209,9 +279,10 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
+        x = x.contiguous() if x.device.type == 'mps' else x
+        x += self.attn1(self.norm1(x))
+        x += self.attn2(self.norm2(x), context=context)
+        x += self.ff(self.norm3(x))
         return x
 
 
||||
|
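`einsum_op_cuda` above chooses how many slices to split the query dimension into from the ratio of required to free memory. The slice arithmetic can be sketched in isolation as below. `slice_plan` is a hypothetical helper, the memory figures are plain numbers rather than values read from `torch.cuda`, and `math.log2` is used in place of the source's `math.log(x, 2)`. The `steps > 64` error path is left out.

```python
import math


def slice_plan(seq_len, mem_required, mem_free):
    """Return (steps, [(start, end), ...]) for slicing a sequence of seq_len rows."""
    steps = 1
    if mem_required > mem_free:
        # Round the shortfall up to the next power of two
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free))
    # As in the diff: only slice when seq_len divides evenly, else use one full slice
    slice_size = seq_len // steps if seq_len % steps == 0 else seq_len
    slices = [(i, min(seq_len, i + slice_size)) for i in range(0, seq_len, slice_size)]
    return steps, slices
```

One consequence of the divisibility check is that a sequence length not divisible by `steps` is processed in a single slice, so no memory is saved in that case.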
@ -1,4 +1,5 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import gc
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -119,18 +120,30 @@ class ResnetBlock(nn.Module):
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
h1 = x
|
||||
h2 = self.norm1(h1)
|
||||
del h1
|
||||
|
||||
h3 = nonlinearity(h2)
|
||||
del h2
|
||||
|
||||
h4 = self.conv1(h3)
|
||||
del h3
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
h5 = self.norm2(h4)
|
||||
del h4
|
||||
|
||||
h6 = nonlinearity(h5)
|
||||
del h5
|
||||
|
||||
h7 = self.dropout(h6)
|
||||
del h6
|
||||
|
||||
h8 = self.conv2(h7)
|
||||
del h7
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
@ -138,7 +151,7 @@ class ResnetBlock(nn.Module):
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
return x + h8
|
||||
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
@ -178,28 +191,65 @@ class AttnBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
q1 = self.q(h_)
|
||||
k1 = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
b, c, h, w = q1.shape
|
||||
|
||||
q2 = q1.reshape(b, c, h*w)
|
||||
del q1
|
||||
|
||||
q = q2.permute(0, 2, 1) # b,hw,c
|
||||
del q2
|
||||
|
||||
k = k1.reshape(b, c, h*w) # b,c,hw
|
||||
del k1
|
||||
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
|
||||
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w2 = w1 * (int(c)**(-0.5))
|
||||
del w1
|
||||
w3 = torch.nn.functional.softmax(w2, dim=2)
|
||||
del w2
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b,c,h*w)
|
||||
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b,c,h,w)
|
||||
v1 = v.reshape(b, c, h*w)
|
||||
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
del w3
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
del v1, w4
|
||||
|
||||
return x+h_
|
||||
h2 = h_.reshape(b, c, h, w)
|
||||
del h_
|
||||
|
||||
h3 = self.proj_out(h2)
|
||||
del h2
|
||||
|
||||
h3 += x
|
||||
|
||||
return h3
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
|
||||
@ -540,31 +590,54 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
h1 = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
h2 = self.mid.block_1(h1, temb)
|
||||
del h1
|
||||
|
||||
h3 = self.mid.attn_1(h2)
|
||||
del h2
|
||||
|
||||
h = self.mid.block_2(h3, temb)
|
||||
del h3
|
||||
|
||||
# prepare for up sampling
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
t = h
|
||||
h = self.up[i_level].attn[i_block](t)
|
||||
del t
|
||||
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
t = h
|
||||
h = self.up[i_level].upsample(t)
|
||||
del t
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
h1 = self.norm_out(h)
|
||||
del h
|
||||
|
||||
h2 = nonlinearity(h1)
|
||||
del h1
|
||||
|
||||
h = self.conv_out(h2)
|
||||
del h2
|
||||
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
t = h
|
||||
h = torch.tanh(t)
|
||||
del t
|
||||
|
||||
return h
|
||||
|
||||
|
||||
|
@ -54,7 +54,8 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
 
     # assert ddim_timesteps.shape[0] == num_ddim_timesteps
     # add one to get the final alpha values right (the ones from first scale to data during sampling)
-    steps_out = ddim_timesteps + 1
+    # steps_out = ddim_timesteps + 1 # removed due to some issues when reaching 1000
+    steps_out = np.where(ddim_timesteps != 999, ddim_timesteps+1, ddim_timesteps)
     if verbose:
         print(f'Selected timesteps for ddim sampler: {steps_out}')
     return steps_out
||||
|
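The `np.where` line in this hunk shifts every timestep up by one except 999, which would otherwise become 1000. A pure-Python sketch of the same rule, assuming a hypothetical helper `shift_timesteps` and a `last` parameter that the source hardcodes as 999:

```python
def shift_timesteps(timesteps, last=999):
    """Add one to each timestep, leaving `last` unchanged so it cannot overflow."""
    return [t if t == last else t + 1 for t in timesteps]


print(shift_timesteps([0, 500, 999]))
```

Only the final index is treated specially; all other timesteps keep the original `+ 1` shift.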
1312
scripts/DeforumStableDiffusion.py
Normal file
File diff suppressed because it is too large
46
scripts/ModelManager.py
Normal file
@ -0,0 +1,46 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
|
||||
|
||||
#other imports
|
||||
import pandas as pd
|
||||
from io import StringIO
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
def layout():
|
||||
#search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
|
||||
|
||||
csvString = f"""
|
||||
,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media
|
||||
,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
|
||||
,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
|
||||
,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
|
||||
,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt
|
||||
,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt
|
||||
,Stable Diffusion Concept Library , ./models/custom/sd-concepts-library , https://github.com/sd-webui/sd-concepts-library
|
||||
"""
|
||||
colms = st.columns((1, 3, 4, 6)) # same widths as the data rows so the header lines up with them
|
||||
columns = ["№",'Model Name','Save Location','Download Link']
|
||||
|
||||
# Convert String into StringIO
|
||||
csvStringIO = StringIO(csvString)
|
||||
df = pd.read_csv(csvStringIO, sep=",", header=None, names=columns)
|
||||
|
||||
for col, field_name in zip(colms, columns):
|
||||
# table header
|
||||
col.write(field_name)
|
||||
|
||||
for x, model_name in enumerate(df["Model Name"]):
|
||||
col1, col2, col3, col4 = st.columns((1, 3, 4, 6))
|
||||
col1.write(x) # index
|
||||
col2.write(df['Model Name'][x])
|
||||
col3.write(df['Save Location'][x])
|
||||
col4.write(df['Download Link'][x])
|
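Each row of `csvString` starts with a comma and pads its fields with spaces, so, as far as I can tell, `pd.read_csv` with its default settings leaves the "№" column empty and keeps the surrounding whitespace in the other three. The sketch below shows one way to parse the same layout with only the standard library and trim the fields. `parse_model_table`, `SAMPLE`, and the `example.invalid` URLs are hypothetical names and placeholders, not part of the commit.

```python
import csv
from io import StringIO

# Hypothetical two-row excerpt in the same layout as csvString.
SAMPLE = """
,Model A , ./models/custom , https://example.invalid/a.ckpt
,Model B , ./models/ldm/stable-diffusion-v1 , https://example.invalid/b.ckpt
"""

COLUMNS = ["№", "Model Name", "Save Location", "Download Link"]


def parse_model_table(text):
    """Return one dict per non-empty row, with whitespace trimmed from every field."""
    rows = []
    for record in csv.reader(StringIO(text.strip())):
        if not record:
            continue  # skip blank lines inside the block
        values = [field.strip() for field in record]
        rows.append(dict(zip(COLUMNS, values)))
    return rows


if __name__ == "__main__":
    for index, row in enumerate(parse_model_table(SAMPLE)):
        print(index, row["Model Name"], row["Save Location"], row["Download Link"])
```

With pandas, passing `skipinitialspace=True` to `read_csv` removes the space after each delimiter, but trailing spaces would still need a `.str.strip()` pass.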
5
scripts/Settings.py
Normal file
@ -0,0 +1,5 @@
|
||||
from webui_streamlit import st
|
||||
|
||||
# The global settings section will be moved to the Settings page.
|
||||
#with st.expander("Global Settings:"):
|
||||
st.write("Global Settings:")
|
216
scripts/home.py
Normal file
@ -0,0 +1,216 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
|
||||
|
||||
#other imports
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
class plugin_info():
|
||||
plugname = "home"
|
||||
description = "Home"
|
||||
isTab = True
|
||||
displayPriority = 0
|
||||
|
||||
def getLatestGeneratedImagesFromPath():
|
||||
#get the latest images from the generated images folder
|
||||
#get the path to the generated images folder
|
||||
generatedImagesPath = os.path.join(os.getcwd(),'outputs')
|
||||
#get all the files from the folders and subfolders
|
||||
files = []
|
||||
#collect every .png under the outputs folder, walking its subfolders
|
||||
for r, d, f in os.walk(generatedImagesPath):
|
||||
for file in f:
|
||||
if file.lower().endswith('.png'):
|
||||
files.append(os.path.join(r, file))
|
||||
#sort the files by date
|
||||
files.sort(reverse=True, key=os.path.getmtime)
|
||||
latest = files[:90]
|
||||
latest.reverse()
|
||||
|
||||
# keep the 90 most recent images (10 pages of 9 images) and return them
|
||||
# oldest first; layout() reverses the list again before displaying them
|
||||
return [Image.open(f) for f in latest]
|
||||
|
||||
def get_images_from_lexica():
|
||||
#scrape images from lexica.art
|
||||
#get the html from the page
|
||||
#get the html with cookies and javascript
|
||||
apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D'
|
||||
#REST API call
|
||||
#
|
||||
from requests_html import HTMLSession
|
||||
session = HTMLSession()
|
||||
|
||||
response = session.get(apiEndpoint)
|
||||
#req = requests.Session()
|
||||
#req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36'
|
||||
#response = req.get(apiEndpoint)
|
||||
print(response.status_code)
|
||||
print(response.text)
|
||||
#get the json from the response
|
||||
#json = response.json()
|
||||
#get the prompts from the json
|
||||
print(response)
|
||||
#session = requests.Session()
|
||||
#parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False)
|
||||
#print(parseEndpointJson)
|
||||
#print('test2')
|
||||
#page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'})
|
||||
#parse the html
|
||||
#soup = BeautifulSoup(page.content, 'html.parser')
|
||||
#find all the images
|
||||
#print(soup)
|
||||
#images = soup.find_all('alt-image')
|
||||
#create a list to store the image urls
|
||||
image_urls = []
|
||||
#loop through the images
|
||||
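images = [] # the scraping code above is commented out, so there is nothing to iterate over yet
|
||||
for image in images: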
|
||||
#get the url
|
||||
image_url = image['src']
|
||||
#add it to the list
|
||||
image_urls.append('http://www.lexica.art/'+image_url)
|
||||
#return the list
|
||||
print(image_urls)
|
||||
return image_urls
|
||||
|
||||
def layout():
|
||||
#streamlit home page layout
|
||||
#center the title
|
||||
st.markdown("<h1 style='text-align: center; color: white;'>Welcome, let's make some 🎨</h1>", unsafe_allow_html=True)
|
||||
#make a gallery of images
|
||||
#st.markdown("<h2 style='text-align: center; color: white;'>Gallery</h2>", unsafe_allow_html=True)
|
||||
#create a gallery of images using columns
|
||||
#col1, col2, col3 = st.columns(3)
|
||||
#load the images
|
||||
#create 3 columns
|
||||
# create a tab for the gallery
|
||||
#st.markdown("<h2 style='text-align: center; color: white;'>Gallery</h2>", unsafe_allow_html=True)
|
||||
#st.markdown("<h2 style='text-align: center; color: white;'>Gallery</h2>", unsafe_allow_html=True)
|
||||
|
||||
history_tab, discover_tabs = st.tabs(["History","Discover"])
|
||||
|
||||
latestImages = getLatestGeneratedImagesFromPath()
|
||||
st.session_state['latestImages'] = latestImages
|
||||
|
||||
with history_tab:
|
||||
##---------------------------------------------------------
|
||||
## image slideshow test
|
||||
## Number of entries per screen
|
||||
#slideshow_N = 9
|
||||
#slideshow_page_number = 0
|
||||
#slideshow_last_page = len(latestImages) // slideshow_N
|
||||
|
||||
## Add a next button and a previous button
|
||||
|
||||
#slideshow_prev, slideshow_image_col , slideshow_next = st.columns([1, 10, 1])
|
||||
|
||||
#with slideshow_image_col:
|
||||
#slideshow_image = st.empty()
|
||||
|
||||
#slideshow_image.image(st.session_state['latestImages'][0])
|
||||
|
||||
#current_image = 0
|
||||
|
||||
#if slideshow_next.button("Next", key=1):
|
||||
##print (current_image+1)
|
||||
#current_image = current_image+1
|
||||
#slideshow_image.image(st.session_state['latestImages'][current_image+1])
|
||||
#if slideshow_prev.button("Previous", key=0):
|
||||
##print ([current_image-1])
|
||||
#current_image = current_image-1
|
||||
#slideshow_image.image(st.session_state['latestImages'][current_image - 1])
|
||||
|
||||
|
||||
#---------------------------------------------------------
|
||||
|
||||
# image gallery
|
||||
# Number of entries per screen
|
||||
gallery_N = 9
|
||||
if "galleryPage" not in st.session_state:
|
||||
st.session_state["galleryPage"] = 0
|
||||
gallery_last_page = max(0, (len(latestImages) - 1) // gallery_N) # avoids an empty last page when the count is a multiple of gallery_N
|
||||
|
||||
# Add a next button and a previous button
|
||||
|
||||
gallery_prev, gallery_refresh, gallery_pagination , gallery_next = st.columns([2, 2, 8, 1])
|
||||
|
||||
# the pagination doesn't work for now, so it's better to disable the buttons.
|
||||
|
||||
if gallery_refresh.button("Refresh", key=4):
|
||||
st.session_state["galleryPage"] = 0
|
||||
|
||||
if gallery_next.button("Next", key=3):
|
||||
|
||||
if st.session_state["galleryPage"] + 1 > gallery_last_page:
|
||||
st.session_state["galleryPage"] = 0
|
||||
else:
|
||||
st.session_state["galleryPage"] += 1
|
||||
|
||||
if gallery_prev.button("Previous", key=2):
|
||||
|
||||
if st.session_state["galleryPage"] - 1 < 0:
|
||||
st.session_state["galleryPage"] = gallery_last_page
|
||||
else:
|
||||
st.session_state["galleryPage"] -= 1
|
||||
|
||||
print(st.session_state["galleryPage"])
|
||||
# Get the start and end indices of the current gallery page
|
||||
gallery_start_idx = st.session_state["galleryPage"] * gallery_N
|
||||
gallery_end_idx = (1 + st.session_state["galleryPage"]) * gallery_N
|
||||
|
||||
#---------------------------------------------------------
|
||||
|
||||
placeholder = st.empty()
|
||||
|
||||
#populate the 3 images per column
|
||||
with placeholder.container():
|
||||
col1, col2, col3 = st.columns(3)
|
||||
col1_cont = st.container()
|
||||
col2_cont = st.container()
|
||||
col3_cont = st.container()
|
||||
|
||||
print (len(st.session_state['latestImages']))
|
||||
images = list(reversed(st.session_state['latestImages']))[gallery_start_idx:(gallery_start_idx+gallery_N)]
|
||||
|
||||
with col1_cont:
|
||||
with col1:
|
||||
[st.image(images[index]) for index in [0, 3, 6] if index < len(images)]
|
||||
with col2_cont:
|
||||
with col2:
|
||||
[st.image(images[index]) for index in [1, 4, 7] if index < len(images)]
|
||||
with col3_cont:
|
||||
with col3:
|
||||
[st.image(images[index]) for index in [2, 5, 8] if index < len(images)]
|
||||
|
||||
|
||||
st.session_state['historyTab'] = [history_tab,col1,col2,col3,placeholder,col1_cont,col2_cont,col3_cont]
|
||||
|
||||
with discover_tabs:
|
||||
st.markdown("<h1 style='text-align: center; color: white;'>Soon :)</h1>", unsafe_allow_html=True)
|
||||
|
||||
#display the images
|
||||
#add a button to the gallery
|
||||
#st.markdown("<h2 style='text-align: center; color: white;'>Try it out</h2>", unsafe_allow_html=True)
|
||||
#create a button to the gallery
|
||||
#if st.button("Try it out"):
|
||||
#if the button is clicked, go to the gallery
|
||||
#st.experimental_rerun()
|
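The gallery code in `layout()` keeps a page number in `st.session_state["galleryPage"]`, wraps around at both ends, and slices nine images per page. The sketch below isolates that arithmetic as pure functions so it can be checked without Streamlit. All four function names are hypothetical. Note that `len(items) // page_size` yields one extra, empty page whenever the count is an exact multiple of the page size, so the sketch uses `(count - 1) // page_size` instead.

```python
def last_page_index(item_count, page_size):
    """Zero-based index of the last page that still contains an item."""
    if item_count <= 0:
        return 0
    return (item_count - 1) // page_size


def next_page(page, last_page):
    """Advance one page, wrapping to the first page after the last."""
    return 0 if page + 1 > last_page else page + 1


def prev_page(page, last_page):
    """Go back one page, wrapping to the last page before the first."""
    return last_page if page - 1 < 0 else page - 1


def page_items(items, page, page_size):
    """Slice out the items shown on the given page."""
    start = page * page_size
    return items[start:start + page_size]


if __name__ == "__main__":
    images = list(range(20))          # stand-in for the loaded images
    last = last_page_index(len(images), 9)
    page = next_page(next_page(0, last), last)
    print(last, page, page_items(images, page, 9))
```

Keeping the state transitions separate from the button handlers also makes it easier to see why the "Next" and "Previous" branches mirror each other.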
592
scripts/img2img.py
Normal file
@ -0,0 +1,592 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
from streamlit import StopException
|
||||
|
||||
#other imports
|
||||
import cv2
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
import torch
|
||||
import k_diffusion as K
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import skimage
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
|
||||
mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
|
||||
n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
|
||||
seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
|
||||
variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
|
||||
write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
|
||||
separate_prompts:bool = False, normalize_prompt_weights:bool = True,
|
||||
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
|
||||
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False,
|
||||
random_seed_loopback: bool = False
|
||||
):
|
||||
|
||||
outpath = st.session_state['defaults'].general.outdir_img2img or st.session_state['defaults'].general.outdir or "outputs/img2img-samples"
|
||||
#err = False
|
||||
#loopback = False
|
||||
#skip_save = False
|
||||
seed = seed_to_int(seed)
|
||||
|
||||
batch_size = 1
|
||||
|
||||
#prompt_matrix = 0
|
||||
#normalize_prompt_weights = 1 in toggles
|
||||
#loopback = 2 in toggles
|
||||
#random_seed_loopback = 3 in toggles
|
||||
#skip_save = 4 not in toggles
|
||||
#save_grid = 5 in toggles
|
||||
#sort_samples = 6 in toggles
|
||||
#write_info_files = 7 in toggles
|
||||
#write_sample_info_to_log_file = 8 in toggles
|
||||
#jpg_sample = 9 in toggles
|
||||
#use_GFPGAN = 10 in toggles
|
||||
#use_RealESRGAN = 11 in toggles
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(st.session_state["model"])
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(st.session_state["model"])
|
||||
elif sampler_name == 'k_dpm_2_a':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
|
||||
elif sampler_name == 'k_dpm_2':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
|
||||
elif sampler_name == 'k_euler_a':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
|
||||
elif sampler_name == 'k_euler':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'euler')
|
||||
elif sampler_name == 'k_heun':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'heun')
|
||||
elif sampler_name == 'k_lms':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'lms')
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
def process_init_mask(init_mask: Image):
|
||||
if init_mask.mode == "RGBA":
|
||||
init_mask = init_mask.convert('RGBA')
|
||||
background = Image.new('RGBA', init_mask.size, (0, 0, 0))
|
||||
init_mask = Image.alpha_composite(background, init_mask)
|
||||
init_mask = init_mask.convert('RGB')
|
||||
return init_mask
|
||||
|
||||
init_img = init_info
|
||||
init_mask = None
|
||||
if mask_mode == 0:
|
||||
if init_info_mask:
|
||||
init_mask = process_init_mask(init_info_mask)
|
||||
elif mask_mode == 1:
|
||||
if init_info_mask:
|
||||
init_mask = process_init_mask(init_info_mask)
|
||||
init_mask = ImageOps.invert(init_mask)
|
||||
elif mask_mode == 2:
|
||||
init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
|
||||
init_mask = init_img_transparency
|
||||
init_mask = init_mask.convert("RGB")
|
||||
init_mask = resize_image(resize_mode, init_mask, width, height)
|
||||
init_mask = init_mask.convert("RGB")
|
||||
|
||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(denoising_strength * ddim_steps)
|
||||
|
||||
if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None:
|
||||
noise_q = 0.99
|
||||
color_variation = 0.0
|
||||
mask_blend_factor = 1.0
|
||||
|
||||
np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
|
||||
np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
|
||||
np_mask_rgb -= np.min(np_mask_rgb)
|
||||
np_mask_rgb /= np.max(np_mask_rgb)
|
||||
np_mask_rgb = 1. - np_mask_rgb
|
||||
np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
|
||||
blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
|
||||
blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
|
||||
#np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
|
||||
#np_mask_rgb = np_mask_rgb + blurred
|
||||
np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
|
||||
np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
|
||||
|
||||
noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
|
||||
blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
|
||||
noised = noise_rgb[:]
|
||||
blend_mask_rgb **= (2.)
|
||||
noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
|
||||
|
||||
np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
|
||||
ref_mask = np_mask_grey < 1e-3
|
||||
|
||||
all_mask = np.ones((height, width), dtype=bool)
|
||||
noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
|
||||
|
||||
init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
|
||||
st.session_state["editor_image"].image(init_img) # debug
|
||||
|
||||
def init():
|
||||
image = init_img.convert('RGB')
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask_channel = None
|
||||
if init_mask:
|
||||
alpha = resize_image(resize_mode, init_mask, width // 8, height // 8)
|
||||
mask_channel = alpha.split()[-1]
|
||||
|
||||
mask = None
|
||||
if mask_channel is not None:
|
||||
mask = np.array(mask_channel).astype(np.float32) / 255.0
|
||||
mask = (1 - mask)
|
||||
mask = np.tile(mask, (4, 1, 1))
|
||||
mask = mask[None].transpose(0, 1, 2, 3)
|
||||
mask = torch.from_numpy(mask).to(st.session_state["device"])
|
||||
|
||||
if st.session_state['defaults'].general.optimized:
|
||||
st.session_state.modelFS.to(st.session_state["device"] )
|
||||
|
||||
init_image = 2. * image - 1.
|
||||
init_image = init_image.to(st.session_state["device"])
|
||||
init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if st.session_state['defaults'].general.optimized:
|
||||
mem = torch.cuda.memory_allocated()/1e6
|
||||
st.session_state.modelFS.to("cpu")
|
||||
while(torch.cuda.memory_allocated()/1e6 >= mem):
|
||||
time.sleep(1)
|
||||
|
||||
return init_latent, mask
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
t_enc_steps = t_enc
|
||||
obliterate = False
|
||||
if ddim_steps == t_enc_steps:
|
||||
t_enc_steps = t_enc_steps - 1
|
||||
obliterate = True
|
||||
|
||||
if sampler_name != 'DDIM':
|
||||
x0, z_mask = init_data
|
||||
|
||||
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
|
||||
noise = x * sigmas[ddim_steps - t_enc_steps - 1]
|
||||
|
||||
xi = x0 + noise
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=xi.device)
|
||||
xi = (z_mask * noise) + ((1-z_mask) * xi)
|
||||
|
||||
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
|
||||
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
|
||||
samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
|
||||
extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
|
||||
'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
|
||||
callback=generation_callback)
|
||||
else:
|
||||
|
||||
x0, z_mask = init_data
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
|
||||
z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] ))
|
||||
|
||||
# Obliterate masked image
|
||||
if z_mask is not None and obliterate:
|
||||
random = torch.randn(z_mask.shape, device=z_enc.device)
|
||||
z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
|
||||
|
||||
# decode it
|
||||
samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
z_mask=z_mask, x0=x0)
|
||||
return samples_ddim
|
||||
|
||||
|
||||
|
||||
if loopback:
|
||||
output_images, info = None, None
|
||||
history = []
|
||||
initial_seed = None
|
||||
|
||||
do_color_correction = False
|
||||
try:
|
||||
from skimage import exposure
|
||||
do_color_correction = True
|
||||
except ImportError:
|
||||
print("Install scikit-image to perform color correction on loopback")
|
||||
|
||||
for i in range(n_iter):
|
||||
if do_color_correction and i == 0:
|
||||
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
|
||||
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=1,
|
||||
n_iter=1,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=False, # Forcefully disable upscaling when using loopback
|
||||
realesrgan_model_name=RealESRGAN_model,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
init_img=init_img,
|
||||
init_mask=init_mask,
|
||||
mask_blur_strength=mask_blur_strength,
|
||||
mask_restore=mask_restore,
|
||||
denoising_strength=denoising_strength,
|
||||
noise_mode=noise_mode,
|
||||
find_noise_steps=find_noise_steps,
|
||||
resize_mode=resize_mode,
|
||||
uses_loopback=loopback,
|
||||
uses_random_seed_loopback=random_seed_loopback,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg
|
||||
)
|
||||
|
||||
if initial_seed is None:
|
||||
initial_seed = seed
|
||||
|
||||
input_image = init_img
|
||||
init_img = output_images[0]
|
||||
|
||||
if do_color_correction and correction_target is not None:
|
||||
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
||||
cv2.cvtColor(
|
||||
np.asarray(init_img),
|
||||
cv2.COLOR_RGB2LAB
|
||||
),
|
||||
correction_target,
|
||||
channel_axis=2
|
||||
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
||||
if mask_restore is True and init_mask is not None:
|
||||
color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
|
||||
color_mask = color_mask.convert('L')
|
||||
source_image = input_image.convert('RGB')
|
||||
target_image = init_img.convert('RGB')
|
||||
|
||||
init_img = Image.composite(source_image, target_image, color_mask)
|
||||
|
||||
if not random_seed_loopback:
|
||||
seed = seed + 1
|
||||
else:
|
||||
seed = seed_to_int(None)
|
||||
|
||||
denoising_strength = max(denoising_strength * 0.95, 0.1)
|
||||
history.append(init_img)
|
||||
|
||||
output_images = history
|
||||
seed = initial_seed
|
||||
|
||||
else:
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=use_RealESRGAN,
|
||||
realesrgan_model_name=RealESRGAN_model,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
init_img=init_img,
|
||||
init_mask=init_mask,
|
||||
mask_blur_strength=mask_blur_strength,
|
||||
denoising_strength=denoising_strength,
|
||||
noise_mode=noise_mode,
|
||||
find_noise_steps=find_noise_steps,
|
||||
mask_restore=mask_restore,
|
||||
resize_mode=resize_mode,
|
||||
uses_loopback=loopback,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg
|
||||
)
|
||||
|
||||
del sampler
|
||||
|
||||
return output_images, seed, info, stats
|
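Two small pieces of arithmetic drive `img2img` above: `t_enc = int(denoising_strength * ddim_steps)` decides how many steps of noise are applied to the init image, and the loopback branch multiplies `denoising_strength` by 0.95 after every iteration with a floor of 0.1. The sketch below reproduces only that arithmetic. `encode_steps` and `loopback_strengths` are hypothetical names, and the `decay` and `floor` parameters simply expose the constants hard-coded in the diff.

```python
def encode_steps(denoising_strength, ddim_steps):
    """Number of sampler steps used to noise the init image."""
    assert 0.0 <= denoising_strength <= 1.0, 'can only work with strength in [0.0, 1.0]'
    return int(denoising_strength * ddim_steps)


def loopback_strengths(initial_strength, n_iter, decay=0.95, floor=0.1):
    """Denoising strength passed to each loopback iteration, in order."""
    strengths = []
    strength = initial_strength
    for _ in range(n_iter):
        strengths.append(strength)
        # Same update as the loop body: shrink by 5%, never below the floor.
        strength = max(strength * decay, floor)
    return strengths


if __name__ == "__main__":
    print(encode_steps(0.5, 50))
    print(loopback_strengths(0.8, 5))
```

Printing the schedule for a given `n_iter` is a quick way to see how fast later iterations approach the 0.1 floor.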
||||
|
||||
#
|
||||
|
||||
|
||||
def layout():
|
||||
with st.form("img2img-inputs"):
|
||||
st.session_state["generation_mode"] = "img2img"
|
||||
|
||||
img2img_input_col, img2img_generate_col = st.columns([10,1])
|
||||
with img2img_input_col:
|
||||
#prompt = st.text_area("Input Text","")
|
||||
prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
|
||||
|
||||
# Every form must have a submit button. The two blank writes below are a temporary way to align it with the input field; this should be done in CSS or some other way.
|
||||
img2img_generate_col.write("")
|
||||
img2img_generate_col.write("")
|
||||
generate_button = img2img_generate_col.form_submit_button("Generate")
|
||||
|
||||
|
||||
# creating the page layout using columns
|
||||
col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small")
|
||||
|
||||
with col1_img2img_layout:
|
||||
# If we have custom models available on the "models/custom"
|
||||
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
|
||||
if st.session_state["CustomModel_available"]:
|
||||
st.session_state["custom_model"] = st.selectbox("Custom Model:", st.session_state["custom_models"],
|
||||
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
|
||||
help="Select the model you want to use. This option is only available if you have custom models \
|
||||
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
|
||||
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
|
||||
else:
|
||||
st.session_state["custom_model"] = "Stable Diffusion v1.4"
|
||||
|
||||
|
||||
st.session_state["sampling_steps"] = st.slider("Sampling Steps",
|
||||
value=st.session_state['defaults'].img2img.sampling_steps,
|
||||
min_value=st.session_state['defaults'].img2img.slider_bounds.sampling.lower,
|
||||
max_value=st.session_state['defaults'].img2img.slider_bounds.sampling.upper,
|
||||
step=st.session_state['defaults'].img2img.slider_steps.sampling)
|
||||
|
||||
sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list,
|
||||
index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.")
|
||||
|
||||
mask_mode_list = ["Mask", "Inverted mask", "Image alpha"]
|
||||
mask_mode = st.selectbox("Mask Mode", mask_mode_list,
|
||||
help="Select how you want your image to be masked. \"Mask\" modifies the image where the mask is white.\n\
|
||||
\"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
|
||||
)
|
||||
mask_mode = mask_mode_list.index(mask_mode)
|
||||
|
||||
width = st.slider("Width:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.width, step=64)
|
||||
height = st.slider("Height:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.height, step=64)
|
||||
seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help="The seed to use. If left blank, a random seed will be generated.")
|
||||
noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
|
||||
noise_mode = st.selectbox(
|
||||
"Noise Mode", noise_mode_list,
|
||||
help=""
|
||||
)
|
||||
noise_mode = noise_mode_list.index(noise_mode)
|
||||
find_noise_steps = st.slider("Find Noise Steps", value=100, min_value=1, max_value=500)
|
||||
batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_count, step=1,
|
||||
help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
#
|
||||
with st.expander("Advanced"):
|
||||
separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts,
|
||||
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
|
||||
normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights,
|
||||
help="Ensure that all weights add up to 1.0")
|
||||
loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from previous batch when creating next batch.")
|
||||
random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Random loopback seed")
|
||||
img2img_mask_restore = st.checkbox("Only modify regenerated parts of image",
|
||||
value=st.session_state['defaults'].img2img.mask_restore,
|
||||
help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail")
|
||||
save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images,
|
||||
help="Save each image generated before any filter or enhancement is applied.")
|
||||
save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.")
|
||||
group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt,
|
||||
help="Saves all the images with the same prompt into the same folder. \
|
||||
When using a prompt matrix each prompt combination will have its own folder.")
|
||||
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files,
|
||||
help="Save a file next to the image with information about the generation.")
|
||||
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
|
||||
|
||||
if st.session_state["GFPGAN_available"]:
|
||||
use_GFPGAN = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
|
||||
This greatly improves the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
else:
|
||||
use_GFPGAN = False
|
||||
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN,
|
||||
help="Uses the RealESRGAN model to upscale the images after the generation.\
|
||||
This greatly improves the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
|
||||
variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed,
|
||||
help="The seed to use when generating a variant, if left blank a random seed will be generated.")
|
||||
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].img2img.cfg_scale, step=0.5,
|
||||
help="How strongly the image should follow the prompt.")
|
||||
batch_size = st.slider("Batch size", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_size, step=1,
|
||||
help="How many images are generated at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish \
|
||||
generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength,
|
||||
min_value=0.01, max_value=1.0, step=0.01)
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].img2img.update_preview,
|
||||
help="If enabled the image preview will be updated during the generation instead of at the end. \
|
||||
You can use the Update Preview Frequency option below to customize how frequently it's updated. \
|
||||
By default this is enabled and the frequency is set to 1 step.")
|
||||
|
||||
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].img2img.update_preview_frequency,
|
||||
help="Frequency in steps at which the preview image is updated. By default the frequency \
|
||||
is set to 1 step.")
|
||||
|
||||
with col2_img2img_layout:
|
||||
editor_tab = st.tabs(["Editor"])
|
||||
|
||||
editor_image = st.empty()
|
||||
st.session_state["editor_image"] = editor_image
|
||||
|
||||
st.form_submit_button("Refresh")
|
||||
|
||||
masked_image_holder = st.empty()
|
||||
image_holder = st.empty()
|
||||
|
||||
uploaded_images = st.file_uploader(
|
||||
"Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
|
||||
help="Upload an image which will be used for the image to image generation.",
|
||||
)
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images).convert('RGBA')
|
||||
new_img = image.resize((width, height))
|
||||
image_holder.image(new_img)
|
||||
|
||||
mask_holder = st.empty()
|
||||
|
||||
uploaded_masks = st.file_uploader(
|
||||
"Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
|
||||
help="Upload a mask image which will be used for masking the image to image generation.",
|
||||
)
|
||||
if uploaded_masks:
|
||||
mask = Image.open(uploaded_masks)
|
||||
if mask.mode == "RGBA":
|
||||
mask = mask.convert('RGBA')
|
||||
background = Image.new('RGBA', mask.size, (0, 0, 0))
|
||||
mask = Image.alpha_composite(background, mask)
|
||||
mask = mask.resize((width, height))
|
||||
mask_holder.image(mask)
|
||||
|
||||
if uploaded_images and uploaded_masks:
|
||||
if mask_mode != 2:
|
||||
final_img = new_img.copy()
|
||||
alpha_layer = mask.convert('L')
|
||||
strength = st.session_state["denoising_strength"]
|
||||
if mask_mode == 0:
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
alpha_layer = alpha_layer.point(lambda a: a * strength)
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
elif mask_mode == 1:
|
||||
alpha_layer = alpha_layer.point(lambda a: a * strength)
|
||||
alpha_layer = ImageOps.invert(alpha_layer)
|
||||
|
||||
final_img.putalpha(alpha_layer)
|
||||
|
||||
with masked_image_holder.container():
|
||||
st.text("Masked Image Preview")
|
||||
st.image(final_img)
|
||||
|
||||
|
||||
with col3_img2img_layout:
|
||||
result_tab = st.tabs(["Result"])
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
preview_image = st.empty()
|
||||
st.session_state["preview_image"] = preview_image
|
||||
|
||||
#st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
|
||||
message = st.empty()
|
||||
|
||||
#if uploaded_images:
|
||||
#image = Image.open(uploaded_images).convert('RGB')
|
||||
##img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
#new_img = image.resize((width, height))
|
||||
#st.image(new_img, use_column_width=True)
|
||||
|
||||
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
# load the models the first time the generate button is hit; they won't be reloaded after that.
|
||||
load_models(False, use_GFPGAN, st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"],
|
||||
st.session_state["custom_model"])
|
||||
|
||||
if uploaded_images:
|
||||
image = Image.open(uploaded_images).convert('RGBA')
|
||||
new_img = image.resize((width, height))
|
||||
#img_array = np.array(image) # if you want to pass it to OpenCV
|
||||
new_mask = None
|
||||
if uploaded_masks:
|
||||
mask = Image.open(uploaded_masks).convert('RGBA')
|
||||
new_mask = mask.resize((width, height))
|
||||
|
||||
try:
|
||||
output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
|
||||
mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
|
||||
sampler_name=st.session_state["sampler_name"], n_iter=batch_count,
|
||||
cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
|
||||
seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width,
|
||||
height=height, variant_amount=variant_amount,
|
||||
ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=st.session_state["RealESRGAN_model"],
|
||||
separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images, save_grid=save_grid,
|
||||
group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"] if not loopback else False, loopback=loopback
|
||||
)
|
||||
|
||||
#show a message when the generation is complete.
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
except (StopException, KeyError):
|
||||
print("Received Streamlit StopException")
|
||||
|
||||
# this will render all the images at the end of the generation but it's better if it's moved to a second tab inside col2 and shown as a gallery.
|
||||
# use the current col2 first tab to show the preview_img and update it as its generated.
|
||||
#preview_image.image(output_images, width=750)
|
||||
|
||||
#on import run init
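The masked-image preview earlier in this file derives its alpha channel from the mask and the denoising strength. As a rough, dependency-free illustration of that per-pixel arithmetic, the sketch below works on a single 0-255 mask value. `preview_alpha` is a hypothetical helper that is not part of the repository, and truncating with `int()` is an assumption: Pillow's `point()` may round differently.

```python
def preview_alpha(mask_value, strength, mask_mode):
    """Alpha (0-255) for one preview pixel, mirroring the two preview branches."""
    if mask_mode == 0:
        # invert the mask, scale by strength, invert back
        return 255 - int((255 - mask_value) * strength)
    if mask_mode == 1:
        # scale the mask by strength, then invert
        return 255 - int(mask_value * strength)
    # the preview code only runs when mask_mode != 2
    raise ValueError("no preview is built for this mask_mode")


print(preview_alpha(0, 0.5, 0), preview_alpha(255, 0.5, 0))
```

Under these assumptions, a lower denoising strength keeps every pixel closer to fully opaque in the preview, whichever of the two modes is selected.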
|
161
scripts/imglab.py
Normal file
@ -0,0 +1,161 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
#home plugin
|
||||
import os
|
||||
from PIL import Image
|
||||
#from bs4 import BeautifulSoup
|
||||
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
|
||||
from streamlit.elements import image as STImage
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
class plugin_info():
|
||||
plugname = "imglab"
|
||||
description = "Image Lab"
|
||||
isTab = True
|
||||
displayPriority = 3
|
||||
|
||||
def getLatestGeneratedImagesFromPath():
|
||||
#get the latest images from the generated images folder
|
||||
#get the path to the generated images folder
|
||||
generatedImagesPath = os.path.join(os.getcwd(),'outputs')
|
||||
#get all the files from the folders and subfolders
|
||||
files = []
|
||||
#collect every .png file under the output folder, including subfolders
|
||||
for r, d, f in os.walk(generatedImagesPath):
|
||||
for file in f:
|
||||
if '.png' in file:
|
||||
files.append(os.path.join(r, file))
|
||||
#sort the files by date
|
||||
files.sort(key=os.path.getmtime)
|
||||
#open each file as a PIL image, replacing its path in the list
|
||||
for f in files:
|
||||
img = Image.open(f)
|
||||
files[files.index(f)] = img
|
||||
#get the latest 10 files
|
||||
#get all the files with the .png or .jpg extension
|
||||
#sort files by date
|
||||
#get the latest 10 files
|
||||
latestFiles = files[-10:]
|
||||
#reverse the list
|
||||
latestFiles.reverse()
|
||||
return latestFiles
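The function above opens every image it finds before keeping the last ten. A possible alternative, sketched here under the assumption that only the newest files are needed, selects paths by modification time first and leaves opening them to the caller. `latest_png_paths` is a hypothetical name, not something defined in the repository.

```python
import heapq
from pathlib import Path


def latest_png_paths(root, limit=10):
    """Return up to `limit` .png paths under `root` (recursively), newest first."""
    candidates = (p for p in Path(root).rglob("*.png") if p.is_file())
    # nlargest keeps only `limit` entries in memory and returns them
    # sorted from the largest (most recent) mtime downwards
    return heapq.nlargest(limit, candidates, key=lambda p: p.stat().st_mtime)
```

A caller could then run `Image.open(path)` on just the returned paths, which avoids holding a file handle for every image in the outputs folder.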
|
||||
|
||||
def getImagesFromLexica():
|
||||
#scrape images from lexica.art
|
||||
#get the html from the page
|
||||
#get the html with cookies and javascript
|
||||
apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D'
|
||||
#REST API call
|
||||
#
|
||||
from requests_html import HTMLSession
|
||||
session = HTMLSession()
|
||||
|
||||
response = session.get(apiEndpoint)
|
||||
#req = requests.Session()
|
||||
#req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36'
|
||||
#response = req.get(apiEndpoint)
|
||||
print(response.status_code)
|
||||
print(response.text)
|
||||
#get the json from the response
|
||||
#json = response.json()
|
||||
#get the prompts from the json
|
||||
print(response)
|
||||
#session = requests.Session()
|
||||
#parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False)
|
||||
#print(parseEndpointJson)
|
||||
#print('test2')
|
||||
#page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'})
|
||||
#parse the html
|
||||
#soup = BeautifulSoup(page.content, 'html.parser')
|
||||
#find all the images
|
||||
#print(soup)
|
||||
#images = soup.find_all('alt-image')
|
||||
#create a list to store the image urls
|
||||
image_urls = []
|
||||
#loop through the images
|
||||
for image in images:
|
||||
#get the url
|
||||
image_url = image['src']
|
||||
#add it to the list
|
||||
image_urls.append('http://www.lexica.art/'+image_url)
|
||||
#return the list
|
||||
print(image_urls)
|
||||
return image_urls
|
||||
def changeImage():
|
||||
#change the image in the image holder
|
||||
#check if the file is not empty
|
||||
if len(st.session_state['uploaded_file']) > 0:
|
||||
#read the file
|
||||
print('test2')
|
||||
uploaded = st.session_state['uploaded_file'][0].read()
|
||||
#show the image in the image holder
|
||||
st.session_state['previewImg'].empty()
|
||||
st.session_state['previewImg'].image(uploaded,use_column_width=True)
|
||||
def createHTMLGallery(images):
|
||||
html3 = """
|
||||
<div class="gallery-history" style="
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;">
|
||||
"""
|
||||
mkdwn_array = []
|
||||
for i in images:
|
||||
bImg = i.read()
|
||||
i = Image.save(bImg, 'PNG')
|
||||
width, height = i.size
|
||||
#get random number for the id
|
||||
image_id = "%s" % (str(images.index(i)))
|
||||
(data, mimetype) = STImage._normalize_to_bytes(bImg.getvalue(), width, 'auto')
|
||||
this_file = in_memory_file_manager.add(data, mimetype, image_id)
|
||||
img_str = this_file.url
|
||||
#img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
|
||||
#get image size
|
||||
|
||||
#make sure the image is not bigger than 150px but keep the aspect ratio
|
||||
if width > 150:
|
||||
height = int(height * (150/width))
|
||||
width = 150
|
||||
if height > 150:
|
||||
width = int(width * (150/height))
|
||||
height = 150
|
||||
|
||||
#mkdwn = f"""<img src="{img_str}" alt="Image" with="200" height="200" />"""
|
||||
mkdwn = f'''<div class="gallery" style="margin: 3px;" >
|
||||
<a href="{img_str}">
|
||||
<img src="{img_str}" alt="Image" width="{width}" height="{height}">
|
||||
</a>
|
||||
</div>
|
||||
'''
|
||||
mkdwn_array.append(mkdwn)
|
||||
html3 += "".join(mkdwn_array)
|
||||
html3 += '</div>'
|
||||
return html3
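The thumbnail sizing inside `createHTMLGallery` can be read in isolation. The sketch below reproduces the same two-step clamp as a standalone function so it can be checked without Streamlit or Pillow. `fit_within` and its `limit` parameter are hypothetical names introduced only for this illustration.

```python
def fit_within(width, height, limit=150):
    """Shrink (width, height) so neither side exceeds `limit`, keeping the aspect ratio."""
    if width > limit:
        height = int(height * (limit / width))
        width = limit
    if height > limit:
        width = int(width * (limit / height))
        height = limit
    return width, height


print(fit_within(300, 150))
print(fit_within(100, 400))
```

Because of the `int()` truncation, very elongated images can lose a pixel of accuracy in the shorter side, which is usually harmless for gallery thumbnails.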
|
||||
def layout():
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
st.session_state['uploaded_file'] = st.file_uploader("Choose an image or images", type=["png", "jpg", "jpeg", "webp"],accept_multiple_files=True,on_change=changeImage)
|
||||
if 'previewImg' not in st.session_state:
|
||||
st.session_state['previewImg'] = st.empty()
|
||||
else:
|
||||
if len(st.session_state['uploaded_file']) > 0:
|
||||
st.session_state['previewImg'].empty()
|
||||
st.session_state['previewImg'].image(st.session_state['uploaded_file'][0],use_column_width=True)
|
||||
else:
|
||||
st.session_state['previewImg'] = st.empty()
|
||||
|
48
scripts/perlin.py
Normal file
@ -0,0 +1,48 @@
|
||||
import numpy as np


def perlin(x, y, seed=0):
    # permutation table
    np.random.seed(seed)
    p = np.arange(256, dtype=int)
    np.random.shuffle(p)
    p = np.stack([p, p]).flatten()
    # coordinates of the top-left
    xi, yi = x.astype(int), y.astype(int)
    # internal coordinates
    xf, yf = x - xi, y - yi
    # fade factors
    u, v = fade(xf), fade(yf)
    # noise components
    n00 = gradient(p[p[xi] + yi], xf, yf)
    n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
    n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
    n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
    # combine noises
    x1 = lerp(n00, n10, u)
    x2 = lerp(n01, n11, u)  # FIX1: I was using n10 instead of n01
    return lerp(x1, x2, v)  # FIX2: I also had to reverse x1 and x2 here


def lerp(a, b, x):
    "linear interpolation"
    return a + x * (b - a)


def fade(t):
    "6t^5 - 15t^4 + 10t^3"
    return 6 * t**5 - 15 * t**4 + 10 * t**3


def gradient(h, x, y):
    "grad converts h to the right gradient vector and return the dot product with (x,y)"
    vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
    g = vectors[h % 4]
    return g[:, :, 0] * x + g[:, :, 1] * y


lin = np.linspace(0, 5, 100, endpoint=False)
x, y = np.meshgrid(lin, lin)


def perlinNoise(height,width,octavesx=5,octavesy=5,seed=None):
    linx = np.linspace(0,octavesx,width,endpoint=False)
    liny = np.linspace(0,octavesy,height,endpoint=False)
    x,y = np.meshgrid(linx,liny)
    return perlin(x,y,seed=seed)
|
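The `fade` and `lerp` helpers in `scripts/perlin.py` are what keep the noise smooth across grid cells. The scalar sketch below restates them without NumPy and adds a hypothetical `fade_slope` function (the analytic derivative, not present in the file) to show that the fade curve is flat at both ends of a cell.

```python
def lerp(a, b, x):
    """Linear interpolation between a and b."""
    return a + x * (b - a)


def fade(t):
    """Quintic fade curve 6t^5 - 15t^4 + 10t^3."""
    return 6 * t**5 - 15 * t**4 + 10 * t**3


def fade_slope(t):
    """Derivative of fade: 30t^4 - 60t^3 + 30t^2 (hypothetical helper)."""
    return 30 * t**4 - 60 * t**3 + 30 * t**2


for t in (0, 0.5, 1):
    print(t, fade(t), fade_slope(t))
```

The zero slope at `t = 0` and `t = 1` means the interpolation weights stop changing right at the cell borders, which is why neighbouring cells blend without visible creases.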
@ -19,6 +19,8 @@ optimized_turbo = False
|
||||
# Creates a public xxxxx.gradio.app share link to allow others to use your interface (requires properly forwarded ports to work correctly)
|
||||
share = False
|
||||
|
||||
# Generate tiling images
|
||||
tiling = False
|
||||
|
||||
# Enter other `--arguments` you wish to use - Must be entered as a `--argument ` syntax
|
||||
additional_arguments = ""
|
||||
@ -37,6 +39,8 @@ if optimized_turbo == True:
|
||||
common_arguments += "--optimized-turbo "
|
||||
if optimized == True:
|
||||
common_arguments += "--optimized "
|
||||
if tiling == True:
|
||||
common_arguments += "--tiling "
|
||||
if share == True:
|
||||
common_arguments += "--share "
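The chain of `if flag == True:` checks above grows by one branch per option. One possible alternative, shown only as a sketch, derives each `--argument` from the flag's name. It assumes every flag maps to its name with underscores replaced by hyphens, which holds for the four flags visible in this hunk but is not guaranteed for the rest of the file. `build_common_arguments` is a hypothetical helper.

```python
def build_common_arguments(flags, additional_arguments=""):
    """Turn {"flag_name": bool} into a space-separated "--flag-name" string."""
    parts = ["--" + name.replace("_", "-") for name, enabled in flags.items() if enabled]
    extra = additional_arguments.strip()
    if extra:
        parts.append(extra)
    return " ".join(parts)


print(build_common_arguments(
    {"optimized_turbo": False, "optimized": True, "tiling": True, "share": False}
))
```

Passing the flags as a dictionary keeps the order explicit, and a new option only needs a new dictionary entry.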
|
||||
|
||||
|
1728
scripts/sd_utils.py
Normal file
File diff suppressed because it is too large
233
scripts/stable_diffusion_pipeline.py
Normal file
@ -0,0 +1,233 @@
|
||||
import inspect
|
||||
import warnings
|
||||
from tqdm.auto import tqdm
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
||||
StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
|
||||
PNDMScheduler)
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
class StableDiffusionPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
text_embeddings: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
**kwargs,
|
||||
):
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and"
|
||||
" will be removed in v0.3.0. Consider using `pipe.to(torch_device)`"
|
||||
" instead."
|
||||
)
|
||||
|
||||
# Set device as before (to be removed in 0.3.0)
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
if text_embeddings is None:
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(
|
||||
"`height` and `width` have to be divisible by 8 but are"
|
||||
f" {height} and {width}."
|
||||
)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
else:
|
||||
batch_size = text_embeddings.shape[0]
|
||||
|
||||
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
# max_length = text_input.input_ids.shape[-1]
|
||||
max_length = 77 # self.tokenizer.model_max_length
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(
|
||||
uncond_input.input_ids.to(self.device)
|
||||
)[0]
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
if latents is None:
|
||||
latents = torch.randn(
|
||||
latents_shape,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(
|
||||
f"Unexpected latents shape, got {latents.shape}, expected"
|
||||
f" {latents_shape}"
|
||||
)
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# set timesteps
|
||||
accepts_offset = "offset" in set(
|
||||
inspect.signature(self.scheduler.set_timesteps).parameters.keys()
|
||||
)
|
||||
extra_set_kwargs = {}
|
||||
if accepts_offset:
|
||||
extra_set_kwargs["offset"] = 1
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
latents = latents * self.scheduler.sigmas[0]
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
)
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
sigma = self.scheduler.sigmas[i]
|
||||
# the model input needs to be scaled to match the continuous ODE formulation in K-LMS
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input, t, encoder_hidden_states=text_embeddings
|
||||
)["sample"]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, i, latents, **extra_step_kwargs
|
||||
)["prev_sample"]
|
||||
else:
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs
|
||||
)["prev_sample"]
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
safety_cheker_input = self.feature_extractor(
|
||||
self.numpy_to_pil(image), return_tensors="pt"
|
||||
).to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_cheker_input.pixel_values
|
||||
)
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
|
||||
|
||||
def embed_text(self, text):
|
||||
"""Helper to embed some text"""
|
||||
with torch.autocast("cuda"):
|
||||
text_input = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
with torch.no_grad():
|
||||
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
return embed
|
||||
|
||||
|
||||
class NoCheck(ModelMixin):
|
||||
"""Can be used in place of safety checker. Use responsibly and at your own risk."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))
|
||||
|
||||
def forward(self, images=None, **kwargs):
|
||||
return images, [False]
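The denoising loop in `__call__` above combines the unconditional and text-conditioned noise predictions with `guidance_scale`. The sketch below applies the same formula to plain Python lists so the effect of the scale can be checked by hand. `apply_guidance` is a hypothetical name, and real predictions are tensors rather than lists.

```python
def apply_guidance(noise_uncond, noise_text, guidance_scale):
    """Classifier-free guidance: uncond + scale * (text - uncond), element-wise."""
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_text)]


uncond = [0.25, 0.5]
text = [0.75, 0.0]
for scale in (0.0, 1.0, 2.0):
    print(scale, apply_guidance(uncond, text, scale))
```

A scale of 1 reproduces the text-conditioned prediction exactly, which matches the comment in the pipeline that `guidance_scale = 1` corresponds to doing no classifier-free guidance. Larger values extrapolate past the conditioned prediction.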
|
218
scripts/stable_diffusion_walk.py
Normal file
@ -0,0 +1,218 @@
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
|
||||
PNDMScheduler)
|
||||
from diffusers import ModelMixin
|
||||
|
||||
from stable_diffusion_pipeline import StableDiffusionPipeline
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
use_auth_token=True,
|
||||
torch_dtype=torch.float16,
|
||||
revision="fp16",
|
||||
).to("cuda")
|
||||
|
||||
default_scheduler = PNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
)
|
||||
ddim_scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
klms_scheduler = LMSDiscreteScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
)
|
||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
|
||||
|
||||
|
||||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
"""helper function to spherically interpolate two arrays v0 and v1"""

# default to False so plain numpy inputs do not hit an undefined name below
inputs_are_torch = False
if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
input_device = v0.device
|
||||
v0 = v0.cpu().numpy()
|
||||
v1 = v1.cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > DOT_THRESHOLD:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(input_device)
|
||||
|
||||
return v2
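To make the interpolation in `slerp` easier to follow, here is a dependency-free sketch of the same formula on plain Python lists. It is an illustration only: the name `slerp_lists` is hypothetical, and the real function works on NumPy arrays or torch tensors.

```python
import math


def slerp_lists(t, v0, v1, dot_threshold=0.9995):
    """Spherical interpolation between two equal-length lists of floats."""
    norm0 = math.sqrt(sum(a * a for a in v0))
    norm1 = math.sqrt(sum(b * b for b in v1))
    dot = sum(a * b for a, b in zip(v0, v1)) / (norm0 * norm1)
    if abs(dot) > dot_threshold:
        # nearly parallel vectors: fall back to linear interpolation
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    theta_0 = math.acos(dot)           # angle between the two vectors
    sin_theta_0 = math.sin(theta_0)
    theta_t = theta_0 * t              # angle covered at fraction t
    s0 = math.sin(theta_0 - theta_t) / sin_theta_0
    s1 = math.sin(theta_t) / sin_theta_0
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]


print(slerp_lists(0.5, [1.0, 0.0], [0.0, 1.0]))
```

Unlike linear interpolation, the midpoint between two orthogonal unit vectors keeps unit length, which is the usual reason for preferring slerp when walking between Gaussian latents.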
|
||||
|
||||
|
||||
def make_video_ffmpeg(frame_dir, output_file_name='output.mp4', frame_filename="frame%06d.jpg", fps=30):
|
||||
frame_ref_path = str(frame_dir / frame_filename)
|
||||
video_path = str(frame_dir / output_file_name)
|
||||
subprocess.call(
|
||||
f"ffmpeg -r {fps} -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p"
|
||||
f" {video_path}".split()
|
||||
)
|
||||
return video_path
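`make_video_ffmpeg` builds one command string and splits it on whitespace, which would break if the output directory contained a space. A possible workaround, sketched below, assembles the argument list directly. `build_ffmpeg_command` is a hypothetical helper; the ffmpeg options are copied unchanged from the function above.

```python
from pathlib import Path


def build_ffmpeg_command(frame_dir, output_file_name="output.mp4",
                         frame_filename="frame%06d.jpg", fps=30):
    """Return the ffmpeg argument list, one list element per argument."""
    frame_dir = Path(frame_dir)
    return [
        "ffmpeg",
        "-r", str(fps),
        "-i", str(frame_dir / frame_filename),
        "-vcodec", "libx264",
        "-crf", "10",
        "-pix_fmt", "yuv420p",
        str(frame_dir / output_file_name),
    ]


print(build_ffmpeg_command("my dreams", "walk.mp4", fps=24))
```

The resulting list can be passed directly to `subprocess.call` or `subprocess.run`, so no shell quoting is needed.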
|
||||
|
||||
|
||||
def walk(
|
||||
prompts=["blueberry spaghetti", "strawberry spaghetti"],
|
||||
seeds=[42, 123],
|
||||
num_steps=5,
|
||||
output_dir="dreams",
|
||||
name="berry_good_spaghetti",
|
||||
height=512,
|
||||
width=512,
|
||||
guidance_scale=7.5,
|
||||
eta=0.0,
|
||||
num_inference_steps=50,
|
||||
do_loop=False,
|
||||
make_video=False,
|
||||
use_lerp_for_text=False,
|
||||
scheduler="klms", # choices: default, ddim, klms
|
||||
disable_tqdm=False,
|
||||
upsample=False,
|
||||
fps=30,
|
||||
):
|
||||
"""Generate video frames/a video given a list of prompts and seeds.
|
||||
|
||||
Args:
|
||||
prompts (List[str], optional): List of text prompts. Defaults to ["blueberry spaghetti", "strawberry spaghetti"].
|
||||
seeds (List[int], optional): List of random seeds corresponding to given prompts.
|
||||
num_steps (int, optional): Number of steps to walk. Increase this value to 60-200 for good results. Defaults to 5.
|
||||
output_dir (str, optional): Root dir where images will be saved. Defaults to "dreams".
|
||||
name (str, optional): Sub directory of output_dir to save this run's files. Defaults to "berry_good_spaghetti".
|
||||
height (int, optional): Height of image to generate. Defaults to 512.
|
||||
width (int, optional): Width of image to generate. Defaults to 512.
|
||||
guidance_scale (float, optional): Higher = more adherence to prompt. Lower = let model take the wheel. Defaults to 7.5.
|
||||
eta (float, optional): The eta (η) parameter from the DDIM paper; only used with the DDIM scheduler. Defaults to 0.0.
|
||||
num_inference_steps (int, optional): Number of diffusion steps. Defaults to 50.
|
||||
do_loop (bool, optional): Whether to loop from last prompt back to first. Defaults to False.
|
||||
make_video (bool, optional): Whether to make a video or just save the images. Defaults to False.
|
||||
use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to False.
|
||||
scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms".
|
||||
disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False.
|
||||
upsample (bool, optional): If True, uses Real-ESRGAN to upsample images 4x. Requires it to be installed
|
||||
which you can do by running: `pip install git+https://github.com/xinntao/Real-ESRGAN.git`. Defaults to False.
|
||||
fps (int, optional): The frames per second (fps) that you want the video to use. Does nothing if make_video is False. Defaults to 30.
|
||||
|
||||
Returns:
|
||||
str: Path to video file saved if make_video=True, else None.
|
||||
"""
|
||||
if upsample:
|
||||
from .upsampling import PipelineRealESRGAN
|
||||
|
||||
upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
|
||||
|
||||
pipeline.set_progress_bar_config(disable=disable_tqdm)
|
||||
|
||||
pipeline.scheduler = SCHEDULERS[scheduler]
|
||||
|
||||
output_path = Path(output_dir) / name
|
||||
output_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Write prompt info to file in output dir so we can keep track of what we did
|
||||
prompt_config_path = output_path / 'prompt_config.json'
|
||||
prompt_config_path.write_text(
|
||||
json.dumps(
|
||||
dict(
|
||||
prompts=prompts,
|
||||
seeds=seeds,
|
||||
num_steps=num_steps,
|
||||
name=name,
|
||||
guidance_scale=guidance_scale,
|
||||
eta=eta,
|
||||
num_inference_steps=num_inference_steps,
|
||||
do_loop=do_loop,
|
||||
make_video=make_video,
|
||||
use_lerp_for_text=use_lerp_for_text,
|
||||
scheduler=scheduler
|
||||
),
|
||||
indent=2,
|
||||
sort_keys=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(prompts) == len(seeds)
|
||||
|
||||
first_prompt, *prompts = prompts
|
||||
embeds_a = pipeline.embed_text(first_prompt)
|
||||
|
||||
first_seed, *seeds = seeds
|
||||
latents_a = torch.randn(
|
||||
(1, pipeline.unet.in_channels, height // 8, width // 8),
|
||||
device=pipeline.device,
|
||||
generator=torch.Generator(device=pipeline.device).manual_seed(first_seed),
|
||||
)
|
||||
|
||||
if do_loop:
|
||||
prompts.append(first_prompt)
|
||||
seeds.append(first_seed)
|
||||
|
||||
frame_index = 0
|
||||
for prompt, seed in zip(prompts, seeds):
|
||||
# Text
|
||||
embeds_b = pipeline.embed_text(prompt)
|
||||
|
||||
# Latent Noise
|
||||
latents_b = torch.randn(
|
||||
(1, pipeline.unet.in_channels, height // 8, width // 8),
|
||||
device=pipeline.device,
|
||||
generator=torch.Generator(device=pipeline.device).manual_seed(seed),
|
||||
)
|
||||
|
||||
for i, t in enumerate(np.linspace(0, 1, num_steps)):
|
||||
do_print_progress = (i == 0) or ((frame_index + 1) % 20 == 0)
|
||||
if do_print_progress:
|
||||
print(f"COUNT: {frame_index+1}/{len(seeds)*num_steps}")
|
||||
|
||||
if use_lerp_for_text:
|
||||
embeds = torch.lerp(embeds_a, embeds_b, float(t))
|
||||
else:
|
||||
embeds = slerp(float(t), embeds_a, embeds_b)
|
||||
latents = slerp(float(t), latents_a, latents_b)
|
||||
|
||||
with torch.autocast("cuda"):
|
||||
im = pipeline(
|
||||
latents=latents,
|
||||
text_embeddings=embeds,
|
||||
height=height,
|
||||
width=width,
|
||||
guidance_scale=guidance_scale,
|
||||
eta=eta,
|
||||
num_inference_steps=num_inference_steps,
|
||||
output_type='pil' if not upsample else 'numpy'
|
||||
)["sample"][0]
|
||||
|
||||
if upsample:
|
||||
im = upsampling_pipeline(im)
|
||||
|
||||
im.save(output_path / ("frame%06d.jpg" % frame_index))
|
||||
frame_index += 1
|
||||
|
||||
embeds_a = embeds_b
|
||||
latents_a = latents_b
|
||||
|
||||
if make_video:
|
||||
return make_video_ffmpeg(output_path, f"{name}.mp4", fps=fps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(walk)
|
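Reviewer note: the walk loop above calls `slerp(float(t), latents_a, latents_b)`, but the definition of `slerp` falls outside this excerpt. As a rough illustration only, the sketch below shows the usual spherical linear interpolation on plain Python lists. The function body, the `dot_threshold` parameter, and the fallback to linear interpolation for nearly parallel vectors are assumptions, not the code from this commit, and the sketch expects non-zero vectors of equal length.

```python
import math


def slerp(t, v0, v1, dot_threshold=0.9995):
    """Spherical linear interpolation between two non-zero vectors (illustrative sketch)."""
    norm0 = math.sqrt(sum(a * a for a in v0))
    norm1 = math.sqrt(sum(b * b for b in v1))
    # Cosine of the angle between the two vectors.
    dot = sum(a * b for a, b in zip(v0, v1)) / (norm0 * norm1)
    if abs(dot) > dot_threshold:
        # Nearly parallel vectors: sin(theta) is close to zero, so use plain lerp.
        return [(1 - t) * a + t * b for a, b in zip(v0, v1)]
    theta = math.acos(dot)
    sin_theta = math.sin(theta)
    s0 = math.sin((1 - t) * theta) / sin_theta
    s1 = math.sin(t * theta) / sin_theta
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]
```

Unlike `torch.lerp`, which the loop uses for text embeddings when `use_lerp_for_text` is set, this keeps the interpolated point on the arc between the two endpoints, so intermediate noise latents do not shrink in norm halfway through the walk.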
57
scripts/textual_inversion.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
|
||||
|
||||
#other imports
|
||||
#from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
#def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
|
||||
|
||||
#loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
|
||||
|
||||
## separate token and the embeds
|
||||
#print (loaded_learned_embeds)
|
||||
#trained_token = list(loaded_learned_embeds.keys())[0]
|
||||
#embeds = loaded_learned_embeds[trained_token]
|
||||
|
||||
## cast to dtype of text_encoder
|
||||
#dtype = text_encoder.get_input_embeddings().weight.dtype
|
||||
#embeds.to(dtype)
|
||||
|
||||
## add the token in tokenizer
|
||||
#token = token if token is not None else trained_token
|
||||
#num_added_tokens = tokenizer.add_tokens(token)
|
||||
#i = 1
|
||||
#while(num_added_tokens == 0):
|
||||
#print(f"The tokenizer already contains the token {token}.")
|
||||
#token = f"{token[:-1]}-{i}>"
|
||||
#print(f"Attempting to add the token {token}.")
|
||||
#num_added_tokens = tokenizer.add_tokens(token)
|
||||
#i+=1
|
||||
|
||||
## resize the token embeddings
|
||||
#text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
## get the id for the token and assign the embeds
|
||||
#token_id = tokenizer.convert_tokens_to_ids(token)
|
||||
#text_encoder.get_input_embeddings().weight.data[token_id] = embeds
|
||||
#return token
|
||||
|
||||
##def token_loader()
|
||||
#learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", st.session_state.pipe.text_encoder, st.session_state.pipe.tokenizer, "*")
|
||||
#model_content["token"] = learned_token
|
||||
#models.append(model_content)
|
||||
|
||||
model_id = "./models/custom/embeddings/"
|
||||
|
||||
def layout():
|
||||
st.write("Textual Inversion")
|
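Reviewer note: the commented-out `load_learned_embed_in_clip` above renames a token until the tokenizer accepts it. Because it rewrites `token` in place on every pass, the suffixes appear to stack up (for example `<cat-1-2>`), and a token without a closing `>` such as `*` would lose its last character. The sketch below is one possible way to pick a free name. `unique_token` and the plain `set` standing in for the tokenizer vocabulary are hypothetical names used for illustration, not part of this commit or of the `transformers` API.

```python
def unique_token(token, vocab):
    """Return `token`, or a numbered variant of it that is not yet in `vocab`."""
    if token not in vocab:
        return token
    # Keep a trailing ">" (as in "<cat>") outside the numbered suffix.
    closing = ">" if token.endswith(">") else ""
    base = token[:-1] if closing else token
    i = 1
    while f"{base}-{i}{closing}" in vocab:
        i += 1
    return f"{base}-{i}{closing}"
```

For example, with a vocabulary that already holds `<cat>` and `<cat-1>`, the helper returns `<cat-2>`; the chosen name could then be passed to the tokenizer's `add_tokens` call as in the commented-out code.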
368
scripts/txt2img.py
Normal file
@@ -0,0 +1,368 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
from streamlit import StopException
|
||||
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
|
||||
from streamlit.elements import image as STImage
|
||||
|
||||
#other imports
|
||||
import os
|
||||
from typing import Union
|
||||
from io import BytesIO
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
# Temp imports
|
||||
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
class plugin_info():
|
||||
plugname = "txt2img"
|
||||
description = "Text to Image"
|
||||
isTab = True
|
||||
displayPriority = 1
|
||||
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
|
||||
GFPGAN_available = True
|
||||
else:
|
||||
GFPGAN_available = False
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
|
||||
RealESRGAN_available = True
|
||||
else:
|
||||
RealESRGAN_available = False
|
||||
|
||||
#
|
||||
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str,
|
||||
n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
|
||||
height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
|
||||
save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
|
||||
save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True,
|
||||
RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None,
|
||||
variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True):
|
||||
|
||||
outpath = st.session_state['defaults'].general.outdir_txt2img or st.session_state['defaults'].general.outdir or "outputs/txt2img-samples"
|
||||
|
||||
seed = seed_to_int(seed)
|
||||
|
||||
#prompt_matrix = 0 in toggles
|
||||
#normalize_prompt_weights = 1 in toggles
|
||||
#skip_save = 2 not in toggles
|
||||
#save_grid = 3 not in toggles
|
||||
#sort_samples = 4 in toggles
|
||||
#write_info_files = 5 in toggles
|
||||
#jpg_sample = 6 in toggles
|
||||
#use_GFPGAN = 7 in toggles
|
||||
#use_RealESRGAN = 8 in toggles
|
||||
|
||||
if sampler_name == 'PLMS':
|
||||
sampler = PLMSSampler(st.session_state["model"])
|
||||
elif sampler_name == 'DDIM':
|
||||
sampler = DDIMSampler(st.session_state["model"])
|
||||
elif sampler_name == 'k_dpm_2_a':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
|
||||
elif sampler_name == 'k_dpm_2':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
|
||||
elif sampler_name == 'k_euler_a':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
|
||||
elif sampler_name == 'k_euler':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'euler')
|
||||
elif sampler_name == 'k_heun':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'heun')
|
||||
elif sampler_name == 'k_lms':
|
||||
sampler = KDiffusionSampler(st.session_state["model"],'lms')
|
||||
else:
|
||||
raise Exception("Unknown sampler: " + sampler_name)
|
||||
|
||||
def init():
|
||||
pass
|
||||
|
||||
def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
|
||||
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback,
|
||||
log_every_t=int(st.session_state.update_preview_frequency))
|
||||
|
||||
return samples_ddim
|
||||
|
||||
#try:
|
||||
output_images, seed, info, stats = process_images(
|
||||
outpath=outpath,
|
||||
func_init=init,
|
||||
func_sample=sample,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
sampler_name=sampler_name,
|
||||
save_grid=save_grid,
|
||||
batch_size=batch_size,
|
||||
n_iter=n_iter,
|
||||
steps=ddim_steps,
|
||||
cfg_scale=cfg_scale,
|
||||
width=width,
|
||||
height=height,
|
||||
prompt_matrix=separate_prompts,
|
||||
use_GFPGAN=st.session_state["use_GFPGAN"],
|
||||
use_RealESRGAN=st.session_state["use_RealESRGAN"],
|
||||
realesrgan_model_name=realesrgan_model_name,
|
||||
ddim_eta=ddim_eta,
|
||||
normalize_prompt_weights=normalize_prompt_weights,
|
||||
save_individual_images=save_individual_images,
|
||||
sort_samples=group_by_prompt,
|
||||
write_info_files=write_info_files,
|
||||
jpg_sample=save_as_jpg,
|
||||
variant_amount=variant_amount,
|
||||
variant_seed=variant_seed,
|
||||
)
|
||||
|
||||
del sampler
|
||||
|
||||
return output_images, seed, info, stats
|
||||
|
||||
#except RuntimeError as e:
|
||||
#err = e
|
||||
#err_msg = f'CRASHED:<br><textarea rows="5" style="color:white;background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
|
||||
#stats = err_msg
|
||||
#return [], seed, 'err', stats
|
||||
|
||||
def layout():
|
||||
with st.form("txt2img-inputs"):
|
||||
st.session_state["generation_mode"] = "txt2img"
|
||||
|
||||
input_col1, generate_col1 = st.columns([10,1])
|
||||
|
||||
with input_col1:
|
||||
#prompt = st.text_area("Input Text","")
|
||||
prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
|
||||
|
||||
# Every form must have a submit button. The extra blank writes are a temporary way to align it with the input field; this should be done in CSS or some other way.
|
||||
generate_col1.write("")
|
||||
generate_col1.write("")
|
||||
generate_button = generate_col1.form_submit_button("Generate")
|
||||
|
||||
# creating the page layout using columns
|
||||
col1, col2, col3 = st.columns([1,2,1], gap="large")
|
||||
|
||||
with col1:
|
||||
width = st.slider("Width:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.width, step=64)
|
||||
height = st.slider("Height:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.height, step=64)
|
||||
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
|
||||
seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help="The seed to use. If left blank, a random seed will be generated.")
|
||||
batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
|
||||
|
||||
bs_slider_max_value = 5
|
||||
if st.session_state.defaults.general.optimized:
|
||||
bs_slider_max_value = 100
|
||||
|
||||
batch_size = st.slider(
|
||||
"Batch size",
|
||||
min_value=1,
|
||||
max_value=bs_slider_max_value,
|
||||
value=st.session_state.defaults.txt2img.batch_size,
|
||||
step=1,
|
||||
help="How many images are generated at once in a batch.\
|
||||
It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
|
||||
Default: 1")
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2img.update_preview,
|
||||
help="If enabled the image preview will be updated during the generation instead of at the end. \
|
||||
You can use the Update Image Preview Frequency option below to customize how frequently it is updated. \
|
||||
By default this is enabled and the frequency is set to 1 step.")
|
||||
|
||||
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency,
|
||||
help="Frequency in steps at which the preview image is updated. By default the frequency \
|
||||
is set to 1 step.")
|
||||
|
||||
with col2:
|
||||
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
|
||||
|
||||
with preview_tab:
|
||||
#st.write("Image")
|
||||
#Image for testing
|
||||
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
|
||||
#new_image = image.resize((175, 240))
|
||||
#preview_image = st.image(image)
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
st.session_state["preview_image"] = st.empty()
|
||||
|
||||
st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
message = st.empty()
|
||||
|
||||
with col3:
|
||||
# If we have custom models available in the "models/custom" folder,
|
||||
# we show a menu to select which model to use; otherwise we use the main SD model.
|
||||
if st.session_state.CustomModel_available:
|
||||
st.session_state.custom_model = st.selectbox("Custom Model:", st.session_state.custom_models,
|
||||
index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
|
||||
help="Select the model you want to use. This option is only available if you have custom models \
|
||||
in your 'models/custom' folder. The model name shown here is the same as the name \
|
||||
of the model file in that folder. It is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
|
||||
|
||||
st.session_state.sampling_steps = st.slider("Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2img.sampling_steps,
|
||||
min_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.lower,
|
||||
max_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.upper,
|
||||
step=st.session_state['defaults'].txt2img.slider_steps.sampling)
|
||||
|
||||
sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
||||
index=sampler_name_list.index(st.session_state['defaults'].txt2img.default_sampler), help="Sampling method to use. Default: k_euler")
|
||||
|
||||
|
||||
|
||||
#basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
|
||||
|
||||
#with basic_tab:
|
||||
#summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
|
||||
#help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
|
||||
|
||||
with st.expander("Advanced"):
|
||||
separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.")
|
||||
normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].txt2img.normalize_prompt_weights, help="Ensure the sum of all weights adds up to 1.0")
|
||||
save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].txt2img.save_individual_images, help="Save each image generated before any filter or enhancement is applied.")
|
||||
save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].txt2img.save_grid, help="Save a grid with all the images generated into a single image.")
|
||||
group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
|
||||
help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
|
||||
write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files, help="Save a file next to the image with information about the generation.")
|
||||
save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
|
||||
|
||||
if st.session_state["GFPGAN_available"]:
|
||||
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
|
||||
This greatly improves the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
else:
|
||||
st.session_state["use_GFPGAN"] = False
|
||||
|
||||
if st.session_state["RealESRGAN_available"]:
|
||||
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2img.use_RealESRGAN,
|
||||
help="Uses the RealESRGAN model to upscale the images after the generation.\
|
||||
This greatly improves the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
|
||||
variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
|
||||
galleryCont = st.empty()
|
||||
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
# Load the models the first time the generate button is pressed; they are not reloaded after that.
|
||||
load_models(False, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"],
|
||||
st.session_state["custom_model"])
|
||||
|
||||
|
||||
try:
|
||||
#
|
||||
output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, st.session_state["RealESRGAN_model"], batch_count, batch_size,
|
||||
cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
|
||||
save_grid, group_by_prompt, save_as_jpg, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"],
|
||||
variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
|
||||
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
|
||||
|
||||
#if 'latestImages' in st.session_state:
|
||||
#for i in output_images:
|
||||
##push the new image to the list of latest images and remove the oldest one
|
||||
##remove the last index from the list\
|
||||
#st.session_state['latestImages'].pop()
|
||||
##add the new image to the start of the list
|
||||
#st.session_state['latestImages'].insert(0, i)
|
||||
#PlaceHolder.empty()
|
||||
#with PlaceHolder.container():
|
||||
#col1, col2, col3 = st.columns(3)
|
||||
#col1_cont = st.container()
|
||||
#col2_cont = st.container()
|
||||
#col3_cont = st.container()
|
||||
#images = st.session_state['latestImages']
|
||||
#with col1_cont:
|
||||
#with col1:
|
||||
#[st.image(images[index]) for index in [0, 3, 6] if index < len(images)]
|
||||
#with col2_cont:
|
||||
#with col2:
|
||||
#[st.image(images[index]) for index in [1, 4, 7] if index < len(images)]
|
||||
#with col3_cont:
|
||||
#with col3:
|
||||
#[st.image(images[index]) for index in [2, 5, 8] if index < len(images)]
|
||||
#historyGallery = st.empty()
|
||||
|
||||
## check if output_images length is the same as seeds length
|
||||
#with gallery_tab:
|
||||
#st.markdown(createHTMLGallery(output_images,seeds), unsafe_allow_html=True)
|
||||
|
||||
|
||||
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
|
||||
|
||||
except (StopException, KeyError):
|
||||
print(f"Received Streamlit StopException")
|
||||
|
||||
# this will render all the images at the end of the generation, but it's better if it's moved to a second tab inside col2 and shown as a gallery.
|
||||
# use the current col2 first tab to show the preview_img and update it as its generated.
|
||||
#preview_image.image(output_images)
|
||||
|
||||
#on import run init
|
||||
def createHTMLGallery(images,info):
|
||||
html3 = """
|
||||
<div class="gallery-history" style="
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;">
|
||||
"""
|
||||
mkdwn_array = []
|
||||
for i in images:
|
||||
try:
|
||||
seed = info[images.index(i)]
|
||||
except:
|
||||
seed = ' '
|
||||
image_io = BytesIO()
|
||||
i.save(image_io, 'PNG')
|
||||
width, height = i.size
|
||||
#use the image's index in the list as its id
|
||||
image_id = "%s" % (str(images.index(i)))
|
||||
(data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
|
||||
this_file = in_memory_file_manager.add(data, mimetype, image_id)
|
||||
img_str = this_file.url
|
||||
#img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
|
||||
#get image size
|
||||
|
||||
#make sure the image is not bigger than 150px but keep the aspect ratio
|
||||
if width > 150:
|
||||
height = int(height * (150/width))
|
||||
width = 150
|
||||
if height > 150:
|
||||
width = int(width * (150/height))
|
||||
height = 150
|
||||
|
||||
#mkdwn = f"""<img src="{img_str}" alt="Image" with="200" height="200" />"""
|
||||
mkdwn = f'''<div class="gallery" style="margin: 3px;" >
|
||||
<a href="{img_str}">
|
||||
<img src="{img_str}" alt="Image" width="{width}" height="{height}">
|
||||
</a>
|
||||
<div class="desc" style="text-align: center; opacity: 40%;">{seed}</div>
|
||||
</div>
|
||||
'''
|
||||
mkdwn_array.append(mkdwn)
|
||||
html3 += "".join(mkdwn_array)
|
||||
html3 += '</div>'
|
||||
return html3
|
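Reviewer note: `createHTMLGallery` above clamps each thumbnail to 150px per side while keeping its aspect ratio. The same arithmetic is sketched below as a standalone helper so it can be checked in isolation. `fit_within` and its `max_side` parameter are hypothetical names introduced for illustration, not part of this commit.

```python
def fit_within(width, height, max_side=150):
    """Shrink (width, height) so neither side exceeds max_side, keeping the aspect ratio."""
    if width > max_side:
        # Scale the height by the same factor that brings the width down to max_side.
        height = int(height * (max_side / width))
        width = max_side
    if height > max_side:
        # The image is still too tall (portrait input): scale again by height.
        width = int(width * (max_side / height))
        height = max_side
    return width, height
```

A 300x600 image passes through both branches and ends up at 75x150, while images that already fit are returned unchanged.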
780
scripts/txt2vid.py
Normal file
@@ -0,0 +1,780 @@
|
||||
# base webui import and utils.
|
||||
from webui_streamlit import st
|
||||
from sd_utils import *
|
||||
|
||||
# streamlit imports
|
||||
from streamlit import StopException
|
||||
from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
|
||||
from streamlit.elements import image as STImage
|
||||
|
||||
#other imports
|
||||
|
||||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
import numpy as np
|
||||
import time, inspect, timeit
|
||||
from torch import autocast
|
||||
from io import BytesIO
|
||||
import imageio
|
||||
from slugify import slugify
|
||||
|
||||
# Temp imports
|
||||
|
||||
# these are for testing txt2vid, should be removed and we should use things from our own code.
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
|
||||
# end of imports
|
||||
#---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
except:
|
||||
pass
|
||||
|
||||
class plugin_info():
|
||||
plugname = "txt2vid"
|
||||
description = "Text to Video"
|
||||
isTab = True
|
||||
displayPriority = 1
|
||||
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
|
||||
GFPGAN_available = True
|
||||
else:
|
||||
GFPGAN_available = False
|
||||
|
||||
if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].txt2vid.RealESRGAN_model}.pth")):
|
||||
RealESRGAN_available = True
|
||||
else:
|
||||
RealESRGAN_available = False
|
||||
|
||||
#
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def diffuse(
|
||||
pipe,
|
||||
cond_embeddings, # text conditioning, should be (1, 77, 768)
|
||||
cond_latents, # image conditioning, should be (1, 4, 64, 64)
|
||||
num_inference_steps,
|
||||
cfg_scale,
|
||||
eta,
|
||||
):
|
||||
|
||||
torch_device = cond_latents.get_device()
|
||||
|
||||
# classifier-free guidance: add the unconditional embedding
|
||||
max_length = cond_embeddings.shape[1] # 77
|
||||
uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
|
||||
uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
|
||||
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
|
||||
|
||||
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
|
||||
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
|
||||
cond_latents = cond_latents * pipe.scheduler.sigmas[0]
|
||||
|
||||
# init the scheduler
|
||||
accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys())
|
||||
extra_set_kwargs = {}
|
||||
if accepts_offset:
|
||||
extra_set_kwargs["offset"] = 1
|
||||
|
||||
pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs)
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
|
||||
step_counter = 0
|
||||
inference_counter = 0
|
||||
|
||||
if "current_chunk_speed" not in st.session_state:
|
||||
st.session_state["current_chunk_speed"] = 0
|
||||
|
||||
if "previous_chunk_speed_list" not in st.session_state:
|
||||
st.session_state["previous_chunk_speed_list"] = [0]
|
||||
st.session_state["previous_chunk_speed_list"].append(st.session_state["current_chunk_speed"])
|
||||
|
||||
if "update_preview_frequency_list" not in st.session_state:
|
||||
st.session_state["update_preview_frequency_list"] = [0]
|
||||
st.session_state["update_preview_frequency_list"].append(st.session_state['defaults'].txt2vid.update_preview_frequency)
|
||||
|
||||
|
||||
# diffuse!
|
||||
for i, t in enumerate(pipe.scheduler.timesteps):
|
||||
start = timeit.default_timer()
|
||||
|
||||
#status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}")
|
||||
|
||||
# expand the latents for classifier free guidance
|
||||
latent_model_input = torch.cat([cond_latents] * 2)
|
||||
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
|
||||
sigma = pipe.scheduler.sigmas[i]
|
||||
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
|
||||
|
||||
# cfg
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
if isinstance(pipe.scheduler, LMSDiscreteScheduler):
|
||||
cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
|
||||
else:
|
||||
cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]
|
||||
|
||||
#print (st.session_state["update_preview_frequency"])
|
||||
#update the preview image if it is enabled and the frequency matches the step_counter
|
||||
if st.session_state['defaults'].txt2vid.update_preview:
|
||||
step_counter += 1
|
||||
|
||||
if st.session_state['defaults'].txt2vid.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps:
|
||||
if st.session_state.dynamic_preview_frequency:
|
||||
st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["avg_update_preview_frequency"] = optimize_update_preview_frequency(st.session_state["current_chunk_speed"], st.session_state["previous_chunk_speed_list"], st.session_state['defaults'].txt2vid.update_preview_frequency, st.session_state["update_preview_frequency_list"])
|
||||
|
||||
#scale and decode the image latents with vae
|
||||
cond_latents_2 = 1 / 0.18215 * cond_latents
|
||||
image = pipe.vae.decode(cond_latents_2)
|
||||
|
||||
# generate output numpy image as uint8
|
||||
image = torch.clamp((image["sample"] + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
image = transforms.ToPILImage()(image.squeeze_(0))
|
||||
|
||||
st.session_state["preview_image"].image(image)
|
||||
|
||||
step_counter = 0
|
||||
|
||||
duration = timeit.default_timer() - start
|
||||
|
||||
st.session_state["current_chunk_speed"] = duration
|
||||
|
||||
if duration >= 1:
|
||||
speed = "s/it"
|
||||
else:
|
||||
speed = "it/s"
|
||||
duration = 1 / duration
|
||||
|
||||
if i > st.session_state.sampling_steps:
|
||||
inference_counter += 1
|
||||
inference_percent = int(100 * float(inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
|
||||
inference_progress = f"{inference_counter + 1 if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
|
||||
else:
|
||||
inference_progress = ""
|
||||
|
||||
percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
|
||||
frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames))
|
||||
|
||||
st.session_state["progress_bar_text"].text(
|
||||
f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
|
||||
f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
|
||||
f"Frame: {st.session_state.current_frame + 1 if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
|
||||
f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
|
||||
)
|
||||
st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
|
||||
|
||||
return image
|
||||
|
||||
#
|
||||
def txt2vid(
|
||||
# --------------------------------------
|
||||
# args you probably want to change
|
||||
prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about
|
||||
gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on
|
||||
#name:str = 'test', # name of this project, for the output directory
|
||||
#rootdir:str = st.session_state['defaults'].general.outdir,
|
||||
num_steps:int = 200, # number of steps between each pair of sampled points
|
||||
max_frames:int = 10000, # number of frames to write and then exit the script
|
||||
num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
|
||||
cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
|
||||
do_loop = False,
|
||||
use_lerp_for_text = False,
|
||||
seeds = None,
|
||||
quality:int = 100, # for jpeg compression of the output images
|
||||
eta:float = 0.0,
|
||||
width:int = 256,
|
||||
height:int = 256,
|
||||
weights_path = "CompVis/stable-diffusion-v1-4",
|
||||
scheduler="klms", # choices: default, ddim, klms
|
||||
disable_tqdm = False,
|
||||
#-----------------------------------------------
|
||||
beta_start = 0.0001,
|
||||
beta_end = 0.00012,
|
||||
beta_schedule = "scaled_linear",
|
||||
starting_image=None
|
||||
):
|
||||
"""
|
||||
prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about
|
||||
gpu:int = st.session_state['defaults'].general.gpu, # id of the gpu to run on
|
||||
#name:str = 'test', # name of this project, for the output directory
|
||||
#rootdir:str = st.session_state['defaults'].general.outdir,
|
||||
num_steps:int = 200, # number of steps between each pair of sampled points
|
||||
max_frames:int = 10000, # number of frames to write and then exit the script
|
||||
num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
|
||||
cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
|
||||
do_loop = False,
|
||||
use_lerp_for_text = False,
|
||||
seeds = None,
|
||||
quality:int = 100, # for jpeg compression of the output images
|
||||
eta:float = 0.0,
|
||||
width:int = 256,
|
||||
height:int = 256,
|
||||
weights_path = "CompVis/stable-diffusion-v1-4",
|
||||
scheduler="klms", # choices: default, ddim, klms
|
||||
disable_tqdm = False,
|
||||
beta_start = 0.0001,
|
||||
beta_end = 0.00012,
|
||||
beta_schedule = "scaled_linear"
|
||||
"""
|
||||
mem_mon = MemUsageMonitor('MemMon')
|
||||
mem_mon.start()
|
||||
|
||||
|
||||
seeds = seed_to_int(seeds)
|
||||
|
||||
# We add an extra frame because most
|
||||
# of the time the first frame is just the noise.
|
||||
#max_frames +=1
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
assert height % 8 == 0 and width % 8 == 0
|
||||
torch.manual_seed(seeds)
|
||||
torch_device = f"cuda:{gpu}"
|
||||
|
||||
# init the output dir
|
||||
sanitized_prompt = slugify(prompts)
|
||||
|
||||
full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
|
||||
|
||||
if len(full_path) > 220:
|
||||
sanitized_prompt = sanitized_prompt[:220-len(full_path)]
|
||||
full_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
|
||||
|
||||
os.makedirs(full_path, exist_ok=True)
|
||||
|
||||
# Write prompt info to file in output dir so we can keep track of what we did
|
||||
if st.session_state.write_info_files:
|
||||
with open(os.path.join(full_path , f'{slugify(str(seeds))}_config.json' if len(prompts) > 1 else "prompts_config.json"), "w") as outfile:
|
||||
outfile.write(json.dumps(
|
||||
dict(
|
||||
prompts = prompts,
|
||||
gpu = gpu,
|
||||
num_steps = num_steps,
|
||||
max_frames = max_frames,
|
||||
num_inference_steps = num_inference_steps,
|
||||
cfg_scale = cfg_scale,
|
||||
do_loop = do_loop,
|
||||
use_lerp_for_text = use_lerp_for_text,
|
||||
seeds = seeds,
|
||||
quality = quality,
|
||||
eta = eta,
|
||||
width = width,
|
||||
height = height,
|
||||
weights_path = weights_path,
|
||||
scheduler=scheduler,
|
||||
disable_tqdm = disable_tqdm,
|
||||
beta_start = beta_start,
|
||||
beta_end = beta_end,
|
||||
beta_schedule = beta_schedule
|
||||
),
|
||||
indent=2,
|
||||
sort_keys=False,
|
||||
))
|
||||
|
||||
#print(scheduler)
|
||||
default_scheduler = PNDMScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
)
|
||||
# ------------------------------------------------------------------------------
|
||||
#Schedulers
|
||||
ddim_scheduler = DDIMScheduler(
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
beta_schedule=beta_schedule,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
|
||||
klms_scheduler = LMSDiscreteScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
|
||||
)
|
||||
|
||||
SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
st.session_state["progress_bar_text"].text("Loading models...")
|
||||
|
||||
try:
|
||||
if "model" in st.session_state:
|
||||
del st.session_state["model"]
|
||||
except:
|
||||
pass
|
||||
|
||||
#print (st.session_state["weights_path"] != weights_path)
|
||||
|
||||
try:
|
||||
if not "pipe" in st.session_state or st.session_state["weights_path"] != weights_path:
|
||||
if st.session_state["weights_path"] != weights_path:
|
||||
del st.session_state["weights_path"]
|
||||
|
||||
st.session_state["weights_path"] = weights_path
|
||||
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None
|
||||
)
|
||||
|
||||
st.session_state["pipe"].unet.to(torch_device)
|
||||
st.session_state["pipe"].vae.to(torch_device)
|
||||
st.session_state["pipe"].text_encoder.to(torch_device)
|
||||
|
||||
if st.session_state.defaults.general.enable_attention_slicing:
|
||||
st.session_state["pipe"].enable_attention_slicing()
|
||||
if st.session_state.defaults.general.enable_minimal_memory_usage:
|
||||
st.session_state["pipe"].enable_minimal_memory_usage()
|
||||
|
||||
print("Txt2Vid Model Loaded")
|
||||
else:
|
||||
print("Txt2Vid Model already Loaded")
|
||||
|
||||
except:
|
||||
#del st.session_state["weights_path"]
|
||||
#del st.session_state["pipe"]
|
||||
|
||||
st.session_state["weights_path"] = weights_path
|
||||
st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
|
||||
weights_path,
|
||||
use_local_file=True,
|
||||
use_auth_token=True,
|
||||
torch_dtype=torch.float16 if st.session_state['defaults'].general.use_float16 else None,
|
||||
revision="fp16" if not st.session_state['defaults'].general.no_half else None
|
||||
)
|
||||
|
||||
st.session_state["pipe"].unet.to(torch_device)
|
||||
st.session_state["pipe"].vae.to(torch_device)
|
||||
st.session_state["pipe"].text_encoder.to(torch_device)
|
||||
|
||||
if st.session_state.defaults.general.enable_attention_slicing:
|
||||
st.session_state["pipe"].enable_attention_slicing()
|
||||
|
||||
|
||||
print("Txt2Vid Model Loaded")
|
||||
|
||||
st.session_state["pipe"].scheduler = SCHEDULERS[scheduler]
|
||||
|
||||
# get the conditional text embeddings based on the prompt
|
||||
text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
|
||||
cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]
|
||||
|
||||
#
|
||||
if st.session_state.defaults.general.use_sd_concepts_library:
|
||||
|
||||
prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompts)
|
||||
|
||||
if prompt_tokens:
|
||||
# CompVis
|
||||
#tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
|
||||
#text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer
|
||||
|
||||
# diffusers
|
||||
tokenizer = st.session_state.pipe.tokenizer
|
||||
text_encoder = st.session_state.pipe.text_encoder
|
||||
|
||||
ext = ('pt', 'bin')
|
||||
#print (prompt_tokens)
|
||||
|
||||
if len(prompt_tokens) > 1:
|
||||
for token_name in prompt_tokens:
|
||||
embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name)
|
||||
if os.path.exists(embedding_path):
|
||||
for files in os.listdir(embedding_path):
|
||||
if files.endswith(ext):
|
||||
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
|
||||
else:
|
||||
embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0])
|
||||
if os.path.exists(embedding_path):
|
||||
for files in os.listdir(embedding_path):
|
||||
if files.endswith(ext):
|
||||
load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")
|
||||
|
||||
# sample a source
|
||||
init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
|
||||
|
||||
if do_loop:
|
||||
prompts = [prompts, prompts]
|
||||
seeds = [seeds, seeds]
|
||||
#first_seed, *seeds = seeds
|
||||
#prompts.append(prompts)
|
||||
#seeds.append(first_seed)
|
||||
|
||||
|
||||
# iterate the loop
|
||||
frames = []
|
||||
frame_index = 0
|
||||
|
||||
st.session_state["total_frames_avg_duration"] = []
|
||||
st.session_state["total_frames_avg_speed"] = []
|
||||
|
||||
try:
|
||||
while frame_index < max_frames:
|
||||
st.session_state["frame_duration"] = 0
|
||||
st.session_state["frame_speed"] = 0
|
||||
st.session_state["current_frame"] = frame_index
|
||||
|
||||
# sample the destination
|
||||
init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
|
||||
|
||||
for i, t in enumerate(np.linspace(0, 1, max_frames)):
|
||||
start = timeit.default_timer()
|
||||
print(f"COUNT: {frame_index+1}/{max_frames}")
|
||||
|
||||
#if use_lerp_for_text:
|
||||
#init = torch.lerp(init1, init2, float(t))
|
||||
#else:
|
||||
#init = slerp(gpu, float(t), init1, init2)
|
||||
|
||||
init = slerp(gpu, float(t), init1, init2)
|
||||
|
||||
with autocast("cuda"):
|
||||
image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
|
||||
|
||||
#im = Image.fromarray(image)
|
||||
outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
|
||||
image.save(outpath, quality=quality)
|
||||
|
||||
# send the image to the UI to update it
|
||||
#st.session_state["preview_image"].image(im)
|
||||
|
||||
#append the frames to the frames list so we can use them later.
|
||||
frames.append(np.asarray(image))
|
||||
|
||||
#increase frame_index counter.
|
||||
frame_index += 1
|
||||
|
||||
st.session_state["current_frame"] = frame_index
|
||||
|
||||
duration = timeit.default_timer() - start
|
||||
|
||||
if duration >= 1:
|
||||
speed = "s/it"
|
||||
else:
|
||||
speed = "it/s"
|
||||
duration = 1 / duration
|
||||
|
||||
st.session_state["frame_duration"] = duration
|
||||
st.session_state["frame_speed"] = speed
|
||||
|
||||
init1 = init2
|
||||
|
||||
except StopException:
|
||||
pass
|
||||
|
||||
|
||||
if st.session_state['save_video']:
|
||||
# write video to disk
|
||||
#output = io.BytesIO()
|
||||
#writer = imageio.get_writer(os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30)
|
||||
try:
|
||||
video_path = os.path.join(os.getcwd(), st.session_state['defaults'].general.outdir, "txt2vid-samples","temp.mp4")
|
||||
writer = imageio.get_writer(video_path, fps=24)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
except:
|
||||
print("Can't save video, skipping.")
|
||||
|
||||
# show video preview on the UI
|
||||
st.session_state["preview_video"].video(open(video_path, 'rb').read())
|
||||
|
||||
mem_max_used, mem_total = mem_mon.read_and_stop()
|
||||
time_diff = time.time()- start
|
||||
|
||||
info = f"""
|
||||
{prompts}
|
||||
Sampling Steps: {num_steps}, Sampler: {scheduler}, CFG scale: {cfg_scale}, Seed: {seeds}, Max Frames: {max_frames}""".strip()
|
||||
stats = f'''
|
||||
Took { round(time_diff, 2) }s total ({ round(time_diff/(max_frames),2) }s per image)
|
||||
Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
|
||||
|
||||
return video_path, seeds, info, stats
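The frame loop in `txt2vid` blends two noise tensors with `slerp(gpu, float(t), init1, init2)`, but the helper itself is not part of this hunk. As an illustration only, here is a NumPy sketch of spherical linear interpolation. The signature (no `gpu` argument), the `dot_threshold` fallback, and the 4-channel latent shape are assumptions and are not taken from the repository.

```python
import numpy as np


def slerp(t: float, v0: np.ndarray, v1: np.ndarray, dot_threshold: float = 0.9995) -> np.ndarray:
    """Spherical linear interpolation between two arrays of the same shape."""
    a = v0.ravel()
    b = v1.ravel()
    # Cosine of the angle between the two flattened inputs.
    dot = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
    if abs(dot) > dot_threshold:
        # Nearly parallel inputs: plain linear interpolation is numerically safer.
        return (1 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)      # angle between the inputs
    sin_theta_0 = np.sin(theta_0)
    theta_t = theta_0 * t         # angle covered at position t
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = np.sin(theta_t) / sin_theta_0
    return s0 * v0 + s1 * v1


# Hypothetical latents for a 256x256 image: (1, channels, height // 8, width // 8).
rng = np.random.default_rng(0)
init1 = rng.standard_normal((1, 4, 32, 32))
init2 = rng.standard_normal((1, 4, 32, 32))
latents = [slerp(float(t), init1, init2) for t in np.linspace(0, 1, 5)]
```

Unlike `torch.lerp`, which the commented-out branch refers to, spherical interpolation keeps the intermediate tensors at roughly the same norm as the endpoints, which is the usual reason for preferring it with Gaussian noise.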
|
||||
|
||||
#on import run init
|
||||
def createHTMLGallery(images,info):
|
||||
html3 = """
|
||||
<div class="gallery-history" style="
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: flex-start;">
|
||||
"""
|
||||
mkdwn_array = []
|
||||
for i in images:
|
||||
try:
|
||||
seed = info[images.index(i)]
|
||||
except:
|
||||
seed = ' '
|
||||
image_io = BytesIO()
|
||||
i.save(image_io, 'PNG')
|
||||
width, height = i.size
|
||||
#use the image's index in the list as its id
|
||||
image_id = "%s" % (str(images.index(i)))
|
||||
(data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
|
||||
this_file = in_memory_file_manager.add(data, mimetype, image_id)
|
||||
img_str = this_file.url
|
||||
#img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
|
||||
#get image size
|
||||
|
||||
#make sure the image is not bigger than 150px but keep the aspect ratio
|
||||
if width > 150:
|
||||
height = int(height * (150/width))
|
||||
width = 150
|
||||
if height > 150:
|
||||
width = int(width * (150/height))
|
||||
height = 150
|
||||
|
||||
#mkdwn = f"""<img src="{img_str}" alt="Image" with="200" height="200" />"""
|
||||
mkdwn = f'''<div class="gallery" style="margin: 3px;" >
|
||||
<a href="{img_str}">
|
||||
<img src="{img_str}" alt="Image" width="{width}" height="{height}">
|
||||
</a>
|
||||
<div class="desc" style="text-align: center; opacity: 40%;">{seed}</div>
|
||||
</div>
|
||||
'''
|
||||
mkdwn_array.append(mkdwn)
|
||||
|
||||
html3 += "".join(mkdwn_array)
|
||||
html3 += '</div>'
|
||||
return html3
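`createHTMLGallery` shrinks each thumbnail so that neither side exceeds 150px while keeping the aspect ratio. The same two-step clamp can be sketched as a standalone function. The name `fit_within` and the `limit` parameter are hypothetical and introduced here only for illustration.

```python
from typing import Tuple


def fit_within(width: int, height: int, limit: int = 150) -> Tuple[int, int]:
    """Scale (width, height) down so that neither side exceeds `limit`."""
    if width > limit:
        # Shrink the height by the same factor that brings the width to the limit.
        height = int(height * (limit / width))
        width = limit
    if height > limit:
        # Portrait images can still be too tall after the first step.
        width = int(width * (limit / height))
        height = limit
    return width, height


thumb_size = fit_within(300, 600)
```

Images that already fit are returned unchanged, so the function never upscales.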
|
||||
#
|
||||
def layout():
|
||||
with st.form("txt2vid-inputs"):
|
||||
st.session_state["generation_mode"] = "txt2vid"
|
||||
|
||||
input_col1, generate_col1 = st.columns([10,1])
|
||||
with input_col1:
|
||||
#prompt = st.text_area("Input Text","")
|
||||
prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
|
||||
|
||||
# Every form must have a submit button. The extra blank writes are a temporary way to align it with the input field; this should be done in CSS or some other way.
|
||||
generate_col1.write("")
|
||||
generate_col1.write("")
|
||||
generate_button = generate_col1.form_submit_button("Generate")
|
||||
|
||||
# creating the page layout using columns
|
||||
col1, col2, col3 = st.columns([1,2,1], gap="large")
|
||||
|
||||
with col1:
|
||||
width = st.slider("Width:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.width, step=64)
|
||||
height = st.slider("Height:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.height, step=64)
|
||||
cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2vid.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
|
||||
|
||||
#uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
|
||||
#help="Upload an image which will be used for the image to image generation.")
|
||||
seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2vid.seed, help="The seed to use. If left blank, a random seed will be generated.")
|
||||
#batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2vid.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
|
||||
#batch_size = st.slider("Batch size", min_value=1, max_value=250, value=st.session_state['defaults'].txt2vid.batch_size, step=1,
|
||||
#help="How many images are at once in a batch.\
|
||||
#It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
|
||||
#Default: 1")
|
||||
|
||||
st.session_state["max_frames"] = int(st.text_input("Max Frames:", value=st.session_state['defaults'].txt2vid.max_frames, help="Specify the max number of frames you want to generate."))
|
||||
|
||||
with st.expander("Preview Settings"):
|
||||
st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2vid.update_preview,
|
||||
help="If enabled the image preview will be updated during the generation instead of at the end. \
|
||||
You can use the Update Preview Frequency option below to customize how frequently it's updated. \
|
||||
By default this is enabled and the frequency is set to 1 step.")
|
||||
|
||||
st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2vid.update_preview_frequency,
|
||||
help="Frequency in steps at which the preview image is updated. By default the frequency \
|
||||
is set to 1 step.")
|
||||
|
||||
#
|
||||
|
||||
|
||||
|
||||
with col2:
|
||||
preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
|
||||
|
||||
with preview_tab:
|
||||
#st.write("Image")
|
||||
#Image for testing
|
||||
#image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
|
||||
#new_image = image.resize((175, 240))
|
||||
#preview_image = st.image(image)
|
||||
|
||||
# create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
|
||||
st.session_state["preview_image"] = st.empty()
|
||||
|
||||
st.session_state["loading"] = st.empty()
|
||||
|
||||
st.session_state["progress_bar_text"] = st.empty()
|
||||
st.session_state["progress_bar"] = st.empty()
|
||||
|
||||
#generate_video = st.empty()
|
||||
st.session_state["preview_video"] = st.empty()
|
||||
|
||||
message = st.empty()
|
||||
|
||||
with gallery_tab:
|
||||
st.write('Here should be the image gallery, if I could make a grid in streamlit.')
|
||||
|
||||
with col3:
|
||||
# If we have custom models available on the "models/custom"
|
||||
#folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
|
||||
if st.session_state["CustomModel_available"]:
|
||||
custom_model = st.selectbox("Custom Model:", st.session_state["defaults"].txt2vid.custom_models_list,
|
||||
index=st.session_state["defaults"].txt2vid.custom_models_list.index(st.session_state["defaults"].txt2vid.default_model),
|
||||
help="Select the model you want to use. This option is only available if you have custom models \
|
||||
on your 'models/custom' folder. The model name that will be shown here is the same as the name\
|
||||
the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
|
||||
will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
|
||||
else:
|
||||
custom_model = "CompVis/stable-diffusion-v1-4"
|
||||
|
||||
#st.session_state["weights_path"] = custom_model
|
||||
#else:
|
||||
#custom_model = "CompVis/stable-diffusion-v1-4"
|
||||
#st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}"
|
||||
|
||||
st.session_state.sampling_steps = st.slider("Sampling Steps",
|
||||
value=st.session_state['defaults'].txt2vid.sampling_steps,
|
||||
min_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.lower,
|
||||
max_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.upper,
|
||||
step=st.session_state['defaults'].txt2vid.slider_steps.sampling,
|
||||
help="Number of steps between each pair of sampled points")
|
||||
st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps, min_value=10,step=10, max_value=500,
|
||||
help="Higher values (e.g. 100, 200 etc) can create better images.")
|
||||
|
||||
#sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
|
||||
#sampler_name = st.selectbox("Sampling method", sampler_name_list,
|
||||
#index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
|
||||
scheduler_name_list = ["klms", "ddim"]
|
||||
scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
|
||||
index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
|
||||
|
||||
beta_scheduler_type_list = ["scaled_linear", "linear"]
|
||||
beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list,
|
||||
index=beta_scheduler_type_list.index(st.session_state['defaults'].txt2vid.beta_scheduler_type), help="Schedule Type to use. Default: linear")
|
||||
|
||||
|
||||
#basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
|
||||
|
||||
#with basic_tab:
|
||||
#submit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
|
||||
#help="Press the Enter key to submit, when 'No' is selected you can use the Enter key to write multiple lines.")
|
||||
|
||||
with st.expander("Advanced"):
|
||||
st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2vid.separate_prompts,
|
||||
help="Separate multiple prompts using the `|` character, and get all combinations of them.")
|
||||
st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.",
|
||||
value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, help="Ensure that all weights add up to 1.0")
|
||||
st.session_state["save_individual_images"] = st.checkbox("Save individual images.",
|
||||
value=st.session_state['defaults'].txt2vid.save_individual_images, help="Save each image generated before any filter or enhancement is applied.")
|
||||
st.session_state["save_video"] = st.checkbox("Save video",value=st.session_state['defaults'].txt2vid.save_video, help="Save a video with all the images generated as frames at the end of the generation.")
|
||||
st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
|
||||
help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
|
||||
st.session_state["write_info_files"] = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2vid.write_info_files,
|
||||
help="Save a file next to the image with information about the generation.")
|
||||
st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency,
|
||||
help="This option tries to find the best value at which we can update \
|
||||
the preview image during generation while minimizing the impact it has on performance. Default: True")
|
||||
st.session_state["do_loop"] = st.checkbox("Do Loop", value=st.session_state['defaults'].txt2vid.do_loop,
|
||||
help="Do loop")
|
||||
st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
|
||||
|
||||
if GFPGAN_available:
|
||||
st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improves the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
else:
|
||||
st.session_state["use_GFPGAN"] = False
|
||||
|
||||
if RealESRGAN_available:
|
||||
st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
|
||||
help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improves the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
|
||||
st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
|
||||
else:
|
||||
st.session_state["use_RealESRGAN"] = False
|
||||
st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
|
||||
|
||||
st.session_state["variant_amount"] = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2vid.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
|
||||
st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2vid.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
|
||||
st.session_state["beta_start"] = st.slider("Beta Start:", value=st.session_state['defaults'].txt2vid.beta_start, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
|
||||
st.session_state["beta_end"] = st.slider("Beta End:", value=st.session_state['defaults'].txt2vid.beta_end, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
|
||||
|
||||
if generate_button:
|
||||
#print("Loading models")
|
||||
# Load the models when the generate button is hit for the first time; they won't be loaded again after that.
|
||||
#load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model)
|
||||
|
||||
try:
|
||||
# run video generation
|
||||
video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
|
||||
num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
|
||||
num_inference_steps=st.session_state.num_inference_steps,
|
||||
cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
|
||||
seeds=seed, quality=100, eta=0.0, width=width,
|
||||
height=height, weights_path=custom_model, scheduler=scheduler_name,
|
||||
disable_tqdm=False, beta_start=st.session_state["beta_start"], beta_end=st.session_state["beta_end"],
|
||||
beta_schedule=beta_scheduler_type, starting_image=None)
|
||||
|
||||
#message.success('Done!', icon="✅")
|
||||
message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
|
||||
|
||||
#history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
|
||||
|
||||
#if 'latestVideos' in st.session_state:
|
||||
#for i in video:
|
||||
##push the new image to the list of latest images and remove the oldest one
|
||||
##remove the last index from the list\
|
||||
#st.session_state['latestVideos'].pop()
|
||||
##add the new image to the start of the list
|
||||
#st.session_state['latestVideos'].insert(0, i)
|
||||
#PlaceHolder.empty()
|
||||
|
||||
#with PlaceHolder.container():
|
||||
#col1, col2, col3 = st.columns(3)
|
||||
#col1_cont = st.container()
|
||||
#col2_cont = st.container()
|
||||
#col3_cont = st.container()
|
||||
|
||||
#with col1_cont:
|
||||
#with col1:
|
||||
#st.image(st.session_state['latestVideos'][0])
|
||||
#st.image(st.session_state['latestVideos'][3])
|
||||
#st.image(st.session_state['latestVideos'][6])
|
||||
#with col2_cont:
|
||||
#with col2:
|
||||
#st.image(st.session_state['latestVideos'][1])
|
||||
#st.image(st.session_state['latestVideos'][4])
|
||||
#st.image(st.session_state['latestVideos'][7])
|
||||
#with col3_cont:
|
||||
#with col3:
|
||||
#st.image(st.session_state['latestVideos'][2])
|
||||
#st.image(st.session_state['latestVideos'][5])
|
||||
#st.image(st.session_state['latestVideos'][8])
|
||||
#historyGallery = st.empty()
|
||||
|
||||
## check if output_images length is the same as seeds length
|
||||
#with gallery_tab:
|
||||
#st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
|
||||
|
||||
|
||||
#st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
|
||||
|
||||
except (StopException, KeyError):
|
||||
print(f"Received Streamlit StopException")
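The seed field in `layout()` is a free-text input whose help text promises a random seed when it is left blank, and `txt2vid` passes the value through `seed_to_int`, which is defined outside this hunk. The sketch below shows one way such a conversion could work. `parse_seed` and `MAX_SEED` are hypothetical names, and the behaviour for non-numeric text is an assumption rather than the repository's implementation.

```python
import random
from typing import Union

MAX_SEED = 2**32 - 1  # assumed upper bound; it fits an unsigned 32-bit seed


def parse_seed(text: Union[str, int, None]) -> int:
    """Turn the content of a seed text field into an integer seed."""
    if isinstance(text, int):
        return text % (MAX_SEED + 1)
    text = (text or "").strip()
    if not text:
        # Blank field: pick a fresh random seed.
        return random.randint(0, MAX_SEED)
    if text.isdigit():
        return int(text) % (MAX_SEED + 1)
    # Non-numeric text: derive a repeatable seed from the string itself.
    return random.Random(text).randint(0, MAX_SEED)


seed = parse_seed("42")
```

Returning the resolved integer, as `txt2vid` does with `seeds`, lets the UI report the seed that was actually used so a render can be reproduced.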
|
||||
|
||||
|
559
scripts/webui.py
File diff suppressed because it is too large
2738
scripts/webui_streamlit_old.py
Normal file
File diff suppressed because it is too large
2
setup.py
@@ -1,7 +1,7 @@
 from setuptools import setup, find_packages
 
 setup(
-    name='latent-diffusion',
+    name='sd-webui',
     version='0.0.1',
     description='',
     packages=find_packages(),
2
webui.sh
@@ -37,7 +37,7 @@ if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then
     ENV_UPDATED=1
 elif [[ ! -z $CONDA_FORCE_UPDATE && $CONDA_FORCE_UPDATE == "true" ]] || (( $ENV_MODIFIED > $ENV_MODIFIED_CACHED )); then
     echo "Updating conda env: ${ENV_NAME} ..."
-    conda env update --file $ENV_FILE --prune
+    PIP_EXISTS_ACTION=w conda env update --file $ENV_FILE --prune
     ENV_UPDATED=1
 fi