diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..4ac12df
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,3 @@
+models/
+outputs/
+src/
diff --git a/.env_docker.example b/.env_docker.example
index 5a34945..51eb059 100644
--- a/.env_docker.example
+++ b/.env_docker.example
@@ -6,9 +6,13 @@ CONDA_FORCE_UPDATE=false
# (useful to set to false after you're sure the model files are already in place)
VALIDATE_MODELS=true
-#Automatically relaunch the webui on crashes
+# Automatically relaunch the webui on crashes
WEBUI_RELAUNCH=true
-#Pass cli arguments to webui.py e.g:
-#WEBUI_ARGS=--gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1
+# Which webui to launch
+# WEBUI_SCRIPT=webui_streamlit.py
+WEBUI_SCRIPT=webui.py
+
+# Pass cli arguments to webui.py e.g:
+# WEBUI_ARGS=--optimized --extra-models-cpu --gpu=1 --esrgan-gpu=1 --gfpgan-gpu=1
WEBUI_ARGS=
diff --git a/.gitignore b/.gitignore
index 4b30236..b014154 100644
--- a/.gitignore
+++ b/.gitignore
@@ -47,16 +47,21 @@ MANIFEST
.env_updated
condaenv.*.requirements.txt
+# Visual Studio directories
+.vs/
+.vscode/
# =========================================================================== #
# Repo-specific
# =========================================================================== #
+/configs/webui/userconfig_streamlit.yaml
/custom-conda-path.txt
/src/*
-/outputs/*
+/outputs
+/model_cache
/log/**/*.png
/log/log.csv
/flagged/*
/gfpgan/*
/models/*
-z_version_env.tmp
\ No newline at end of file
+z_version_env.tmp
diff --git a/Dockerfile b/Dockerfile
index 8d5ecb4..2b061b0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,10 @@
FROM nvidia/cuda:11.3.1-runtime-ubuntu20.04
-ENV DEBIAN_FRONTEND=noninteractive
+ENV DEBIAN_FRONTEND=noninteractive \
+ PYTHONUNBUFFERED=1 \
+ PYTHONIOENCODING=UTF-8 \
+ CONDA_DIR=/opt/conda
+
WORKDIR /sd
SHELL ["/bin/bash", "-c"]
@@ -11,7 +15,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/*
# Install miniconda
-ENV CONDA_DIR /opt/conda
RUN wget -O ~/miniconda.sh -q --show-progress --progress=bar:force https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
/bin/bash ~/miniconda.sh -b -p $CONDA_DIR && \
rm ~/miniconda.sh
@@ -20,7 +23,7 @@ ENV PATH=$CONDA_DIR/bin:$PATH
# Install font for prompt matrix
COPY /data/DejaVuSans.ttf /usr/share/fonts/truetype/
-EXPOSE 7860
+EXPOSE 7860 8501
COPY ./entrypoint.sh /sd/
ENTRYPOINT /sd/entrypoint.sh
diff --git a/README.md b/README.md
index 36d1dc0..f5d96ba 100644
--- a/README.md
+++ b/README.md
@@ -46,8 +46,8 @@ Features:
* Gradio GUI: Idiot-proof, fully featured frontend for both txt2img and img2img generation
* No more manually typing parameters, now all you have to do is write your prompt and adjust sliders
-* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui#gfpgan)Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second
-* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui#realesrgan) Boosts the resolution of images with a built-in RealESRGAN option
+* GFPGAN Face Correction 🔥: [Download the model](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Automatically correct distorted faces with a built-in GFPGAN option, fixes them in less than half a second
+* RealESRGAN Upscaling 🔥: [Download the models](https://github.com/sd-webui/stable-diffusion-webui/wiki/Installation#optional-additional-models) Boosts the resolution of images with a built-in RealESRGAN option
* :computer: esrgan/gfpgan on cpu support :computer:
* Textual inversion 🔥: [info](https://textual-inversion.github.io/) - requires enabling, see [here](https://github.com/hlky/sd-enable-textual-inversion), script works as usual without it enabled
* Advanced img2img editor :art: :fire: :art:
@@ -106,7 +106,7 @@ that are not in original script.
### GFPGAN
Lets you improve faces in pictures using the GFPGAN model. There is a checkbox in every tab to use GFPGAN at 100%, and
-also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strongthe effect is.
+also a separate tab that just allows you to use GFPGAN on any picture, with a slider that controls how strong the effect is.
![](images/GFPGAN.png)
diff --git a/configs/webui/webui.yaml b/configs/webui/webui.yaml
index b7bd258..25d222b 100644
--- a/configs/webui/webui.yaml
+++ b/configs/webui/webui.yaml
@@ -12,8 +12,9 @@ txt2img:
# 5: Write sample info files
# 6: write sample info to log file
# 7: jpg samples
- # 8: Fix faces using GFPGAN
- # 9: Upscale images using RealESRGAN
+ # 8: Filter NSFW content
+ # 9: Fix faces using GFPGAN
+ # 10: Upscale images using RealESRGAN
toggles: [1, 2, 3, 4, 5]
sampler_name: k_lms
ddim_eta: 0.0 # legacy name, applies to all algorithms.
@@ -40,8 +41,10 @@ img2img:
# 6: Sort samples by prompt
# 7: Write sample info files
# 8: jpg samples
- # 9: Fix faces using GFPGAN
- # 10: Upscale images using Real-ESRGAN
+ # 9: Color correction
+ # 10: Filter NSFW content
+ # 11: Fix faces using GFPGAN
+ # 12: Upscale images using Real-ESRGAN
toggles: [1, 4, 5, 6, 7]
sampler_name: k_lms
ddim_eta: 0.0
diff --git a/configs/webui/webui_streamlit.yaml b/configs/webui/webui_streamlit.yaml
index 84263bd..394494c 100644
--- a/configs/webui/webui_streamlit.yaml
+++ b/configs/webui/webui_streamlit.yaml
@@ -1,14 +1,19 @@
# UI defaults configuration file. It is automatically loaded if located at configs/webui/webui_streamlit.yaml.
# Any changes made here will be available automatically on the web app without having to stop it.
+# You may add overrides in a file named "userconfig_streamlit.yaml" in this folder, which can contain any subset
+# of the properties below.
general:
gpu: 0
outdir: outputs
- ckpt: "models/ldm/stable-diffusion-v1/model.ckpt"
- fp:
- name: 'embeddings/alex/embeddings_gs-11000.pt'
+ default_model: "Stable Diffusion v1.4"
+ default_model_config: "configs/stable-diffusion/v1-inference.yaml"
+ default_model_path: "models/ldm/stable-diffusion-v1/model.ckpt"
+ use_sd_concepts_library: True
+ sd_concepts_library_folder: "models/custom/sd-concepts-library"
GFPGAN_dir: "./src/gfpgan"
RealESRGAN_dir: "./src/realesrgan"
RealESRGAN_model: "RealESRGAN_x4plus"
+ LDSR_dir: "./src/latent-diffusion"
outdir_txt2img: outputs/txt2img-samples
outdir_img2img: outputs/img2img-samples
gfpgan_cpu: False
@@ -16,88 +21,161 @@ general:
extra_models_cpu: False
extra_models_gpu: False
save_metadata: True
+ save_format: "png"
skip_grid: False
skip_save: False
grid_format: "jpg:95"
n_rows: -1
no_verify_input: False
no_half: False
+ use_float16: False
precision: "autocast"
optimized: False
optimized_turbo: False
+ optimized_config: "optimizedSD/v1-inference.yaml"
+ enable_attention_slicing: False
+ enable_minimal_memory_usage : False
update_preview: True
- update_preview_frequency: 1
+ update_preview_frequency: 5
txt2img:
prompt:
height: 512
width: 512
- cfg_scale: 5.0
+ cfg_scale: 7.5
seed: ""
batch_count: 1
batch_size: 1
- sampling_steps: 50
- default_sampler: "k_lms"
+ sampling_steps: 30
+ default_sampler: "k_euler"
separate_prompts: False
+ update_preview: True
+ update_preview_frequency: 5
normalize_prompt_weights: True
save_individual_images: True
save_grid: True
group_by_prompt: True
save_as_jpg: False
- use_GFPGAN: True
- use_RealESRGAN: True
+ use_GFPGAN: False
+ use_RealESRGAN: False
RealESRGAN_model: "RealESRGAN_x4plus"
variant_amount: 0.0
variant_seed: ""
+ write_info_files: True
+ slider_steps: {
+ sampling: 1
+ }
+ slider_bounds: {
+ sampling: {
+ lower: 1,
+ upper: 150
+ }
+ }
+
+txt2vid:
+ default_model: "CompVis/stable-diffusion-v1-4"
+ custom_models_list: ["CompVis/stable-diffusion-v1-4", "naclbit/trinart_stable_diffusion_v2", "hakurei/waifu-diffusion", "osanseviero/BigGAN-deep-128"]
+ prompt:
+ height: 512
+ width: 512
+ cfg_scale: 7.5
+ seed: ""
+ batch_count: 1
+ batch_size: 1
+ sampling_steps: 30
+ num_inference_steps: 200
+ default_sampler: "k_euler"
+ scheduler_name: "klms"
+ separate_prompts: False
+ update_preview: True
+ update_preview_frequency: 5
+ dynamic_preview_frequency: True
+ normalize_prompt_weights: True
+ save_individual_images: True
+ save_video: True
+ group_by_prompt: True
+ write_info_files: True
+ do_loop: False
+ save_as_jpg: False
+ use_GFPGAN: False
+ use_RealESRGAN: False
+ RealESRGAN_model: "RealESRGAN_x4plus"
+ variant_amount: 0.0
+ variant_seed: ""
+ beta_start: 0.00085
+ beta_end: 0.012
+ beta_scheduler_type: "linear"
+ max_frames: 1000
+ slider_steps: {
+ sampling: 1
+ }
+ slider_bounds: {
+ sampling: {
+ lower: 1,
+ upper: 150
+ }
+ }
img2img:
- prompt:
- sampling_steps: 50
- # Adding an int to toggles enables the corresponding feature.
- # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them)
- # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0)
- # 2: Loopback (use images from previous batch when creating next batch)
- # 3: Random loopback seed
- # 4: Save individual images
- # 5: Save grid
- # 6: Sort samples by prompt
- # 7: Write sample info files
- # 8: jpg samples
- # 9: Fix faces using GFPGAN
- # 10: Upscale images using Real-ESRGAN
- sampler_name: k_lms
- denoising_strength: 0.45
- # 0: Keep masked area
- # 1: Regenerate only masked area
- mask_mode: 0
- # 0: Just resize
- # 1: Crop and resize
- # 2: Resize and fill
- resize_mode: 0
- # Leave blank for random seed:
- seed: ""
- ddim_eta: 0.0
- cfg_scale: 5.0
- batch_count: 1
- batch_size: 1
- height: 512
- width: 512
- # Textual inversion embeddings file path:
- fp: ""
- loopback: True
- random_seed_loopback: True
- separate_prompts: False
- normalize_prompt_weights: True
- save_individual_images: True
- save_grid: True
- group_by_prompt: True
- save_as_jpg: False
- use_GFPGAN: True
- use_RealESRGAN: True
- RealESRGAN_model: "RealESRGAN_x4plus"
- variant_amount: 0.0
- variant_seed: ""
+ prompt:
+ sampling_steps: 30
+ # Adding an int to toggles enables the corresponding feature.
+ # 0: Create prompt matrix (separate multiple prompts using |, and get all combinations of them)
+ # 1: Normalize Prompt Weights (ensure sum of weights add up to 1.0)
+ # 2: Loopback (use images from previous batch when creating next batch)
+ # 3: Random loopback seed
+ # 4: Save individual images
+ # 5: Save grid
+ # 6: Sort samples by prompt
+ # 7: Write sample info files
+ # 8: jpg samples
+ # 9: Fix faces using GFPGAN
+ # 10: Upscale images using Real-ESRGAN
+ sampler_name: "k_euler"
+ denoising_strength: 0.75
+ # 0: Keep masked area
+ # 1: Regenerate only masked area
+ mask_mode: 0
+ mask_restore: False
+ # 0: Just resize
+ # 1: Crop and resize
+ # 2: Resize and fill
+ resize_mode: 0
+ # Leave blank for random seed:
+ seed: ""
+ ddim_eta: 0.0
+ cfg_scale: 7.5
+ batch_count: 1
+ batch_size: 1
+ height: 512
+ width: 512
+ # Textual inversion embeddings file path:
+ fp: ""
+ loopback: True
+ random_seed_loopback: True
+ separate_prompts: False
+ update_preview: True
+ update_preview_frequency: 5
+ normalize_prompt_weights: True
+ save_individual_images: True
+ save_grid: True
+ group_by_prompt: True
+ save_as_jpg: False
+ use_GFPGAN: False
+ use_RealESRGAN: False
+ RealESRGAN_model: "RealESRGAN_x4plus"
+ variant_amount: 0.0
+ variant_seed: ""
+ write_info_files: True
+ slider_steps: {
+ sampling: 1
+ }
+ slider_bounds: {
+ sampling: {
+ lower: 1,
+ upper: 150
+ }
+ }
gfpgan:
strength: 100
-
diff --git a/docker-compose.yml b/docker-compose.yml
index 968df1c..f378963 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -2,7 +2,7 @@ version: '3.3'
services:
stable-diffusion:
- container_name: sd
+ container_name: sd-webui
build:
context: .
dockerfile: Dockerfile
@@ -12,6 +12,7 @@ services:
volumes:
- .:/sd
- ./outputs:/sd/outputs
+ - ./model_cache:/sd/model_cache
- conda_env:/opt/conda
- root_profile:/root
ports:
@@ -21,7 +22,7 @@ services:
resources:
reservations:
devices:
- - capabilities: [gpu]
+ - capabilities: [ gpu ]
volumes:
conda_env:
diff --git a/docker-reset.sh b/docker-reset.sh
old mode 100644
new mode 100755
index 3ca3158..5042026
--- a/docker-reset.sh
+++ b/docker-reset.sh
@@ -10,12 +10,13 @@ echo $(pwd)
read -p "Is the directory above correct to run reset on? (y/n) " -n 1 DIRCONFIRM
if [[ $DIRCONFIRM =~ ^[Yy]$ ]]; then
docker compose down
- docker image rm stable-diffusion_stable-diffusion:latest
- docker volume rm stable-diffusion_conda_env
- docker volume rm stable-diffusion_root_profile
+ docker image rm stable-diffusion-webui_stable-diffusion:latest
+ docker volume rm stable-diffusion-webui_conda_env
+ docker volume rm stable-diffusion-webui_root_profile
echo "Remove ./src"
sudo rm -rf src
- sudo rm -rf latent_diffusion.egg-info
+ sudo rm -rf gfpgan
+ sudo rm -rf sd_webui.egg-info
sudo rm .env_updated
else
echo "Exited without resetting"
diff --git a/entrypoint.sh b/entrypoint.sh
index 21ab01e..e130ea0 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -3,26 +3,36 @@
# Starts the gui inside the docker container using the conda env
#
+# set -x
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+cd $SCRIPT_DIR
+export PYTHONPATH=$SCRIPT_DIR
+
+MODEL_DIR="${SCRIPT_DIR}/model_cache"
# Array of model files to pre-download
# local filename
# local path in container (no trailing slash)
# download URL
# sha256sum
MODEL_FILES=(
- 'model.ckpt /sd/models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'
- 'GFPGANv1.3.pth /sd/src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70'
- 'RealESRGAN_x4plus.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1'
- 'RealESRGAN_x4plus_anime_6B.pth /sd/src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
+ 'model.ckpt models/ldm/stable-diffusion-v1 https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'
+ 'GFPGANv1.3.pth src/gfpgan/experiments/pretrained_models https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70'
+ 'RealESRGAN_x4plus.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth 4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1'
+ 'RealESRGAN_x4plus_anime_6B.pth src/realesrgan/experiments/pretrained_models https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth f872d837d3c90ed2e05227bed711af5671a6fd1c9f7d7e91c911a61f155e99da'
+ 'project.yaml src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1 9d6ad53c5dafeb07200fb712db14b813b527edd262bc80ea136777bdb41be2ba'
+ 'model.ckpt src/latent-diffusion/experiments/pretrained_models https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1 c209caecac2f97b4bb8f4d726b70ac2ac9b35904b7fc99801e1f5e61f9210c13'
)
# Conda environment installs/updates
# @see https://github.com/ContinuumIO/docker-images/issues/89#issuecomment-467287039
ENV_NAME="ldm"
-ENV_FILE="/sd/environment.yaml"
+ENV_FILE="${SCRIPT_DIR}/environment.yaml"
ENV_UPDATED=0
ENV_MODIFIED=$(date -r $ENV_FILE "+%s")
-ENV_MODIFED_FILE="/sd/.env_updated"
+ENV_MODIFED_FILE="${SCRIPT_DIR}/.env_updated"
if [[ -f $ENV_MODIFED_FILE ]]; then ENV_MODIFIED_CACHED=$(<${ENV_MODIFED_FILE}); else ENV_MODIFIED_CACHED=0; fi
+export PIP_EXISTS_ACTION=w
# Create/update conda env if needed
if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then
@@ -51,54 +61,67 @@ conda info | grep active
# Function to checks for valid hash for model files and download/replaces if invalid or does not exist
validateDownloadModel() {
local file=$1
- local path=$2
+ local path="${SCRIPT_DIR}/${2}"
local url=$3
local hash=$4
echo "checking ${file}..."
- sha256sum --check --status <<< "${hash} ${path}/${file}"
+ sha256sum --check --status <<< "${hash} ${MODEL_DIR}/${file}.${hash}"
if [[ $? == "1" ]]; then
echo "Downloading: ${url} please wait..."
mkdir -p ${path}
- wget --output-document=${path}/${file} --no-verbose --show-progress --progress=dot:giga ${url}
- echo "saved ${file}"
+ wget --output-document=${MODEL_DIR}/${file}.${hash} --no-verbose --show-progress --progress=dot:giga ${url}
+ ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file}
+ if [[ -e "${path}/${file}" ]]; then
+ echo "saved ${file}"
+ else
+ echo "error saving ${path}/${file}!"
+ exit 1
+ fi
else
- echo -e "${file} is valid!\n"
+ if [[ ! -e ${path}/${file} || ! -L ${path}/${file} ]]; then
+ mkdir -p ${path}
+ ln -sf ${MODEL_DIR}/${file}.${hash} ${path}/${file}
+ echo -e "linked valid ${file}\n"
+ else
+ echo -e "${file} is valid!\n"
+ fi
fi
}
# Validate model files
-if [[ -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then
- echo "Validating model files..."
- for models in "${MODEL_FILES[@]}"; do
- model=($models)
+echo "Validating model files..."
+for models in "${MODEL_FILES[@]}"; do
+ model=($models)
+ if [[ ! -e ${model[1]}/${model[0]} || ! -L ${model[1]}/${model[0]} || -z $VALIDATE_MODELS || $VALIDATE_MODELS == "true" ]]; then
validateDownloadModel ${model[0]} ${model[1]} ${model[2]} ${model[3]}
- done
-fi
+ fi
+done
# Launch web gui
-cd /sd
-
-if [[ -z $WEBUI_ARGS ]]; then
- launch_message="entrypoint.sh: Launching..."
+if [[ ! -z $WEBUI_SCRIPT && $WEBUI_SCRIPT == "webui_streamlit.py" ]]; then
+ launch_command="streamlit run scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
else
- launch_message="entrypoint.sh: Launching with arguments ${WEBUI_ARGS}"
+ launch_command="python scripts/${WEBUI_SCRIPT:-webui.py} $WEBUI_ARGS"
fi
+launch_message="entrypoint.sh: Run ${launch_command}..."
if [[ -z $WEBUI_RELAUNCH || $WEBUI_RELAUNCH == "true" ]]; then
n=0
while true; do
-
echo $launch_message
+
if (( $n > 0 )); then
echo "Relaunch count: ${n}"
fi
- python -u scripts/webui.py $WEBUI_ARGS
+
+ $launch_command
+
echo "entrypoint.sh: Process is ending. Relaunching in 0.5s..."
((n++))
sleep 0.5
done
else
echo $launch_message
- python -u scripts/webui.py $WEBUI_ARGS
+ $launch_command
fi
diff --git a/environment.yaml b/environment.yaml
index 5bb3bf8..b5bd8e3 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -3,39 +3,47 @@ channels:
- pytorch
- defaults
dependencies:
- - git
- - python=3.8.5
- - pip=20.3
- cudatoolkit=11.3
+ - git
+ - numpy=1.22.3
+ - pip=20.3
+ - python=3.8.5
- pytorch=1.11.0
+ - scikit-image=0.19.2
- torchvision=0.12.0
- - numpy=1.19.2
- pip:
- - albumentations==0.4.3
- - opencv-python==4.1.2.30
- - opencv-python-headless==4.1.2.30
- - pudb==2019.2
- - imageio==2.9.0
- - imageio-ffmpeg==0.4.2
- - pytorch-lightning==1.4.2
- - omegaconf==2.1.1
- - test-tube>=0.7.5
- - einops==0.3.0
- - torch-fidelity==0.3.0
- - transformers==4.19.2
- - torchmetrics==0.6.0
- - kornia==0.6
- - gradio==3.1.6
- - accelerate==0.12.0
- - pynvml==11.4.1
- - basicsr>=1.3.4.0
- - facexlib>=0.2.3
- - python-slugify>=6.1.2
- - streamlit>=1.12.2
- - retry>=0.9.2
+ - -e .
- -e git+https://github.com/CompVis/taming-transformers#egg=taming-transformers
- -e git+https://github.com/openai/CLIP#egg=clip
- -e git+https://github.com/TencentARC/GFPGAN#egg=GFPGAN
- -e git+https://github.com/xinntao/Real-ESRGAN#egg=realesrgan
- -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
- - -e .
\ No newline at end of file
+ - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
+ - accelerate==0.12.0
+ - albumentations==0.4.3
+ - basicsr>=1.3.4.0
+ - diffusers==0.3.0
+ - einops==0.3.0
+ - facexlib>=0.2.3
+ - gradio==3.1.6
+ - imageio-ffmpeg==0.4.2
+ - imageio==2.9.0
+ - kornia==0.6
+ - omegaconf==2.1.1
+ - opencv-python-headless==4.6.0.66
+ - pandas==1.4.3
+ - piexif==1.1.3
+ - pudb==2019.2
+ - pynvml==11.4.1
+ - python-slugify>=6.1.2
+ - pytorch-lightning==1.4.2
+ - retry>=0.9.2
+ - streamlit>=1.12.2
+ - streamlit-on-Hover-tabs==1.0.1
+ - streamlit-option-menu==0.3.2
+ - streamlit_nested_layout
+ - test-tube>=0.7.5
+ - tensorboard
+ - torch-fidelity==0.3.0
+ - torchmetrics==0.6.0
+ - transformers==4.19.2
diff --git a/frontend/css/streamlit.main.css b/frontend/css/streamlit.main.css
index 4e11b77..a11d21d 100644
--- a/frontend/css/streamlit.main.css
+++ b/frontend/css/streamlit.main.css
@@ -1,15 +1,111 @@
-.css-18e3th9 {
- padding-top: 2rem;
- padding-bottom: 10rem;
- padding-left: 5rem;
- padding-right: 5rem;
-}
-.css-1d391kg {
- padding-top: 3.5rem;
- padding-right: 1rem;
- padding-bottom: 3.5rem;
- padding-left: 1rem;
-}
+/***********************************************************
+* Additional CSS for streamlit builtin components *
+************************************************************/
+
+/* Tab name (e.g. Text-to-Image) */
button[data-baseweb="tab"] {
- font-size: 25px;
+ font-size: 25px; //improve legibility
}
+
+/* Image Container (only appear after run finished) */
+.css-du1fp8 {
+ justify-content: center; //center the image, especially better looks in wide screen
+}
+
+/* Streamlit header */
+.css-1avcm0n {
+ background-color: transparent;
+}
+
+/* Main streamlit container (below header) */
+.css-18e3th9 {
+ padding-top: 2rem; //reduce the empty spaces
+}
+
+/* @media only for widescreen, to ensure enough space to see all */
+@media (min-width: 1024px) {
+ /* Main streamlit container (below header) */
+ .css-18e3th9 {
+ padding-top: 0px; //reduce the empty spaces, can go fully to the top on widescreen devices
+ }
+}
+
+/***********************************************************
+* Additional CSS for streamlit custom/3rd party components *
+************************************************************/
+/* For stream_on_hover */
+section[data-testid="stSidebar"] > div:nth-of-type(1) {
+ background-color: #111;
+}
+
+button[kind="header"] {
+ background-color: transparent;
+ color: rgb(180, 167, 141);
+}
+
+@media (hover) {
+ /* header element */
+ header[data-testid="stHeader"] {
+ /* display: none;*/ /*suggested behavior by streamlit hover components*/
+ pointer-events: none; /* disable interaction of the transparent background */
+ }
+
+ /* The button on the streamlit navigation menu */
+ button[kind="header"] {
+ /* display: none;*/ /*suggested behavior by streamlit hover components*/
+ pointer-events: auto; /* enable interaction of the button even if parents intereaction disabled */
+ }
+
+ /* added to avoid main sectors (all element to the right of sidebar from) moving */
+ section[data-testid="stSidebar"] {
+ width: 3.5% !important;
+ min-width: 3.5% !important;
+ }
+
+ /* The navigation menu specs and size */
+ section[data-testid="stSidebar"] > div {
+ height: 100%;
+ width: 2% !important;
+ min-width: 100% !important;
+ position: relative;
+ z-index: 1;
+ top: 0;
+ left: 0;
+ background-color: #111;
+ overflow-x: hidden;
+ transition: 0.5s ease-in-out;
+ padding-top: 0px;
+ white-space: nowrap;
+ }
+
+ /* The navigation menu open and close on hover and size */
+ section[data-testid="stSidebar"] > div:hover {
+ width: 300px !important;
+ }
+}
+
+@media (max-width: 272px) {
+ section[data-testid="stSidebar"] > div {
+ width: 15rem;
+ }
+}
+
+/***********************************************************
+* Additional CSS for other elements
+************************************************************/
+button[data-baseweb="tab"] {
+ font-size: 20px;
+}
+
+@media (min-width: 1200px){
+h1 {
+ font-size: 1.75rem;
+}
+}
+#tabs-1-tabpanel-0 > div:nth-child(1) > div > div.stTabs.css-0.exp6ofz0 {
+ width: 50rem;
+ align-self: center;
+}
+div.gallery:hover {
+ border: 1px solid #777;
+}
\ No newline at end of file
diff --git a/frontend/frontend.py b/frontend/frontend.py
index 29d3c50..94c76c9 100644
--- a/frontend/frontend.py
+++ b/frontend/frontend.py
@@ -3,6 +3,8 @@ from frontend.css_and_js import css, js, call_JS, js_parse_prompt, js_copy_txt2i
from frontend.job_manager import JobManager
import frontend.ui_functions as uifn
import uuid
+import torch
+
def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda x: x, txt2img_defaults={},
@@ -36,8 +38,11 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
value=txt2img_defaults['cfg_scale'], elem_id='cfg_slider')
txt2img_seed = gr.Textbox(label="Seed (blank to randomize)", lines=1, max_lines=1,
value=txt2img_defaults["seed"])
+ txt2img_batch_size = gr.Slider(minimum=1, maximum=50, step=1,
+ label='Images per batch',
+ value=txt2img_defaults['batch_size'])
txt2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1,
- label='Number of images to generate',
+ label='Number of batches to generate',
value=txt2img_defaults['n_iter'])
txt2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None
@@ -51,11 +56,15 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
gr.Markdown(
"Select an image from the gallery, then click one of the buttons below to perform an action.")
with gr.Row(elem_id='txt2img_actions_row'):
- gr.Button("Copy to clipboard").click(fn=None,
- inputs=output_txt2img_gallery,
- outputs=[],
- # _js=js_copy_to_clipboard( 'txt2img_gallery_output')
- )
+ gr.Button("Copy to clipboard").click(
+ fn=None,
+ inputs=output_txt2img_gallery,
+ outputs=[],
+ _js=call_JS(
+ "copyImageFromGalleryToClipboard",
+ fromId="txt2img_gallery_output"
+ )
+ )
output_txt2img_copy_to_input_btn = gr.Button("Push to img2img")
output_txt2img_to_imglab = gr.Button("Send to Lab", visible=True)
@@ -91,9 +100,6 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
with gr.TabItem('Advanced'):
txt2img_toggles = gr.CheckboxGroup(label='', choices=txt2img_toggles,
value=txt2img_toggle_defaults, type="index")
- txt2img_batch_size = gr.Slider(minimum=1, maximum=8, step=1,
- label='Batch size (how many images are in a batch; memory-hungry)',
- value=txt2img_defaults['batch_size'])
txt2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model',
choices=['RealESRGAN_x4plus',
'RealESRGAN_x4plus_anime_6B'],
@@ -124,20 +130,27 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
inputs=txt2img_inputs,
outputs=txt2img_outputs
)
+ use_queue = False
+ else:
+ use_queue = True
txt2img_btn.click(
txt2img_func,
txt2img_inputs,
- txt2img_outputs
+ txt2img_outputs,
+ api_name='txt2img',
+ queue=use_queue
)
txt2img_prompt.submit(
txt2img_func,
txt2img_inputs,
- txt2img_outputs
+ txt2img_outputs,
+ queue=use_queue
)
- # txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
- # txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
+ txt2img_width.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
+ txt2img_height.change(fn=uifn.update_dimensions_info, inputs=[txt2img_width, txt2img_height], outputs=txt2img_dimensions_info_text_box)
+ txt2img_dimensions_info_text_box.value = uifn.update_dimensions_info(txt2img_width.value, txt2img_height.value)
# Temporarily disable prompt parsing until memory issues could be solved
# See #676
@@ -189,8 +202,9 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
with gr.TabItem("Editor Options"):
with gr.Row():
# disable Uncrop for now
- # choices=["Mask", "Crop", "Uncrop"]
- img2img_image_editor_mode = gr.Radio(choices=["Mask", "Crop"],
+ choices=["Mask", "Crop", "Uncrop"]
+ #choices=["Mask", "Crop"]
+ img2img_image_editor_mode = gr.Radio(choices=choices,
label="Image Editor Mode",
value="Mask", elem_id='edit_mode_select',
visible=True)
@@ -199,9 +213,13 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
value=img2img_mask_modes[img2img_defaults['mask_mode']],
visible=True)
- img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
+ img2img_mask_restore = gr.Checkbox(label="Only modify regenerated parts of image",
+ value=img2img_defaults['mask_restore'],
+ visible=True)
+
+ img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=100, step=1,
label="How much blurry should the mask be? (to avoid hard edges)",
- value=3, visible=False)
+ value=3, visible=True)
img2img_resize = gr.Radio(label="Resize mode",
choices=["Just resize", "Crop and resize",
@@ -293,7 +311,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
img2img_height
],
[img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
- img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
+ img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength, img2img_mask_restore]
)
# img2img_image_editor_mode.change(
@@ -334,8 +352,8 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
)
img2img_func = img2img
- img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
- img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
+ img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask, img2img_mask_blur_strength,
+ img2img_mask_restore, img2img_steps, img2img_sampling, img2img_toggles,
img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
img2img_image_editor, img2img_image_mask, img2img_embeddings]
@@ -349,11 +367,16 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
inputs=img2img_inputs,
outputs=img2img_outputs,
)
+ use_queue = False
+ else:
+ use_queue = True
img2img_btn_mask.click(
img2img_func,
img2img_inputs,
- img2img_outputs
+ img2img_outputs,
+ api_name="img2img",
+ queue=use_queue
)
def img2img_submit_params():
@@ -383,6 +406,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
outputs=img2img_dimensions_info_text_box)
img2img_height.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
outputs=img2img_dimensions_info_text_box)
+ img2img_dimensions_info_text_box.value = uifn.update_dimensions_info(img2img_width.value, img2img_height.value)
with gr.TabItem("Image Lab", id='imgproc_tab'):
gr.Markdown("Post-process results")
@@ -397,8 +421,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
# value=gfpgan_defaults['strength'])
# select folder with images to process
with gr.TabItem('Batch Process'):
- imgproc_folder = gr.File(label="Batch Process", file_count="multiple", source="upload",
- interactive=True, type="file")
+ imgproc_folder = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file")
imgproc_pngnfo = gr.Textbox(label="PNG Metadata", placeholder="PngNfo", visible=False,
max_lines=5)
with gr.Row():
@@ -540,7 +563,7 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
imgproc_width, imgproc_cfg, imgproc_denoising, imgproc_seed,
imgproc_gfpgan_strength, imgproc_ldsr_steps, imgproc_ldsr_pre_downSample,
imgproc_ldsr_post_downSample],
- [imgproc_output])
+ [imgproc_output], api_name="imgproc")
imgproc_source.change(
uifn.get_png_nfo,
@@ -631,11 +654,12 @@ def draw_gradio_ui(opt, img2img=lambda x: x, txt2img=lambda x: x, imgproc=lambda
"""
gr.HTML("""
-
For help and advanced usage guides, visit the Project Wiki
-
Stable Diffusion WebUI is an open-source project.
- If you would like to contribute to development or test bleeding edge builds, use the dev branch.
+
For help and advanced usage guides, visit the Project Wiki
+
Stable Diffusion WebUI is an open-source project. You can find the latest stable builds on the main repository.
+ If you would like to contribute to development or test bleeding edge builds, you can visit the developement repository.
+
Device ID {current_device_index}: {current_device_name} {total_device_count} total devices
- """)
+ """.format(current_device_name=torch.cuda.get_device_name(), current_device_index=torch.cuda.current_device(), total_device_count=torch.cuda.device_count()))
# Hack: Detect the load event on the frontend
# Won't be needed in the next version of gradio
# See the relevant PR: https://github.com/gradio-app/gradio/pull/2108
diff --git a/frontend/image_metadata.py b/frontend/image_metadata.py
new file mode 100644
index 0000000..8448088
--- /dev/null
+++ b/frontend/image_metadata.py
@@ -0,0 +1,57 @@
+''' Class to store image generation parameters to be stored as metadata in the image'''
+from __future__ import annotations
+from dataclasses import dataclass, asdict
+from typing import Dict, Optional
+from PIL import Image
+from PIL.PngImagePlugin import PngInfo
+import copy
+
+@dataclass
+class ImageMetadata:
+ prompt: str = None
+ seed: str = None
+ width: str = None
+ height: str = None
+ steps: str = None
+ cfg_scale: str = None
+ normalize_prompt_weights: str = None
+ denoising_strength: str = None
+ GFPGAN: str = None
+
+ def as_png_info(self) -> PngInfo:
+ info = PngInfo()
+ for key, value in self.as_dict().items():
+ info.add_text(key, value)
+ return info
+
+ def as_dict(self) -> Dict[str, str]:
+ return {f"SD:{key}": str(value) for key, value in asdict(self).items() if value is not None}
+
+ @classmethod
+ def set_on_image(cls, image: Image, metadata: ImageMetadata) -> None:
+ ''' Sets metadata on image, in both text form and as an ImageMetadata object '''
+ if metadata:
+ image.info = metadata.as_dict()
+ else:
+ metadata = ImageMetadata()
+ image.info["ImageMetadata"] = copy.copy(metadata)
+
+ @classmethod
+ def get_from_image(cls, image: Image) -> Optional[ImageMetadata]:
+ ''' Gets metadata from an image, first looking for an ImageMetadata,
+ then if not found tries to construct one from the info '''
+ metadata = image.info.get("ImageMetadata", None)
+ if not metadata:
+ found_metadata = False
+ metadata = ImageMetadata()
+ for key, value in image.info.items():
+ if key.lower().startswith("sd:"):
+ key = key[3:]
+ if f"{key}" in metadata.__dict__:
+ metadata.__dict__[key] = value
+ found_metadata = True
+ if not found_metadata:
+ metadata = None
+ if not metadata:
+ print("Couldn't find metadata on image")
+ return metadata
diff --git a/frontend/job_manager.py b/frontend/job_manager.py
index 8eda8d9..026742f 100644
--- a/frontend/job_manager.py
+++ b/frontend/job_manager.py
@@ -1,7 +1,7 @@
''' Provides simple job management for gradio, allowing viewing and stopping in-progress multi-batch generations '''
from __future__ import annotations
import gradio as gr
-from gradio.components import Component, Gallery
+from gradio.components import Component, Gallery, Slider
from threading import Event, Timer
from typing import Callable, List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, field
@@ -9,6 +9,7 @@ from functools import partial
from PIL.Image import Image
import uuid
import traceback
+import time
@dataclass(eq=True, frozen=True)
@@ -30,9 +31,21 @@ class JobInfo:
session_key: str
job_token: Optional[int] = None
images: List[Image] = field(default_factory=list)
+ active_image: Image = None
+ rec_steps_enabled: bool = False
+ rec_steps_imgs: List[Image] = field(default_factory=list)
+ rec_steps_intrvl: int = None
+ rec_steps_to_gallery: bool = False
+ rec_steps_to_file: bool = False
should_stop: Event = field(default_factory=Event)
+ refresh_active_image_requested: Event = field(default_factory=Event)
+ refresh_active_image_done: Event = field(default_factory=Event)
+ stop_cur_iter: Event = field(default_factory=Event)
+ active_iteration_cnt: int = field(default_factory=int)
job_status: str = field(default_factory=str)
finished: bool = False
+ started: bool = False
+ timestamp: float = None
removed_output_idxs: List[int] = field(default_factory=list)
@@ -76,7 +89,7 @@ class JobManagerUi:
'''
return self._job_manager._wrap_func(
func=func, inputs=inputs, outputs=outputs,
- refresh_btn=self._refresh_btn, stop_btn=self._stop_btn, status_text=self._status_text
+ job_ui=self
)
_refresh_btn: gr.Button
@@ -84,10 +97,19 @@ class JobManagerUi:
_status_text: gr.Textbox
_stop_all_session_btn: gr.Button
_free_done_sessions_btn: gr.Button
+ _active_image: gr.Image
+ _active_image_stop_btn: gr.Button
+ _active_image_refresh_btn: gr.Button
+ _rec_steps_intrvl_sldr: gr.Slider
+ _rec_steps_checkbox: gr.Checkbox
+ _save_rec_steps_to_gallery_chkbx: gr.Checkbox
+ _save_rec_steps_to_file_chkbx: gr.Checkbox
_job_manager: JobManager
class JobManager:
+ JOB_MAX_START_TIME = 5.0 # How long can a job be stuck 'starting' before assuming it isn't running
+
def __init__(self, max_jobs: int):
self._max_jobs: int = max_jobs
self._avail_job_tokens: List[Any] = list(range(max_jobs))
@@ -102,11 +124,23 @@ class JobManager:
'''
assert gr.context.Context.block is not None, "draw_gradio_ui must be called within a 'gr.Blocks' 'with' context"
with gr.Tabs():
- with gr.TabItem("Current Session"):
+ with gr.TabItem("Job Controls"):
with gr.Row():
- stop_btn = gr.Button("Stop", elem_id="stop", variant="secondary")
- refresh_btn = gr.Button("Refresh", elem_id="refresh", variant="secondary")
+ stop_btn = gr.Button("Stop All Batches", elem_id="stop", variant="secondary")
+ refresh_btn = gr.Button("Refresh Finished Batches", elem_id="refresh", variant="secondary")
status_text = gr.Textbox(placeholder="Job Status", interactive=False, show_label=False)
+ with gr.Row():
+ active_image_stop_btn = gr.Button("Skip Active Batch", variant="secondary")
+ active_image_refresh_btn = gr.Button("View Batch Progress", variant="secondary")
+ active_image = gr.Image(type="pil", interactive=False, visible=False, elem_id="active_iteration_image")
+ with gr.TabItem("Batch Progress Settings"):
+ with gr.Row():
+ record_steps_checkbox = gr.Checkbox(value=False, label="Enable Batch Progress Grid")
+ record_steps_interval_slider = gr.Slider(
+ value=3, label="Record Interval (steps)", minimum=1, maximum=25, step=1)
+ with gr.Row() as record_steps_box:
+ steps_to_gallery_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to Gallery")
+ steps_to_file_checkbox = gr.Checkbox(value=False, label="Save Progress Grid to File")
with gr.TabItem("Maintenance"):
with gr.Row():
gr.Markdown(
@@ -118,9 +152,15 @@ class JobManager:
free_done_sessions_btn = gr.Button(
"Clear Finished Jobs", elem_id="clear_finished", variant="secondary"
)
+
return JobManagerUi(_refresh_btn=refresh_btn, _stop_btn=stop_btn, _status_text=status_text,
_stop_all_session_btn=stop_all_sessions_btn, _free_done_sessions_btn=free_done_sessions_btn,
- _job_manager=self)
+ _active_image=active_image, _active_image_stop_btn=active_image_stop_btn,
+ _active_image_refresh_btn=active_image_refresh_btn,
+ _rec_steps_checkbox=record_steps_checkbox,
+ _save_rec_steps_to_gallery_chkbx=steps_to_gallery_checkbox,
+ _save_rec_steps_to_file_chkbx=steps_to_file_checkbox,
+ _rec_steps_intrvl_sldr=record_steps_interval_slider, _job_manager=self)
def clear_all_finished_jobs(self):
''' Removes all currently finished jobs, across all sessions.
@@ -134,6 +174,7 @@ class JobManager:
for session in self._sessions.values():
for job in session.jobs.values():
job.should_stop.set()
+ job.stop_cur_iter.set()
def _get_job_token(self, block: bool = False) -> Optional[int]:
''' Attempts to acquire a job token, optionally blocking until available '''
@@ -175,6 +216,26 @@ class JobManager:
job_info.should_stop.set()
return "Stopping after current batch finishes"
+ def _refresh_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
+ ''' Updates information from the active iteration '''
+ session_info, job_info = self._get_call_info(func_key, session_key)
+ if job_info is None:
+ return [None, f"Session {session_key} was not running function {func_key}"]
+
+ job_info.refresh_active_image_requested.set()
+ if job_info.refresh_active_image_done.wait(timeout=20.0):
+ job_info.refresh_active_image_done.clear()
+ return [gr.Image.update(value=job_info.active_image, visible=True), f"Sample iteration {job_info.active_iteration_cnt}"]
+ return [gr.Image.update(visible=False), "Timed out getting image"]
+
+ def _stop_cur_iter_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
+ ''' Marks that the active iteration should be stopped'''
+ session_info, job_info = self._get_call_info(func_key, session_key)
+ if job_info is None:
+ return [None, f"Session {session_key} was not running function {func_key}"]
+ job_info.stop_cur_iter.set()
+ return [gr.Image.update(visible=False), "Stopping current iteration"]
+
def _get_call_info(self, func_key: FuncKey, session_key: str) -> Tuple[SessionInfo, JobInfo]:
''' Helper to get the SessionInfo and JobInfo. '''
session_info = self._sessions.get(session_key, None)
@@ -207,19 +268,22 @@ class JobManager:
def _pre_call_func(
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
- status_text: gr.Textbox, session_key: str) -> List[Component]:
+ status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
+ session_key: str) -> List[Component]:
''' Called when a job is about to start '''
session_info, job_info = self._get_call_info(func_key, session_key)
# If we didn't already get a token then queue up for one
if job_info.job_token is None:
- job_info.token = self._get_job_token(block=True)
+ job_info.job_token = self._get_job_token(block=True)
# Buttons don't seem to update unless value is set on them as well...
return {output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="primary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="primary", value=stop_btn.value),
- status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' for updates")
+ status_text: gr.Textbox.update(value="Generation has started. Click 'Refresh' to see finished images, 'View Batch Progress' for active images"),
+ active_refresh_btn: gr.Button.update(variant="primary", value=active_refresh_btn.value),
+ active_stop_btn: gr.Button.update(variant="primary", value=active_stop_btn.value),
}
def _call_func(self, func_key: FuncKey, session_key: str) -> List[Component]:
@@ -228,12 +292,19 @@ class JobManager:
if session_info is None or job_info is None:
return []
+ job_info.started = True
try:
+ if job_info.should_stop.is_set():
+ raise Exception(f"Job {job_info} requested a stop before execution began")
outputs = job_info.func(*job_info.inputs, job_info=job_info)
except Exception as e:
job_info.job_status = f"Error: {e}"
print(f"Exception processing job {job_info}: {e}\n{traceback.format_exc()}")
- outputs = []
+ raise
+ finally:
+ job_info.finished = True
+ session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
+ self._release_job_token(job_info.job_token)
# Filter the function output for any removed outputs
filtered_output = []
@@ -241,11 +312,6 @@ class JobManager:
if idx not in job_info.removed_output_idxs:
filtered_output.append(output)
- job_info.finished = True
- session_info.finished_jobs[func_key] = session_info.jobs.pop(func_key)
-
- self._release_job_token(job_info.job_token)
-
# The wrapper added a dummy JSON output. Append a random text string
# to fire the dummy objects 'change' event to notify that the job is done
filtered_output.append(triggerChangeEvent())
@@ -254,12 +320,16 @@ class JobManager:
def _post_call_func(
self, func_key: FuncKey, output_dummy_obj: Component, refresh_btn: gr.Button, stop_btn: gr.Button,
- status_text: gr.Textbox, session_key: str) -> List[Component]:
+ status_text: gr.Textbox, active_image: gr.Image, active_refresh_btn: gr.Button, active_stop_btn: gr.Button,
+ session_key: str) -> List[Component]:
''' Called when a job completes '''
return {output_dummy_obj: triggerChangeEvent(),
refresh_btn: gr.Button.update(variant="secondary", value=refresh_btn.value),
stop_btn: gr.Button.update(variant="secondary", value=stop_btn.value),
- status_text: gr.Textbox.update(value="Generation has finished!")
+ status_text: gr.Textbox.update(value="Generation has finished!"),
+ active_refresh_btn: gr.Button.update(variant="secondary", value=active_refresh_btn.value),
+ active_stop_btn: gr.Button.update(variant="secondary", value=active_stop_btn.value),
+ active_image: gr.Image.update(visible=False)
}
def _update_gallery_event(self, func_key: FuncKey, session_key: str) -> List[Component]:
@@ -270,21 +340,17 @@ class JobManager:
if session_info is None or job_info is None:
return []
- if job_info.finished:
- session_info.finished_jobs.pop(func_key)
-
return job_info.images
- def _wrap_func(
- self, func: Callable, inputs: List[Component], outputs: List[Component],
- refresh_btn: gr.Button = None, stop_btn: gr.Button = None,
- status_text: Optional[gr.Textbox] = None) -> Tuple[Callable, List[Component]]:
+ def _wrap_func(self, func: Callable, inputs: List[Component],
+ outputs: List[Component],
+ job_ui: JobManagerUi) -> Tuple[Callable, List[Component]]:
''' handles JobManageUI's wrap_func'''
assert gr.context.Context.block is not None, "wrap_func must be called within a 'gr.Blocks' 'with' context"
# Create a unique key for this job
- func_key = FuncKey(job_id=uuid.uuid4(), func=func)
+ func_key = FuncKey(job_id=uuid.uuid4().hex, func=func)
# Create a unique session key (next gradio release can use gr.State, see https://gradio.app/state_in_blocks/)
if self._session_key is None:
@@ -302,31 +368,59 @@ class JobManager:
del outputs[idx]
break
- # Add the session key to the inputs
- inputs += [self._session_key]
-
# Create dummy objects
update_gallery_obj = gr.JSON(visible=False, elem_id="JobManagerDummyObject")
update_gallery_obj.change(
partial(self._update_gallery_event, func_key),
[self._session_key],
- [gallery_comp]
+ [gallery_comp],
+ queue=False
)
- if refresh_btn:
- refresh_btn.variant = 'secondary'
- refresh_btn.click(
+ if job_ui._refresh_btn:
+ job_ui._refresh_btn.variant = 'secondary'
+ job_ui._refresh_btn.click(
partial(self._refresh_func, func_key),
[self._session_key],
- [update_gallery_obj, status_text]
+ [update_gallery_obj, job_ui._status_text],
+ queue=False
)
- if stop_btn:
- stop_btn.variant = 'secondary'
- stop_btn.click(
+ if job_ui._stop_btn:
+ job_ui._stop_btn.variant = 'secondary'
+ job_ui._stop_btn.click(
partial(self._stop_wrapped_func, func_key),
[self._session_key],
- [status_text]
+ [job_ui._status_text],
+ queue=False
+ )
+
+ if job_ui._active_image and job_ui._active_image_refresh_btn:
+ job_ui._active_image_refresh_btn.click(
+ partial(self._refresh_cur_iter_func, func_key),
+ [self._session_key],
+ [job_ui._active_image, job_ui._status_text],
+ queue=False
+ )
+
+ if job_ui._active_image_stop_btn:
+ job_ui._active_image_stop_btn.click(
+ partial(self._stop_cur_iter_func, func_key),
+ [self._session_key],
+ [job_ui._active_image, job_ui._status_text],
+ queue=False
+ )
+
+ if job_ui._stop_all_session_btn:
+ job_ui._stop_all_session_btn.click(
+ self.stop_all_jobs, [], [],
+ queue=False
+ )
+
+ if job_ui._free_done_sessions_btn:
+ job_ui._free_done_sessions_btn.click(
+ self.clear_all_finished_jobs, [], [],
+ queue=False
)
# (ab)use gr.JSON to forward events.
@@ -343,7 +437,8 @@ class JobManager:
# Since some parameters are optional it makes sense to use the 'dict' return value type, which requires
# the Component as a key... so group together the UI components that the event listeners are going to update
# to make it easy to append to function calls and outputs
- job_ui_params = [refresh_btn, stop_btn, status_text]
+ job_ui_params = [job_ui._refresh_btn, job_ui._stop_btn, job_ui._status_text,
+ job_ui._active_image, job_ui._active_image_refresh_btn, job_ui._active_image_stop_btn]
job_ui_outputs = [comp for comp in job_ui_params if comp is not None]
# Here a chain is constructed that will make a 'pre' call, a 'run' call, and a 'post' call,
@@ -352,44 +447,70 @@ class JobManager:
post_call_dummyobj.change(
partial(self._post_call_func, func_key, update_gallery_obj, *job_ui_params),
[self._session_key],
- [update_gallery_obj] + job_ui_outputs
+ [update_gallery_obj] + job_ui_outputs,
+ queue=False
)
call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_runCall")
call_dummyobj.change(
partial(self._call_func, func_key),
[self._session_key],
- outputs + [post_call_dummyobj]
+ outputs + [post_call_dummyobj],
+ queue=False
)
pre_call_dummyobj = gr.JSON(visible=False, elem_id="JobManagerDummyObject_preCall")
pre_call_dummyobj.change(
partial(self._pre_call_func, func_key, call_dummyobj, *job_ui_params),
[self._session_key],
- [call_dummyobj] + job_ui_outputs
+ [call_dummyobj] + job_ui_outputs,
+ queue=False
)
- # Now replace the original function with one that creates a JobInfo and triggers the dummy obj
+ # Add any components that we want the runtime values for
+ added_inputs = [self._session_key, job_ui._rec_steps_checkbox, job_ui._save_rec_steps_to_gallery_chkbx,
+ job_ui._save_rec_steps_to_file_chkbx, job_ui._rec_steps_intrvl_sldr]
- def wrapped_func(*inputs):
- session_key = inputs[-1]
- inputs = inputs[:-1]
+ # Now replace the original function with one that creates a JobInfo and triggers the dummy obj
+ def wrapped_func(*wrapped_inputs):
+ # Remove the added_inputs (pop opposite order of list)
+
+ wrapped_inputs = list(wrapped_inputs)
+ rec_steps_interval: int = wrapped_inputs.pop()
+ save_rec_steps_file: bool = wrapped_inputs.pop()
+ save_rec_steps_grid: bool = wrapped_inputs.pop()
+ record_steps_enabled: bool = wrapped_inputs.pop()
+ session_key: str = wrapped_inputs.pop()
+ job_inputs = tuple(wrapped_inputs)
# Get or create a session for this key
session_info = self._sessions.setdefault(session_key, SessionInfo())
# Is this session already running this job?
if func_key in session_info.jobs:
- return {status_text: "This session is already running that function!"}
+ job_info = session_info.jobs[func_key]
+ # If the job seems stuck in 'starting' then go ahead and toss it
+ if not job_info.started and time.time() > job_info.timestamp + JobManager.JOB_MAX_START_TIME:
+ job_info.should_stop.set()
+ job_info.stop_cur_iter.set()
+ session_info.jobs.pop(func_key)
+ return {job_ui._status_text: "Canceled possibly hung job. Try again"}
+ return {job_ui._status_text: "This session is already running that function!"}
+
+ # Is this a new run of a previously finished job? Clear old info
+ if func_key in session_info.finished_jobs:
+ session_info.finished_jobs.pop(func_key)
job_token = self._get_job_token(block=False)
- job = JobInfo(inputs=inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
- job_token=job_token)
+ job = JobInfo(
+ inputs=job_inputs, func=func, removed_output_idxs=removed_idxs, session_key=session_key,
+ job_token=job_token, rec_steps_enabled=record_steps_enabled, rec_steps_intrvl=rec_steps_interval,
+ rec_steps_to_gallery=save_rec_steps_grid, rec_steps_to_file=save_rec_steps_file, timestamp=time.time())
session_info.jobs[func_key] = job
ret = {pre_call_dummyobj: triggerChangeEvent()}
if job_token is None:
- ret[status_text] = "Job is queued"
+ ret[job_ui._status_text] = "Job is queued"
return ret
- return wrapped_func, inputs, [pre_call_dummyobj, status_text]
+ return wrapped_func, inputs + added_inputs, [pre_call_dummyobj, job_ui._status_text]
diff --git a/frontend/ui_functions.py b/frontend/ui_functions.py
index 6557841..ee6af8d 100644
--- a/frontend/ui_functions.py
+++ b/frontend/ui_functions.py
@@ -9,10 +9,10 @@ import re
def change_image_editor_mode(choice, cropped_image, masked_image, resize_mode, width, height):
if choice == "Mask":
update_image_result = update_image_mask(cropped_image, resize_mode, width, height)
- return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
+ return [gr.update(visible=False), update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)]
update_image_result = update_image_mask(masked_image["image"] if masked_image is not None else None, resize_mode, width, height)
- return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
+ return [update_image_result, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
def update_image_mask(cropped_image, resize_mode, width, height):
resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
diff --git a/images/nsfw.jpeg b/images/nsfw.jpeg
new file mode 100644
index 0000000..0ecf3b6
Binary files /dev/null and b/images/nsfw.jpeg differ
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index f4eff39..4485c1e 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -7,6 +7,8 @@ from einops import rearrange, repeat
from ldm.modules.diffusionmodules.util import checkpoint
+import psutil
+
def exists(val):
return val is not None
@@ -167,30 +169,98 @@ class CrossAttention(nn.Module):
nn.Dropout(dropout)
)
+ if torch.cuda.is_available():
+ self.einsum_op = self.einsum_op_cuda
+ else:
+ self.mem_total = psutil.virtual_memory().total / (1024**3)
+ self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
+
+ def einsum_op_compvis(self, q, k, v, r1):
+ s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
+ del s1
+ r1 = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+ return r1
+
+ def einsum_op_mps_v1(self, q, k, v, r1):
+ if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ r1 = self.einsum_op_compvis(q, k, v, r1)
+ else:
+ slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+ s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+ del s1
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+ return r1
+
+ def einsum_op_mps_v2(self, q, k, v, r1):
+ if self.mem_total >= 8 and q.shape[1] <= 4096:
+ r1 = self.einsum_op_compvis(q, k, v, r1)
+ else:
+ slice_size = 1
+ for i in range(0, q.shape[0], slice_size):
+ end = min(q.shape[0], i + slice_size)
+ s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+ s1 *= self.scale
+ s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+ del s1
+ r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+ del s2
+ return r1
+
+ def einsum_op_cuda(self, q, k, v, r1):
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ gb = 1024 ** 3
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
+ mem_required = tensor_size * 2.5
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+ if steps > 64:
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+ f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = min(q.shape[1], i + slice_size)
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+ s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+ del s1
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+ del s2
+ return r1
+
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
+ del x
k = self.to_k(context)
v = self.to_v(context)
+ del context
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
-
- if exists(mask):
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
-
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
-
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+ r1 = self.einsum_op(q, k, v, r1)
+ del q, k, v
+ r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+ del r1
+ return self.to_out(r2)
class BasicTransformerBlock(nn.Module):
@@ -209,9 +279,10 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
- x = self.attn1(self.norm1(x)) + x
- x = self.attn2(self.norm2(x), context=context) + x
- x = self.ff(self.norm3(x)) + x
+ x = x.contiguous() if x.device.type == 'mps' else x
+ x += self.attn1(self.norm1(x))
+ x += self.attn2(self.norm2(x), context=context)
+ x += self.ff(self.norm3(x))
return x
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
index 533e589..dbbb325 100644
--- a/ldm/modules/diffusionmodules/model.py
+++ b/ldm/modules/diffusionmodules/model.py
@@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
+import gc
import math
import torch
import torch.nn as nn
@@ -119,18 +120,30 @@ class ResnetBlock(nn.Module):
padding=0)
def forward(self, x, temb):
- h = x
- h = self.norm1(h)
- h = nonlinearity(h)
- h = self.conv1(h)
+ h1 = x
+ h2 = self.norm1(h1)
+ del h1
+
+ h3 = nonlinearity(h2)
+ del h2
+
+ h4 = self.conv1(h3)
+ del h3
if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+ h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
- h = self.norm2(h)
- h = nonlinearity(h)
- h = self.dropout(h)
- h = self.conv2(h)
+ h5 = self.norm2(h4)
+ del h4
+
+ h6 = nonlinearity(h5)
+ del h5
+
+ h7 = self.dropout(h6)
+ del h6
+
+ h8 = self.conv2(h7)
+ del h7
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
@@ -138,7 +151,7 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
- return x+h
+ return x + h8
class LinAttnBlock(LinearAttention):
@@ -178,28 +191,65 @@ class AttnBlock(nn.Module):
def forward(self, x):
h_ = x
h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
+ q1 = self.q(h_)
+ k1 = self.k(h_)
v = self.v(h_)
# compute attention
- b,c,h,w = q.shape
- q = q.reshape(b,c,h*w)
- q = q.permute(0,2,1) # b,hw,c
- k = k.reshape(b,c,h*w) # b,c,hw
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w_ = w_ * (int(c)**(-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
+ b, c, h, w = q1.shape
- # attend to values
- v = v.reshape(b,c,h*w)
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b,c,h,w)
+ q2 = q1.reshape(b, c, h*w)
+ del q1
- h_ = self.proj_out(h_)
+ q = q2.permute(0, 2, 1) # b,hw,c
+ del q2
- return x+h_
+ k = k1.reshape(b, c, h*w) # b,c,hw
+ del k1
+
+ h_ = torch.zeros_like(k, device=q.device)
+
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
+ mem_required = tensor_size * 2.5
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w2 = w1 * (int(c)**(-0.5))
+ del w1
+ w3 = torch.nn.functional.softmax(w2, dim=2)
+ del w2
+
+ # attend to values
+ v1 = v.reshape(b, c, h*w)
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ del w3
+
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ del v1, w4
+
+ h2 = h_.reshape(b, c, h, w)
+ del h_
+
+ h3 = self.proj_out(h2)
+ del h2
+
+ h3 += x
+
+ return h3
def make_attn(in_channels, attn_type="vanilla"):
@@ -540,31 +590,54 @@ class Decoder(nn.Module):
temb = None
# z to block_in
- h = self.conv_in(z)
+ h1 = self.conv_in(z)
# middle
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
+ h2 = self.mid.block_1(h1, temb)
+ del h1
+
+ h3 = self.mid.attn_1(h2)
+ del h2
+
+ h = self.mid.block_2(h3, temb)
+ del h3
+
+ # prepare for up sampling
+ gc.collect()
+ torch.cuda.empty_cache()
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
+ t = h
+ h = self.up[i_level].attn[i_block](t)
+ del t
+
if i_level != 0:
- h = self.up[i_level].upsample(h)
+ t = h
+ h = self.up[i_level].upsample(t)
+ del t
# end
if self.give_pre_end:
return h
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
+ h1 = self.norm_out(h)
+ del h
+
+ h2 = nonlinearity(h1)
+ del h1
+
+ h = self.conv_out(h2)
+ del h2
+
if self.tanh_out:
- h = torch.tanh(h)
+ t = h
+ h = torch.tanh(t)
+ del t
+
return h
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
index a952e6c..f872ba0 100644
--- a/ldm/modules/diffusionmodules/util.py
+++ b/ldm/modules/diffusionmodules/util.py
@@ -54,7 +54,8 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
- steps_out = ddim_timesteps + 1
+ # steps_out = ddim_timesteps + 1 # removed due to some issues when reaching 1000
+ steps_out = np.where(ddim_timesteps != 999, ddim_timesteps+1, ddim_timesteps)
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
@@ -264,4 +265,4 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
- return repeat_noise() if repeat else noise()
\ No newline at end of file
+ return repeat_noise() if repeat else noise()
diff --git a/scripts/DeforumStableDiffusion.py b/scripts/DeforumStableDiffusion.py
new file mode 100644
index 0000000..cd88539
--- /dev/null
+++ b/scripts/DeforumStableDiffusion.py
@@ -0,0 +1,1312 @@
+#Deforum Stable Diffusion v0.4
+#Stable Diffusion by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the Stability.ai Team. K Diffusion by Katherine Crowson. You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from HuggingFace.
+
+#Notebook by deforum
+#Local Version by DGSpitzer 大谷的游戏创作小屋
+
+import os, time
+def get_output_folder(output_path, batch_folder):
+ out_path = os.path.join(output_path,time.strftime('%Y-%m'))
+ if batch_folder != "":
+ out_path = os.path.join(out_path, batch_folder)
+ os.makedirs(out_path, exist_ok=True)
+ return out_path
+
+
+def main():
+
+ import argparse
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--settings",
+ type=str,
+ default="./examples/runSettings_StillImages.txt",
+ help="Settings file",
+ )
+
+ parser.add_argument(
+ "--enable_animation_mode",
+ default=False,
+ action='store_true',
+ help="Enable animation mode settings",
+ )
+
+ opt = parser.parse_args()
+
+ #@markdown **Model and Output Paths**
+ # ask for the link
+ print("Local Path Variables:\n")
+
+ models_path = "./models" #@param {type:"string"}
+ output_path = "./output" #@param {type:"string"}
+
+ #@markdown **Google Drive Path Variables (Optional)**
+ mount_google_drive = False #@param {type:"boolean"}
+ force_remount = False
+
+
+
+
+ if mount_google_drive:
+ from google.colab import drive # type: ignore
+ try:
+ drive_path = "/content/drive"
+ drive.mount(drive_path,force_remount=force_remount)
+ models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"}
+ output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"}
+ models_path = models_path_gdrive
+ output_path = output_path_gdrive
+ except:
+ print("...error mounting drive or with drive path variables")
+ print("...reverting to default path variables")
+
+ import os
+ os.makedirs(models_path, exist_ok=True)
+ os.makedirs(output_path, exist_ok=True)
+
+ print(f"models_path: {models_path}")
+ print(f"output_path: {output_path}")
+
+
+
+ #@markdown **Python Definitions**
+ import IPython
+ import json
+ from IPython import display
+
+ import gc, math, os, pathlib, shutil, subprocess, sys, time
+ import cv2
+ import numpy as np
+ import pandas as pd
+ import random
+ import requests
+ import torch, torchvision
+ import torch.nn as nn
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as TF
+ from contextlib import contextmanager, nullcontext
+ from einops import rearrange, repeat
+ from itertools import islice
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from pytorch_lightning import seed_everything
+ from skimage.exposure import match_histograms
+ from torchvision.utils import make_grid
+ from tqdm import tqdm, trange
+ from types import SimpleNamespace
+ from torch import autocast
+
+ sys.path.extend([
+ 'src/taming-transformers',
+ 'src/clip',
+ 'stable-diffusion/',
+ 'k-diffusion',
+ 'pytorch3d-lite',
+ 'AdaBins',
+ 'MiDaS',
+ ])
+
+ import py3d_tools as p3d
+
+ from helpers import DepthModel, sampler_fn
+ from k_diffusion.external import CompVisDenoiser
+ from ldm.util import instantiate_from_config
+ from ldm.models.diffusion.ddim import DDIMSampler
+ from ldm.models.diffusion.plms import PLMSSampler
+
+
+ #Read settings files
+ def load_args(path):
+ with open(path, "r") as f:
+ loaded_args = json.load(f)#, ensure_ascii=False, indent=4)
+ return loaded_args
+
+ master_args = load_args(opt.settings)
+
+
+ def sanitize(prompt):
+ whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ')
+ tmp = ''.join(filter(whitelist.__contains__, prompt))
+ return tmp.replace(' ', '_')
+
+ def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):
+ angle = keys.angle_series[frame_idx]
+ zoom = keys.zoom_series[frame_idx]
+ translation_x = keys.translation_x_series[frame_idx]
+ translation_y = keys.translation_y_series[frame_idx]
+
+ center = (args.W // 2, args.H // 2)
+ trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])
+ rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)
+ trans_mat = np.vstack([trans_mat, [0,0,1]])
+ rot_mat = np.vstack([rot_mat, [0,0,1]])
+ xform = np.matmul(rot_mat, trans_mat)
+
+ return cv2.warpPerspective(
+ prev_img_cv2,
+ xform,
+ (prev_img_cv2.shape[1], prev_img_cv2.shape[0]),
+ borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE
+ )
+
+ def anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx):
+ TRANSLATION_SCALE = 1.0/200.0 # matches Disco
+ translate_xyz = [
+ -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE,
+ keys.translation_y_series[frame_idx] * TRANSLATION_SCALE,
+ -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE
+ ]
+ rotate_xyz = [
+ math.radians(keys.rotation_3d_x_series[frame_idx]),
+ math.radians(keys.rotation_3d_y_series[frame_idx]),
+ math.radians(keys.rotation_3d_z_series[frame_idx])
+ ]
+ rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0)
+ result = transform_image_3d(prev_img_cv2, depth, rot_mat, translate_xyz, anim_args)
+ torch.cuda.empty_cache()
+ return result
+
+ def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor:
+ return sample + torch.randn(sample.shape, device=sample.device) * noise_amt
+
+ def load_img(path, shape, use_alpha_as_mask=False):
+ # use_alpha_as_mask: Read the alpha channel of the image as the mask image
+ if path.startswith('http://') or path.startswith('https://'):
+ image = Image.open(requests.get(path, stream=True).raw)
+ else:
+ image = Image.open(path)
+
+ if use_alpha_as_mask:
+ image = image.convert('RGBA')
+ else:
+ image = image.convert('RGB')
+
+ image = image.resize(shape, resample=Image.LANCZOS)
+
+ mask_image = None
+ if use_alpha_as_mask:
+ # Split alpha channel into a mask_image
+ red, green, blue, alpha = Image.Image.split(image)
+ mask_image = alpha.convert('L')
+ image = image.convert('RGB')
+
+ image = np.array(image).astype(np.float16) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ image = 2.*image - 1.
+
+ return image, mask_image
+
+ def load_mask_latent(mask_input, shape):
+ # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object
+ # shape (list-like len(4)): shape of the image to match, usually latent_image.shape
+
+ if isinstance(mask_input, str): # mask input is probably a file name
+ if mask_input.startswith('http://') or mask_input.startswith('https://'):
+ mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA')
+ else:
+ mask_image = Image.open(mask_input).convert('RGBA')
+ elif isinstance(mask_input, Image.Image):
+ mask_image = mask_input
+ else:
+ raise Exception("mask_input must be a PIL image or a file name")
+
+ mask_w_h = (shape[-1], shape[-2])
+ mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS)
+ mask = mask.convert("L")
+ return mask
+
+ def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0):
+ # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object
+ # shape (list-like len(4)): shape of the image to match, usually latent_image.shape
+ # mask_brightness_adjust (non-negative float): amount to adjust brightness of the iamge,
+ # 0 is black, 1 is no adjustment, >1 is brighter
+ # mask_contrast_adjust (non-negative float): amount to adjust contrast of the image,
+ # 0 is a flat grey image, 1 is no adjustment, >1 is more contrast
+
+ mask = load_mask_latent(mask_input, mask_shape)
+
+ # Mask brightness/contrast adjustments
+ if mask_brightness_adjust != 1:
+ mask = TF.adjust_brightness(mask, mask_brightness_adjust)
+ if mask_contrast_adjust != 1:
+ mask = TF.adjust_contrast(mask, mask_contrast_adjust)
+
+ # Mask image to array
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask,(4,1,1))
+ mask = np.expand_dims(mask,axis=0)
+ mask = torch.from_numpy(mask)
+
+ if args.invert_mask:
+ mask = ( (mask - 0.5) * -1) + 0.5
+
+ mask = np.clip(mask,0,1)
+ return mask
+
+ def maintain_colors(prev_img, color_match_sample, mode):
+ if mode == 'Match Frame 0 RGB':
+ return match_histograms(prev_img, color_match_sample, multichannel=True)
+ elif mode == 'Match Frame 0 HSV':
+ prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)
+ color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)
+ matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)
+ return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)
+ else: # Match Frame 0 LAB
+ prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)
+ color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)
+ matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)
+ return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)
+
+
+ def make_callback(sampler_name, dynamic_threshold=None, static_threshold=None, mask=None, init_latent=None, sigmas=None, sampler=None, masked_noise_modifier=1.0):
+ # Creates the callback function to be passed into the samplers
+ # The callback function is applied to the image at each step
+ def dynamic_thresholding_(img, threshold):
+ # Dynamic thresholding from Imagen paper (May 2022)
+ s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim)))
+ s = np.max(np.append(s,1.0))
+ torch.clamp_(img, -1*s, s)
+ torch.FloatTensor.div_(img, s)
+
+ # Callback for samplers in the k-diffusion repo, called thus:
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ def k_callback_(args_dict):
+ if dynamic_threshold is not None:
+ dynamic_thresholding_(args_dict['x'], dynamic_threshold)
+ if static_threshold is not None:
+ torch.clamp_(args_dict['x'], -1*static_threshold, static_threshold)
+ if mask is not None:
+ init_noise = init_latent + noise * args_dict['sigma']
+ is_masked = torch.logical_and(mask >= mask_schedule[args_dict['i']], mask != 0 )
+ new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1)
+ args_dict['x'].copy_(new_img)
+
+ # Function that is called on the image (img) and step (i) at each step
+ def img_callback_(img, i):
+ # Thresholding functions
+ if dynamic_threshold is not None:
+ dynamic_thresholding_(img, dynamic_threshold)
+ if static_threshold is not None:
+ torch.clamp_(img, -1*static_threshold, static_threshold)
+ if mask is not None:
+ i_inv = len(sigmas) - i - 1
+ init_noise = sampler.stochastic_encode(init_latent, torch.tensor([i_inv]*batch_size).to(device), noise=noise)
+ is_masked = torch.logical_and(mask >= mask_schedule[i], mask != 0 )
+ new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1)
+ img.copy_(new_img)
+
+ if init_latent is not None:
+ noise = torch.randn_like(init_latent, device=device) * masked_noise_modifier
+ if sigmas is not None and len(sigmas) > 0:
+ mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas))
+ elif len(sigmas) == 0:
+ mask = None # no mask needed if no steps (usually happens because strength==1.0)
+ if sampler_name in ["plms","ddim"]:
+ # Callback function formated for compvis latent diffusion samplers
+ if mask is not None:
+ assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"
+ batch_size = init_latent.shape[0]
+
+ callback = img_callback_
+ else:
+ # Default callback function uses k-diffusion sampler variables
+ callback = k_callback_
+
+ return callback
+
+ def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
+ sample = ((sample.astype(float) / 255.0) * 2) - 1
+ sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16)
+ sample = torch.from_numpy(sample)
+ return sample
+
+ def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray:
+ sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32)
+ sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)
+ sample_int8 = (sample_f32 * 255)
+ return sample_int8.astype(type)
+
+ def transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate, anim_args):
+ # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion
+ w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]
+
+ aspect_ratio = float(w)/float(h)
+ near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov
+ persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device)
+ persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device)
+
+ # range of [-1,1] is important to torch grid_sample's padding handling
+ y,x = torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device))
+ z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device)
+ xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1)
+
+ xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]
+ xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]
+
+ offset_xy = xyz_new_cam_xy - xyz_old_cam_xy
+ # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation.
+ identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0)
+ # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs.
+ coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False)
+ offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0)
+
+ image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device)
+ new_image = torch.nn.functional.grid_sample(
+ image_tensor.add(1/512 - 0.0001).unsqueeze(0),
+ offset_coords_2d,
+ mode=anim_args.sampling_mode,
+ padding_mode=anim_args.padding_mode,
+ align_corners=False
+ )
+
+ # convert back to cv2 style numpy array
+ result = rearrange(
+ new_image.squeeze().clamp(0,255),
+ 'c h w -> h w c'
+ ).cpu().numpy().astype(prev_img_cv2.dtype)
+ return result
+
+ def generate(args, return_latent=False, return_sample=False, return_c=False):
+ seed_everything(args.seed)
+ os.makedirs(args.outdir, exist_ok=True)
+
+ sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model)
+ model_wrap = CompVisDenoiser(model)
+ batch_size = args.n_samples
+ prompt = args.prompt
+ assert prompt is not None
+ data = [batch_size * [prompt]]
+ precision_scope = autocast if args.precision == "autocast" else nullcontext
+
+ init_latent = None
+ mask_image = None
+ init_image = None
+ if args.init_latent is not None:
+ init_latent = args.init_latent
+ elif args.init_sample is not None:
+ with precision_scope("cuda"):
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))
+ elif args.use_init and args.init_image != None and args.init_image != '':
+ init_image, mask_image = load_img(args.init_image,
+ shape=(args.W, args.H),
+ use_alpha_as_mask=args.use_alpha_as_mask)
+ init_image = init_image.to(device)
+ init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
+ with precision_scope("cuda"):
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
+
+ if not args.use_init and args.strength > 0 and args.strength_0_no_init:
+ print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.")
+ print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n")
+ args.strength = 0
+
+ # Mask functions
+ if args.use_mask:
+ assert args.mask_file is not None or mask_image is not None, "use_mask==True: An mask image is required for a mask. Please enter a mask_file or use an init image with an alpha channel"
+ assert args.use_init, "use_mask==True: use_init is required for a mask"
+ assert init_latent is not None, "use_mask==True: An latent init image is required for a mask"
+
+ mask = prepare_mask(args.mask_file if mask_image is None else mask_image,
+ init_latent.shape,
+ args.mask_contrast_adjust,
+ args.mask_brightness_adjust)
+
+ if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask:
+ raise Warning("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.")
+
+ mask = mask.to(device)
+ mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+ else:
+ mask = None
+
+ t_enc = int((1.0-args.strength) * args.steps)
+
+ # Noise schedule for the k-diffusion samplers (used for masking)
+ k_sigmas = model_wrap.get_sigmas(args.steps)
+ k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:]
+
+ if args.sampler in ['plms','ddim']:
+ sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)
+
+ callback = make_callback(sampler_name=args.sampler,
+ dynamic_threshold=args.dynamic_threshold,
+ static_threshold=args.static_threshold,
+ mask=mask,
+ init_latent=init_latent,
+ sigmas=k_sigmas,
+ sampler=sampler)
+
+ results = []
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with model.ema_scope():
+ for prompts in data:
+ uc = None
+ if args.scale != 1.0:
+ uc = model.get_learned_conditioning(batch_size * [""])
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+ c = model.get_learned_conditioning(prompts)
+
+ if args.init_c != None:
+ c = args.init_c
+
+ if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]:
+ samples = sampler_fn(
+ c=c,
+ uc=uc,
+ args=args,
+ model_wrap=model_wrap,
+ init_latent=init_latent,
+ t_enc=t_enc,
+ device=device,
+ cb=callback)
+ else:
+ # args.sampler == 'plms' or args.sampler == 'ddim':
+ if init_latent is not None and args.strength > 0:
+ z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
+ else:
+ z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)
+ if args.sampler == 'ddim':
+ samples = sampler.decode(z_enc,
+ c,
+ t_enc,
+ unconditional_guidance_scale=args.scale,
+ unconditional_conditioning=uc,
+ img_callback=callback)
+ elif args.sampler == 'plms': # no "decode" function in plms, so use "sample"
+ shape = [args.C, args.H // args.f, args.W // args.f]
+ samples, _ = sampler.sample(S=args.steps,
+ conditioning=c,
+ batch_size=args.n_samples,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=args.scale,
+ unconditional_conditioning=uc,
+ eta=args.ddim_eta,
+ x_T=z_enc,
+ img_callback=callback)
+ else:
+ raise Exception(f"Sampler {args.sampler} not recognised.")
+
+ if return_latent:
+ results.append(samples.clone())
+
+ x_samples = model.decode_first_stage(samples)
+ if return_sample:
+ results.append(x_samples.clone())
+
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if return_c:
+ results.append(c.clone())
+
+ for x_sample in x_samples:
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ image = Image.fromarray(x_sample.astype(np.uint8))
+ results.append(image)
+ return results
+
+ #@markdown **Select and Load Model**
+
+ model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"]
+ model_checkpoint = "sd-v1-4.ckpt" #@param ["custom","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt"]
+ custom_config_path = "" #@param {type:"string"}
+ custom_checkpoint_path = "" #@param {type:"string"}
+
+ load_on_run_all = True #@param {type: 'boolean'}
+ half_precision = True # check
+ check_sha256 = True #@param {type:"boolean"}
+
+ model_map = {
+ "sd-v1-4-full-ema.ckpt": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},
+ "sd-v1-4.ckpt": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},
+ "sd-v1-3-full-ema.ckpt": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'},
+ "sd-v1-3.ckpt": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'},
+ "sd-v1-2-full-ema.ckpt": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'},
+ "sd-v1-2.ckpt": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'},
+ "sd-v1-1-full-ema.ckpt": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'},
+ "sd-v1-1.ckpt": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}
+ }
+
+ # config path
+ ckpt_config_path = custom_config_path if model_config == "custom" else os.path.join(models_path, model_config)
+ if os.path.exists(ckpt_config_path):
+ print(f"{ckpt_config_path} exists")
+ else:
+ ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
+ print(f"Using config: {ckpt_config_path}")
+
+ # checkpoint path or download
+ ckpt_path = custom_checkpoint_path if model_checkpoint == "custom" else os.path.join(models_path, model_checkpoint)
+ ckpt_valid = True
+ if os.path.exists(ckpt_path):
+ print(f"{ckpt_path} exists")
+ else:
+ print(f"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}")
+ ckpt_valid = False
+
+ if check_sha256 and model_checkpoint != "custom" and ckpt_valid:
+ import hashlib
+ print("\n...checking sha256")
+ with open(ckpt_path, "rb") as f:
+ bytes = f.read()
+ hash = hashlib.sha256(bytes).hexdigest()
+ del bytes
+ if model_map[model_checkpoint]["sha256"] == hash:
+ print("hash is correct\n")
+ else:
+ print("hash in not correct\n")
+ ckpt_valid = False
+
+ if ckpt_valid:
+ print(f"Using ckpt: {ckpt_path}")
+
+ def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):
+ map_location = "cuda" #@param ["cpu", "cuda"]
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location=map_location)
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if half_precision:
+ model = model.half().to(device)
+ else:
+ model = model.to(device)
+ model.eval()
+ return model
+
+ if load_on_run_all and ckpt_valid:
+ local_config = OmegaConf.load(f"{ckpt_config_path}")
+ model = load_model_from_config(local_config, f"{ckpt_path}", half_precision=half_precision)
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+
+ def DeforumAnimArgs():
+
+ #@markdown ####**Animation:**
+ if opt.enable_animation_mode == True:
+ animation_mode = master_args["animation_mode"] #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}
+ max_frames = master_args["max_frames"] #@param {type:"number"}
+ border = master_args["border"] #@param ['wrap', 'replicate'] {type:'string'}
+
+ #@markdown ####**Motion Parameters:**
+ angle = master_args["angle"]#@param {type:"string"}
+ zoom = master_args["zoom"] #@param {type:"string"}
+ translation_x = master_args["translation_x"] #@param {type:"string"}
+ translation_y = master_args["translation_y"] #@param {type:"string"}
+ translation_z = master_args["translation_z"] #@param {type:"string"}
+ rotation_3d_x = master_args["rotation_3d_x"] #@param {type:"string"}
+ rotation_3d_y = master_args["rotation_3d_y"] #@param {type:"string"}
+ rotation_3d_z = master_args["rotation_3d_z"] #@param {type:"string"}
+ noise_schedule = master_args["noise_schedule"] #@param {type:"string"}
+ strength_schedule = master_args["strength_schedule"] #@param {type:"string"}
+ contrast_schedule = master_args["contrast_schedule"] #@param {type:"string"}
+
+ #@markdown ####**Coherence:**
+ color_coherence = master_args["color_coherence"] #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
+ diffusion_cadence = master_args["diffusion_cadence"] #@param ['1','2','3','4','5','6','7','8'] {type:'string'}
+
+ #@markdown #### Depth Warping
+ use_depth_warping = master_args["use_depth_warping"] #@param {type:"boolean"}
+ midas_weight = master_args["midas_weight"] #@param {type:"number"}
+ near_plane = master_args["near_plane"]
+ far_plane = master_args["far_plane"]
+ fov = master_args["fov"] #@param {type:"number"}
+ padding_mode = master_args["padding_mode"] #@param ['border', 'reflection', 'zeros'] {type:'string'}
+ sampling_mode = master_args["sampling_mode"] #@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}
+ save_depth_maps = master_args["save_depth_maps"] #@param {type:"boolean"}
+
+ #@markdown ####**Video Input:**
+ video_init_path = master_args["video_init_path"] #@param {type:"string"}
+ extract_nth_frame = master_args["extract_nth_frame"] #@param {type:"number"}
+
+ #@markdown ####**Interpolation:**
+ interpolate_key_frames = master_args["interpolate_key_frames"] #@param {type:"boolean"}
+ interpolate_x_frames = master_args["interpolate_x_frames"] #@param {type:"number"}
+
+ #@markdown ####**Resume Animation:**
+ resume_from_timestring = master_args["resume_from_timestring"] #@param {type:"boolean"}
+ resume_timestring = master_args["resume_timestring"] #@param {type:"string"}
+ else:
+ #@markdown ####**Still image mode:**
+ animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}
+ max_frames = 10 #@param {type:"number"}
+ border = 'wrap' #@param ['wrap', 'replicate'] {type:'string'}
+
+ #@markdown ####**Motion Parameters:**
+ angle = "0:(0)"#@param {type:"string"}
+ zoom = "0:(1.04)"#@param {type:"string"}
+ translation_x = "0:(0)"#@param {type:"string"}
+ translation_y = "0:(2)"#@param {type:"string"}
+ translation_z = "0:(0.5)"#@param {type:"string"}
+ rotation_3d_x = "0:(0)"#@param {type:"string"}
+ rotation_3d_y = "0:(0)"#@param {type:"string"}
+ rotation_3d_z = "0:(0)"#@param {type:"string"}
+ noise_schedule = "0: (0.02)"#@param {type:"string"}
+ strength_schedule = "0: (0.6)"#@param {type:"string"}
+ contrast_schedule = "0: (1.0)"#@param {type:"string"}
+
+ #@markdown ####**Coherence:**
+ color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
+ diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}
+
+ #@markdown #### Depth Warping
+ use_depth_warping = True #@param {type:"boolean"}
+ midas_weight = 0.3#@param {type:"number"}
+ near_plane = 200
+ far_plane = 10000
+ fov = 40#@param {type:"number"}
+ padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'}
+ sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}
+ save_depth_maps = False #@param {type:"boolean"}
+
+ #@markdown ####**Video Input:**
+ video_init_path ='./input/video_in.mp4'#@param {type:"string"}
+ extract_nth_frame = 1#@param {type:"number"}
+
+ #@markdown ####**Interpolation:**
+ interpolate_key_frames = True #@param {type:"boolean"}
+ interpolate_x_frames = 100 #@param {type:"number"}
+
+ #@markdown ####**Resume Animation:**
+ resume_from_timestring = False #@param {type:"boolean"}
+ resume_timestring = "20220829210106" #@param {type:"string"}
+
+ return locals()
+
+ class DeformAnimKeys():
+ def __init__(self, anim_args):
+ self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle))
+ self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom))
+ self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x))
+ self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y))
+ self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z))
+ self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x))
+ self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y))
+ self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z))
+ self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule))
+ self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule))
+ self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule))
+
+
+ def get_inbetweens(key_frames, integer=False, interp_method='Linear'):
+ key_frame_series = pd.Series([np.nan for a in range(anim_args.max_frames)])
+
+ for i, value in key_frames.items():
+ key_frame_series[i] = value
+ key_frame_series = key_frame_series.astype(float)
+
+ if interp_method == 'Cubic' and len(key_frames.items()) <= 3:
+ interp_method = 'Quadratic'
+ if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:
+ interp_method = 'Linear'
+
+ key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
+ key_frame_series[anim_args.max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]
+ key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')
+ if integer:
+ return key_frame_series.astype(int)
+ return key_frame_series
+
+ def parse_key_frames(string, prompt_parser=None):
+ import re
+ pattern = r'((?P[0-9]+):[\s]*[\(](?P[\S\s]*?)[\)])'
+ frames = dict()
+ for match_object in re.finditer(pattern, string):
+ frame = int(match_object.groupdict()['frame'])
+ param = match_object.groupdict()['param']
+ if prompt_parser:
+ frames[frame] = prompt_parser(param)
+ else:
+ frames[frame] = param
+ if frames == {} and len(string) != 0:
+ raise RuntimeError('Key Frame string not correctly formatted')
+ return frames
+
+ #Prompt will be put in here: for example:
+ '''
+ prompts = [
+ "a beaufiful young girl holding a flower, art by huang guangjian and gil elvgren and sachin teng, trending on artstation",
+ "a beaufiful young girl holding a flower, art by greg rutkowski and alphonse mucha, trending on artstation",
+ #"the third prompt I don't want it I commented it with an",
+ ]
+
+ animation_prompts = {
+ 0: "amazing alien landscape with lush vegetation and colourful galaxy foreground, digital art, breathtaking, golden ratio, extremely detailed, hyper - detailed, establishing shot, hyperrealistic, cinematic lighting, particles, unreal engine, simon stalenhag, rendered by beeple, makoto shinkai, syd meade, kentaro miura, jean giraud, environment concept, artstation, octane render, 8k uhd image",
+ 50: "desolate landscape fill with giant flowers, moody :: by James Jean, Jeff Koons, Dan McPharlin Daniel Merrian :: ornate, dynamic, particulate, rich colors, intricate, elegant, highly detailed, centered, artstation, smooth, sharp focus, octane render, 3d",
+ }
+ '''
+
+ #Replace by text file
+ prompts = master_args["prompts"]
+
+ if opt.enable_animation_mode:
+ animation_prompts = master_args["animation_prompts"]
+ else:
+ animation_prompts = {}
+
+
+
+ def DeforumArgs():
+
+ #@markdown **Image Settings**
+ W = master_args["width"] #@param
+ H = master_args["height"] #@param
+ W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64
+
+ #@markdown **Sampling Settings**
+ seed = master_args["seed"] #@param
+ sampler = master_args["sampler"] #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"]
+ steps = master_args["steps"] #@param
+ scale = master_args["scale"] #@param
+ ddim_eta = master_args["ddim_eta"] #@param
+ dynamic_threshold = None
+ static_threshold = None
+
+ #@markdown **Save & Display Settings**
+ save_samples = True #@param {type:"boolean"}
+ save_settings = True #@param {type:"boolean"}
+ display_samples = True #@param {type:"boolean"}
+
+ #@markdown **Batch Settings**
+ n_batch = master_args["n_batch"] #@param
+ batch_name = master_args["batch_name"] #@param {type:"string"}
+ filename_format = master_args["filename_format"] #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"]
+ seed_behavior = master_args["seed_behavior"] #@param ["iter","fixed","random"]
+ make_grid = False #@param {type:"boolean"}
+ grid_rows = 2 #@param
+ outdir = get_output_folder(output_path, batch_name)
+
+ #@markdown **Init Settings**
+ use_init = master_args["use_init"] #@param {type:"boolean"}
+ strength = master_args["strength"] #@param {type:"number"}
+ init_image = master_args["init_image"] #@param {type:"string"}
+ strength_0_no_init = True # Set the strength to 0 automatically when no init image is used
+ # Whiter areas of the mask are areas that change more
+ use_mask = master_args["use_mask"] #@param {type:"boolean"}
+ use_alpha_as_mask = master_args["use_alpha_as_mask"] # use the alpha channel of the init image as the mask
+ mask_file = master_args["mask_file"] #@param {type:"string"}
+ invert_mask = master_args["invert_mask"] #@param {type:"boolean"}
+ # Adjust mask image, 1.0 is no adjustment. Should be positive numbers.
+ mask_brightness_adjust = 1.0 #@param {type:"number"}
+ mask_contrast_adjust = 1.0 #@param {type:"number"}
+
+ n_samples = 1 # doesnt do anything
+ precision = 'autocast'
+ C = 4
+ f = 8
+
+ prompt = ""
+ timestring = ""
+ init_latent = None
+ init_sample = None
+ init_c = None
+
+ return locals()
+
+
+
+ def next_seed(args):
+ if args.seed_behavior == 'iter':
+ args.seed += 1
+ elif args.seed_behavior == 'fixed':
+ pass # always keep seed the same
+ else:
+ args.seed = random.randint(0, 2**32)
+ return args.seed
+
+ def render_image_batch(args):
+ args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)}
+
+ # create output folder for the batch
+ os.makedirs(args.outdir, exist_ok=True)
+ if args.save_settings or args.save_samples:
+ print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*")
+
+ # save settings for the batch
+ if args.save_settings:
+ filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
+ with open(filename, "w+", encoding="utf-8") as f:
+ dictlist = dict(args.__dict__)
+ del dictlist['master_args']
+ json.dump(dictlist, f, ensure_ascii=False, indent=4)
+
+ index = 0
+
+ # function for init image batching
+ init_array = []
+ if args.use_init:
+ if args.init_image == "":
+ raise FileNotFoundError("No path was given for init_image")
+ if args.init_image.startswith('http://') or args.init_image.startswith('https://'):
+ init_array.append(args.init_image)
+ elif not os.path.isfile(args.init_image):
+ if args.init_image[-1] != "/": # avoids path error by adding / to end if not there
+ args.init_image += "/"
+ for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array
+ if image.split(".")[-1] in ("png", "jpg", "jpeg"):
+ init_array.append(args.init_image + image)
+ else:
+ init_array.append(args.init_image)
+ else:
+ init_array = [""]
+
+ # when doing large batches don't flood browser with images
+ clear_between_batches = args.n_batch >= 32
+
+ for iprompt, prompt in enumerate(prompts):
+ args.prompt = prompt
+ print(f"Prompt {iprompt+1} of {len(prompts)}")
+ print(f"{args.prompt}")
+
+ all_images = []
+
+ for batch_index in range(args.n_batch):
+ if clear_between_batches and batch_index % 32 == 0:
+ display.clear_output(wait=True)
+ print(f"Batch {batch_index+1} of {args.n_batch}")
+
+ for image in init_array: # iterates the init images
+ args.init_image = image
+ results = generate(args)
+ for image in results:
+ if args.make_grid:
+ all_images.append(T.functional.pil_to_tensor(image))
+ if args.save_samples:
+ if args.filename_format == "{timestring}_{index}_{prompt}.png":
+ filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png"
+ else:
+ filename = f"{args.timestring}_{index:05}_{args.seed}.png"
+ image.save(os.path.join(args.outdir, filename))
+ if args.display_samples:
+ display.display(image)
+ index += 1
+ args.seed = next_seed(args)
+
+ #print(len(all_images))
+ if args.make_grid:
+ grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))
+ grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()
+ filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png"
+ grid_image = Image.fromarray(grid.astype(np.uint8))
+ grid_image.save(os.path.join(args.outdir, filename))
+ display.clear_output(wait=True)
+ display.display(grid_image)
+
+
+ def render_animation(args, anim_args):
+ # animations use key framed prompts
+ args.prompts = animation_prompts
+
+ # expand key frame strings to values
+ keys = DeformAnimKeys(anim_args)
+
+ # resume animation
+ start_frame = 0
+ if anim_args.resume_from_timestring:
+ for tmp in os.listdir(args.outdir):
+ if tmp.split("_")[0] == anim_args.resume_timestring:
+ start_frame += 1
+ start_frame = start_frame - 1
+
+ # create output folder for the batch
+ os.makedirs(args.outdir, exist_ok=True)
+ print(f"Saving animation frames to {args.outdir}")
+
+ # save settings for the batch
+ settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
+ with open(settings_filename, "w+", encoding="utf-8") as f:
+ s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
+ del s['master_args']
+ del s['opt']
+ json.dump(s, f, ensure_ascii=False, indent=4)
+
+ # resume from timestring
+ if anim_args.resume_from_timestring:
+ args.timestring = anim_args.resume_timestring
+
+ # expand prompts out to per-frame
+ prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)])
+ for i, prompt in animation_prompts.items():
+ prompt_series[int(i)] = prompt
+ prompt_series = prompt_series.ffill().bfill()
+
+ # check for video inits
+ using_vid_init = anim_args.animation_mode == 'Video Input'
+
+ # load depth model for 3D
+ predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps
+ if predict_depths:
+ depth_model = DepthModel(device)
+ depth_model.load_midas(models_path)
+ if anim_args.midas_weight < 1.0:
+ depth_model.load_adabins()
+ else:
+ depth_model = None
+ anim_args.save_depth_maps = False
+
+ # state for interpolating between diffusion steps
+ turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)
+ turbo_prev_image, turbo_prev_frame_idx = None, 0
+ turbo_next_image, turbo_next_frame_idx = None, 0
+
+ # resume animation
+ prev_sample = None
+ color_match_sample = None
+ if anim_args.resume_from_timestring:
+ last_frame = start_frame-1
+ if turbo_steps > 1:
+ last_frame -= last_frame%turbo_steps
+ path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png")
+ img = cv2.imread(path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ prev_sample = sample_from_cv2(img)
+ if anim_args.color_coherence != 'None':
+ color_match_sample = img
+ if turbo_steps > 1:
+ turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame
+ turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
+ start_frame = last_frame+turbo_steps
+
+ args.n_samples = 1
+ frame_idx = start_frame
+ while frame_idx < anim_args.max_frames:
+ print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}")
+ noise = keys.noise_schedule_series[frame_idx]
+ strength = keys.strength_schedule_series[frame_idx]
+ contrast = keys.contrast_schedule_series[frame_idx]
+ depth = None
+
+ # emit in-between frames
+ if turbo_steps > 1:
+ tween_frame_start_idx = max(0, frame_idx-turbo_steps)
+ for tween_frame_idx in range(tween_frame_start_idx, frame_idx):
+ tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)
+ print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}")
+
+ advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx
+ advance_next = tween_frame_idx > turbo_next_frame_idx
+
+ if depth_model is not None:
+ assert(turbo_next_image is not None)
+ depth = depth_model.predict(turbo_next_image, anim_args)
+
+ if anim_args.animation_mode == '2D':
+ if advance_prev:
+ turbo_prev_image = anim_frame_warp_2d(turbo_prev_image, args, anim_args, keys, tween_frame_idx)
+ if advance_next:
+ turbo_next_image = anim_frame_warp_2d(turbo_next_image, args, anim_args, keys, tween_frame_idx)
+ else: # '3D'
+ if advance_prev:
+ turbo_prev_image = anim_frame_warp_3d(turbo_prev_image, depth, anim_args, keys, tween_frame_idx)
+ if advance_next:
+ turbo_next_image = anim_frame_warp_3d(turbo_next_image, depth, anim_args, keys, tween_frame_idx)
+ turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx
+
+ if turbo_prev_image is not None and tween < 1.0:
+ img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween
+ else:
+ img = turbo_next_image
+
+ filename = f"{args.timestring}_{tween_frame_idx:05}.png"
+ cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))
+ if anim_args.save_depth_maps:
+ depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth)
+ if turbo_next_image is not None:
+ prev_sample = sample_from_cv2(turbo_next_image)
+
+ # apply transforms to previous frame
+ if prev_sample is not None:
+ if anim_args.animation_mode == '2D':
+ prev_img = anim_frame_warp_2d(sample_to_cv2(prev_sample), args, anim_args, keys, frame_idx)
+ else: # '3D'
+ prev_img_cv2 = sample_to_cv2(prev_sample)
+ depth = depth_model.predict(prev_img_cv2, anim_args) if depth_model else None
+ prev_img = anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx)
+
+ # apply color matching
+ if anim_args.color_coherence != 'None':
+ if color_match_sample is None:
+ color_match_sample = prev_img.copy()
+ else:
+ prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)
+
+ # apply scaling
+ contrast_sample = prev_img * contrast
+ # apply frame noising
+ noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)
+
+ # use transformed previous frame as init for current
+ args.use_init = True
+ if half_precision:
+ args.init_sample = noised_sample.half().to(device)
+ else:
+ args.init_sample = noised_sample.to(device)
+ args.strength = max(0.0, min(1.0, strength))
+
+ # grab prompt for current frame
+ args.prompt = prompt_series[frame_idx]
+ print(f"{args.prompt} {args.seed}")
+
+ # grab init image for current frame
+ if using_vid_init:
+ init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:04}.jpg")
+ print(f"Using video init frame {init_frame}")
+ args.init_image = init_frame
+
+ # sample the diffusion model
+ sample, image = generate(args, return_latent=False, return_sample=True)
+ if not using_vid_init:
+ prev_sample = sample
+
+ if turbo_steps > 1:
+ turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx
+ turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx
+ frame_idx += turbo_steps
+ else:
+ filename = f"{args.timestring}_{frame_idx:05}.png"
+ image.save(os.path.join(args.outdir, filename))
+ if anim_args.save_depth_maps:
+ if depth is None:
+ depth = depth_model.predict(sample_to_cv2(sample), anim_args)
+ depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth)
+ frame_idx += 1
+
+ display.clear_output(wait=True)
+ display.display(image)
+
+ args.seed = next_seed(args)
+
+ def render_input_video(args, anim_args):
+ # create a folder for the video input frames to live in
+ video_in_frame_path = os.path.join(args.outdir, 'inputframes')
+ os.makedirs(video_in_frame_path, exist_ok=True)
+
+ # save the video frames from input video
+ print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...")
+ try:
+ for f in pathlib.Path(video_in_frame_path).glob('*.jpg'):
+ f.unlink()
+ except:
+ pass
+ vf = r'select=not(mod(n\,'+str(anim_args.extract_nth_frame)+'))'
+ subprocess.run([
+ 'ffmpeg', '-i', f'{anim_args.video_init_path}',
+ '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2',
+ '-loglevel', 'error', '-stats',
+ os.path.join(video_in_frame_path, '%04d.jpg')
+ ], stdout=subprocess.PIPE).stdout.decode('utf-8')
+
+ # determine max frames from length of input frames
+ anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')])
+
+ args.use_init = True
+ print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}")
+ render_animation(args, anim_args)
+
+ def render_interpolation(args, anim_args):
+ # animations use key framed prompts
+ args.prompts = animation_prompts
+
+ # create output folder for the batch
+ os.makedirs(args.outdir, exist_ok=True)
+ print(f"Saving animation frames to {args.outdir}")
+
+ # save settings for the batch
+ settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt")
+ with open(settings_filename, "w+", encoding="utf-8") as f:
+ s = {**dict(args.__dict__), **dict(anim_args.__dict__)}
+ del s['master_args']
+ del s['opt']
+ json.dump(s, f, ensure_ascii=False, indent=4)
+
+ # Interpolation Settings
+ args.n_samples = 1
+ args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available
+ prompts_c_s = [] # cache all the text embeddings
+
+ print(f"Preparing for interpolation of the following...")
+
+ for i, prompt in animation_prompts.items():
+ args.prompt = prompt
+
+ # sample the diffusion model
+ results = generate(args, return_c=True)
+ c, image = results[0], results[1]
+ prompts_c_s.append(c)
+
+ # display.clear_output(wait=True)
+ display.display(image)
+
+ args.seed = next_seed(args)
+
+ display.clear_output(wait=True)
+ print(f"Interpolation start...")
+
+ frame_idx = 0
+
+ if anim_args.interpolate_key_frames:
+ for i in range(len(prompts_c_s)-1):
+ dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]
+ if dist_frames <= 0:
+ print("key frames duplicated or reversed. interpolation skipped.")
+ return
+ else:
+ for j in range(dist_frames):
+ # interpolate the text embedding
+ prompt1_c = prompts_c_s[i]
+ prompt2_c = prompts_c_s[i+1]
+ args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))
+
+ # sample the diffusion model
+ results = generate(args)
+ image = results[0]
+
+ filename = f"{args.timestring}_{frame_idx:05}.png"
+ image.save(os.path.join(args.outdir, filename))
+ frame_idx += 1
+
+ display.clear_output(wait=True)
+ display.display(image)
+
+ args.seed = next_seed(args)
+
+ else:
+ for i in range(len(prompts_c_s)-1):
+ for j in range(anim_args.interpolate_x_frames+1):
+ # interpolate the text embedding
+ prompt1_c = prompts_c_s[i]
+ prompt2_c = prompts_c_s[i+1]
+ args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))
+
+ # sample the diffusion model
+ results = generate(args)
+ image = results[0]
+
+ filename = f"{args.timestring}_{frame_idx:05}.png"
+ image.save(os.path.join(args.outdir, filename))
+ frame_idx += 1
+
+ display.clear_output(wait=True)
+ display.display(image)
+
+ args.seed = next_seed(args)
+
+ # generate the last prompt
+ args.init_c = prompts_c_s[-1]
+ results = generate(args)
+ image = results[0]
+ filename = f"{args.timestring}_{frame_idx:05}.png"
+ image.save(os.path.join(args.outdir, filename))
+
+ display.clear_output(wait=True)
+ display.display(image)
+ args.seed = next_seed(args)
+
+ #clear init_c
+ args.init_c = None
+
+
+ args = SimpleNamespace(**DeforumArgs())
+ anim_args = SimpleNamespace(**DeforumAnimArgs())
+
+ args.timestring = time.strftime('%Y%m%d%H%M%S')
+ args.strength = max(0.0, min(1.0, args.strength))
+
+ if args.seed == -1:
+ args.seed = random.randint(0, 2**32 - 1)
+ if not args.use_init:
+ args.init_image = None
+ if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):
+ print(f"Init images aren't supported with PLMS yet, switching to KLMS")
+ args.sampler = 'klms'
+ if args.sampler != 'ddim':
+ args.ddim_eta = 0
+
+ if anim_args.animation_mode == 'None':
+ anim_args.max_frames = 1
+ elif anim_args.animation_mode == 'Video Input':
+ args.use_init = True
+
+ # clean up unused memory
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # dispatch to appropriate renderer
+ if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D':
+ render_animation(args, anim_args)
+ elif anim_args.animation_mode == 'Video Input':
+ render_input_video(args, anim_args)
+ elif anim_args.animation_mode == 'Interpolation':
+ render_interpolation(args, anim_args)
+ else:
+ render_image_batch(args)
+
+
+ skip_video_for_run_all = False #@param {type: 'boolean'}
+ fps = 12 #@param {type:"number"}
+ #@markdown **Manual Settings**
+ use_manual_settings = False #@param {type:"boolean"}
+ image_path = "./output/out_%05d.png" #@param {type:"string"}
+ mp4_path = "./output/out_%05d.mp4" #@param {type:"string"}
+
+
+ if skip_video_for_run_all == True or opt.enable_animation_mode == False:
+ print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')
+ else:
+ import os
+ import subprocess
+ from base64 import b64encode
+
+ print(f"{image_path} -> {mp4_path}")
+
+ if use_manual_settings:
+ max_frames = "200" #@param {type:"string"}
+ else:
+ image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png")
+ mp4_path = os.path.join(args.outdir, f"{args.timestring}.mp4")
+ max_frames = str(anim_args.max_frames)
+
+ # make video
+ cmd = [
+ 'ffmpeg',
+ '-y',
+ '-vcodec', 'png',
+ '-r', str(fps),
+ '-start_number', str(0),
+ '-i', image_path,
+ '-frames:v', max_frames,
+ '-c:v', 'libx264',
+ '-vf',
+ f'fps={fps}',
+ '-pix_fmt', 'yuv420p',
+ '-crf', '17',
+ '-preset', 'veryfast',
+ mp4_path
+ ]
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout, stderr = process.communicate()
+ if process.returncode != 0:
+ print(stderr)
+ raise RuntimeError(stderr)
+
+ mp4 = open(mp4_path,'rb').read()
+ data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
+ display.display( display.HTML(f'') )
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/scripts/ModelManager.py b/scripts/ModelManager.py
new file mode 100644
index 0000000..983f85b
--- /dev/null
+++ b/scripts/ModelManager.py
@@ -0,0 +1,46 @@
+# base webui import and utils.
+from webui_streamlit import st
+from sd_utils import *
+
+# streamlit imports
+
+
+#other imports
+import pandas as pd
+from io import StringIO
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+def layout():
+ #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
+
+ csvString = f"""
+ ,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media
+ ,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
+ ,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
+ ,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
+ ,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt
+ ,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt
+ ,Stable Diffusion Concept Library , ./models/customsd-concepts-library , https://github.com/sd-webui/sd-concepts-library
+ """
+ colms = st.columns((1, 3, 5, 5))
+ columns = ["№",'Model Name','Save Location','Download Link']
+
+ # Convert String into StringIO
+ csvStringIO = StringIO(csvString)
+ df = pd.read_csv(csvStringIO, sep=",", header=None, names=columns)
+
+ for col, field_name in zip(colms, columns):
+ # table header
+ col.write(field_name)
+
+ for x, model_name in enumerate(df["Model Name"]):
+ col1, col2, col3, col4 = st.columns((1, 3, 4, 6))
+ col1.write(x) # index
+ col2.write(df['Model Name'][x])
+ col3.write(df['Save Location'][x])
+ col4.write(df['Download Link'][x])
\ No newline at end of file
diff --git a/scripts/Settings.py b/scripts/Settings.py
new file mode 100644
index 0000000..a1e21ec
--- /dev/null
+++ b/scripts/Settings.py
@@ -0,0 +1,5 @@
+from webui_streamlit import st
+
+# The global settings section will be moved to the Settings page.
+#with st.expander("Global Settings:"):
+st.write("Global Settings:")
diff --git a/scripts/home.py b/scripts/home.py
new file mode 100644
index 0000000..2702fcc
--- /dev/null
+++ b/scripts/home.py
@@ -0,0 +1,216 @@
+# base webui import and utils.
+from webui_streamlit import st
+from sd_utils import *
+
+# streamlit imports
+
+
+#other imports
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+import os
+from PIL import Image
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+class plugin_info():
+ plugname = "home"
+ description = "Home"
+ isTab = True
+ displayPriority = 0
+
+def getLatestGeneratedImagesFromPath():
+ #get the latest images from the generated images folder
+ #get the path to the generated images folder
+ generatedImagesPath = os.path.join(os.getcwd(),'outputs')
+ #get all the files from the folders and subfolders
+ files = []
+ #get the latest 10 images from the output folder without walking the subfolders
+ for r, d, f in os.walk(generatedImagesPath):
+ for file in f:
+ if '.png' in file:
+ files.append(os.path.join(r, file))
+ #sort the files by date
+ files.sort(reverse=True, key=os.path.getmtime)
+ latest = files[:90]
+ latest.reverse()
+
+ # reverse the list so the latest images are first and truncate to
+ # a reasonable number of images, 10 pages worth
+ return [Image.open(f) for f in latest]
+
+def get_images_from_lexica():
+ #scrape images from lexica.art
+ #get the html from the page
+ #get the html with cookies and javascript
+ apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D'
+ #REST API call
+ #
+ from requests_html import HTMLSession
+ session = HTMLSession()
+
+ response = session.get(apiEndpoint)
+ #req = requests.Session()
+ #req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36'
+ #response = req.get(apiEndpoint)
+ print(response.status_code)
+ print(response.text)
+ #get the json from the response
+ #json = response.json()
+ #get the prompts from the json
+ print(response)
+ #session = requests.Session()
+ #parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False)
+ #print(parseEndpointJson)
+ #print('test2')
+ #page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'})
+ #parse the html
+ #soup = BeautifulSoup(page.content, 'html.parser')
+ #find all the images
+ #print(soup)
+ #images = soup.find_all('alt-image')
+ #create a list to store the image urls
+ image_urls = []
+ #loop through the images
+ for image in images:
+ #get the url
+ image_url = image['src']
+ #add it to the list
+ image_urls.append('http://www.lexica.art/'+image_url)
+ #return the list
+ print(image_urls)
+ return image_urls
+
+def layout():
+ #streamlit home page layout
+ #center the title
+ st.markdown("
Welcome, let's make some 🎨
", unsafe_allow_html=True)
+ #make a gallery of images
+ #st.markdown("
Gallery
", unsafe_allow_html=True)
+ #create a gallery of images using columns
+ #col1, col2, col3 = st.columns(3)
+ #load the images
+ #create 3 columns
+ # create a tab for the gallery
+ #st.markdown("
Gallery
", unsafe_allow_html=True)
+ #st.markdown("
Gallery
", unsafe_allow_html=True)
+
+ history_tab, discover_tabs = st.tabs(["History","Discover"])
+
+ latestImages = getLatestGeneratedImagesFromPath()
+ st.session_state['latestImages'] = latestImages
+
+ with history_tab:
+ ##---------------------------------------------------------
+ ## image slideshow test
+ ## Number of entries per screen
+ #slideshow_N = 9
+ #slideshow_page_number = 0
+ #slideshow_last_page = len(latestImages) // slideshow_N
+
+ ## Add a next button and a previous button
+
+ #slideshow_prev, slideshow_image_col , slideshow_next = st.columns([1, 10, 1])
+
+ #with slideshow_image_col:
+ #slideshow_image = st.empty()
+
+ #slideshow_image.image(st.session_state['latestImages'][0])
+
+ #current_image = 0
+
+ #if slideshow_next.button("Next", key=1):
+ ##print (current_image+1)
+ #current_image = current_image+1
+ #slideshow_image.image(st.session_state['latestImages'][current_image+1])
+ #if slideshow_prev.button("Previous", key=0):
+ ##print ([current_image-1])
+ #current_image = current_image-1
+ #slideshow_image.image(st.session_state['latestImages'][current_image - 1])
+
+
+ #---------------------------------------------------------
+
+ # image gallery
+ # Number of entries per screen
+ gallery_N = 9
+ if not "galleryPage" in st.session_state:
+ st.session_state["galleryPage"] = 0
+ gallery_last_page = len(latestImages) // gallery_N
+
+ # Add a next button and a previous button
+
+ gallery_prev, gallery_refresh, gallery_pagination , gallery_next = st.columns([2, 2, 8, 1])
+
+ # the pagination doesnt work for now so its better to disable the buttons.
+
+ if gallery_refresh.button("Refresh", key=4):
+ st.session_state["galleryPage"] = 0
+
+ if gallery_next.button("Next", key=3):
+
+ if st.session_state["galleryPage"] + 1 > gallery_last_page:
+ st.session_state["galleryPage"] = 0
+ else:
+ st.session_state["galleryPage"] += 1
+
+ if gallery_prev.button("Previous", key=2):
+
+ if st.session_state["galleryPage"] - 1 < 0:
+ st.session_state["galleryPage"] = gallery_last_page
+ else:
+ st.session_state["galleryPage"] -= 1
+
+ print(st.session_state["galleryPage"])
+ # Get start and end indices of the next page of the dataframe
+ gallery_start_idx = st.session_state["galleryPage"] * gallery_N
+ gallery_end_idx = (1 + st.session_state["galleryPage"]) * gallery_N
+
+ #---------------------------------------------------------
+
+ placeholder = st.empty()
+
+ #populate the 3 images per column
+ with placeholder.container():
+ col1, col2, col3 = st.columns(3)
+ col1_cont = st.container()
+ col2_cont = st.container()
+ col3_cont = st.container()
+
+ print (len(st.session_state['latestImages']))
+ images = list(reversed(st.session_state['latestImages']))[gallery_start_idx:(gallery_start_idx+gallery_N)]
+
+ with col1_cont:
+ with col1:
+ [st.image(images[index]) for index in [0, 3, 6] if index < len(images)]
+ with col2_cont:
+ with col2:
+ [st.image(images[index]) for index in [1, 4, 7] if index < len(images)]
+ with col3_cont:
+ with col3:
+ [st.image(images[index]) for index in [2, 5, 8] if index < len(images)]
+
+
+ st.session_state['historyTab'] = [history_tab,col1,col2,col3,placeholder,col1_cont,col2_cont,col3_cont]
+
+ with discover_tabs:
+ st.markdown("
Soon :)
", unsafe_allow_html=True)
+
+ #display the images
+ #add a button to the gallery
+ #st.markdown("
Try it out
", unsafe_allow_html=True)
+ #create a button to the gallery
+ #if st.button("Try it out"):
+ #if the button is clicked, go to the gallery
+ #st.experimental_rerun()
diff --git a/scripts/img2img.py b/scripts/img2img.py
new file mode 100644
index 0000000..142fe81
--- /dev/null
+++ b/scripts/img2img.py
@@ -0,0 +1,592 @@
+# base webui import and utils.
+from webui_streamlit import st
+from sd_utils import *
+
+# streamlit imports
+from streamlit import StopException
+
+#other imports
+import cv2
+from PIL import Image, ImageOps
+import torch
+import k_diffusion as K
+import numpy as np
+import time
+import torch
+import skimage
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
+ mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
+ n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
+ seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
+ variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
+ write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
+ separate_prompts:bool = False, normalize_prompt_weights:bool = True,
+ save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
+ save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False,
+ random_seed_loopback: bool = False
+ ):
+
+ outpath = st.session_state['defaults'].general.outdir_img2img or st.session_state['defaults'].general.outdir or "outputs/img2img-samples"
+ #err = False
+ #loopback = False
+ #skip_save = False
+ seed = seed_to_int(seed)
+
+ batch_size = 1
+
+ #prompt_matrix = 0
+ #normalize_prompt_weights = 1 in toggles
+ #loopback = 2 in toggles
+ #random_seed_loopback = 3 in toggles
+ #skip_save = 4 not in toggles
+ #save_grid = 5 in toggles
+ #sort_samples = 6 in toggles
+ #write_info_files = 7 in toggles
+ #write_sample_info_to_log_file = 8 in toggles
+ #jpg_sample = 9 in toggles
+ #use_GFPGAN = 10 in toggles
+ #use_RealESRGAN = 11 in toggles
+
+ if sampler_name == 'PLMS':
+ sampler = PLMSSampler(st.session_state["model"])
+ elif sampler_name == 'DDIM':
+ sampler = DDIMSampler(st.session_state["model"])
+ elif sampler_name == 'k_dpm_2_a':
+ sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
+ elif sampler_name == 'k_dpm_2':
+ sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
+ elif sampler_name == 'k_euler_a':
+ sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
+ elif sampler_name == 'k_euler':
+ sampler = KDiffusionSampler(st.session_state["model"],'euler')
+ elif sampler_name == 'k_heun':
+ sampler = KDiffusionSampler(st.session_state["model"],'heun')
+ elif sampler_name == 'k_lms':
+ sampler = KDiffusionSampler(st.session_state["model"],'lms')
+ else:
+ raise Exception("Unknown sampler: " + sampler_name)
+
+ def process_init_mask(init_mask: Image):
+ if init_mask.mode == "RGBA":
+ init_mask = init_mask.convert('RGBA')
+ background = Image.new('RGBA', init_mask.size, (0, 0, 0))
+ init_mask = Image.alpha_composite(background, init_mask)
+ init_mask = init_mask.convert('RGB')
+ return init_mask
+
+ init_img = init_info
+ init_mask = None
+ if mask_mode == 0:
+ if init_info_mask:
+ init_mask = process_init_mask(init_info_mask)
+ elif mask_mode == 1:
+ if init_info_mask:
+ init_mask = process_init_mask(init_info_mask)
+ init_mask = ImageOps.invert(init_mask)
+ elif mask_mode == 2:
+ init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
+ init_mask = init_img_transparency
+ init_mask = init_mask.convert("RGB")
+ init_mask = resize_image(resize_mode, init_mask, width, height)
+ init_mask = init_mask.convert("RGB")
+
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ t_enc = int(denoising_strength * ddim_steps)
+
+ if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None:
+ noise_q = 0.99
+ color_variation = 0.0
+ mask_blend_factor = 1.0
+
+ np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
+ np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
+ np_mask_rgb -= np.min(np_mask_rgb)
+ np_mask_rgb /= np.max(np_mask_rgb)
+ np_mask_rgb = 1. - np_mask_rgb
+ np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
+ blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
+ blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
+ #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
+ #np_mask_rgb = np_mask_rgb + blurred
+ np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
+ np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
+
+ noise_rgb = get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
+ blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
+ noised = noise_rgb[:]
+ blend_mask_rgb **= (2.)
+ noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
+
+ np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
+ ref_mask = np_mask_grey < 1e-3
+
+ all_mask = np.ones((height, width), dtype=bool)
+ noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
+
+ init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
+ st.session_state["editor_image"].image(init_img) # debug
+
+ def init():
+ image = init_img.convert('RGB')
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+
+ mask_channel = None
+ if init_mask:
+ alpha = resize_image(resize_mode, init_mask, width // 8, height // 8)
+ mask_channel = alpha.split()[-1]
+
+ mask = None
+ if mask_channel is not None:
+ mask = np.array(mask_channel).astype(np.float32) / 255.0
+ mask = (1 - mask)
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3)
+ mask = torch.from_numpy(mask).to(st.session_state["device"])
+
+ if st.session_state['defaults'].general.optimized:
+ st.session_state.modelFS.to(st.session_state["device"] )
+
+ init_image = 2. * image - 1.
+ init_image = init_image.to(st.session_state["device"])
+ init_latent = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).get_first_stage_encoding((st.session_state["model"] if not st.session_state['defaults'].general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
+
+ if st.session_state['defaults'].general.optimized:
+ mem = torch.cuda.memory_allocated()/1e6
+ st.session_state.modelFS.to("cpu")
+ while(torch.cuda.memory_allocated()/1e6 >= mem):
+ time.sleep(1)
+
+ return init_latent, mask,
+
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ t_enc_steps = t_enc
+ obliterate = False
+ if ddim_steps == t_enc_steps:
+ t_enc_steps = t_enc_steps - 1
+ obliterate = True
+
+ if sampler_name != 'DDIM':
+ x0, z_mask = init_data
+
+ sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
+ noise = x * sigmas[ddim_steps - t_enc_steps - 1]
+
+ xi = x0 + noise
+
+ # Obliterate masked image
+ if z_mask is not None and obliterate:
+ random = torch.randn(z_mask.shape, device=xi.device)
+ xi = (z_mask * noise) + ((1-z_mask) * xi)
+
+ sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
+ model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
+ samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
+ extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
+ 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
+ callback=generation_callback)
+ else:
+
+ x0, z_mask = init_data
+
+ sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
+ z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] ))
+
+ # Obliterate masked image
+ if z_mask is not None and obliterate:
+ random = torch.randn(z_mask.shape, device=z_enc.device)
+ z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
+
+ # decode it
+ samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ z_mask=z_mask, x0=x0)
+ return samples_ddim
+
+
+
+ if loopback:
+ output_images, info = None, None
+ history = []
+ initial_seed = None
+
+ do_color_correction = False
+ try:
+ from skimage import exposure
+ do_color_correction = True
+ except:
+ print("Install scikit-image to perform color correction on loopback")
+
+ for i in range(n_iter):
+ if do_color_correction and i == 0:
+ correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
+
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=1,
+ n_iter=1,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback
+ realesrgan_model_name=RealESRGAN_model,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ init_img=init_img,
+ init_mask=init_mask,
+ mask_blur_strength=mask_blur_strength,
+ mask_restore=mask_restore,
+ denoising_strength=denoising_strength,
+ noise_mode=noise_mode,
+ find_noise_steps=find_noise_steps,
+ resize_mode=resize_mode,
+ uses_loopback=loopback,
+ uses_random_seed_loopback=random_seed_loopback,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg
+ )
+
+ if initial_seed is None:
+ initial_seed = seed
+
+ input_image = init_img
+ init_img = output_images[0]
+
+ if do_color_correction and correction_target is not None:
+ init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
+ cv2.cvtColor(
+ np.asarray(init_img),
+ cv2.COLOR_RGB2LAB
+ ),
+ correction_target,
+ channel_axis=2
+ ), cv2.COLOR_LAB2RGB).astype("uint8"))
+ if mask_restore is True and init_mask is not None:
+ color_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
+ color_mask = color_mask.convert('L')
+ source_image = input_image.convert('RGB')
+ target_image = init_img.convert('RGB')
+
+ init_img = Image.composite(source_image, target_image, color_mask)
+
+ if not random_seed_loopback:
+ seed = seed + 1
+ else:
+ seed = seed_to_int(None)
+
+ denoising_strength = max(denoising_strength * 0.95, 0.1)
+ history.append(init_img)
+
+ output_images = history
+ seed = initial_seed
+
+ else:
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ use_RealESRGAN=use_RealESRGAN,
+ realesrgan_model_name=RealESRGAN_model,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ init_img=init_img,
+ init_mask=init_mask,
+ mask_blur_strength=mask_blur_strength,
+ denoising_strength=denoising_strength,
+ noise_mode=noise_mode,
+ find_noise_steps=find_noise_steps,
+ mask_restore=mask_restore,
+ resize_mode=resize_mode,
+ uses_loopback=loopback,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg
+ )
+
+ del sampler
+
+ return output_images, seed, info, stats
+
+#
+
+
+def layout():
+ with st.form("img2img-inputs"):
+ st.session_state["generation_mode"] = "img2img"
+
+ img2img_input_col, img2img_generate_col = st.columns([10,1])
+ with img2img_input_col:
+ #prompt = st.text_area("Input Text","")
+ prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
+
+ # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
+ img2img_generate_col.write("")
+ img2img_generate_col.write("")
+ generate_button = img2img_generate_col.form_submit_button("Generate")
+
+
+ # creating the page layout using columns
+ col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small")
+
+ with col1_img2img_layout:
+ # If we have custom models available on the "models/custom"
+ #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
+ if st.session_state["CustomModel_available"]:
+ st.session_state["custom_model"] = st.selectbox("Custom Model:", st.session_state["custom_models"],
+ index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
+ else:
+ st.session_state["custom_model"] = "Stable Diffusion v1.4"
+
+
+ st.session_state["sampling_steps"] = st.slider("Sampling Steps",
+ value=st.session_state['defaults'].img2img.sampling_steps,
+ min_value=st.session_state['defaults'].img2img.slider_bounds.sampling.lower,
+ max_value=st.session_state['defaults'].img2img.slider_bounds.sampling.upper,
+ step=st.session_state['defaults'].img2img.slider_steps.sampling)
+
+ sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ st.session_state["sampler_name"] = st.selectbox("Sampling method",sampler_name_list,
+ index=sampler_name_list.index(st.session_state['defaults'].img2img.sampler_name), help="Sampling method to use.")
+
+ mask_mode_list = ["Mask", "Inverted mask", "Image alpha"]
+ mask_mode = st.selectbox("Mask Mode", mask_mode_list,
+ help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\
+ \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
+ )
+ mask_mode = mask_mode_list.index(mask_mode)
+
+ width = st.slider("Width:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.width, step=64)
+ height = st.slider("Height:", min_value=64, max_value=1024, value=st.session_state['defaults'].img2img.height, step=64)
+ seed = st.text_input("Seed:", value=st.session_state['defaults'].img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
+ noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
+ noise_mode = st.selectbox(
+ "Noise Mode", noise_mode_list,
+ help=""
+ )
+ noise_mode = noise_mode_list.index(noise_mode)
+ find_noise_steps = st.slider("Find Noise Steps", value=100, min_value=1, max_value=500)
+ batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_count, step=1,
+ help="How many iterations or batches of images to generate in total.")
+
+ #
+ with st.expander("Advanced"):
+ separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].img2img.separate_prompts,
+ help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+ normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].img2img.normalize_prompt_weights,
+ help="Ensure the sum of all weights add up to 1.0")
+ loopback = st.checkbox("Loopback.", value=st.session_state['defaults'].img2img.loopback, help="Use images from previous batch when creating next batch.")
+ random_seed_loopback = st.checkbox("Random loopback seed.", value=st.session_state['defaults'].img2img.random_seed_loopback, help="Random loopback seed")
+ img2img_mask_restore = st.checkbox("Only modify regenerated parts of image",
+ value=st.session_state['defaults'].img2img.mask_restore,
+ help="Enable to restore the unmasked parts of the image with the input, may not blend as well but preserves detail")
+ save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].img2img.save_individual_images,
+ help="Save each image generated before any filter or enhancement is applied.")
+ save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].img2img.save_grid, help="Save a grid with all the images generated into a single image.")
+ group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].img2img.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. \
+ When using a prompt matrix each prompt combination will have its own folder.")
+ write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].img2img.write_info_files,
+ help="Save a file next to the image with informartion about the generation.")
+ save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ if st.session_state["GFPGAN_available"]:
+ use_GFPGAN = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
+ This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
+ else:
+ use_GFPGAN = False
+
+ if st.session_state["RealESRGAN_available"]:
+ st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].img2img.use_RealESRGAN,
+ help="Uses the RealESRGAN model to upscale the images after the generation.\
+ This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
+ st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
+ else:
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+
+ variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
+ variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].img2img.variant_seed,
+ help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+ cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].img2img.cfg_scale, step=0.5,
+ help="How strongly the image should follow the prompt.")
+ batch_size = st.slider("Batch size", min_value=1, max_value=100, value=st.session_state['defaults'].img2img.batch_size, step=1,
+ help="How many images are at once in a batch.\
+ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish \
+ generation as more images are generated at once.\
+ Default: 1")
+
+ st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=st.session_state['defaults'].img2img.denoising_strength,
+ min_value=0.01, max_value=1.0, step=0.01)
+
+ with st.expander("Preview Settings"):
+ st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].img2img.update_preview,
+ help="If enabled the image preview will be updated during the generation instead of at the end. \
+ You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
+ By default this is enabled and the frequency is set to 1 step.")
+
+ st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].img2img.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 1 step.")
+
+ with col2_img2img_layout:
+ editor_tab = st.tabs(["Editor"])
+
+ editor_image = st.empty()
+ st.session_state["editor_image"] = editor_image
+
+ st.form_submit_button("Refresh")
+
+ masked_image_holder = st.empty()
+ image_holder = st.empty()
+
+ uploaded_images = st.file_uploader(
+ "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
+ help="Upload an image which will be used for the image to image generation.",
+ )
+ if uploaded_images:
+ image = Image.open(uploaded_images).convert('RGBA')
+ new_img = image.resize((width, height))
+ image_holder.image(new_img)
+
+ mask_holder = st.empty()
+
+ uploaded_masks = st.file_uploader(
+ "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
+ help="Upload an mask image which will be used for masking the image to image generation.",
+ )
+ if uploaded_masks:
+ mask = Image.open(uploaded_masks)
+ if mask.mode == "RGBA":
+ mask = mask.convert('RGBA')
+ background = Image.new('RGBA', mask.size, (0, 0, 0))
+ mask = Image.alpha_composite(background, mask)
+ mask = mask.resize((width, height))
+ mask_holder.image(mask)
+
+ if uploaded_images and uploaded_masks:
+ if mask_mode != 2:
+ final_img = new_img.copy()
+ alpha_layer = mask.convert('L')
+ strength = st.session_state["denoising_strength"]
+ if mask_mode == 0:
+ alpha_layer = ImageOps.invert(alpha_layer)
+ alpha_layer = alpha_layer.point(lambda a: a * strength)
+ alpha_layer = ImageOps.invert(alpha_layer)
+ elif mask_mode == 1:
+ alpha_layer = alpha_layer.point(lambda a: a * strength)
+ alpha_layer = ImageOps.invert(alpha_layer)
+
+ final_img.putalpha(alpha_layer)
+
+ with masked_image_holder.container():
+ st.text("Masked Image Preview")
+ st.image(final_img)
+
+
+ with col3_img2img_layout:
+ result_tab = st.tabs(["Result"])
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ preview_image = st.empty()
+ st.session_state["preview_image"] = preview_image
+
+ #st.session_state["loading"] = st.empty()
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar"] = st.empty()
+
+
+ message = st.empty()
+
+ #if uploaded_images:
+ #image = Image.open(uploaded_images).convert('RGB')
+ ##img_array = np.array(image) # if you want to pass it to OpenCV
+ #new_img = image.resize((width, height))
+ #st.image(new_img, use_column_width=True)
+
+
+ if generate_button:
+ #print("Loading models")
+ # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
+ load_models(False, use_GFPGAN, st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"],
+ st.session_state["custom_model"])
+
+ if uploaded_images:
+ image = Image.open(uploaded_images).convert('RGBA')
+ new_img = image.resize((width, height))
+ #img_array = np.array(image) # if you want to pass it to OpenCV
+ new_mask = None
+ if uploaded_masks:
+ mask = Image.open(uploaded_masks).convert('RGBA')
+ new_mask = mask.resize((width, height))
+
+ try:
+ output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode,
+ mask_restore=img2img_mask_restore, ddim_steps=st.session_state["sampling_steps"],
+ sampler_name=st.session_state["sampler_name"], n_iter=batch_count,
+ cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
+ seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width,
+ height=height, variant_amount=variant_amount,
+ ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=st.session_state["RealESRGAN_model"],
+ separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images, save_grid=save_grid,
+ group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
+ use_RealESRGAN=st.session_state["use_RealESRGAN"] if not loopback else False, loopback=loopback
+ )
+
+ #show a message when the generation is complete.
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ except (StopException, KeyError):
+ print(f"Received Streamlit StopException")
+
+ # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
+ # use the current col2 first tab to show the preview_img and update it as its generated.
+ #preview_image.image(output_images, width=750)
+
+#on import run init
diff --git a/scripts/imglab.py b/scripts/imglab.py
new file mode 100644
index 0000000..eb09c6a
--- /dev/null
+++ b/scripts/imglab.py
@@ -0,0 +1,161 @@
+# base webui import and utils.
+from webui_streamlit import st
+from sd_utils import *
+
+#home plugin
+import os
+from PIL import Image
+#from bs4 import BeautifulSoup
+from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
+from streamlit.elements import image as STImage
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+class plugin_info():
+ plugname = "imglab"
+ description = "Image Lab"
+ isTab = True
+ displayPriority = 3
+
+def getLatestGeneratedImagesFromPath():
+ #get the latest images from the generated images folder
+ #get the path to the generated images folder
+ generatedImagesPath = os.path.join(os.getcwd(),'outputs')
+ #get all the files from the folders and subfolders
+ files = []
+ #get the laest 10 images from the output folder without walking the subfolders
+ for r, d, f in os.walk(generatedImagesPath):
+ for file in f:
+ if '.png' in file:
+ files.append(os.path.join(r, file))
+ #sort the files by date
+ files.sort(key=os.path.getmtime)
+ #reverse the list so the latest images are first
+ for f in files:
+ img = Image.open(f)
+ files[files.index(f)] = img
+ #get the latest 10 files
+ #get all the files with the .png or .jpg extension
+ #sort files by date
+ #get the latest 10 files
+ latestFiles = files[-10:]
+ #reverse the list
+ latestFiles.reverse()
+ return latestFiles
+
+def getImagesFromLexica():
+ #scrape images from lexica.art
+ #get the html from the page
+ #get the html with cookies and javascript
+ apiEndpoint = r'https://lexica.art/api/trpc/prompts.infinitePrompts?batch=1&input=%7B%220%22%3A%7B%22json%22%3A%7B%22limit%22%3A10%2C%22text%22%3A%22%22%2C%22cursor%22%3A10%7D%7D%7D'
+ #REST API call
+ #
+ from requests_html import HTMLSession
+ session = HTMLSession()
+
+ response = session.get(apiEndpoint)
+ #req = requests.Session()
+ #req.headers['user-agent'] = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Safari/537.36'
+ #response = req.get(apiEndpoint)
+ print(response.status_code)
+ print(response.text)
+ #get the json from the response
+ #json = response.json()
+ #get the prompts from the json
+ print(response)
+ #session = requests.Session()
+ #parseEndpointJson = session.get(apiEndpoint,headers=headers,verify=False)
+ #print(parseEndpointJson)
+ #print('test2')
+ #page = requests.get("https://lexica.art/", headers={'User-Agent': 'Mozilla/5.0'})
+ #parse the html
+ #soup = BeautifulSoup(page.content, 'html.parser')
+ #find all the images
+ #print(soup)
+ #images = soup.find_all('alt-image')
+ #create a list to store the image urls
+ image_urls = []
+ #loop through the images
+ for image in images:
+ #get the url
+ image_url = image['src']
+ #add it to the list
+ image_urls.append('http://www.lexica.art/'+image_url)
+ #return the list
+ print(image_urls)
+ return image_urls
+def changeImage():
+ #change the image in the image holder
+ #check if the file is not empty
+ if len(st.session_state['uploaded_file']) > 0:
+ #read the file
+ print('test2')
+ uploaded = st.session_state['uploaded_file'][0].read()
+ #show the image in the image holder
+ st.session_state['previewImg'].empty()
+ st.session_state['previewImg'].image(uploaded,use_column_width=True)
+def createHTMLGallery(images):
+ html3 = """
+
+ """
+ mkdwn_array = []
+ for i in images:
+ bImg = i.read()
+ i = Image.save(bImg, 'PNG')
+ width, height = i.size
+ #get random number for the id
+ image_id = "%s" % (str(images.index(i)))
+ (data, mimetype) = STImage._normalize_to_bytes(bImg.getvalue(), width, 'auto')
+ this_file = in_memory_file_manager.add(data, mimetype, image_id)
+ img_str = this_file.url
+ #img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
+ #get image size
+
+ #make sure the image is not bigger then 150px but keep the aspect ratio
+ if width > 150:
+ height = int(height * (150/width))
+ width = 150
+ if height > 150:
+ width = int(width * (150/height))
+ height = 150
+
+ #mkdwn = f""""""
+ mkdwn = f'''
'
+ return html3
+def layout():
+
+ col1, col2 = st.columns(2)
+ with col1:
+ st.session_state['uploaded_file'] = st.file_uploader("Choose an image or images", type=["png", "jpg", "jpeg", "webp"],accept_multiple_files=True,on_change=changeImage)
+ if 'previewImg' not in st.session_state:
+ st.session_state['previewImg'] = st.empty()
+ else:
+ if len(st.session_state['uploaded_file']) > 0:
+ st.session_state['previewImg'].empty()
+ st.session_state['previewImg'].image(st.session_state['uploaded_file'][0],use_column_width=True)
+ else:
+ st.session_state['previewImg'] = st.empty()
+
diff --git a/scripts/perlin.py b/scripts/perlin.py
new file mode 100644
index 0000000..327a994
--- /dev/null
+++ b/scripts/perlin.py
@@ -0,0 +1,48 @@
+import numpy as np
+
+def perlin(x, y, seed=0):
+ # permutation table
+ np.random.seed(seed)
+ p = np.arange(256, dtype=int)
+ np.random.shuffle(p)
+ p = np.stack([p, p]).flatten()
+ # coordinates of the top-left
+ xi, yi = x.astype(int), y.astype(int)
+ # internal coordinates
+ xf, yf = x - xi, y - yi
+ # fade factors
+ u, v = fade(xf), fade(yf)
+ # noise components
+ n00 = gradient(p[p[xi] + yi], xf, yf)
+ n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
+ n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
+ n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
+ # combine noises
+ x1 = lerp(n00, n10, u)
+ x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01
+ return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here
+
+def lerp(a, b, x):
+ "linear interpolation"
+ return a + x * (b - a)
+
+def fade(t):
+ "6t^5 - 15t^4 + 10t^3"
+ return 6 * t**5 - 15 * t**4 + 10 * t**3
+
+def gradient(h, x, y):
+ "grad converts h to the right gradient vector and return the dot product with (x,y)"
+ vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
+ g = vectors[h % 4]
+ return g[:, :, 0] * x + g[:, :, 1] * y
+
+lin = np.linspace(0, 5, 100, endpoint=False)
+x, y = np.meshgrid(lin, lin)
+
+
+
+def perlinNoise(height,width,octavesx=5,octavesy=5,seed=None):
+ linx = np.linspace(0,octavesx,width,endpoint=False)
+ liny = np.linspace(0,octavesy,height,endpoint=False)
+ x,y = np.meshgrid(linx,liny)
+ return perlin(x,y,seed=seed)
\ No newline at end of file
diff --git a/scripts/relauncher.py b/scripts/relauncher.py
index 7179d7f..457d539 100644
--- a/scripts/relauncher.py
+++ b/scripts/relauncher.py
@@ -19,6 +19,8 @@ optimized_turbo = False
# Creates a public xxxxx.gradio.app share link to allow others to use your interface (requires properly forwarded ports to work correctly)
share = False
+# Generate tiling images
+tiling = False
# Enter other `--arguments` you wish to use - Must be entered as a `--argument ` syntax
additional_arguments = ""
@@ -37,6 +39,8 @@ if optimized_turbo == True:
common_arguments += "--optimized-turbo "
if optimized == True:
common_arguments += "--optimized "
+if tiling == True:
+ common_arguments += "--tiling "
if share == True:
common_arguments += "--share "
diff --git a/scripts/sd_utils.py b/scripts/sd_utils.py
new file mode 100644
index 0000000..6983edb
--- /dev/null
+++ b/scripts/sd_utils.py
@@ -0,0 +1,1728 @@
+# base webui import and utils.
+from webui_streamlit import st
+
+
+# streamlit imports
+from streamlit import StopException
+#other imports
+
+import warnings
+import json
+
+import base64
+import os, sys, re, random, datetime, time, math, glob
+from PIL import Image, ImageFont, ImageDraw, ImageFilter
+from PIL.PngImagePlugin import PngInfo
+from scipy import integrate
+import torch
+from torchdiffeq import odeint
+import k_diffusion as K
+import math
+import mimetypes
+import numpy as np
+import pynvml
+import threading
+import torch
+from torch import autocast
+from torchvision import transforms
+import torch.nn as nn
+from omegaconf import OmegaConf
+import yaml
+from pathlib import Path
+from contextlib import nullcontext
+from einops import rearrange
+from ldm.util import instantiate_from_config
+from retry import retry
+from slugify import slugify
+import skimage
+import piexif
+import piexif.helper
+from tqdm import trange
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+# remove some annoying deprecation warnings that show every now and then.
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
+mimetypes.init()
+mimetypes.add_type('application/javascript', '.js')
+
+# some of those options should not be changed at all because they would break the model, so I removed them from options.
+opt_C = 4
+opt_f = 8
+
+if not "defaults" in st.session_state:
+ st.session_state["defaults"] = {}
+
+st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
+
+if (os.path.exists("configs/webui/userconfig_streamlit.yaml")):
+ user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
+ st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults)
+
+
+# should and will be moved to a settings menu in the UI at some point
+grid_format = [s.lower() for s in st.session_state["defaults"].general.grid_format.split(':')]
+grid_lossless = False
+grid_quality = 100
+if grid_format[0] == 'png':
+ grid_ext = 'png'
+ grid_format = 'png'
+elif grid_format[0] in ['jpg', 'jpeg']:
+ grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100
+ grid_ext = 'jpg'
+ grid_format = 'jpeg'
+elif grid_format[0] == 'webp':
+ grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100
+ grid_ext = 'webp'
+ grid_format = 'webp'
+ if grid_quality < 0: # e.g. webp:-100 for lossless mode
+ grid_lossless = True
+ grid_quality = abs(grid_quality)
+
+# should and will be moved to a settings menu in the UI at some point
+save_format = [s.lower() for s in st.session_state["defaults"].general.save_format.split(':')]
+save_lossless = False
+save_quality = 100
+if save_format[0] == 'png':
+ save_ext = 'png'
+ save_format = 'png'
+elif save_format[0] in ['jpg', 'jpeg']:
+ save_quality = int(save_format[1]) if len(save_format) > 1 else 100
+ save_ext = 'jpg'
+ save_format = 'jpeg'
+elif save_format[0] == 'webp':
+ save_quality = int(save_format[1]) if len(save_format) > 1 else 100
+ save_ext = 'webp'
+ save_format = 'webp'
+ if save_quality < 0: # e.g. webp:-100 for lossless mode
+ save_lossless = True
+ save_quality = abs(save_quality)
+
+# this should force GFPGAN and RealESRGAN onto the selected gpu as well
+os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
+os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu)
+
+@retry(tries=5)
+def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus",
+ CustomModel_available=False, custom_model="Stable Diffusion v1.4"):
+ """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
+
+ print ("Loading models.")
+
+ st.session_state["progress_bar_text"].text("Loading models...")
+
+ # Generate random run ID
+ # Used to link runs linked w/ continue_prev_run which is not yet implemented
+ # Use URL and filesystem safe version just in case.
+ st.session_state["run_id"] = base64.urlsafe_b64encode(
+ os.urandom(6)
+ ).decode("ascii")
+
+ # check what models we want to use and if the they are already loaded.
+
+ if use_GFPGAN:
+ if "GFPGAN" in st.session_state:
+ print("GFPGAN already loaded")
+ else:
+ # Load GFPGAN
+ if os.path.exists(st.session_state["defaults"].general.GFPGAN_dir):
+ try:
+ st.session_state["GFPGAN"] = load_GFPGAN()
+ print("Loaded GFPGAN")
+ except Exception:
+ import traceback
+ print("Error loading GFPGAN:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if "GFPGAN" in st.session_state:
+ del st.session_state["GFPGAN"]
+
+ if use_RealESRGAN:
+ if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model:
+ print("RealESRGAN already loaded")
+ else:
+ #Load RealESRGAN
+ try:
+ # We first remove the variable in case it has something there,
+ # some errors can load the model incorrectly and leave things in memory.
+ del st.session_state["RealESRGAN"]
+ except KeyError:
+ pass
+
+ if os.path.exists(st.session_state["defaults"].general.RealESRGAN_dir):
+ # st.session_state is used for keeping the models in memory across multiple pages or runs.
+ st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model)
+ print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name)
+
+ else:
+ if "RealESRGAN" in st.session_state:
+ del st.session_state["RealESRGAN"]
+
+ if "model" in st.session_state:
+ if "model" in st.session_state and st.session_state["loaded_model"] == custom_model:
+ # TODO: check if the optimized mode was changed?
+ print("Model already loaded")
+
+ return
+ else:
+ try:
+ del st.session_state.model
+ del st.session_state.modelCS
+ del st.session_state.modelFS
+ del st.session_state.loaded_model
+ except KeyError:
+ pass
+
+ # At this point the model is either
+ # is not loaded yet or have been evicted:
+ # load new model into memory
+ st.session_state.custom_model = custom_model
+
+ config, device, model, modelCS, modelFS = load_sd_model(custom_model)
+
+ st.session_state.device = device
+ st.session_state.model = model
+ st.session_state.modelCS = modelCS
+ st.session_state.modelFS = modelFS
+ st.session_state.loaded_model = custom_model
+
+ if st.session_state.defaults.general.enable_attention_slicing:
+ st.session_state.model.enable_attention_slicing()
+
+ if st.session_state.defaults.general.enable_minimal_memory_usage:
+ st.session_state.model.enable_minimal_memory_usage()
+
+ print("Model loaded.")
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+
+ print(f"Loading model from {ckpt}")
+
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+
+def load_sd_from_config(ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ return sd
+
+
+class MemUsageMonitor(threading.Thread):
+ stop_flag = False
+ max_usage = 0
+ total = -1
+
+ def __init__(self, name):
+ threading.Thread.__init__(self)
+ self.name = name
+
+ def run(self):
+ try:
+ pynvml.nvmlInit()
+ except:
+ print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n")
+ return
+ print(f"[{self.name}] Recording max memory usage...\n")
+ # Missing context
+ #handle = pynvml.nvmlDeviceGetHandleByIndex(st.session_state['defaults'].general.gpu)
+ handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+ self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
+ while not self.stop_flag:
+ m = pynvml.nvmlDeviceGetMemoryInfo(handle)
+ self.max_usage = max(self.max_usage, m.used)
+ # print(self.max_usage)
+ time.sleep(0.1)
+ print(f"[{self.name}] Stopped recording.\n")
+ pynvml.nvmlShutdown()
+
+ def read(self):
+ return self.max_usage, self.total
+
+ def stop(self):
+ self.stop_flag = True
+
+ def read_and_stop(self):
+ self.stop_flag = True
+ return self.max_usage, self.total
+
+class CFGMaskedDenoiser(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
+ x_in = x
+ x_in = torch.cat([x_in] * 2)
+ sigma_in = torch.cat([sigma] * 2)
+ cond_in = torch.cat([uncond, cond])
+ uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+ denoised = uncond + (cond - uncond) * cond_scale
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = x0
+ mask_inv = 1. - mask
+ denoised = (img_orig * mask_inv) + (mask * denoised)
+
+ return denoised
+
+class CFGDenoiser(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+
+ def forward(self, x, sigma, uncond, cond, cond_scale):
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigma] * 2)
+ cond_in = torch.cat([uncond, cond])
+ uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+ return uncond + (cond - uncond) * cond_scale
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
+ """Constructs the noise schedule of Karras et al. (2022)."""
+ ramp = torch.linspace(0, 1, n)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return append_zero(sigmas).to(device)
+
+#
+# helper fft routines that keep ortho normalization and auto-shift before and after fft
+def _fft2(data):
+ if data.ndim > 2: # has channels
+ out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
+ for c in range(data.shape[2]):
+ c_data = data[:,:,c]
+ out_fft[:,:,c] = np.fft.fft2(np.fft.fftshift(c_data),norm="ortho")
+ out_fft[:,:,c] = np.fft.ifftshift(out_fft[:,:,c])
+ else: # one channel
+ out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+ out_fft[:,:] = np.fft.fft2(np.fft.fftshift(data),norm="ortho")
+ out_fft[:,:] = np.fft.ifftshift(out_fft[:,:])
+
+ return out_fft
+
+def _ifft2(data):
+ if data.ndim > 2: # has channels
+ out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)
+ for c in range(data.shape[2]):
+ c_data = data[:,:,c]
+ out_ifft[:,:,c] = np.fft.ifft2(np.fft.fftshift(c_data),norm="ortho")
+ out_ifft[:,:,c] = np.fft.ifftshift(out_ifft[:,:,c])
+ else: # one channel
+ out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
+ out_ifft[:,:] = np.fft.ifft2(np.fft.fftshift(data),norm="ortho")
+ out_ifft[:,:] = np.fft.ifftshift(out_ifft[:,:])
+
+ return out_ifft
+
+def _get_gaussian_window(width, height, std=3.14, mode=0):
+
+ window_scale_x = float(width / min(width, height))
+ window_scale_y = float(height / min(width, height))
+
+ window = np.zeros((width, height))
+ x = (np.arange(width) / width * 2. - 1.) * window_scale_x
+ for y in range(height):
+ fy = (y / height * 2. - 1.) * window_scale_y
+ if mode == 0:
+ window[:, y] = np.exp(-(x**2+fy**2) * std)
+ else:
+ window[:, y] = (1/((x**2+1.) * (fy**2+1.))) ** (std/3.14) # hey wait a minute that's not gaussian
+
+ return window
+
+def _get_masked_window_rgb(np_mask_grey, hardness=1.):
+ np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))
+ if hardness != 1.:
+ hardened = np_mask_grey[:] ** hardness
+ else:
+ hardened = np_mask_grey[:]
+ for c in range(3):
+ np_mask_rgb[:,:,c] = hardened[:]
+ return np_mask_rgb
+
+def get_matched_noise(_np_src_image, np_mask_rgb, noise_q, color_variation):
+ """
+ Explanation:
+ Getting good results in/out-painting with stable diffusion can be challenging.
+ Although there are simpler effective solutions for in-painting, out-painting can be especially challenging because there is no color data
+ in the masked area to help prompt the generator. Ideally, even for in-painting we'd like work effectively without that data as well.
+ Provided here is my take on a potential solution to this problem.
+
+ By taking a fourier transform of the masked src img we get a function that tells us the presence and orientation of each feature scale in the unmasked src.
+ Shaping the init/seed noise for in/outpainting to the same distribution of feature scales, orientations, and positions increases output coherence
+ by helping keep features aligned. This technique is applicable to any continuous generation task such as audio or video, each of which can
+ be conceptualized as a series of out-painting steps where the last half of the input "frame" is erased. For multi-channel data such as color
+ or stereo sound the "color tone" or histogram of the seed noise can be matched to improve quality (using scikit-image currently)
+ This method is quite robust and has the added benefit of being fast independently of the size of the out-painted area.
+ The effects of this method include things like helping the generator integrate the pre-existing view distance and camera angle.
+
+ Carefully managing color and brightness with histogram matching is also essential to achieving good coherence.
+
+ noise_q controls the exponent in the fall-off of the distribution can be any positive number, lower values means higher detail (range > 0, default 1.)
+ color_variation controls how much freedom is allowed for the colors/palette of the out-painted area (range 0..1, default 0.01)
+ This code is provided as is under the Unlicense (https://unlicense.org/)
+ Although you have no obligation to do so, if you found this code helpful please find it in your heart to credit me [parlance-zz].
+
+ Questions or comments can be sent to parlance@fifth-harmonic.com (https://github.com/parlance-zz/)
+ This code is part of a new branch of a discord bot I am working on integrating with diffusers (https://github.com/parlance-zz/g-diffuser-bot)
+
+ """
+
+ global DEBUG_MODE
+ global TMP_ROOT_PATH
+
+ width = _np_src_image.shape[0]
+ height = _np_src_image.shape[1]
+ num_channels = _np_src_image.shape[2]
+
+ np_src_image = _np_src_image[:] * (1. - np_mask_rgb)
+ np_mask_grey = (np.sum(np_mask_rgb, axis=2)/3.)
+ np_src_grey = (np.sum(np_src_image, axis=2)/3.)
+ all_mask = np.ones((width, height), dtype=bool)
+ img_mask = np_mask_grey > 1e-6
+ ref_mask = np_mask_grey < 1e-3
+
+ windowed_image = _np_src_image * (1.-_get_masked_window_rgb(np_mask_grey))
+ windowed_image /= np.max(windowed_image)
+ windowed_image += np.average(_np_src_image) * np_mask_rgb# / (1.-np.average(np_mask_rgb)) # rather than leave the masked area black, we get better results from fft by filling the average unmasked color
+ #windowed_image += np.average(_np_src_image) * (np_mask_rgb * (1.- np_mask_rgb)) / (1.-np.average(np_mask_rgb)) # compensate for darkening across the mask transition area
+ #_save_debug_img(windowed_image, "windowed_src_img")
+
+ src_fft = _fft2(windowed_image) # get feature statistics from masked src img
+ src_dist = np.absolute(src_fft)
+ src_phase = src_fft / src_dist
+ #_save_debug_img(src_dist, "windowed_src_dist")
+
+ noise_window = _get_gaussian_window(width, height, mode=1) # start with simple gaussian noise
+ noise_rgb = np.random.random_sample((width, height, num_channels))
+ noise_grey = (np.sum(noise_rgb, axis=2)/3.)
+ noise_rgb *= color_variation # the colorfulness of the starting noise is blended to greyscale with a parameter
+ for c in range(num_channels):
+ noise_rgb[:,:,c] += (1. - color_variation) * noise_grey
+
+ noise_fft = _fft2(noise_rgb)
+ for c in range(num_channels):
+ noise_fft[:,:,c] *= noise_window
+ noise_rgb = np.real(_ifft2(noise_fft))
+ shaped_noise_fft = _fft2(noise_rgb)
+ shaped_noise_fft[:,:,:] = np.absolute(shaped_noise_fft[:,:,:])**2 * (src_dist ** noise_q) * src_phase # perform the actual shaping
+
+ brightness_variation = 0.#color_variation # todo: temporarily tieing brightness variation to color variation for now
+ contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) - brightness_variation * 2.
+
+ # scikit-image is used for histogram matching, very convenient!
+ shaped_noise = np.real(_ifft2(shaped_noise_fft))
+ shaped_noise -= np.min(shaped_noise)
+ shaped_noise /= np.max(shaped_noise)
+ shaped_noise[img_mask,:] = skimage.exposure.match_histograms(shaped_noise[img_mask,:]**1., contrast_adjusted_np_src[ref_mask,:], channel_axis=1)
+ shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb
+ #_save_debug_img(shaped_noise, "shaped_noise")
+
+ matched_noise = np.zeros((width, height, num_channels))
+ matched_noise = shaped_noise[:]
+ #matched_noise[all_mask,:] = skimage.exposure.match_histograms(shaped_noise[all_mask,:], _np_src_image[ref_mask,:], channel_axis=1)
+ #matched_noise = _np_src_image[:] * (1. - np_mask_rgb) + matched_noise * np_mask_rgb
+
+ #_save_debug_img(matched_noise, "matched_noise")
+
+ """
+ todo:
+ color_variation doesnt have to be a single number, the overall color tone of the out-painted area could be param controlled
+ """
+
+ return np.clip(matched_noise, 0., 1.)
+
+
+#
+def find_noise_for_image(model, device, init_image, prompt, steps=200, cond_scale=2.0, verbose=False, normalize=False, generation_callback=None):
+ image = np.array(init_image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ image = 2. * image - 1.
+ image = image.to(device)
+ x = model.get_first_stage_encoding(model.encode_first_stage(image))
+
+ uncond = model.get_learned_conditioning([''])
+ cond = model.get_learned_conditioning([prompt])
+
+ s_in = x.new_ones([x.shape[0]])
+ dnw = K.external.CompVisDenoiser(model)
+ sigmas = dnw.get_sigmas(steps).flip(0)
+
+ if verbose:
+ print(sigmas)
+
+ for i in trange(1, len(sigmas)):
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
+ cond_in = torch.cat([uncond, cond])
+
+ c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
+
+ if i == 1:
+ t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))
+ else:
+ t = dnw.sigma_to_t(sigma_in)
+
+ eps = model.apply_model(x_in * c_in, t, cond=cond_in)
+ denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)
+
+ denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale
+
+ if i == 1:
+ d = (x - denoised) / (2 * sigmas[i])
+ else:
+ d = (x - denoised) / sigmas[i - 1]
+
+ if generation_callback is not None:
+ generation_callback(x, i)
+
+ dt = sigmas[i] - sigmas[i - 1]
+ x = x + d * dt
+
+ return x / sigmas[-1]
+
+
+def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
+ """Constructs an exponential noise schedule."""
+ sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
+ return append_zero(sigmas)
+
+
+def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
+ """Constructs a continuous VP noise schedule."""
+ t = torch.linspace(1, eps_s, n, device=device)
+ sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
+ return append_zero(sigmas)
+
+
+def to_d(x, sigma, denoised):
+ """Converts a denoiser output to a Karras ODE derivative."""
+ return (x - denoised) / append_dims(sigma, x.ndim)
+def linear_multistep_coeff(order, t, i, j):
+ if order - 1 > i:
+ raise ValueError(f'Order {order} too high for step {i}')
+ def fn(tau):
+ prod = 1.
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
+
+class KDiffusionSampler:
+ def __init__(self, m, sampler):
+ self.model = m
+ self.model_wrap = K.external.CompVisDenoiser(m)
+ self.schedule = sampler
+ def get_sampler_name(self):
+ return self.schedule
+ def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None):
+ sigmas = self.model_wrap.get_sigmas(S)
+ x = x_T * sigmas[0]
+ model_wrap_cfg = CFGDenoiser(self.model_wrap)
+ samples_ddim = None
+ samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas,
+ extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
+ 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback)
+ #
+ return samples_ddim, None
+
+
+@torch.no_grad()
+def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ v = torch.randint_like(x, 2) * 2 - 1
+ fevals = 0
+ def ode_fn(sigma, x):
+ nonlocal fevals
+ with torch.enable_grad():
+ x = x[0].detach().requires_grad_()
+ denoised = model(x, sigma * s_in, **extra_args)
+ d = to_d(x, sigma, denoised)
+ fevals += 1
+ grad = torch.autograd.grad((d * v).sum(), x)[0]
+ d_ll = (v * grad).flatten(1).sum(1)
+ return d.detach(), d_ll
+ x_min = x, x.new_zeros([x.shape[0]])
+ t = x.new_tensor([sigma_min, sigma_max])
+ sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
+ latent, delta_ll = sol[0][-1], sol[1][-1]
+ ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
+ return ll_prior + delta_ll, {'fevals': fevals}
+
+
+def create_random_tensors(shape, seeds):
+ xs = []
+ for seed in seeds:
+ torch.manual_seed(seed)
+
+ # randn results depend on device; gpu and cpu get different results for same seed;
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
+ # but the original script had it like this so i do not dare change it for now because
+ # it will break everyone's seeds.
+ xs.append(torch.randn(shape, device=st.session_state['defaults'].general.gpu))
+ x = torch.stack(xs)
+ return x
+
+def torch_gc():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+def load_GFPGAN():
+ model_name = 'GFPGANv1.3'
+ model_path = os.path.join(st.session_state['defaults'].general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
+ if not os.path.isfile(model_path):
+ raise Exception("GFPGAN model not found at path "+model_path)
+
+ sys.path.append(os.path.abspath(st.session_state['defaults'].general.GFPGAN_dir))
+ from gfpgan import GFPGANer
+
+ if st.session_state['defaults'].general.gfpgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
+ instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
+ elif st.session_state['defaults'].general.extra_models_gpu:
+ instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gfpgan_gpu}"))
+ else:
+ instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
+ return instance
+
+def load_RealESRGAN(model_name: str):
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ RealESRGAN_models = {
+ 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
+ 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ }
+
+ model_path = os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
+ if not os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")):
+ raise Exception(model_name+".pth not found at path "+model_path)
+
+ sys.path.append(os.path.abspath(st.session_state['defaults'].general.RealESRGAN_dir))
+ from realesrgan import RealESRGANer
+
+ if st.session_state['defaults'].general.esrgan_cpu or st.session_state['defaults'].general.extra_models_cpu:
+ instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half
+ instance.device = torch.device('cpu')
+ instance.model.to('cpu')
+ elif st.session_state['defaults'].general.extra_models_gpu:
+ instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.esrgan_gpu}"))
+ else:
+ instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not st.session_state['defaults'].general.no_half, device=torch.device(f"cuda:{st.session_state['defaults'].general.gpu}"))
+ instance.model.name = model_name
+
+ return instance
+
+#
+def load_LDSR(checking=False):
+ model_name = 'model'
+ yaml_name = 'project'
+ model_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', model_name + '.ckpt')
+ yaml_path = os.path.join(st.session_state['defaults'].general.LDSR_dir, 'experiments/pretrained_models', yaml_name + '.yaml')
+ if not os.path.isfile(model_path):
+ raise Exception("LDSR model not found at path "+model_path)
+ if not os.path.isfile(yaml_path):
+ raise Exception("LDSR model not found at path "+yaml_path)
+ if checking == True:
+ return True
+
+ sys.path.append(os.path.abspath(st.session_state['defaults'].general.LDSR_dir))
+ from LDSR import LDSR
+ LDSRObject = LDSR(model_path, yaml_path)
+ return LDSRObject
+
+#
+LDSR = None
+def try_loading_LDSR(model_name: str,checking=False):
+ global LDSR
+ if os.path.exists(st.session_state['defaults'].general.LDSR_dir):
+ try:
+ LDSR = load_LDSR(checking=True) # TODO: Should try to load both models before giving up
+ if checking == True:
+ print("Found LDSR")
+ return True
+ print("Latent Diffusion Super Sampling (LDSR) model loaded")
+ except Exception:
+ import traceback
+ print("Error loading LDSR:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ print("LDSR not found at path, please make sure you have cloned the LDSR repo to ./src/latent-diffusion/")
+
+#try_loading_LDSR('model',checking=True)
+
+
+# Loads Stable Diffusion model by name
+def load_sd_model(model_name: str) -> [any, any, any, any, any]:
+ ckpt_path = st.session_state.defaults.general.default_model_path
+ if model_name != st.session_state.defaults.general.default_model:
+ ckpt_path = os.path.join("models", "custom", f"{model_name}.ckpt")
+
+ if st.session_state.defaults.general.optimized:
+ config = OmegaConf.load(st.session_state.defaults.general.optimized_config)
+
+ sd = load_sd_from_config(ckpt_path)
+ li, lo = [], []
+ for key, v_ in sd.items():
+ sp = key.split('.')
+ if (sp[0]) == 'model':
+ if 'input_blocks' in sp:
+ li.append(key)
+ elif 'middle_block' in sp:
+ li.append(key)
+ elif 'time_embed' in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+ for key in li:
+ sd['model1.' + key[6:]] = sd.pop(key)
+ for key in lo:
+ sd['model2.' + key[6:]] = sd.pop(key)
+
+ device = torch.device(f"cuda:{st.session_state.defaults.general.gpu}") \
+ if torch.cuda.is_available() else torch.device("cpu")
+
+ model = instantiate_from_config(config.modelUNet)
+ _, _ = model.load_state_dict(sd, strict=False)
+ model.cuda()
+ model.eval()
+ model.turbo = st.session_state.defaults.general.optimized_turbo
+
+ modelCS = instantiate_from_config(config.modelCondStage)
+ _, _ = modelCS.load_state_dict(sd, strict=False)
+ modelCS.cond_stage_model.device = device
+ modelCS.eval()
+
+ modelFS = instantiate_from_config(config.modelFirstStage)
+ _, _ = modelFS.load_state_dict(sd, strict=False)
+ modelFS.eval()
+
+ del sd
+
+ if not st.session_state.defaults.general.no_half:
+ model = model.half()
+ modelCS = modelCS.half()
+ modelFS = modelFS.half()
+
+ return config, device, model, modelCS, modelFS
+ else:
+ config = OmegaConf.load(st.session_state.defaults.general.default_model_config)
+ model = load_model_from_config(config, ckpt_path)
+
+ device = torch.device(f"cuda:{st.session_state.defaults.general.gpu}") \
+ if torch.cuda.is_available() else torch.device("cpu")
+ model = (model if st.session_state.defaults.general.no_half
+ else model.half()).to(device)
+
+ return config, device, model, None, None
+
+
+# @codedealer: No usages
+def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'):
+ #get global variables
+ global_vars = globals()
+ #check if m is in globals
+ if unload:
+ for m in models:
+ if m in global_vars:
+ #if it is, delete it
+ del global_vars[m]
+ if st.session_state['defaults'].general.optimized:
+ if m == 'model':
+ del global_vars[m+'FS']
+ del global_vars[m+'CS']
+ if m == 'model':
+ m = 'Stable Diffusion'
+ print('Unloaded ' + m)
+ if load:
+ for m in models:
+ if m not in global_vars or m in global_vars and type(global_vars[m]) == bool:
+ #if it isn't, load it
+ if m == 'GFPGAN':
+ global_vars[m] = load_GFPGAN()
+ elif m == 'model':
+ sdLoader = load_sd_from_config()
+ global_vars[m] = sdLoader[0]
+ if st.session_state['defaults'].general.optimized:
+ global_vars[m+'CS'] = sdLoader[1]
+ global_vars[m+'FS'] = sdLoader[2]
+ elif m == 'RealESRGAN':
+ global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name)
+ elif m == 'LDSR':
+ global_vars[m] = load_LDSR()
+ if m =='model':
+ m='Stable Diffusion'
+ print('Loaded ' + m)
+ torch_gc()
+
+
+#
+@retry(tries=5)
+def generation_callback(img, i=0):
+ if "update_preview_frequency" not in st.session_state:
+ raise StopException
+
+ try:
+ if i == 0:
+ if img['i']: i = img['i']
+ except TypeError:
+ pass
+
+ if i % int(st.session_state.update_preview_frequency) == 0 and st.session_state.update_preview and i > 0:
+ #print (img)
+ #print (type(img))
+ # The following lines will convert the tensor we got on img to an actual image we can render on the UI.
+ # It can probably be done in a better way for someone who knows what they're doing. I don't.
+ #print (img,isinstance(img, torch.Tensor))
+ if isinstance(img, torch.Tensor):
+ x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(img)
+ else:
+ # When using the k Diffusion samplers they return a dict instead of a tensor that look like this:
+ # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}
+ x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(img["denoised"])
+
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if x_samples_ddim.ndimension() == 4:
+ pil_images = [transforms.ToPILImage()(x.squeeze_(0)) for x in x_samples_ddim]
+ pil_image = image_grid(pil_images, 1)
+ else:
+ pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0))
+
+ # update image on the UI so we can see the progress
+ st.session_state["preview_image"].image(pil_image)
+
+ # Show a progress bar so we can keep track of the progress even when the image progress is not been shown,
+ # Dont worry, it doesnt affect the performance.
+ if st.session_state["generation_mode"] == "txt2img":
+ percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
+ st.session_state["progress_bar_text"].text(
+ f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%")
+ else:
+ if st.session_state["generation_mode"] == "img2img":
+ round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"])
+ percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps))
+ st.session_state["progress_bar_text"].text(
+ f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""")
+ else:
+ if st.session_state["generation_mode"] == "txt2vid":
+ percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
+ st.session_state["progress_bar_text"].text(
+ f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps}"
+ f"{percent if percent < 100 else 100}%")
+
+ st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
+
+
+prompt_parser = re.compile("""
+ (?P # capture group for 'prompt'
+ [^:]+ # match one or more non ':' characters
+ ) # end 'prompt'
+ (?: # non-capture group
+ :+ # match one or more ':' characters
+ (?P # capture group for 'weight'
+ -?\\d+(?:\\.\\d+)? # match positive or negative decimal number
+ )? # end weight capture group, make optional
+ \\s* # strip spaces after weight
+ | # OR
+ $ # else, if no ':' then match end of line
+ ) # end non-capture group
+""", re.VERBOSE)
+
+# grabs all text up to the first occurrence of ':' as sub-prompt
+# takes the value following ':' as weight
+# if ':' has no value defined, defaults to 1.0
+# repeats until no text remaining
+def split_weighted_subprompts(input_string, normalize=True):
+ parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)]
+ if not normalize:
+ return parsed_prompts
+ # this probably still doesn't handle negative weights very well
+ weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+ return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
+
+def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
+ v0 = v0.detach().cpu().numpy()
+ v1 = v1.detach().cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ v2 = torch.from_numpy(v2).to(device)
+
+ return v2
+
+#
+def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list):
+ """Find the optimal update_preview_frequency value maximizing
+ performance while minimizing the time between updates."""
+ from statistics import mean
+
+ previous_chunk_avg_speed = mean(previous_chunk_speed_list)
+
+ previous_chunk_speed_list.append(current_chunk_speed)
+ current_chunk_avg_speed = mean(previous_chunk_speed_list)
+
+ if current_chunk_avg_speed >= previous_chunk_avg_speed:
+ #print(f"{current_chunk_speed} >= {previous_chunk_speed}")
+ update_preview_frequency_list.append(update_preview_frequency + 1)
+ else:
+ #print(f"{current_chunk_speed} <= {previous_chunk_speed}")
+ update_preview_frequency_list.append(update_preview_frequency - 1)
+
+ update_preview_frequency = round(mean(update_preview_frequency_list))
+
+ return current_chunk_speed, previous_chunk_speed_list, update_preview_frequency, update_preview_frequency_list
+
+
+def get_font(fontsize):
+ fonts = ["arial.ttf", "DejaVuSans.ttf"]
+ for font_name in fonts:
+ try:
+ return ImageFont.truetype(font_name, fontsize)
+ except OSError:
+ pass
+
+ # ImageFont.load_default() is practically unusable as it only supports
+ # latin1, so raise an exception instead if no usable font was found
+ raise Exception(f"No usable font found (tried {', '.join(fonts)})")
+
+def load_embeddings(fp):
+ if fp is not None and hasattr(st.session_state["model"], "embedding_manager"):
+ st.session_state["model"].embedding_manager.load(fp['name'])
+
+def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
+
+ # separate token and the embeds
+ if learned_embeds_path.endswith('.pt'):
+ print(loaded_learned_embeds['string_to_token'])
+ trained_token = list(loaded_learned_embeds['string_to_token'].keys())[0]
+ embeds = list(loaded_learned_embeds['string_to_param'].values())[0]
+
+ elif learned_embeds_path.endswith('.bin'):
+ trained_token = list(loaded_learned_embeds.keys())[0]
+ embeds = loaded_learned_embeds[trained_token]
+
+ embeds = loaded_learned_embeds[trained_token]
+ # cast to dtype of text_encoder
+ dtype = text_encoder.get_input_embeddings().weight.dtype
+ embeds.to(dtype)
+
+ # add the token in tokenizer
+ token = token if token is not None else trained_token
+ num_added_tokens = tokenizer.add_tokens(token)
+
+ # resize the token embeddings
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # get the id for the token and assign the embeds
+ token_id = tokenizer.convert_tokens_to_ids(token)
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
+ return token
+
+def image_grid(imgs, batch_size, force_n_rows=None, captions=None):
+ #print (len(imgs))
+ if force_n_rows is not None:
+ rows = force_n_rows
+ elif st.session_state['defaults'].general.n_rows > 0:
+ rows = st.session_state['defaults'].general.n_rows
+ elif st.session_state['defaults'].general.n_rows == 0:
+ rows = batch_size
+ else:
+ rows = math.sqrt(len(imgs))
+ rows = round(rows)
+
+ cols = math.ceil(len(imgs) / rows)
+
+ w, h = imgs[0].size
+ grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+
+ fnt = get_font(30)
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ if captions and i= 2**32:
+ n = n >> 32
+ return n
+
+#
+def draw_prompt_matrix(im, width, height, all_prompts):
+ def wrap(text, d, font, line_length):
+ lines = ['']
+ for word in text.split():
+ line = f'{lines[-1]} {word}'.strip()
+ if d.textlength(line, font=font) <= line_length:
+ lines[-1] = line
+ else:
+ lines.append(word)
+ return '\n'.join(lines)
+
+ def draw_texts(pos, x, y, texts, sizes):
+ for i, (text, size) in enumerate(zip(texts, sizes)):
+ active = pos & (1 << i) != 0
+
+ if not active:
+ text = '\u0336'.join(text) + '\u0336'
+
+ d.multiline_text((x, y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
+
+ y += size[1] + line_spacing
+
+ fontsize = (width + height) // 25
+ line_spacing = fontsize // 2
+ fnt = get_font(fontsize)
+ color_active = (0, 0, 0)
+ color_inactive = (153, 153, 153)
+
+ pad_top = height // 4
+ pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0
+
+ cols = im.width // width
+ rows = im.height // height
+
+ prompts = all_prompts[1:]
+
+ result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
+ result.paste(im, (pad_left, pad_top))
+
+ d = ImageDraw.Draw(result)
+
+ boundary = math.ceil(len(prompts) / 2)
+ prompts_horiz = [wrap(x, d, fnt, width) for x in prompts[:boundary]]
+ prompts_vert = [wrap(x, d, fnt, pad_left) for x in prompts[boundary:]]
+
+ sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
+ sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
+ hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing
+ ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing
+
+ for col in range(cols):
+ x = pad_left + width * col + width / 2
+ y = pad_top / 2 - hor_text_height / 2
+
+ draw_texts(col, x, y, prompts_horiz, sizes_hor)
+
+ for row in range(rows):
+ x = pad_left / 2
+ y = pad_top + height * row + height / 2 - ver_text_height / 2
+
+ draw_texts(row, x, y, prompts_vert, sizes_ver)
+
+ return result
+
+def check_prompt_length(prompt, comments):
+ """this function tests if prompt is too long, and if so, adds a message to comments"""
+
+ tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
+ max_length = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.max_length
+
+ info = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
+ return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
+ ovf = info['overflowing_tokens'][0]
+ overflowing_count = ovf.shape[0]
+ if overflowing_count == 0:
+ return
+
+ vocab = {v: k for k, v in tokenizer.get_vocab().items()}
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+
+ comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+ save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, model_name):
+
+ filename_i = os.path.join(sample_path_i, filename)
+
+ if st.session_state['defaults'].general.save_metadata or write_info_files:
+ # toggles differ for txt2img vs. img2img:
+ offset = 0 if init_img is None else 2
+ toggles = []
+ if prompt_matrix:
+ toggles.append(0)
+ if normalize_prompt_weights:
+ toggles.append(1)
+ if init_img is not None:
+ if uses_loopback:
+ toggles.append(2)
+ if uses_random_seed_loopback:
+ toggles.append(3)
+ if save_individual_images:
+ toggles.append(2 + offset)
+ if save_grid:
+ toggles.append(3 + offset)
+ if sort_samples:
+ toggles.append(4 + offset)
+ if write_info_files:
+ toggles.append(5 + offset)
+ if use_GFPGAN:
+ toggles.append(6 + offset)
+ metadata = \
+ dict(
+ target="txt2img" if init_img is None else "img2img",
+ prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name,
+ ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale,
+ seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights, model_name=st.session_state["loaded_model"])
+ # Not yet any use for these, but they bloat up the files:
+ # info_dict["init_img"] = init_img
+ # info_dict["init_mask"] = init_mask
+ if init_img is not None:
+ metadata["denoising_strength"] = str(denoising_strength)
+ metadata["resize_mode"] = resize_mode
+
+ if write_info_files:
+ with open(f"{filename_i}.yaml", "w", encoding="utf8") as f:
+ yaml.dump(metadata, f, allow_unicode=True, width=10000)
+
+ if st.session_state['defaults'].general.save_metadata:
+ # metadata = {
+ # "SD:prompt": prompts[i],
+ # "SD:seed": str(seeds[i]),
+ # "SD:width": str(width),
+ # "SD:height": str(height),
+ # "SD:steps": str(steps),
+ # "SD:cfg_scale": str(cfg_scale),
+ # "SD:normalize_prompt_weights": str(normalize_prompt_weights),
+ # }
+ metadata = {"SD:" + k:v for (k,v) in metadata.items()}
+
+ if save_ext == "png":
+ mdata = PngInfo()
+ for key in metadata:
+ mdata.add_text(key, str(metadata[key]))
+ image.save(f"{filename_i}.png", pnginfo=mdata)
+ else:
+ if jpg_sample:
+ image.save(f"{filename_i}.jpg", quality=save_quality,
+ optimize=True)
+ elif save_ext == "webp":
+ image.save(f"{filename_i}.{save_ext}", f"webp", quality=save_quality,
+ lossless=save_lossless)
+ else:
+ # not sure what file format this is
+ image.save(f"{filename_i}.{save_ext}", f"{save_ext}")
+ try:
+ exif_dict = piexif.load(f"{filename_i}.{save_ext}")
+ except:
+ exif_dict = { "Exif": dict() }
+ exif_dict["Exif"][piexif.ExifIFD.UserComment] = piexif.helper.UserComment.dump(
+ json.dumps(metadata), encoding="unicode")
+ piexif.insert(piexif.dump(exif_dict), f"{filename_i}.{save_ext}")
+
+
+def get_next_sequence_number(path, prefix=''):
+ """
+ Determines and returns the next sequence number to use when saving an
+ image in the specified directory.
+
+ If a prefix is given, only consider files whose names start with that
+ prefix, and strip the prefix from filenames before extracting their
+ sequence number.
+
+ The sequence starts at 0.
+ """
+ result = -1
+ for p in Path(path).iterdir():
+ if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix):
+ tmp = p.name[len(prefix):]
+ try:
+ result = max(int(tmp.split('-')[0]), result)
+ except ValueError:
+ pass
+ return result + 1
+
+
+def oxlamon_matrix(prompt, seed, n_iter, batch_size):
+ pattern = re.compile(r'(,\s){2,}')
+
+ class PromptItem:
+ def __init__(self, text, parts, item):
+ self.text = text
+ self.parts = parts
+ if item:
+ self.parts.append( item )
+
+ def clean(txt):
+ return re.sub(pattern, ', ', txt)
+
+ def getrowcount( txt ):
+ for data in re.finditer( ".*?\\((.*?)\\).*", txt ):
+ if data:
+ return len(data.group(1).split("|"))
+ break
+ return None
+
+ def repliter( txt ):
+ for data in re.finditer( ".*?\\((.*?)\\).*", txt ):
+ if data:
+ r = data.span(1)
+ for item in data.group(1).split("|"):
+ yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip())
+ break
+
+ def iterlist( items ):
+ outitems = []
+ for item in items:
+ for newitem, newpart in repliter(item.text):
+ outitems.append( PromptItem(newitem, item.parts.copy(), newpart) )
+
+ return outitems
+
+ def getmatrix( prompt ):
+ dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ]
+ while True:
+ newdataitems = iterlist( dataitems )
+ if len( newdataitems ) == 0:
+ return dataitems
+ dataitems = newdataitems
+
+ def classToArrays( items, seed, n_iter ):
+ texts = []
+ parts = []
+ seeds = []
+
+ for item in items:
+ itemseed = seed
+ for i in range(n_iter):
+ texts.append( item.text )
+ parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) )
+ seeds.append( itemseed )
+ itemseed += 1
+
+ return seeds, texts, parts
+
+ all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter)
+ n_iter = math.ceil(len(all_prompts) / batch_size)
+
+ needrows = getrowcount(prompt)
+ if needrows:
+ xrows = math.sqrt(len(all_prompts))
+ xrows = round(xrows)
+ # if columns is to much
+ cols = math.ceil(len(all_prompts) / xrows)
+ if cols > needrows*4:
+ needrows *= 2
+
+ return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows
+
+#
+def process_images(
+ outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
+ n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
+ ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
+ mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, noise_mode=0, find_noise_steps=1, resize_mode=None, uses_loopback=False,
+ uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
+ variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+ assert prompt is not None
+ torch_gc()
+ # start time after garbage collection (or before?)
+ start_time = time.time()
+
+ # We will use this date here later for the folder name, need to start_time if not need
+ run_start_dt = datetime.datetime.now()
+
+ mem_mon = MemUsageMonitor('MemMon')
+ mem_mon.start()
+
+ if st.session_state.defaults.general.use_sd_concepts_library:
+
+ prompt_tokens = re.findall('<([a-zA-Z0-9-]+)>', prompt)
+
+ if prompt_tokens:
+ # compviz
+ tokenizer = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.tokenizer
+ text_encoder = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).cond_stage_model.transformer
+
+ # diffusers
+ #tokenizer = pipe.tokenizer
+ #text_encoder = pipe.text_encoder
+
+ ext = ('pt', 'bin')
+
+ if len(prompt_tokens) > 1:
+ for token_name in prompt_tokens:
+ embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, token_name)
+ if os.path.exists(embedding_path):
+ for files in os.listdir(embedding_path):
+ if files.endswith(ext):
+ load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{token_name}>")
+ else:
+ embedding_path = os.path.join(st.session_state['defaults'].general.sd_concepts_library_folder, prompt_tokens[0])
+ if os.path.exists(embedding_path):
+ for files in os.listdir(embedding_path):
+ if files.endswith(ext):
+ load_learned_embed_in_clip(f"{os.path.join(embedding_path, files)}", text_encoder, tokenizer, f"<{prompt_tokens[0]}>")
+
+ #
+
+
+ os.makedirs(outpath, exist_ok=True)
+
+ sample_path = os.path.join(outpath, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+
+ if not ("|" in prompt) and prompt.startswith("@"):
+ prompt = prompt[1:]
+
+ negprompt = ''
+ if '###' in prompt:
+ prompt, negprompt = prompt.split('###', 1)
+ prompt = prompt.strip()
+ negprompt = negprompt.strip()
+
+ comments = []
+
+ prompt_matrix_parts = []
+ simple_templating = False
+ add_original_image = not (use_RealESRGAN or use_GFPGAN)
+
+ if prompt_matrix:
+ if prompt.startswith("@"):
+ simple_templating = True
+ add_original_image = not (use_RealESRGAN or use_GFPGAN)
+ all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size)
+ else:
+ all_prompts = []
+ prompt_matrix_parts = prompt.split("|")
+ combination_count = 2 ** (len(prompt_matrix_parts) - 1)
+ for combination_num in range(combination_count):
+ current = prompt_matrix_parts[0]
+
+ for n, text in enumerate(prompt_matrix_parts[1:]):
+ if combination_num & (2 ** n) > 0:
+ current += ("" if text.strip().startswith(",") else ", ") + text
+
+ all_prompts.append(current)
+
+ n_iter = math.ceil(len(all_prompts) / batch_size)
+ all_seeds = len(all_prompts) * [seed]
+
+ print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
+ else:
+
+ if not st.session_state['defaults'].general.no_verify_input:
+ try:
+ check_prompt_length(prompt, comments)
+ except:
+ import traceback
+ print("Error verifying input:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ all_prompts = batch_size * n_iter * [prompt]
+ all_seeds = [seed + x for x in range(len(all_prompts))]
+
+ precision_scope = autocast if st.session_state['defaults'].general.precision == "autocast" else nullcontext
+ output_images = []
+ grid_captions = []
+ stats = []
+ with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not st.session_state['defaults'].general.optimized else nullcontext()):
+ init_data = func_init()
+ tic = time.time()
+
+
+ # if variant_amount > 0.0 create noise from base seed
+ base_x = None
+ if variant_amount > 0.0:
+ target_seed_randomizer = seed_to_int('') # random seed
+ torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration)
+ base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed])
+ # we don't want all_seeds to be sequential from starting seed with variants,
+ # since that makes the same variants each time,
+ # so we add target_seed_randomizer as a random offset
+ for si in range(len(all_seeds)):
+ all_seeds[si] += target_seed_randomizer
+
+ for n in range(n_iter):
+ print(f"Iteration: {n+1}/{n_iter}")
+ prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
+ captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size]
+ seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+
+ print(prompt)
+
+ if st.session_state['defaults'].general.optimized:
+ st.session_state.modelCS.to(st.session_state['defaults'].general.gpu)
+
+ uc = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(len(prompts) * [negprompt])
+
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ # split the prompt if it has : for weighting
+ # TODO for speed it might help to have this occur when all_prompts filled??
+ weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights)
+
+ # sub-prompt weighting used if more than 1
+ if len(weighted_subprompts) > 1:
+ c = torch.zeros_like(uc) # i dont know if this is correct.. but it works
+ for i in range(0, len(weighted_subprompts)):
+ # note if alpha negative, it functions same as torch.sub
+ c = torch.add(c, (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
+ else: # just behave like usual
+ c = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelCS).get_learned_conditioning(prompts)
+
+
+ shape = [opt_C, height // opt_f, width // opt_f]
+
+ if st.session_state['defaults'].general.optimized:
+ mem = torch.cuda.memory_allocated()/1e6
+ st.session_state.modelCS.to("cpu")
+ while(torch.cuda.memory_allocated()/1e6 >= mem):
+ time.sleep(1)
+
+ if noise_mode == 1 or noise_mode == 3:
+ # TODO params for find_noise_to_image
+ x = torch.cat(batch_size * [find_noise_for_image(
+ st.session_state["model"], st.session_state["device"],
+ init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True,
+ generation_callback=generation_callback,
+ )], dim=0)
+ else:
+ # we manually generate all input noises because each one should have a specific seed
+ x = create_random_tensors(shape, seeds=seeds)
+
+ if variant_amount > 0.0: # we are making variants
+ # using variant_seed as sneaky toggle,
+ # when not None or '' use the variant_seed
+ # otherwise use seeds
+ if variant_seed != None and variant_seed != '':
+ specified_variant_seed = seed_to_int(variant_seed)
+ torch.manual_seed(specified_variant_seed)
+ seeds = [specified_variant_seed]
+ # finally, slerp base_x noise to target_x noise for creating a variant
+ x = slerp(st.session_state['defaults'].general.gpu, max(0.0, min(1.0, variant_amount)), base_x, x)
+
+ samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+
+ if st.session_state['defaults'].general.optimized:
+ st.session_state.modelFS.to(st.session_state['defaults'].general.gpu)
+
+ x_samples_ddim = (st.session_state["model"] if not st.session_state['defaults'].general.optimized else st.session_state.modelFS).decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ run_images = []
+ for i, x_sample in enumerate(x_samples_ddim):
+ sanitized_prompt = slugify(prompts[i])
+
+ percent = i / len(x_samples_ddim)
+ st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
+
+ if sort_samples:
+ full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt)
+
+
+ sanitized_prompt = sanitized_prompt[:220-len(full_path)]
+ sample_path_i = os.path.join(sample_path, sanitized_prompt)
+
+ #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}")
+ #print(os.path.join(os.getcwd(), sample_path_i))
+
+ os.makedirs(sample_path_i, exist_ok=True)
+ base_count = get_next_sequence_number(sample_path_i)
+ filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}"
+ else:
+ full_path = os.path.join(os.getcwd(), sample_path)
+ sample_path_i = sample_path
+ base_count = get_next_sequence_number(sample_path_i)
+ filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before
+
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ original_sample = x_sample
+ original_filename = filename
+
+ st.session_state["preview_image"].image(image)
+
+ if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN:
+ st.session_state["progress_bar_text"].text("Running GFPGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
+ #skip_save = True # #287 >_>
+ torch_gc()
+ cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
+ gfpgan_sample = restored_img[:,:,::-1]
+ gfpgan_image = Image.fromarray(gfpgan_sample)
+ gfpgan_filename = original_filename + '-gfpgan'
+
+ save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback,
+ uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta,
+ n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"])
+
+ output_images.append(gfpgan_image) #287
+ run_images.append(gfpgan_image)
+
+ if simple_templating:
+ grid_captions.append( captions[i] + "\ngfpgan" )
+
+ elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN:
+ st.session_state["progress_bar_text"].text("Running RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
+ #skip_save = True # #287 >_>
+ torch_gc()
+
+ if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
+ esrgan_filename = original_filename + '-esrgan4x'
+ esrgan_sample = output[:,:,::-1]
+ esrgan_image = Image.fromarray(esrgan_sample)
+
+ #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
+ #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
+
+ save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+ save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"])
+
+ output_images.append(esrgan_image) #287
+ run_images.append(esrgan_image)
+
+ if simple_templating:
+ grid_captions.append( captions[i] + "\nesrgan" )
+
+ elif use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None:
+ st.session_state["progress_bar_text"].text("Running GFPGAN+RealESRGAN on image %d of %d..." % (i+1, len(x_samples_ddim)))
+ #skip_save = True # #287 >_>
+ torch_gc()
+ cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
+ gfpgan_sample = restored_img[:,:,::-1]
+
+ if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
+ gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
+ gfpgan_esrgan_sample = output[:,:,::-1]
+ gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
+
+ save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+ save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False, st.session_state["loaded_model"])
+
+ output_images.append(gfpgan_esrgan_image) #287
+ run_images.append(gfpgan_esrgan_image)
+
+ if simple_templating:
+ grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
+ else:
+ output_images.append(image)
+ run_images.append(image)
+
+ if mask_restore and init_mask:
+ #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
+ init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
+ init_mask = init_mask.convert('L')
+ init_img = init_img.convert('RGB')
+ image = image.convert('RGB')
+
+ if use_RealESRGAN and st.session_state["RealESRGAN"] is not None:
+ if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
+ init_img = Image.fromarray(output)
+ init_img = init_img.convert('RGB')
+
+ output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
+ init_mask = Image.fromarray(output)
+ init_mask = init_mask.convert('L')
+
+ image = Image.composite(init_img, image, init_mask)
+
+ if save_individual_images:
+ save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+ save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images, st.session_state["loaded_model"])
+
+ #if add_original_image or not simple_templating:
+ #output_images.append(image)
+ #if simple_templating:
+ #grid_captions.append( captions[i] )
+
+ if st.session_state['defaults'].general.optimized:
+ mem = torch.cuda.memory_allocated()/1e6
+ st.session_state.modelFS.to("cpu")
+ while(torch.cuda.memory_allocated()/1e6 >= mem):
+ time.sleep(1)
+
+ if len(run_images) > 1:
+ preview_image = image_grid(run_images, n_iter)
+ else:
+ preview_image = run_images[0]
+
+ # Constrain the final preview image to 1440x900 so we're not sending huge amounts of data
+ # to the browser
+ preview_image = constrain_image(preview_image, 1440, 900)
+ st.session_state["progress_bar_text"].text("Finished!")
+ st.session_state["preview_image"].image(preview_image)
+
+ if prompt_matrix or save_grid:
+ if prompt_matrix:
+ if simple_templating:
+ grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions)
+ else:
+ grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
+ try:
+ grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
+ except:
+ import traceback
+ print("Error creating prompt_matrix text:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ grid = image_grid(output_images, batch_size)
+
+ if grid and (batch_size > 1 or n_iter > 1):
+ output_images.insert(0, grid)
+
+ grid_count = get_next_sequence_number(outpath, 'grid-')
+ grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}"
+ grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True)
+
+ toc = time.time()
+
+ mem_max_used, mem_total = mem_mon.read_and_stop()
+ time_diff = time.time()-start_time
+
+ info = f"""
+ {prompt}
+ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
+ stats = f'''
+ Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
+ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
+
+ for comment in comments:
+ info += "\n\n" + comment
+
+ #mem_mon.stop()
+ #del mem_mon
+ torch_gc()
+
+ return output_images, seed, info, stats
+
+
+def resize_image(resize_mode, im, width, height):
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+ if resize_mode == 0:
+ res = im.resize((width, height), resample=LANCZOS)
+ elif resize_mode == 1:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio > src_ratio else im.width * height // im.height
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
+
+ resized = im.resize((src_w, src_h), resample=LANCZOS)
+ res = Image.new("RGBA", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+ else:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio < src_ratio else im.width * height // im.height
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
+
+ resized = im.resize((src_w, src_h), resample=LANCZOS)
+ res = Image.new("RGBA", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+ if ratio < src_ratio:
+ fill_height = height // 2 - src_h // 2
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ elif ratio > src_ratio:
+ fill_width = width // 2 - src_w // 2
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+
+ return res
+
+def constrain_image(img, max_width, max_height):
+ ratio = max(img.width / max_width, img.height / max_height)
+ if ratio <= 1:
+ return img
+ resampler = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+ resized = img.resize((int(img.width / ratio), int(img.height / ratio)), resample=resampler)
+ return resized
diff --git a/scripts/stable_diffusion_pipeline.py b/scripts/stable_diffusion_pipeline.py
new file mode 100644
index 0000000..6f4f794
--- /dev/null
+++ b/scripts/stable_diffusion_pipeline.py
@@ -0,0 +1,233 @@
+import inspect
+import warnings
+from tqdm.auto import tqdm
+from typing import List, Optional, Union
+
+import torch
+from diffusers import ModelMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion.safety_checker import \
+ StableDiffusionSafetyChecker
+from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
+ PNDMScheduler)
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ text_embeddings: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ **kwargs,
+ ):
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and"
+ " will be removed in v0.3.0. Consider using `pipe.to(torch_device)`"
+ " instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if text_embeddings is None:
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(
+ "`height` and `width` have to be divisible by 8 but are"
+ f" {height} and {width}."
+ )
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ else:
+ batch_size = text_embeddings.shape[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ # max_length = text_input.input_ids.shape[-1]
+ max_length = 77 # self.tokenizer.model_max_length
+ uncond_input = self.tokenizer(
+ [""] * batch_size,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+ uncond_embeddings = self.text_encoder(
+ uncond_input.input_ids.to(self.device)
+ )[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+ latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ if latents is None:
+ latents = torch.randn(
+ latents_shape,
+ generator=generator,
+ device=self.device,
+ )
+ else:
+ if latents.shape != latents_shape:
+ raise ValueError(
+ f"Unexpected latents shape, got {latents.shape}, expected"
+ f" {latents_shape}"
+ )
+ latents = latents.to(self.device)
+
+ # set timesteps
+ accepts_offset = "offset" in set(
+ inspect.signature(self.scheduler.set_timesteps).parameters.keys()
+ )
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ )
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input, t, encoder_hidden_states=text_embeddings
+ )["sample"]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(
+ noise_pred, i, latents, **extra_step_kwargs
+ )["prev_sample"]
+ else:
+ latents = self.scheduler.step(
+ noise_pred, t, latents, **extra_step_kwargs
+ )["prev_sample"]
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ safety_cheker_input = self.feature_extractor(
+ self.numpy_to_pil(image), return_tensors="pt"
+ ).to(self.device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_cheker_input.pixel_values
+ )
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+
+ def embed_text(self, text):
+ """Helper to embed some text"""
+ with torch.autocast("cuda"):
+ text_input = self.tokenizer(
+ text,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ with torch.no_grad():
+ embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
+ return embed
+
+
+class NoCheck(ModelMixin):
+ """Can be used in place of safety checker. Use responsibly and at your own risk."""
+ def __init__(self):
+ super().__init__()
+ self.register_parameter(name='asdf', param=torch.nn.Parameter(torch.randn(3)))
+
+ def forward(self, images=None, **kwargs):
+ return images, [False]
diff --git a/scripts/stable_diffusion_walk.py b/scripts/stable_diffusion_walk.py
new file mode 100644
index 0000000..1ce175d
--- /dev/null
+++ b/scripts/stable_diffusion_walk.py
@@ -0,0 +1,218 @@
+import json
+import subprocess
+from pathlib import Path
+
+import numpy as np
+import torch
+from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler,
+ PNDMScheduler)
+from diffusers import ModelMixin
+
+from stable_diffusion_pipeline import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ use_auth_token=True,
+ torch_dtype=torch.float16,
+ revision="fp16",
+).to("cuda")
+
+default_scheduler = PNDMScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+)
+ddim_scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+)
+klms_scheduler = LMSDiscreteScheduler(
+ beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+)
+SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
+
+
+def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+ """helper function to spherically interpolate two arrays v1 v2"""
+
+ if not isinstance(v0, np.ndarray):
+ inputs_are_torch = True
+ input_device = v0.device
+ v0 = v0.cpu().numpy()
+ v1 = v1.cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ if inputs_are_torch:
+ v2 = torch.from_numpy(v2).to(input_device)
+
+ return v2
+
+
+def make_video_ffmpeg(frame_dir, output_file_name='output.mp4', frame_filename="frame%06d.jpg", fps=30):
+ frame_ref_path = str(frame_dir / frame_filename)
+ video_path = str(frame_dir / output_file_name)
+ subprocess.call(
+ f"ffmpeg -r {fps} -i {frame_ref_path} -vcodec libx264 -crf 10 -pix_fmt yuv420p"
+ f" {video_path}".split()
+ )
+ return video_path
+
+
+def walk(
+ prompts=["blueberry spaghetti", "strawberry spaghetti"],
+ seeds=[42, 123],
+ num_steps=5,
+ output_dir="dreams",
+ name="berry_good_spaghetti",
+ height=512,
+ width=512,
+ guidance_scale=7.5,
+ eta=0.0,
+ num_inference_steps=50,
+ do_loop=False,
+ make_video=False,
+ use_lerp_for_text=False,
+ scheduler="klms", # choices: default, ddim, klms
+ disable_tqdm=False,
+ upsample=False,
+ fps=30,
+):
+ """Generate video frames/a video given a list of prompts and seeds.
+
+ Args:
+ prompts (List[str], optional): List of . Defaults to ["blueberry spaghetti", "strawberry spaghetti"].
+ seeds (List[int], optional): List of random seeds corresponding to given prompts.
+ num_steps (int, optional): Number of steps to walk. Increase this value to 60-200 for good results. Defaults to 5.
+ output_dir (str, optional): Root dir where images will be saved. Defaults to "dreams".
+ name (str, optional): Sub directory of output_dir to save this run's files. Defaults to "berry_good_spaghetti".
+ height (int, optional): Height of image to generate. Defaults to 512.
+ width (int, optional): Width of image to generate. Defaults to 512.
+ guidance_scale (float, optional): Higher = more adherance to prompt. Lower = let model take the wheel. Defaults to 7.5.
+ eta (float, optional): ETA. Defaults to 0.0.
+ num_inference_steps (int, optional): Number of diffusion steps. Defaults to 50.
+ do_loop (bool, optional): Whether to loop from last prompt back to first. Defaults to False.
+ make_video (bool, optional): Whether to make a video or just save the images. Defaults to False.
+ use_lerp_for_text (bool, optional): Use LERP instead of SLERP for text embeddings when walking. Defaults to False.
+ scheduler (str, optional): Which scheduler to use. Defaults to "klms". Choices are "default", "ddim", "klms".
+ disable_tqdm (bool, optional): Whether to turn off the tqdm progress bars. Defaults to False.
+ upsample (bool, optional): If True, uses Real-ESRGAN to upsample images 4x. Requires it to be installed
+ which you can do by running: `pip install git+https://github.com/xinntao/Real-ESRGAN.git`. Defaults to False.
+ fps (int, optional): The frames per second (fps) that you want the video to use. Does nothing if make_video is False. Defaults to 30.
+
+ Returns:
+ str: Path to video file saved if make_video=True, else None.
+ """
+ if upsample:
+ from .upsampling import PipelineRealESRGAN
+
+ upsampling_pipeline = PipelineRealESRGAN.from_pretrained('nateraw/real-esrgan')
+
+ pipeline.set_progress_bar_config(disable=disable_tqdm)
+
+ pipeline.scheduler = SCHEDULERS[scheduler]
+
+ output_path = Path(output_dir) / name
+ output_path.mkdir(exist_ok=True, parents=True)
+
+ # Write prompt info to file in output dir so we can keep track of what we did
+ prompt_config_path = output_path / 'prompt_config.json'
+ prompt_config_path.write_text(
+ json.dumps(
+ dict(
+ prompts=prompts,
+ seeds=seeds,
+ num_steps=num_steps,
+ name=name,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ num_inference_steps=num_inference_steps,
+ do_loop=do_loop,
+ make_video=make_video,
+ use_lerp_for_text=use_lerp_for_text,
+ scheduler=scheduler
+ ),
+ indent=2,
+ sort_keys=False,
+ )
+ )
+
+ assert len(prompts) == len(seeds)
+
+ first_prompt, *prompts = prompts
+ embeds_a = pipeline.embed_text(first_prompt)
+
+ first_seed, *seeds = seeds
+ latents_a = torch.randn(
+ (1, pipeline.unet.in_channels, height // 8, width // 8),
+ device=pipeline.device,
+ generator=torch.Generator(device=pipeline.device).manual_seed(first_seed),
+ )
+
+ if do_loop:
+ prompts.append(first_prompt)
+ seeds.append(first_seed)
+
+ frame_index = 0
+ for prompt, seed in zip(prompts, seeds):
+ # Text
+ embeds_b = pipeline.embed_text(prompt)
+
+ # Latent Noise
+ latents_b = torch.randn(
+ (1, pipeline.unet.in_channels, height // 8, width // 8),
+ device=pipeline.device,
+ generator=torch.Generator(device=pipeline.device).manual_seed(seed),
+ )
+
+ for i, t in enumerate(np.linspace(0, 1, num_steps)):
+ do_print_progress = (i == 0) or ((frame_index + 1) % 20 == 0)
+ if do_print_progress:
+ print(f"COUNT: {frame_index+1}/{len(seeds)*num_steps}")
+
+ if use_lerp_for_text:
+ embeds = torch.lerp(embeds_a, embeds_b, float(t))
+ else:
+ embeds = slerp(float(t), embeds_a, embeds_b)
+ latents = slerp(float(t), latents_a, latents_b)
+
+ with torch.autocast("cuda"):
+ im = pipeline(
+ latents=latents,
+ text_embeddings=embeds,
+ height=height,
+ width=width,
+ guidance_scale=guidance_scale,
+ eta=eta,
+ num_inference_steps=num_inference_steps,
+ output_type='pil' if not upsample else 'numpy'
+ )["sample"][0]
+
+ if upsample:
+ im = upsampling_pipeline(im)
+
+ im.save(output_path / ("frame%06d.jpg" % frame_index))
+ frame_index += 1
+
+ embeds_a = embeds_b
+ latents_a = latents_b
+
+ if make_video:
+ return make_video_ffmpeg(output_path, f"{name}.mp4", fps=fps)
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire(walk)
diff --git a/scripts/textual_inversion.py b/scripts/textual_inversion.py
new file mode 100644
index 0000000..3e5cc3e
--- /dev/null
+++ b/scripts/textual_inversion.py
@@ -0,0 +1,57 @@
+# base webui import and utils.
+from webui_streamlit import st
+from sd_utils import *
+
+# streamlit imports
+
+
+#other imports
+#from transformers import CLIPTextModel, CLIPTokenizer
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+#def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
+
+ #loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
+
+ ## separate token and the embeds
+ #print (loaded_learned_embeds)
+ #trained_token = list(loaded_learned_embeds.keys())[0]
+ #embeds = loaded_learned_embeds[trained_token]
+
+ ## cast to dtype of text_encoder
+ #dtype = text_encoder.get_input_embeddings().weight.dtype
+ #embeds.to(dtype)
+
+ ## add the token in tokenizer
+ #token = token if token is not None else trained_token
+ #num_added_tokens = tokenizer.add_tokens(token)
+ #i = 1
+ #while(num_added_tokens == 0):
+ #print(f"The tokenizer already contains the token {token}.")
+ #token = f"{token[:-1]}-{i}>"
+ #print(f"Attempting to add the token {token}.")
+ #num_added_tokens = tokenizer.add_tokens(token)
+ #i+=1
+
+ ## resize the token embeddings
+ #text_encoder.resize_token_embeddings(len(tokenizer))
+
+ ## get the id for the token and assign the embeds
+ #token_id = tokenizer.convert_tokens_to_ids(token)
+ #text_encoder.get_input_embeddings().weight.data[token_id] = embeds
+ #return token
+
+##def token_loader()
+#learned_token = load_learned_embed_in_clip(f"models/custom/embeddings/Custom Ami.pt", st.session_state.pipe.text_encoder, st.session_state.pipe.tokenizer, "*")
+#model_content["token"] = learned_token
+#models.append(model_content)
+
+model_id = "./models/custom/embeddings/"
+
+def layout():
+ st.write("Textual Inversion")
\ No newline at end of file
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
new file mode 100644
index 0000000..6f74143
--- /dev/null
+++ b/scripts/txt2img.py
@@ -0,0 +1,368 @@
+# base webui import and utils.
+from webui_streamlit import st
+from sd_utils import *
+
+# streamlit imports
+from streamlit import StopException
+from streamlit.runtime.in_memory_file_manager import in_memory_file_manager
+from streamlit.elements import image as STImage
+
+#other imports
+import os
+from typing import Union
+from io import BytesIO
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+
+# Temp imports
+
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
+
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+class plugin_info():
+ plugname = "txt2img"
+ description = "Text to Image"
+ isTab = True
+ displayPriority = 1
+
+
+if os.path.exists(os.path.join(st.session_state['defaults'].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
+ GFPGAN_available = True
+else:
+ GFPGAN_available = False
+
+if os.path.exists(os.path.join(st.session_state['defaults'].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
+ RealESRGAN_available = True
+else:
+ RealESRGAN_available = False
+
+#
+def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str,
+ n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
+ height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
+ save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
+ save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True,
+ RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None,
+ variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True):
+
+ outpath = st.session_state['defaults'].general.outdir_txt2img or st.session_state['defaults'].general.outdir or "outputs/txt2img-samples"
+
+ seed = seed_to_int(seed)
+
+ #prompt_matrix = 0 in toggles
+ #normalize_prompt_weights = 1 in toggles
+ #skip_save = 2 not in toggles
+ #save_grid = 3 not in toggles
+ #sort_samples = 4 in toggles
+ #write_info_files = 5 in toggles
+ #jpg_sample = 6 in toggles
+ #use_GFPGAN = 7 in toggles
+ #use_RealESRGAN = 8 in toggles
+
+ if sampler_name == 'PLMS':
+ sampler = PLMSSampler(st.session_state["model"])
+ elif sampler_name == 'DDIM':
+ sampler = DDIMSampler(st.session_state["model"])
+ elif sampler_name == 'k_dpm_2_a':
+ sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
+ elif sampler_name == 'k_dpm_2':
+ sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
+ elif sampler_name == 'k_euler_a':
+ sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
+ elif sampler_name == 'k_euler':
+ sampler = KDiffusionSampler(st.session_state["model"],'euler')
+ elif sampler_name == 'k_heun':
+ sampler = KDiffusionSampler(st.session_state["model"],'heun')
+ elif sampler_name == 'k_lms':
+ sampler = KDiffusionSampler(st.session_state["model"],'lms')
+ else:
+ raise Exception("Unknown sampler: " + sampler_name)
+
+ def init():
+ pass
+
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback,
+ log_every_t=int(st.session_state.update_preview_frequency))
+
+ return samples_ddim
+
+ #try:
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=st.session_state["use_GFPGAN"],
+ use_RealESRGAN=st.session_state["use_RealESRGAN"],
+ realesrgan_model_name=realesrgan_model_name,
+ ddim_eta=ddim_eta,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg,
+ variant_amount=variant_amount,
+ variant_seed=variant_seed,
+ )
+
+ del sampler
+
+ return output_images, seed, info, stats
+
+ #except RuntimeError as e:
+ #err = e
+ #err_msg = f'CRASHED:
Please wait while the program restarts.'
+ #stats = err_msg
+ #return [], seed, 'err', stats
+
+def layout():
+ with st.form("txt2img-inputs"):
+ st.session_state["generation_mode"] = "txt2img"
+
+ input_col1, generate_col1 = st.columns([10,1])
+
+ with input_col1:
+ #prompt = st.text_area("Input Text","")
+ prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
+
+ # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
+ generate_col1.write("")
+ generate_col1.write("")
+ generate_button = generate_col1.form_submit_button("Generate")
+
+ # creating the page layout using columns
+ col1, col2, col3 = st.columns([1,2,1], gap="large")
+
+ with col1:
+ width = st.slider("Width:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.width, step=64)
+ height = st.slider("Height:", min_value=64, max_value=4096, value=st.session_state['defaults'].txt2img.height, step=64)
+ cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
+ seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
+ batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
+
+ bs_slider_max_value = 5
+ if st.session_state.defaults.general.optimized:
+ bs_slider_max_value = 100
+
+ batch_size = st.slider(
+ "Batch size",
+ min_value=1,
+ max_value=bs_slider_max_value,
+ value=st.session_state.defaults.txt2img.batch_size,
+ step=1,
+ help="How many images are at once in a batch.\
+ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
+ Default: 1")
+
+ with st.expander("Preview Settings"):
+ st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2img.update_preview,
+ help="If enabled the image preview will be updated during the generation instead of at the end. \
+ You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
+ By default this is enabled and the frequency is set to 1 step.")
+
+ st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2img.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 1 step.")
+
+ with col2:
+ preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
+
+ with preview_tab:
+ #st.write("Image")
+ #Image for testing
+ #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
+ #new_image = image.resize((175, 240))
+ #preview_image = st.image(image)
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ st.session_state["preview_image"] = st.empty()
+
+ st.session_state["loading"] = st.empty()
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar"] = st.empty()
+
+ message = st.empty()
+
+ with col3:
+ # If we have custom models available on the "models/custom"
+ #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
+ if st.session_state.CustomModel_available:
+ st.session_state.custom_model = st.selectbox("Custom Model:", st.session_state.custom_models,
+ index=st.session_state["custom_models"].index(st.session_state['defaults'].general.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
+
+ st.session_state.sampling_steps = st.slider("Sampling Steps",
+ value=st.session_state['defaults'].txt2img.sampling_steps,
+ min_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.lower,
+ max_value=st.session_state['defaults'].txt2img.slider_bounds.sampling.upper,
+ step=st.session_state['defaults'].txt2img.slider_steps.sampling)
+
+ sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ sampler_name = st.selectbox("Sampling method", sampler_name_list,
+ index=sampler_name_list.index(st.session_state['defaults'].txt2img.default_sampler), help="Sampling method to use. Default: k_euler")
+
+
+
+ #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
+
+ #with basic_tab:
+ #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
+ #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
+
+ with st.expander("Advanced"):
+ separate_prompts = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+ normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=st.session_state['defaults'].txt2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
+ save_individual_images = st.checkbox("Save individual images.", value=st.session_state['defaults'].txt2img.save_individual_images, help="Save each image generated before any filter or enhancement is applied.")
+ save_grid = st.checkbox("Save grid",value=st.session_state['defaults'].txt2img.save_grid, help="Save a grid with all the images generated into a single image.")
+ group_by_prompt = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2img.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
+ write_info_files = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2img.write_info_files, help="Save a file next to the image with informartion about the generation.")
+ save_as_jpg = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ if st.session_state["GFPGAN_available"]:
+ st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
+ This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
+ else:
+ st.session_state["use_GFPGAN"] = False
+
+ if st.session_state["RealESRGAN_available"]:
+ st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2img.use_RealESRGAN,
+ help="Uses the RealESRGAN model to upscale the images after the generation.\
+ This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
+ st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
+ else:
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+
+ variant_amount = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
+ variant_seed = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+ galleryCont = st.empty()
+
+ if generate_button:
+ #print("Loading models")
+ # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
+ load_models(False, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"], st.session_state["CustomModel_available"],
+ st.session_state["custom_model"])
+
+
+ try:
+ #
+ output_images, seeds, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, st.session_state["RealESRGAN_model"], batch_count, batch_size,
+ cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
+ save_grid, group_by_prompt, save_as_jpg, st.session_state["use_GFPGAN"], st.session_state["use_RealESRGAN"], st.session_state["RealESRGAN_model"],
+ variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
+
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
+
+ #if 'latestImages' in st.session_state:
+ #for i in output_images:
+ ##push the new image to the list of latest images and remove the oldest one
+ ##remove the last index from the list\
+ #st.session_state['latestImages'].pop()
+ ##add the new image to the start of the list
+ #st.session_state['latestImages'].insert(0, i)
+ #PlaceHolder.empty()
+ #with PlaceHolder.container():
+ #col1, col2, col3 = st.columns(3)
+ #col1_cont = st.container()
+ #col2_cont = st.container()
+ #col3_cont = st.container()
+ #images = st.session_state['latestImages']
+ #with col1_cont:
+ #with col1:
+ #[st.image(images[index]) for index in [0, 3, 6] if index < len(images)]
+ #with col2_cont:
+ #with col2:
+ #[st.image(images[index]) for index in [1, 4, 7] if index < len(images)]
+ #with col3_cont:
+ #with col3:
+ #[st.image(images[index]) for index in [2, 5, 8] if index < len(images)]
+ #historyGallery = st.empty()
+
+ ## check if output_images length is the same as seeds length
+ #with gallery_tab:
+ #st.markdown(createHTMLGallery(output_images,seeds), unsafe_allow_html=True)
+
+
+ #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
+
+ except (StopException, KeyError):
+ print(f"Received Streamlit StopException")
+
+ # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
+ # use the current col2 first tab to show the preview_img and update it as its generated.
+ #preview_image.image(output_images)
+
+#on import run init
+def createHTMLGallery(images,info):
+ html3 = """
+
+ """
+ mkdwn_array = []
+ for i in images:
+ try:
+ seed = info[images.index(i)]
+ except:
+ seed = ' '
+ image_io = BytesIO()
+ i.save(image_io, 'PNG')
+ width, height = i.size
+ #get random number for the id
+ image_id = "%s" % (str(images.index(i)))
+ (data, mimetype) = STImage._normalize_to_bytes(image_io.getvalue(), width, 'auto')
+ this_file = in_memory_file_manager.add(data, mimetype, image_id)
+ img_str = this_file.url
+ #img_str = 'data:image/png;base64,' + b64encode(image_io.getvalue()).decode('ascii')
+ #get image size
+
+ #make sure the image is not bigger then 150px but keep the aspect ratio
+ if width > 150:
+ height = int(height * (150/width))
+ width = 150
+ if height > 150:
+ width = int(width * (150/height))
+ height = 150
+
+ #mkdwn = f""""""
+ mkdwn = f'''
'
+ return html3
+#
+def layout():
+ with st.form("txt2vid-inputs"):
+ st.session_state["generation_mode"] = "txt2vid"
+
+ input_col1, generate_col1 = st.columns([10,1])
+ with input_col1:
+ #prompt = st.text_area("Input Text","")
+ prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
+
+ # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
+ generate_col1.write("")
+ generate_col1.write("")
+ generate_button = generate_col1.form_submit_button("Generate")
+
+ # creating the page layout using columns
+ col1, col2, col3 = st.columns([1,2,1], gap="large")
+
+ with col1:
+ width = st.slider("Width:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.width, step=64)
+ height = st.slider("Height:", min_value=64, max_value=2048, value=st.session_state['defaults'].txt2vid.height, step=64)
+ cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=st.session_state['defaults'].txt2vid.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
+
+ #uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
+ #help="Upload an image which will be used for the image to image generation.")
+ seed = st.text_input("Seed:", value=st.session_state['defaults'].txt2vid.seed, help=" The seed to use, if left blank a random seed will be generated.")
+ #batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=st.session_state['defaults'].txt2vid.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
+ #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=st.session_state['defaults'].txt2vid.batch_size, step=1,
+ #help="How many images are at once in a batch.\
+ #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
+ #Default: 1")
+
+ st.session_state["max_frames"] = int(st.text_input("Max Frames:", value=st.session_state['defaults'].txt2vid.max_frames, help="Specify the max number of frames you want to generate."))
+
+ with st.expander("Preview Settings"):
+ st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=st.session_state['defaults'].txt2vid.update_preview,
+ help="If enabled the image preview will be updated during the generation instead of at the end. \
+ You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
+ By default this is enabled and the frequency is set to 1 step.")
+
+ st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=st.session_state['defaults'].txt2vid.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 1 step.")
+
+ #
+
+
+
+ with col2:
+ preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
+
+ with preview_tab:
+ #st.write("Image")
+ #Image for testing
+ #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
+ #new_image = image.resize((175, 240))
+ #preview_image = st.image(image)
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ st.session_state["preview_image"] = st.empty()
+
+ st.session_state["loading"] = st.empty()
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar"] = st.empty()
+
+ #generate_video = st.empty()
+ st.session_state["preview_video"] = st.empty()
+
+ message = st.empty()
+
+ with gallery_tab:
+ st.write('Here should be the image gallery, if I could make a grid in streamlit.')
+
+ with col3:
+ # If we have custom models available on the "models/custom"
+ #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
+ if st.session_state["CustomModel_available"]:
+ custom_model = st.selectbox("Custom Model:", st.session_state["defaults"].txt2vid.custom_models_list,
+ index=st.session_state["defaults"].txt2vid.custom_models_list.index(st.session_state["defaults"].txt2vid.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
+ else:
+ custom_model = "CompVis/stable-diffusion-v1-4"
+
+ #st.session_state["weights_path"] = custom_model
+ #else:
+ #custom_model = "CompVis/stable-diffusion-v1-4"
+ #st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}"
+
+ st.session_state.sampling_steps = st.slider("Sampling Steps",
+ value=st.session_state['defaults'].txt2vid.sampling_steps,
+ min_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.lower,
+ max_value=st.session_state['defaults'].txt2vid.slider_bounds.sampling.upper,
+ step=st.session_state['defaults'].txt2vid.slider_steps.sampling,
+ help="Number of steps between each pair of sampled points")
+ st.session_state.num_inference_steps = st.slider("Inference Steps:", value=st.session_state['defaults'].txt2vid.num_inference_steps, min_value=10,step=10, max_value=500,
+ help="Higher values (e.g. 100, 200 etc) can create better images.")
+
+ #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ #sampler_name = st.selectbox("Sampling method", sampler_name_list,
+ #index=sampler_name_list.index(st.session_state['defaults'].txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
+ scheduler_name_list = ["klms", "ddim"]
+ scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
+ index=scheduler_name_list.index(st.session_state['defaults'].txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
+
+ beta_scheduler_type_list = ["scaled_linear", "linear"]
+ beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list,
+ index=beta_scheduler_type_list.index(st.session_state['defaults'].txt2vid.beta_scheduler_type), help="Schedule Type to use. Default: linear")
+
+
+ #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
+
+ #with basic_tab:
+ #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
+ #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
+
+ with st.expander("Advanced"):
+ st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=st.session_state['defaults'].txt2vid.separate_prompts,
+ help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+ st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.",
+ value=st.session_state['defaults'].txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
+ st.session_state["save_individual_images"] = st.checkbox("Save individual images.",
+ value=st.session_state['defaults'].txt2vid.save_individual_images, help="Save each image generated before any filter or enhancement is applied.")
+ st.session_state["save_video"] = st.checkbox("Save video",value=st.session_state['defaults'].txt2vid.save_video, help="Save a video with all the images generated as frames at the end of the generation.")
+ st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=st.session_state['defaults'].txt2vid.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
+ st.session_state["write_info_files"] = st.checkbox("Write Info file", value=st.session_state['defaults'].txt2vid.write_info_files,
+ help="Save a file next to the image with informartion about the generation.")
+ st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=st.session_state['defaults'].txt2vid.dynamic_preview_frequency,
+ help="This option tries to find the best value at which we can update \
+ the preview image during generation while minimizing the impact it has in performance. Default: True")
+ st.session_state["do_loop"] = st.checkbox("Do Loop", value=st.session_state['defaults'].txt2vid.do_loop,
+ help="Do loop")
+ st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=st.session_state['defaults'].txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ if GFPGAN_available:
+ st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=st.session_state['defaults'].txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
+ else:
+ st.session_state["use_GFPGAN"] = False
+
+ if RealESRGAN_available:
+ st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=st.session_state['defaults'].txt2vid.use_RealESRGAN,
+ help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
+ st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
+ else:
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+
+ st.session_state["variant_amount"] = st.slider("Variant Amount:", value=st.session_state['defaults'].txt2vid.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
+ st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=st.session_state['defaults'].txt2vid.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+ st.session_state["beta_start"] = st.slider("Beta Start:", value=st.session_state['defaults'].txt2vid.beta_start, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
+ st.session_state["beta_end"] = st.slider("Beta End:", value=st.session_state['defaults'].txt2vid.beta_end, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
+
+ if generate_button:
+ #print("Loading models")
+ # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
+ #load_models(False, False, False, st.session_state["RealESRGAN_model"], CustomModel_available=st.session_state["CustomModel_available"], custom_model=custom_model)
+
+ try:
+ # run video generation
+ video, seed, info, stats = txt2vid(prompts=prompt, gpu=st.session_state["defaults"].general.gpu,
+ num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
+ num_inference_steps=st.session_state.num_inference_steps,
+ cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
+ seeds=seed, quality=100, eta=0.0, width=width,
+ height=height, weights_path=custom_model, scheduler=scheduler_name,
+ disable_tqdm=False, beta_start=st.session_state["beta_start"], beta_end=st.session_state["beta_end"],
+ beta_schedule=beta_scheduler_type, starting_image=None)
+
+ #message.success('Done!', icon="✅")
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ #history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont = st.session_state['historyTab']
+
+ #if 'latestVideos' in st.session_state:
+ #for i in video:
+ ##push the new image to the list of latest images and remove the oldest one
+ ##remove the last index from the list\
+ #st.session_state['latestVideos'].pop()
+ ##add the new image to the start of the list
+ #st.session_state['latestVideos'].insert(0, i)
+ #PlaceHolder.empty()
+
+ #with PlaceHolder.container():
+ #col1, col2, col3 = st.columns(3)
+ #col1_cont = st.container()
+ #col2_cont = st.container()
+ #col3_cont = st.container()
+
+ #with col1_cont:
+ #with col1:
+ #st.image(st.session_state['latestVideos'][0])
+ #st.image(st.session_state['latestVideos'][3])
+ #st.image(st.session_state['latestVideos'][6])
+ #with col2_cont:
+ #with col2:
+ #st.image(st.session_state['latestVideos'][1])
+ #st.image(st.session_state['latestVideos'][4])
+ #st.image(st.session_state['latestVideos'][7])
+ #with col3_cont:
+ #with col3:
+ #st.image(st.session_state['latestVideos'][2])
+ #st.image(st.session_state['latestVideos'][5])
+ #st.image(st.session_state['latestVideos'][8])
+ #historyGallery = st.empty()
+
+ ## check if output_images length is the same as seeds length
+ #with gallery_tab:
+ #st.markdown(createHTMLGallery(video,seed), unsafe_allow_html=True)
+
+
+ #st.session_state['historyTab'] = [history_tab,col1,col2,col3,PlaceHolder,col1_cont,col2_cont,col3_cont]
+
+ except (StopException, KeyError):
+ print(f"Received Streamlit StopException")
+
+
diff --git a/scripts/webui.py b/scripts/webui.py
index dd64a4c..eb5d32f 100644
--- a/scripts/webui.py
+++ b/scripts/webui.py
@@ -2,8 +2,10 @@ import argparse, os, sys, glob, re
import cv2
+from perlin import perlinNoise
from frontend.frontend import draw_gradio_ui
from frontend.job_manager import JobManager, JobInfo
+from frontend.image_metadata import ImageMetadata
from frontend.ui_functions import resize_image
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
@@ -13,7 +15,7 @@ parser.add_argument("--defaults", type=str, help="path to configuration file pro
parser.add_argument("--esrgan-cpu", action='store_true', help="run ESRGAN on cpu", default=False)
parser.add_argument("--esrgan-gpu", type=int, help="run ESRGAN on specific gpu (overrides --gpu)", default=0)
parser.add_argument("--extra-models-cpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on cpu", default=False)
-parser.add_argument("--extra-models-gpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on cpu", default=False)
+parser.add_argument("--extra-models-gpu", action='store_true', help="run extra models (GFGPAN/ESRGAN) on gpu", default=False)
parser.add_argument("--gfpgan-cpu", action='store_true', help="run GFPGAN on cpu", default=False)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
parser.add_argument("--gfpgan-gpu", type=int, help="run GFPGAN on specific gpu (overrides --gpu) ", default=0)
@@ -31,6 +33,7 @@ parser.add_argument("--outdir_img2img", type=str, nargs="?", help="dir to write
parser.add_argument("--outdir_imglab", type=str, nargs="?", help="dir to write imglab results to (overrides --outdir)", default=None)
parser.add_argument("--outdir_txt2img", type=str, nargs="?", help="dir to write txt2img results to (overrides --outdir)", default=None)
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
+parser.add_argument("--filename_format", type=str, nargs="?", help="filenames format", default=None)
parser.add_argument("--port", type=int, help="choose the port for the gradio webserver to use", default=7860)
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--realesrgan-dir", type=str, help="RealESRGAN directory", default=('./src/realesrgan' if os.path.exists('./src/realesrgan') else './RealESRGAN'))
@@ -42,6 +45,7 @@ parser.add_argument("--skip-grid", action='store_true', help="do not save a grid
parser.add_argument("--skip-save", action='store_true', help="do not save indiviual samples. For speed measurements.", default=False)
parser.add_argument('--no-job-manager', action='store_true', help="Don't use the experimental job manager on top of gradio", default=False)
parser.add_argument("--max-jobs", type=int, help="Maximum number of concurrent 'generate' commands", default=1)
+parser.add_argument("--tiling", action='store_true', help="Generate tiling images", default=False)
opt = parser.parse_args()
#Should not be needed anymore
@@ -66,16 +70,27 @@ import torch
import torch.nn as nn
import yaml
import glob
-from typing import List, Union, Dict
+import copy
+from typing import List, Union, Dict, Callable, Any, Optional
from pathlib import Path
from collections import namedtuple
+from functools import partial
+
+# tell the user which GPU the code is actually using
+if os.getenv("SD_WEBUI_DEBUG", 'False').lower() in ('true', '1', 'y'):
+ gpu_in_use = opt.gpu
+ # prioritize --esrgan-gpu and --gfpgan-gpu over --gpu, as stated in the option info
+ if opt.esrgan_gpu != opt.gpu:
+ gpu_in_use = opt.esrgan_gpu
+ elif opt.gfpgan_gpu != opt.gpu:
+ gpu_in_use = opt.gfpgan_gpu
+ print("Starting on GPU {selected_gpu_name}".format(selected_gpu_name=torch.cuda.get_device_name(gpu_in_use)))
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from itertools import islice
from omegaconf import OmegaConf
-from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
-from PIL.PngImagePlugin import PngInfo
+from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps, ImageChops
from io import BytesIO
import base64
import re
@@ -84,6 +99,18 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
+# add global options to models
+def patch_conv(**patch):
+ cls = torch.nn.Conv2d
+ init = cls.__init__
+ def __init__(self, *args, **kwargs):
+ return init(self, *args, **kwargs, **patch)
+ cls.__init__ = __init__
+
+if opt.tiling:
+ patch_conv(padding_mode='circular')
+ print("patched for tiling")
+
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -92,6 +119,14 @@ try:
except:
pass
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+
+# load safety model
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+safety_feature_extractor = None
+safety_checker = None
+
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
@@ -203,7 +238,16 @@ class MemUsageMonitor(threading.Thread):
print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n")
return
print(f"[{self.name}] Recording max memory usage...\n")
- handle = pynvml.nvmlDeviceGetHandleByIndex(opt.gpu)
+ # check if we're using a scoped-down GPU environment (pynvml does not listen to CUDA_VISIBLE_DEVICES)
+ # so that we can measure memory on the correct GPU
+ try:
+ isinstance(int(os.environ["CUDA_VISIBLE_DEVICES"]), int)
+ handle = pynvml.nvmlDeviceGetHandleByIndex(int(os.environ["CUDA_VISIBLE_DEVICES"]))
+ except (KeyError, ValueError) as pynvmlHandleError:
+ if os.getenv("SD_WEBUI_DEBUG", 'False').lower() in ('true', '1', 'y'):
+ print("[MemMon][WARNING]", pynvmlHandleError)
+ print("[MemMon][INFO]", "defaulting to monitoring memory on the default gpu (set via --gpu flag)")
+ handle = pynvml.nvmlDeviceGetHandleByIndex(opt.gpu)
self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
while not self.stop_flag:
m = pynvml.nvmlDeviceGetMemoryInfo(handle)
@@ -264,15 +308,21 @@ class KDiffusionSampler:
self.schedule = sampler
def get_sampler_name(self):
return self.schedule
- def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
+ def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback: Callable = None ):
sigmas = self.model_wrap.get_sigmas(S)
x = x_T * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model_wrap)
- samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)
+ samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
return samples_ddim, None
+ @classmethod
+ def img_callback_wrapper(cls, callback: Callable, *args):
+ ''' Converts a KDiffusion callback to the standard img_callback '''
+ if callback:
+ arg_dict = args[0]
+ callback(image_sample=arg_dict['denoised'], iter_num=arg_dict['i'])
def create_random_tensors(shape, seeds):
xs = []
@@ -592,25 +642,18 @@ def check_prompt_length(prompt, comments):
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
-normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
-skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True):
+
+def save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
+skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False):
+ ''' saves the image according to selected parameters. Expects to find generation parameters on image, set by ImageMetadata.set_on_image() '''
+ metadata = ImageMetadata.get_from_image(image)
+ if not skip_metadata and metadata is None:
+ print("No metadata passed in to save. Set metadata on the image before calling save_sample using the ImageMetadata.set_on_image() function.")
+ skip_metadata = True
filename_i = os.path.join(sample_path_i, filename)
if not jpg_sample:
if opt.save_metadata and not skip_metadata:
- metadata = PngInfo()
- metadata.add_text("SD:prompt", prompts[i])
- metadata.add_text("SD:seed", str(seeds[i]))
- metadata.add_text("SD:width", str(width))
- metadata.add_text("SD:height", str(height))
- metadata.add_text("SD:sampler_name", str(sampler_name))
- metadata.add_text("SD:steps", str(steps))
- metadata.add_text("SD:cfg_scale", str(cfg_scale))
- metadata.add_text("SD:normalize_prompt_weights", str(normalize_prompt_weights))
- if init_img is not None:
- metadata.add_text("SD:denoising_strength", str(denoising_strength))
- metadata.add_text("SD:GFPGAN", str(use_GFPGAN and GFPGAN is not None))
- image.save(f"{filename_i}.png", pnginfo=metadata)
+ image.save(f"{filename_i}.png", pnginfo=metadata.as_png_info())
else:
image.save(f"{filename_i}.png")
else:
@@ -621,7 +664,7 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
toggles = []
if prompt_matrix:
toggles.append(0)
- if normalize_prompt_weights:
+ if metadata.normalize_prompt_weights:
toggles.append(1)
if init_img is not None:
if uses_loopback:
@@ -638,14 +681,14 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
toggles.append(5 + offset)
if write_sample_info_to_log_file:
toggles.append(6+offset)
- if use_GFPGAN:
+ if metadata.GFPGAN:
toggles.append(7 + offset)
info_dict = dict(
target="txt2img" if init_img is None else "img2img",
- prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name,
- ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale,
- seed=seeds[i], width=width, height=height
+ prompt=metadata.prompt, ddim_steps=metadata.steps, toggles=toggles, sampler_name=sampler_name,
+ ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=metadata.cfg_scale,
+ seed=metadata.seed, width=metadata.width, height=metadata.height
)
if init_img is not None:
# Not yet any use for these, but they bloat up the files:
@@ -775,16 +818,95 @@ def oxlamon_matrix(prompt, seed, n_iter, batch_size):
return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows
+def perform_masked_image_restoration(image, init_img, init_mask, mask_blur_strength, mask_restore, use_RealESRGAN, RealESRGAN):
+ if not mask_restore:
+ return image
+ else:
+ init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
+ init_mask = init_mask.convert('L')
+ init_img = init_img.convert('RGB')
+ image = image.convert('RGB')
+ if use_RealESRGAN and RealESRGAN is not None:
+ output, img_mode = RealESRGAN.enhance(np.array(init_mask, dtype=np.uint8))
+ init_mask = Image.fromarray(output)
+ init_mask = init_mask.convert('L')
+
+ output, img_mode = RealESRGAN.enhance(np.array(init_img, dtype=np.uint8))
+ init_img = Image.fromarray(output)
+ init_img = init_img.convert('RGB')
+
+ image = Image.composite(init_img, image, init_mask)
+
+ return image
+
+
+def perform_color_correction(img_rgb, correction_target_lab, do_color_correction):
+ try:
+ from skimage import exposure
+ except:
+ print("Install scikit-image to perform color correction")
+ return img_rgb
+
+ if not do_color_correction: return img_rgb
+ if correction_target_lab is None: return img_rgb
+
+ return (
+ Image.fromarray(cv2.cvtColor(exposure.match_histograms(
+ cv2.cvtColor(
+ np.asarray(img_rgb),
+ cv2.COLOR_RGB2LAB
+ ),
+ correction_target_lab,
+ channel_axis=2
+ ), cv2.COLOR_LAB2RGB).astype("uint8")
+ )
+ )
def process_images(
outpath, func_init, func_sample, prompt, seed, sampler_name, skip_grid, skip_save, batch_size,
- n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
+ n_iter, steps, cfg_scale, width, height, prompt_matrix, filter_nsfw, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
fp, ddim_eta=0.0, do_not_save_grid=False, normalize_prompt_weights=True, init_img=None, init_mask=None,
- keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
+ keep_mask=False, mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, write_sample_info_to_log_file=False, jpg_sample=False,
- variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None):
+ variant_amount=0.0, variant_seed=None,imgProcessorTask=False, job_info: JobInfo = None, do_color_correction=False, correction_target=None):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ # load replacement of nsfw content
+ def load_replacement(x):
+ try:
+ hwc = x.shape
+ y = Image.open("images/nsfw.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
+ y = (np.array(y)/255.0).astype(x.dtype)
+ assert y.shape == x.shape
+ return y
+ except Exception:
+ return x
+
+ # check and replace nsfw content
+ def check_safety(x_image):
+ global safety_feature_extractor, safety_checker
+ if safety_feature_extractor is None:
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+ for i in range(len(has_nsfw_concept)):
+ if has_nsfw_concept[i]:
+ x_checked_image[i] = load_replacement(x_checked_image[i])
+ return x_checked_image, has_nsfw_concept
+
prompt = prompt or ''
torch_gc()
# start time after garbage collection (or before?)
@@ -804,6 +926,12 @@ def process_images(
if not ("|" in prompt) and prompt.startswith("@"):
prompt = prompt[1:]
+ negprompt = ''
+ if '###' in prompt:
+ prompt, negprompt = prompt.split('###', 1)
+ prompt = prompt.strip()
+ negprompt = negprompt.strip()
+
comments = []
prompt_matrix_parts = []
@@ -882,12 +1010,14 @@ def process_images(
if job_info:
job_info.job_status = f"Processing Iteration {n+1}/{n_iter}. Batch size {batch_size}"
+ job_info.rec_steps_imgs.clear()
for idx,(p,s) in enumerate(zip(prompts,seeds)):
job_info.job_status += f"\nItem {idx}: Seed {s}\nPrompt: {p}"
+ print(f"Current prompt: {p}")
if opt.optimized:
modelCS.to(device)
- uc = (model if not opt.optimized else modelCS).get_learned_conditioning(len(prompts) * [""])
+ uc = (model if not opt.optimized else modelCS).get_learned_conditioning(len(prompts) * [negprompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
@@ -912,7 +1042,7 @@ def process_images(
while(torch.cuda.memory_allocated()/1e6 >= mem):
time.sleep(1)
- cur_variant_amount = variant_amount
+ cur_variant_amount = variant_amount
if variant_amount == 0.0:
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors(shape, seeds=seeds)
@@ -935,16 +1065,91 @@ def process_images(
# finally, slerp base_x noise to target_x noise for creating a variant
x = slerp(device, max(0.0, min(1.0, cur_variant_amount)), base_x, target_x)
- samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+ # If optimized then use first stage for preview and store it on cpu until needed
+ if opt.optimized:
+ step_preview_model = modelFS
+ step_preview_model.cpu()
+ else:
+ step_preview_model = model
+
+ def sample_iteration_callback(image_sample: torch.Tensor, iter_num: int):
+ ''' Called from the sampler every iteration '''
+ if job_info:
+ job_info.active_iteration_cnt = iter_num
+ record_periodic_image = job_info.rec_steps_enabled and (0 == iter_num % job_info.rec_steps_intrvl)
+ if record_periodic_image or job_info.refresh_active_image_requested.is_set():
+ preview_start_time = time.time()
+ if opt.optimized:
+ step_preview_model.to(device)
+
+ decoded_batch: List[torch.Tensor] = []
+ # Break up batch to save VRAM
+ for sample in image_sample:
+ sample = sample[None, :] # expands the tensor as if it still had a batch dimension
+ decoded_sample = step_preview_model.decode_first_stage(sample)[0]
+ decoded_sample = torch.clamp((decoded_sample + 1.0) / 2.0, min=0.0, max=1.0)
+ decoded_sample = decoded_sample.cpu()
+ decoded_batch.append(decoded_sample)
+
+ batch_size = len(decoded_batch)
+
+ if opt.optimized:
+ step_preview_model.cpu()
+
+ images: List[Image.Image] = []
+ # Convert tensor to image (copied from code below)
+ for ddim in decoded_batch:
+ x_sample = 255. * rearrange(ddim.numpy(), 'c h w -> h w c')
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ images.append(image)
+
+ caption = f"Iter {iter_num}"
+ grid = image_grid(images, len(images), force_n_rows=1, captions=[caption]*len(images))
+
+ # Save the images if recording steps, and append existing saved steps
+ if job_info.rec_steps_enabled:
+ gallery_img_size = tuple(int(0.25*dim) for dim in images[0].size)
+ job_info.rec_steps_imgs.append(grid.resize(gallery_img_size))
+
+ # Notify the requester that the image is updated
+ if job_info.refresh_active_image_requested.is_set():
+ if job_info.rec_steps_enabled:
+ grid_rows = None if batch_size == 1 else len(job_info.rec_steps_imgs)
+ grid = image_grid(imgs=job_info.rec_steps_imgs[::-1], batch_size=1, force_n_rows=grid_rows)
+ job_info.active_image = grid
+ job_info.refresh_active_image_done.set()
+ job_info.refresh_active_image_requested.clear()
+
+ preview_elapsed_timed = time.time() - preview_start_time
+ if preview_elapsed_timed / job_info.rec_steps_intrvl > 1:
+ print(
+ f"Warning: Preview generation is slowing image generation. It took {preview_elapsed_timed:.2f}s to generate progress images for batch of {batch_size} images!")
+
+ # Interrupt current iteration?
+ if job_info.stop_cur_iter.is_set():
+ job_info.stop_cur_iter.clear()
+ raise StopIteration()
+
+ try:
+ samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name, img_callback=sample_iteration_callback)
+ except StopIteration:
+ print("Skipping iteration")
+ job_info.job_status = "Skipping iteration"
+ continue
if opt.optimized:
modelFS.to(device)
+ for i in range(len(samples_ddim)):
+ x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim[i].unsqueeze(0))
+ x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ if filter_nsfw:
+ x_samples_ddim_numpy = x_sample.cpu().permute(0, 2, 3, 1).numpy()
+ x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
+ x_sample = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
- x_samples_ddim = (model if not opt.optimized else modelFS).decode_first_stage(samples_ddim)
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
- for i, x_sample in enumerate(x_samples_ddim):
sanitized_prompt = prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})
if variant_seed != None and variant_seed != '':
if variant_amount == 0.0:
@@ -958,16 +1163,33 @@ def process_images(
sample_path_i = os.path.join(sample_path, sanitized_prompt)
os.makedirs(sample_path_i, exist_ok=True)
base_count = get_next_sequence_number(sample_path_i)
- filename = f"{base_count:05}-{steps}_{sampler_name}_{seed_used}_{cur_variant_amount:.2f}"
+ filename = opt.filename_format or "[STEPS]_[SAMPLER]_[SEED]_[VARIANT_AMOUNT]"
else:
sample_path_i = sample_path
base_count = get_next_sequence_number(sample_path_i)
- sanitized_prompt = sanitized_prompt
- filename = f"{base_count:05}-{steps}_{sampler_name}_{seed_used}_{cur_variant_amount:.2f}_{sanitized_prompt}"[:128] #same as before
+ filename = opt.filename_format or "[STEPS]_[SAMPLER]_[SEED]_[VARIANT_AMOUNT]_[PROMPT]"
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ #Add new filenames tags here
+ filename = f"{base_count:05}-" + filename
+ filename = filename.replace("[STEPS]", str(steps))
+ filename = filename.replace("[CFG]", str(cfg_scale))
+ filename = filename.replace("[PROMPT]", sanitized_prompt[:128])
+ filename = filename.replace("[PROMPT_SPACES]", prompts[i].translate({ord(x): '' for x in invalid_filename_chars})[:128])
+ filename = filename.replace("[WIDTH]", str(width))
+ filename = filename.replace("[HEIGHT]", str(height))
+ filename = filename.replace("[SAMPLER]", sampler_name)
+ filename = filename.replace("[SEED]", seed_used)
+ filename = filename.replace("[VARIANT_AMOUNT]", f"{cur_variant_amount:.2f}")
+
+ x_sample = 255. * rearrange(x_sample[0].cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
+ metadata = ImageMetadata(prompt=prompts[i], seed=seeds[i], height=height, width=width, steps=steps,
+ cfg_scale=cfg_scale, normalize_prompt_weights=normalize_prompt_weights, denoising_strength=denoising_strength,
+ GFPGAN=use_GFPGAN )
image = Image.fromarray(x_sample)
+ image = perform_color_correction(image, correction_target, do_color_correction)
+ ImageMetadata.set_on_image(image, metadata)
+
original_sample = x_sample
original_filename = filename
if use_GFPGAN and GFPGAN is not None and not use_RealESRGAN:
@@ -976,10 +1198,18 @@ def process_images(
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(original_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
gfpgan_sample = restored_img[:,:,::-1]
gfpgan_image = Image.fromarray(gfpgan_sample)
+ gfpgan_image = perform_color_correction(gfpgan_image, correction_target, do_color_correction)
+ gfpgan_image = perform_masked_image_restoration(
+ gfpgan_image, init_img, init_mask,
+ mask_blur_strength, mask_restore,
+ use_RealESRGAN = False, RealESRGAN = None
+ )
+ gfpgan_metadata = copy.copy(metadata)
+ gfpgan_metadata.GFPGAN = True
+ ImageMetadata.set_on_image( gfpgan_image, gfpgan_metadata )
gfpgan_filename = original_filename + '-gfpgan'
- save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
-normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
-skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
+ save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
+skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False)
output_images.append(gfpgan_image) #287
#if simple_templating:
# grid_captions.append( captions[i] + "\ngfpgan" )
@@ -991,9 +1221,15 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
esrgan_filename = original_filename + '-esrgan4x'
esrgan_sample = output[:,:,::-1]
esrgan_image = Image.fromarray(esrgan_sample)
- save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
-normalize_prompt_weights, use_GFPGAN,write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
-skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
+ esrgan_image = perform_color_correction(esrgan_image, correction_target, do_color_correction)
+ esrgan_image = perform_masked_image_restoration(
+ esrgan_image, init_img, init_mask,
+ mask_blur_strength, mask_restore,
+ use_RealESRGAN, RealESRGAN
+ )
+ ImageMetadata.set_on_image( esrgan_image, metadata )
+ save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
+skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False)
output_images.append(esrgan_image) #287
#if simple_templating:
# grid_captions.append( captions[i] + "\nesrgan" )
@@ -1007,9 +1243,15 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
gfpgan_esrgan_sample = output[:,:,::-1]
gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
- save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
-normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
-skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=True)
+ gfpgan_esrgan_image = perform_color_correction(gfpgan_esrgan_image, correction_target, do_color_correction)
+ gfpgan_esrgan_image = perform_masked_image_restoration(
+ gfpgan_esrgan_image, init_img, init_mask,
+ mask_blur_strength, mask_restore,
+ use_RealESRGAN, RealESRGAN
+ )
+ ImageMetadata.set_on_image(gfpgan_esrgan_image, metadata)
+ save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+skip_save, skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, skip_metadata=False)
output_images.append(gfpgan_esrgan_image) #287
#if simple_templating:
# grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
@@ -1018,15 +1260,34 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
if imgProcessorTask == True:
output_images.append(image)
+ image = perform_masked_image_restoration(
+ image, init_img, init_mask,
+ mask_blur_strength, mask_restore,
+ # RealESRGAN image already processed in if-case above.
+ use_RealESRGAN = False, RealESRGAN = None
+ )
+
if not skip_save:
- save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
-normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
+ save_sample(image, sample_path_i, filename, jpg_sample, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
if add_original_image or not simple_templating:
output_images.append(image)
if simple_templating:
grid_captions.append( captions[i] )
+ # Save the progress images?
+ if job_info:
+ if job_info.rec_steps_enabled and (job_info.rec_steps_to_file or job_info.rec_steps_to_gallery):
+ steps_grid = image_grid(job_info.rec_steps_imgs, 1)
+ if job_info.rec_steps_to_gallery:
+ gallery_img_size = tuple(2*dim for dim in image.size)
+ output_images.append( steps_grid.resize( gallery_img_size ) )
+ if job_info.rec_steps_to_file:
+ steps_grid_filename = f"{original_filename}_step_grid"
+ save_sample(steps_grid, sample_path_i, steps_grid_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, write_sample_info_to_log_file, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
+ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, False)
+
if opt.optimized:
mem = torch.cuda.memory_allocated()/1e6
modelFS.to("cpu")
@@ -1046,7 +1307,7 @@ skip_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoisin
import traceback
print("Error creating prompt_matrix text:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
- elif batch_size > 1 or n_iter > 1:
+ elif len(output_images) > 0 and (batch_size > 1 or n_iter > 1):
grid = image_grid(output_images, batch_size)
if grid is not None:
grid_count = get_next_sequence_number(outpath, 'grid-')
@@ -1101,8 +1362,13 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
write_info_files = 5 in toggles
write_to_one_file = 6 in toggles
jpg_sample = 7 in toggles
- use_GFPGAN = 8 in toggles
- use_RealESRGAN = 9 in toggles
+ filter_nsfw = 8 in toggles
+ use_GFPGAN = 9 in toggles
+ use_RealESRGAN = 10 in toggles
+
+ do_color_correction = False
+ correction_target = None
+
ModelLoader(['model'],True,False)
if use_GFPGAN and not use_RealESRGAN:
ModelLoader(['GFPGAN'],True,False)
@@ -1134,8 +1400,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
def init():
pass
- def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
- samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
+ samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=img_callback)
return samples_ddim
try:
@@ -1155,6 +1421,7 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
width=width,
height=height,
prompt_matrix=prompt_matrix,
+ filter_nsfw=filter_nsfw,
use_GFPGAN=use_GFPGAN,
use_RealESRGAN=use_RealESRGAN,
realesrgan_model_name=realesrgan_model_name,
@@ -1168,6 +1435,8 @@ def txt2img(prompt: str, ddim_steps: int, sampler_name: str, toggles: List[int],
variant_amount=variant_amount,
variant_seed=variant_seed,
job_info=job_info,
+ do_color_correction=do_color_correction,
+ correction_target=correction_target
)
del sampler
@@ -1225,7 +1494,15 @@ class Flagging(gr.FlaggingCallback):
print("Logged:", filenames[0])
-def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, ddim_steps: int, sampler_name: str,
+def blurArr(a,r=8):
+ im1=Image.fromarray((a*255).astype(np.int8),"L")
+ im2 = im1.filter(ImageFilter.GaussianBlur(radius = r))
+ out= np.array(im2)/255
+ return out
+
+
+
+def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_strength: int, mask_restore: bool, ddim_steps: int, sampler_name: str,
toggles: List[int], realesrgan_model_name: str, n_iter: int, cfg_scale: float, denoising_strength: float,
seed: int, height: int, width: int, resize_mode: int, init_info: any = None, init_info_mask: any = None, fp = None, job_info: JobInfo = None):
# print([prompt, image_editor_mode, init_info, init_info_mask, mask_mode,
@@ -1249,8 +1526,10 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
write_info_files = 7 in toggles
write_sample_info_to_log_file = 8 in toggles
jpg_sample = 9 in toggles
- use_GFPGAN = 10 in toggles
- use_RealESRGAN = 11 in toggles
+ do_color_correction = 10 in toggles
+ filter_nsfw = 11 in toggles
+ use_GFPGAN = 12 in toggles
+ use_RealESRGAN = 13 in toggles
ModelLoader(['model'],True,False)
if use_GFPGAN and not use_RealESRGAN:
ModelLoader(['GFPGAN'],True,False)
@@ -1279,10 +1558,12 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
if image_editor_mode == 'Mask':
init_img = init_info_mask["image"]
+ init_img_transparency = ImageOps.invert(init_img.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
init_img = init_img.convert("RGB")
init_img = resize_image(resize_mode, init_img, width, height)
init_img = init_img.convert("RGB")
init_mask = init_info_mask["mask"]
+ init_mask = ImageChops.lighter(init_img_transparency, init_mask.convert('L')).convert('RGBA')
init_mask = init_mask.convert("RGB")
init_mask = resize_image(resize_mode, init_mask, width, height)
init_mask = init_mask.convert("RGB")
@@ -1305,16 +1586,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
image = torch.from_numpy(image)
mask_channel = None
- if image_editor_mode == "Uncrop":
- alpha = init_img.convert("RGBA")
- alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
- mask_channel = alpha.split()[-1]
- mask_channel = mask_channel.filter(ImageFilter.GaussianBlur(4))
- mask_channel = np.array(mask_channel)
- mask_channel[mask_channel >= 255] = 255
- mask_channel[mask_channel < 255] = 0
- mask_channel = Image.fromarray(mask_channel).filter(ImageFilter.GaussianBlur(2))
- elif image_editor_mode == "Mask":
+ if image_editor_mode == "Mask":
alpha = init_mask.convert("RGBA")
alpha = resize_image(resize_mode, alpha, width // 8, height // 8)
mask_channel = alpha.split()[1]
@@ -1329,11 +1601,62 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
if opt.optimized:
modelFS.to(device)
- init_image = 2. * image - 1.
+ #let's try and find where init_image is 0's
+ #shape is probably (3,width,height)?
+
+ if image_editor_mode == "Uncrop":
+ _image=image.numpy()[0]
+ _mask=np.ones((_image.shape[1],_image.shape[2]))
+
+ #compute bounding box
+ cmax=np.max(_image,axis=0)
+ rowmax=np.max(cmax,axis=0)
+ colmax=np.max(cmax,axis=1)
+ rowwhere=np.where(rowmax>0)[0]
+ colwhere=np.where(colmax>0)[0]
+ rowstart=rowwhere[0]
+ rowend=rowwhere[-1]+1
+ colstart=colwhere[0]
+ colend=colwhere[-1]+1
+ print('bounding box: ',rowstart,rowend,colstart,colend)
+
+ #this is where noise will get added
+ PAD_IMG=16
+ boundingbox=np.zeros(shape=(height,width))
+ boundingbox[colstart+PAD_IMG:colend-PAD_IMG,rowstart+PAD_IMG:rowend-PAD_IMG]=1
+ boundingbox=blurArr(boundingbox,4)
+
+ #this is the mask for outpainting
+ PAD_MASK=24
+ boundingbox2=np.zeros(shape=(height,width))
+ boundingbox2[colstart+PAD_MASK:colend-PAD_MASK,rowstart+PAD_MASK:rowend-PAD_MASK]=1
+ boundingbox2=blurArr(boundingbox2,4)
+
+ #noise=np.random.randn(*_image.shape)
+ noise=np.array([perlinNoise(height,width,height/64,width/64) for i in range(3)])
+ _mask*=1-boundingbox2
+
+ #convert 0,1 to -1,1
+ _image = 2. * _image - 1.
+
+ #add noise
+ boundingbox=np.tile(boundingbox,(3,1,1))
+ _image=_image*boundingbox+noise*(1-boundingbox)
+
+ #resize mask
+ _mask = np.array(resize_image(resize_mode, Image.fromarray(_mask*255), width // 8, height // 8))/255
+
+ #convert back to torch tensor
+ init_image=torch.from_numpy(np.expand_dims(_image,axis=0).astype(np.float32)).to(device)
+ mask=torch.from_numpy(_mask.astype(np.float32)).to(device)
+
+ else:
+ init_image = 2. * image - 1.
+
init_image = init_image.to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = (model if not opt.optimized else modelFS).get_first_stage_encoding((model if not opt.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
-
+
if opt.optimized:
mem = torch.cuda.memory_allocated()/1e6
modelFS.to("cpu")
@@ -1342,7 +1665,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
return init_latent, mask,
- def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
t_enc_steps = t_enc
obliterate = False
if ddim_steps == t_enc_steps:
@@ -1364,7 +1687,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
- samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False)
+ samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
else:
x0, z_mask = init_data
@@ -1385,18 +1708,14 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
return samples_ddim
-
+ correction_target = None
if loopback:
output_images, info = None, None
history = []
initial_seed = None
- do_color_correction = False
- try:
- from skimage import exposure
- do_color_correction = True
- except:
- print("Install scikit-image to perform color correction on loopback")
+ # turn on color correction for loopback to prevent known issue of color drift
+ do_color_correction = True
for i in range(n_iter):
if do_color_correction and i == 0:
@@ -1418,6 +1737,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
width=width,
height=height,
prompt_matrix=prompt_matrix,
+ filter_nsfw=filter_nsfw,
use_GFPGAN=use_GFPGAN,
use_RealESRGAN=False, # Forcefully disable upscaling when using loopback
realesrgan_model_name=realesrgan_model_name,
@@ -1428,6 +1748,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
init_mask=init_mask,
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
+ mask_restore=mask_restore,
denoising_strength=denoising_strength,
resize_mode=resize_mode,
uses_loopback=loopback,
@@ -1436,7 +1757,9 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
write_info_files=write_info_files,
write_sample_info_to_log_file=write_sample_info_to_log_file,
jpg_sample=jpg_sample,
- job_info=job_info
+ job_info=job_info,
+ do_color_correction=do_color_correction,
+ correction_target=correction_target
)
if initial_seed is None:
@@ -1444,16 +1767,6 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
init_img = output_images[0]
- if do_color_correction and correction_target is not None:
- init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
- cv2.cvtColor(
- np.asarray(init_img),
- cv2.COLOR_RGB2LAB
- ),
- correction_target,
- channel_axis=2
- ), cv2.COLOR_LAB2RGB).astype("uint8"))
-
if not random_seed_loopback:
seed = seed + 1
else:
@@ -1472,6 +1785,9 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
seed = initial_seed
else:
+ if do_color_correction:
+ correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
+
output_images, seed, info, stats = process_images(
outpath=outpath,
func_init=init,
@@ -1488,6 +1804,7 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
width=width,
height=height,
prompt_matrix=prompt_matrix,
+ filter_nsfw=filter_nsfw,
use_GFPGAN=use_GFPGAN,
use_RealESRGAN=use_RealESRGAN,
realesrgan_model_name=realesrgan_model_name,
@@ -1498,13 +1815,16 @@ def img2img(prompt: str, image_editor_mode: str, mask_mode: str, mask_blur_stren
keep_mask=keep_mask,
mask_blur_strength=mask_blur_strength,
denoising_strength=denoising_strength,
+ mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=loopback,
sort_samples=sort_samples,
write_info_files=write_info_files,
write_sample_info_to_log_file=write_sample_info_to_log_file,
jpg_sample=jpg_sample,
- job_info=job_info
+ job_info=job_info,
+ do_color_correction=do_color_correction,
+ correction_target=correction_target
)
del sampler
@@ -1572,8 +1892,13 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
images = []
def processGFPGAN(image,strength):
image = image.convert("RGB")
+ metadata = ImageMetadata.get_from_image(image)
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
result = Image.fromarray(restored_img)
+ if metadata:
+ metadata.GFPGAN = True
+ ImageMetadata.set_on_image(image, metadata)
+
if strength < 1.0:
result = Image.blend(image, result, strength)
@@ -1585,15 +1910,18 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
else:
modelMode = imgproc_realesrgan_model_name
image = image.convert("RGB")
+ metadata = ImageMetadata.get_from_image(image)
RealESRGAN = load_RealESRGAN(modelMode)
result, res = RealESRGAN.enhance(np.array(image, dtype=np.uint8))
result = Image.fromarray(result)
+ ImageMetadata.set_on_image(result, metadata)
if 'x2' in imgproc_realesrgan_model_name:
# downscale to 1/2 size
result = result.resize((result.width//2, result.height//2), LANCZOS)
return result
def processGoBig(image):
+ metadata = ImageMetadata.get_from_image(image)
result = processRealESRGAN(image,)
if 'x4' in imgproc_realesrgan_model_name:
#downscale to 1/2 size
@@ -1638,6 +1966,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
init_img = result
init_mask = None
keep_mask = False
+ mask_restore = False
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
def init():
@@ -1663,7 +1992,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
return init_latent,
- def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name, img_callback: Callable = None):
if sampler_name != 'DDIM':
x0, = init_data
@@ -1673,7 +2002,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
xi = x0 + noise
sigma_sched = sigmas[ddim_steps - t_enc - 1:]
model_wrap_cfg = CFGDenoiser(sampler.model_wrap)
- samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
+ samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False, callback=partial(KDiffusionSampler.img_callback_wrapper, img_callback))
else:
x0, = init_data
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
@@ -1774,6 +2103,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
width=width,
height=height,
prompt_matrix=None,
+ filter_nsfw=False,
use_GFPGAN=None,
use_RealESRGAN=None,
realesrgan_model_name=None,
@@ -1784,6 +2114,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
keep_mask=False,
mask_blur_strength=None,
denoising_strength=denoising_strength,
+ mask_restore=mask_restore,
resize_mode=resize_mode,
uses_loopback=False,
sort_samples=True,
@@ -1808,11 +2139,14 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
del sampler
torch.cuda.empty_cache()
+ ImageMetadata.set_on_image(combined_image, metadata)
return combined_image
def processLDSR(image):
+ metadata = ImageMetadata.get_from_image(image)
result = LDSR.superResolution(image,int(imgproc_ldsr_steps),str(imgproc_ldsr_pre_downSample),str(imgproc_ldsr_post_downSample))
- return result
-
+ ImageMetadata.set_on_image(result, metadata)
+ return result
+
if image_batch != None:
if image != None:
@@ -1839,7 +2173,7 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
if 1 in imgproc_toggles:
if imgproc_upscale_toggles == 0:
ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models
- ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models
+ ModelLoader(['RealESGAN'],True,False,imgproc_realesrgan_model_name) # Load used models
elif imgproc_upscale_toggles == 1:
ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models
ModelLoader(['RealESGAN','model'],True,False) # Load used models
@@ -1851,10 +2185,14 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
ModelLoader(['GFPGAN','LDSR'],False,True) # Unload unused models
ModelLoader(['RealESGAN','model'],True,False,imgproc_realesrgan_model_name) # Load used models
for image in images:
+ metadata = ImageMetadata.get_from_image(image)
if 0 in imgproc_toggles:
#recheck if GFPGAN is loaded since it's the only model that can be loaded in the loop as well
ModelLoader(['GFPGAN'],True,False) # Load used models
image = processGFPGAN(image,imgproc_gfpgan_strength)
+ if metadata:
+ metadata.GFPGAN = True
+ ImageMetadata.set_on_image(image, metadata)
outpathDir = os.path.join(outpath,'GFPGAN')
os.makedirs(outpathDir, exist_ok=True)
batchNumber = get_next_sequence_number(outpathDir)
@@ -1862,47 +2200,51 @@ def imgproc(image,image_batch,imgproc_prompt,imgproc_toggles, imgproc_upscale_to
if 1 not in imgproc_toggles:
output.append(image)
- save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True)
+ save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False)
if 1 in imgproc_toggles:
if imgproc_upscale_toggles == 0:
image = processRealESRGAN(image)
+ ImageMetadata.set_on_image(image, metadata)
outpathDir = os.path.join(outpath,'RealESRGAN')
os.makedirs(outpathDir, exist_ok=True)
batchNumber = get_next_sequence_number(outpathDir)
outFilename = str(batchNumber)+'-'+'result'
output.append(image)
- save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True)
+ save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False)
elif imgproc_upscale_toggles == 1:
image = processGoBig(image)
+ ImageMetadata.set_on_image(image, metadata)
outpathDir = os.path.join(outpath,'GoBig')
os.makedirs(outpathDir, exist_ok=True)
batchNumber = get_next_sequence_number(outpathDir)
outFilename = str(batchNumber)+'-'+'result'
output.append(image)
- save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True)
+ save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False)
elif imgproc_upscale_toggles == 2:
image = processLDSR(image)
+ ImageMetadata.set_on_image(image, metadata)
outpathDir = os.path.join(outpath,'LDSR')
os.makedirs(outpathDir, exist_ok=True)
batchNumber = get_next_sequence_number(outpathDir)
outFilename = str(batchNumber)+'-'+'result'
output.append(image)
- save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True)
+ save_sample(image, outpathDir, outFilename, False, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False)
elif imgproc_upscale_toggles == 3:
image = processGoBig(image)
ModelLoader(['model','GFPGAN','RealESGAN'],False,True) # Unload unused models
ModelLoader(['LDSR'],True,False) # Load used models
image = processLDSR(image)
+ ImageMetadata.set_on_image(image, metadata)
outpathDir = os.path.join(outpath,'GoLatent')
os.makedirs(outpathDir, exist_ok=True)
batchNumber = get_next_sequence_number(outpathDir)
outFilename = str(batchNumber)+'-'+'result'
output.append(image)
- save_sample(image, outpathDir, outFilename, None, None, None, None, None, None, None, None, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, True)
+ save_sample(image, outpathDir, outFilename, None, None, None, None, None, None, False, None, None, None, None, None, None, None, None, None, False)
#LDSR is always unloaded to avoid memory issues
#ModelLoader(['LDSR'],False,True)
@@ -1952,10 +2294,13 @@ def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='Re
def run_GFPGAN(image, strength):
ModelLoader(['LDSR','RealESRGAN'],False,True)
ModelLoader(['GFPGAN'],True,False)
+ metadata = ImageMetadata.get_from_image(image)
image = image.convert("RGB")
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
res = Image.fromarray(restored_img)
+ metadata.GFPGAN = True
+ ImageMetadata.set_on_image(res, metadata)
if strength < 1.0:
res = Image.blend(image, res, strength)
@@ -1968,10 +2313,12 @@ def run_RealESRGAN(image, model_name: str):
if RealESRGAN.model.name != model_name:
try_loading_RealESRGAN(model_name)
+ metadata = ImageMetadata.get_from_image(image)
image = image.convert("RGB")
output, img_mode = RealESRGAN.enhance(np.array(image, dtype=np.uint8))
res = Image.fromarray(output)
+ ImageMetadata.set_on_image(res, metadata)
return res
@@ -1997,6 +2344,7 @@ txt2img_toggles = [
'Write sample info files',
'write sample info to log file',
'jpg samples',
+ 'Filter NSFW content',
]
if GFPGAN is not None:
@@ -2057,6 +2405,8 @@ img2img_toggles = [
'Write sample info files',
'Write sample info to one file',
'jpg samples',
+ 'Color correction (always enabled on loopback mode)',
+ 'Filter NSFW content',
]
# removed for now becuase of Image Lab implementation
if GFPGAN is not None:
@@ -2086,6 +2436,7 @@ img2img_defaults = {
'cfg_scale': 5.0,
'denoising_strength': 0.75,
'mask_mode': 0,
+ 'mask_restore': False,
'resize_mode': 0,
'seed': '',
'height': 512,
@@ -2099,24 +2450,6 @@ if 'img2img' in user_defaults:
img2img_toggle_defaults = [img2img_toggles[i] for i in img2img_defaults['toggles']]
img2img_image_mode = 'sketch'
-def change_image_editor_mode(choice, cropped_image, resize_mode, width, height):
- if choice == "Mask":
- return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)]
- return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)]
-
-def update_image_mask(cropped_image, resize_mode, width, height):
- resized_cropped_image = resize_image(resize_mode, cropped_image, width, height) if cropped_image else None
- return gr.update(value=resized_cropped_image)
-
-
-
-def copy_img_to_upscale_esrgan(img):
- update = gr.update(selected='realesrgan_tab')
- image_data = re.sub('^data:image/.+;base64,', '', img)
- processed_image = Image.open(BytesIO(base64.b64decode(image_data)))
- return {'realesrgan_source': processed_image, 'tabs': update}
-
-
help_text = """
## Mask/Crop
* The masking/cropping is very temperamental.
@@ -2178,7 +2511,7 @@ class ServerLauncher(threading.Thread):
'inbrowser': opt.inbrowser,
'server_name': '0.0.0.0',
'server_port': opt.port,
- 'share': opt.share,
+ 'share': opt.share,
'show_error': True
}
if not opt.share:
diff --git a/scripts/webui_streamlit.py b/scripts/webui_streamlit.py
index bd680da..15a2e2f 100644
--- a/scripts/webui_streamlit.py
+++ b/scripts/webui_streamlit.py
@@ -1,46 +1,31 @@
-import warnings
+# base webui import and utils.
import streamlit as st
-from streamlit import StopException, StreamlitAPIException
-import base64, cv2
-import argparse, os, sys, glob, re, random, datetime
-from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
-from PIL.PngImagePlugin import PngInfo
-import requests
-from scipy import integrate
-import torch
-from torchdiffeq import odeint
-from tqdm.auto import trange, tqdm
+# streamlit imports
+import streamlit_nested_layout
+
+#streamlit components section
+from st_on_hover_tabs import on_hover_tabs
+
+#other imports
+
+import warnings
+import os
import k_diffusion as K
-import math
-import mimetypes
-import numpy as np
-import pynvml
-import threading, asyncio
-import time
-import torch
-from torch import autocast
-from torchvision import transforms
-import torch.nn as nn
-import yaml
-from typing import List, Union
-from pathlib import Path
-from tqdm import tqdm
-from contextlib import contextmanager, nullcontext
-from einops import rearrange, repeat
-from itertools import islice
from omegaconf import OmegaConf
-from io import BytesIO
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
-from ldm.util import instantiate_from_config
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
- extract_into_tensor
-from retry import retry
-# we use python-slugify to make the filenames safe for windows and linux, its better than doing it manually
-# install it with 'pip install python-slugify'
-from slugify import slugify
+from sd_utils import *
+if not "defaults" in st.session_state:
+ st.session_state["defaults"] = {}
+
+st.session_state["defaults"] = OmegaConf.load("configs/webui/webui_streamlit.yaml")
+
+if (os.path.exists("configs/webui/userconfig_streamlit.yaml")):
+ user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml")
+ st.session_state["defaults"] = OmegaConf.merge(st.session_state["defaults"], user_defaults)
+
+# end of imports
+#---------------------------------------------------------------------------------------------------------------
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -53,1412 +38,18 @@ except:
# remove some annoying deprecation warnings that show every now and then.
warnings.filterwarnings("ignore", category=DeprecationWarning)
-defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml")
-
-# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
-mimetypes.init()
-mimetypes.add_type('application/javascript', '.js')
-
-# some of those options should not be changed at all because they would break the model, so I removed them from options.
-opt_C = 4
-opt_f = 8
-
-# should and will be moved to a settings menu in the UI at some point
-grid_format = [s.lower() for s in defaults.general.grid_format.split(':')]
-grid_lossless = False
-grid_quality = 100
-if grid_format[0] == 'png':
- grid_ext = 'png'
- grid_format = 'png'
-elif grid_format[0] in ['jpg', 'jpeg']:
- grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100
- grid_ext = 'jpg'
- grid_format = 'jpeg'
-elif grid_format[0] == 'webp':
- grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100
- grid_ext = 'webp'
- grid_format = 'webp'
- if grid_quality < 0: # e.g. webp:-100 for lossless mode
- grid_lossless = True
- grid_quality = abs(grid_quality)
-
# this should force GFPGAN and RealESRGAN onto the selected gpu as well
-os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
-os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu)
-
-@retry(tries=5)
-def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus"):
- """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
-
- print ("Loading models.")
-
- # Generate random run ID
- # Used to link runs linked w/ continue_prev_run which is not yet implemented
- # Use URL and filesystem safe version just in case.
- st.session_state["run_id"] = base64.urlsafe_b64encode(
- os.urandom(6)
- ).decode("ascii")
-
- # check what models we want to use and if the they are already loaded.
-
- if use_GFPGAN:
- if "GFPGAN" in st.session_state:
- print("GFPGAN already loaded")
- else:
- # Load GFPGAN
- if os.path.exists(defaults.general.GFPGAN_dir):
- try:
- st.session_state["GFPGAN"] = load_GFPGAN()
- print("Loaded GFPGAN")
- except Exception:
- import traceback
- print("Error loading GFPGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- if "GFPGAN" in st.session_state:
- del st.session_state["GFPGAN"]
-
- if use_RealESRGAN:
- if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model:
- print("RealESRGAN already loaded")
- else:
- #Load RealESRGAN
- try:
- # We first remove the variable in case it has something there,
- # some errors can load the model incorrectly and leave things in memory.
- del st.session_state["RealESRGAN"]
- except KeyError:
- pass
-
- if os.path.exists(defaults.general.RealESRGAN_dir):
- # st.session_state is used for keeping the models in memory across multiple pages or runs.
- st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model)
- print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name)
-
- else:
- if "RealESRGAN" in st.session_state:
- del st.session_state["RealESRGAN"]
-
-
- if "model" in st.session_state:
- print("Model already loaded")
- else:
- config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
- model = load_model_from_config(config, defaults.general.ckpt)
-
- st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu")
- st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] )
-
- print("Model loaded.")
-
-
-def load_model_from_config(config, ckpt, verbose=False):
-
- print(f"Loading model from {ckpt}")
-
- pl_sd = torch.load(ckpt, map_location="cpu")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd["state_dict"]
- model = instantiate_from_config(config.model)
- m, u = model.load_state_dict(sd, strict=False)
- if len(m) > 0 and verbose:
- print("missing keys:")
- print(m)
- if len(u) > 0 and verbose:
- print("unexpected keys:")
- print(u)
-
- model.cuda()
- model.eval()
- return model
-
-def load_sd_from_config(ckpt, verbose=False):
- print(f"Loading model from {ckpt}")
- pl_sd = torch.load(ckpt, map_location="cpu")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd["state_dict"]
- return sd
-#
-@retry(tries=5)
-def generation_callback(img, i=0):
-
- try:
- if i == 0:
- if img['i']: i = img['i']
- except TypeError:
- pass
-
-
- if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview:
- #print (img)
- #print (type(img))
- # The following lines will convert the tensor we got on img to an actual image we can render on the UI.
- # It can probably be done in a better way for someone who knows what they're doing. I don't.
- #print (img,isinstance(img, torch.Tensor))
- if isinstance(img, torch.Tensor):
- x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img)
- else:
- # When using the k Diffusion samplers they return a dict instead of a tensor that look like this:
- # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}
- x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"])
-
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-
- pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0))
-
- # update image on the UI so we can see the progress
- st.session_state["preview_image"].image(pil_image)
-
- # Show a progress bar so we can keep track of the progress even when the image progress is not been shown,
- # Dont worry, it doesnt affect the performance.
- if st.session_state["generation_mode"] == "txt2img":
- percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
- st.session_state["progress_bar_text"].text(
- f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%")
- else:
- round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"])
- percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps))
- st.session_state["progress_bar_text"].text(
- f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""")
-
- st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
-
-
-
-class MemUsageMonitor(threading.Thread):
- stop_flag = False
- max_usage = 0
- total = -1
-
- def __init__(self, name):
- threading.Thread.__init__(self)
- self.name = name
-
- def run(self):
- try:
- pynvml.nvmlInit()
- except:
- print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n")
- return
- print(f"[{self.name}] Recording max memory usage...\n")
- handle = pynvml.nvmlDeviceGetHandleByIndex(defaults.general.gpu)
- self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
- while not self.stop_flag:
- m = pynvml.nvmlDeviceGetMemoryInfo(handle)
- self.max_usage = max(self.max_usage, m.used)
- # print(self.max_usage)
- time.sleep(0.1)
- print(f"[{self.name}] Stopped recording.\n")
- pynvml.nvmlShutdown()
-
- def read(self):
- return self.max_usage, self.total
-
- def stop(self):
- self.stop_flag = True
-
- def read_and_stop(self):
- self.stop_flag = True
- return self.max_usage, self.total
-
-class CFGMaskedDenoiser(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
-
- def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
- x_in = x
- x_in = torch.cat([x_in] * 2)
- sigma_in = torch.cat([sigma] * 2)
- cond_in = torch.cat([uncond, cond])
- uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
- denoised = uncond + (cond - uncond) * cond_scale
-
- if mask is not None:
- assert x0 is not None
- img_orig = x0
- mask_inv = 1. - mask
- denoised = (img_orig * mask_inv) + (mask * denoised)
-
- return denoised
-
-class CFGDenoiser(nn.Module):
- def __init__(self, model):
- super().__init__()
- self.inner_model = model
-
- def forward(self, x, sigma, uncond, cond, cond_scale):
- x_in = torch.cat([x] * 2)
- sigma_in = torch.cat([sigma] * 2)
- cond_in = torch.cat([uncond, cond])
- uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
- return uncond + (cond - uncond) * cond_scale
-def append_zero(x):
- return torch.cat([x, x.new_zeros([1])])
-def append_dims(x, target_dims):
- """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
- dims_to_append = target_dims - x.ndim
- if dims_to_append < 0:
- raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
- return x[(...,) + (None,) * dims_to_append]
-def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
- """Constructs the noise schedule of Karras et al. (2022)."""
- ramp = torch.linspace(0, 1, n)
- min_inv_rho = sigma_min ** (1 / rho)
- max_inv_rho = sigma_max ** (1 / rho)
- sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
- return append_zero(sigmas).to(device)
-
-
-def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
- """Constructs an exponential noise schedule."""
- sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
- return append_zero(sigmas)
-
-
-def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
- """Constructs a continuous VP noise schedule."""
- t = torch.linspace(1, eps_s, n, device=device)
- sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
- return append_zero(sigmas)
-
-
-def to_d(x, sigma, denoised):
- """Converts a denoiser output to a Karras ODE derivative."""
- return (x - denoised) / append_dims(sigma, x.ndim)
-def linear_multistep_coeff(order, t, i, j):
- if order - 1 > i:
- raise ValueError(f'Order {order} too high for step {i}')
- def fn(tau):
- prod = 1.
- for k in range(order):
- if j == k:
- continue
- prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
- return prod
- return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
-
-class KDiffusionSampler:
- def __init__(self, m, sampler):
- self.model = m
- self.model_wrap = K.external.CompVisDenoiser(m)
- self.schedule = sampler
- def get_sampler_name(self):
- return self.schedule
- def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None):
- sigmas = self.model_wrap.get_sigmas(S)
- x = x_T * sigmas[0]
- model_wrap_cfg = CFGDenoiser(self.model_wrap)
- samples_ddim = None
- samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas,
- extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
- 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback)
- #
- return samples_ddim, None
-
-
-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
- extra_args = {} if extra_args is None else extra_args
- s_in = x.new_ones([x.shape[0]])
- v = torch.randint_like(x, 2) * 2 - 1
- fevals = 0
- def ode_fn(sigma, x):
- nonlocal fevals
- with torch.enable_grad():
- x = x[0].detach().requires_grad_()
- denoised = model(x, sigma * s_in, **extra_args)
- d = to_d(x, sigma, denoised)
- fevals += 1
- grad = torch.autograd.grad((d * v).sum(), x)[0]
- d_ll = (v * grad).flatten(1).sum(1)
- return d.detach(), d_ll
- x_min = x, x.new_zeros([x.shape[0]])
- t = x.new_tensor([sigma_min, sigma_max])
- sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
- latent, delta_ll = sol[0][-1], sol[1][-1]
- ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
- return ll_prior + delta_ll, {'fevals': fevals}
-
-
-def create_random_tensors(shape, seeds):
- xs = []
- for seed in seeds:
- torch.manual_seed(seed)
-
- # randn results depend on device; gpu and cpu get different results for same seed;
- # the way I see it, it's better to do this on CPU, so that everyone gets same result;
- # but the original script had it like this so i do not dare change it for now because
- # it will break everyone's seeds.
- xs.append(torch.randn(shape, device=defaults.general.gpu))
- x = torch.stack(xs)
- return x
-
-def torch_gc():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
-
-def load_GFPGAN():
- model_name = 'GFPGANv1.3'
- model_path = os.path.join(defaults.general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
- if not os.path.isfile(model_path):
- raise Exception("GFPGAN model not found at path "+model_path)
-
- sys.path.append(os.path.abspath(defaults.general.GFPGAN_dir))
- from gfpgan import GFPGANer
-
- if defaults.general.gfpgan_cpu or defaults.general.extra_models_cpu:
- instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
- elif defaults.general.extra_models_gpu:
- instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gfpgan_gpu}'))
- else:
- instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gpu}'))
- return instance
-
-def load_RealESRGAN(model_name: str):
- from basicsr.archs.rrdbnet_arch import RRDBNet
- RealESRGAN_models = {
- 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
- 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
- }
-
- model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
- if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")):
- raise Exception(model_name+".pth not found at path "+model_path)
-
- sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir))
- from realesrgan import RealESRGANer
-
- if defaults.general.esrgan_cpu or defaults.general.extra_models_cpu:
- instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half
- instance.device = torch.device('cpu')
- instance.model.to('cpu')
- elif defaults.general.extra_models_gpu:
- instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.esrgan_gpu}'))
- else:
- instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.gpu}'))
- instance.model.name = model_name
-
- return instance
-
-prompt_parser = re.compile("""
- (?P # capture group for 'prompt'
- [^:]+ # match one or more non ':' characters
- ) # end 'prompt'
- (?: # non-capture group
- :+ # match one or more ':' characters
- (?P # capture group for 'weight'
- -?\\d+(?:\\.\\d+)? # match positive or negative decimal number
- )? # end weight capture group, make optional
- \\s* # strip spaces after weight
- | # OR
- $ # else, if no ':' then match end of line
- ) # end non-capture group
-""", re.VERBOSE)
-
-# grabs all text up to the first occurrence of ':' as sub-prompt
-# takes the value following ':' as weight
-# if ':' has no value defined, defaults to 1.0
-# repeats until no text remaining
-def split_weighted_subprompts(input_string, normalize=True):
- parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)]
- if not normalize:
- return parsed_prompts
- # this probably still doesn't handle negative weights very well
- weight_sum = sum(map(lambda x: x[1], parsed_prompts))
- return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
-
-def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
- v0 = v0.detach().cpu().numpy()
- v1 = v1.detach().cpu().numpy()
-
- dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
- if np.abs(dot) > DOT_THRESHOLD:
- v2 = (1 - t) * v0 + t * v1
- else:
- theta_0 = np.arccos(dot)
- sin_theta_0 = np.sin(theta_0)
- theta_t = theta_0 * t
- sin_theta_t = np.sin(theta_t)
- s0 = np.sin(theta_0 - theta_t) / sin_theta_0
- s1 = sin_theta_t / sin_theta_0
- v2 = s0 * v0 + s1 * v1
-
- v2 = torch.from_numpy(v2).to(device)
-
- return v2
-
-
-def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'):
- #get global variables
- global_vars = globals()
- #check if m is in globals
- if unload:
- for m in models:
- if m in global_vars:
- #if it is, delete it
- del global_vars[m]
- if defaults.general.optimized:
- if m == 'model':
- del global_vars[m+'FS']
- del global_vars[m+'CS']
- if m =='model':
- m='Stable Diffusion'
- print('Unloaded ' + m)
- if load:
- for m in models:
- if m not in global_vars or m in global_vars and type(global_vars[m]) == bool:
- #if it isn't, load it
- if m == 'GFPGAN':
- global_vars[m] = load_GFPGAN()
- elif m == 'model':
- sdLoader = load_sd_from_config()
- global_vars[m] = sdLoader[0]
- if defaults.general.optimized:
- global_vars[m+'CS'] = sdLoader[1]
- global_vars[m+'FS'] = sdLoader[2]
- elif m == 'RealESRGAN':
- global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name)
- elif m == 'LDSR':
- global_vars[m] = load_LDSR()
- if m =='model':
- m='Stable Diffusion'
- print('Loaded ' + m)
- torch_gc()
-
-
-
-def get_font(fontsize):
- fonts = ["arial.ttf", "DejaVuSans.ttf"]
- for font_name in fonts:
- try:
- return ImageFont.truetype(font_name, fontsize)
- except OSError:
- pass
-
- # ImageFont.load_default() is practically unusable as it only supports
- # latin1, so raise an exception instead if no usable font was found
- raise Exception(f"No usable font found (tried {', '.join(fonts)})")
-
-def load_embeddings(fp):
- if fp is not None and hasattr(st.session_state["model"], "embedding_manager"):
- st.session_state["model"].embedding_manager.load(fp['name'])
-
-def image_grid(imgs, batch_size, force_n_rows=None, captions=None):
- #print (len(imgs))
- if force_n_rows is not None:
- rows = force_n_rows
- elif defaults.general.n_rows > 0:
- rows = defaults.general.n_rows
- elif defaults.general.n_rows == 0:
- rows = batch_size
- else:
- rows = math.sqrt(len(imgs))
- rows = round(rows)
-
- cols = math.ceil(len(imgs) / rows)
-
- w, h = imgs[0].size
- grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
-
- fnt = get_font(30)
-
- for i, img in enumerate(imgs):
- grid.paste(img, box=(i % cols * w, i // cols * h))
- if captions and i= 2**32:
- n = n >> 32
- return n
-
-def check_prompt_length(prompt, comments):
- """this function tests if prompt is too long, and if so, adds a message to comments"""
-
- tokenizer = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer
- max_length = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.max_length
-
- info = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
- return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
- ovf = info['overflowing_tokens'][0]
- overflowing_count = ovf.shape[0]
- if overflowing_count == 0:
- return
-
- vocab = {v: k for k, v in tokenizer.get_vocab().items()}
- overflowing_words = [vocab.get(int(x), "") for x in ovf]
- overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-
- comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
-def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
- normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
- save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images):
-
- filename_i = os.path.join(sample_path_i, filename)
-
- if not jpg_sample:
- if defaults.general.save_metadata:
- metadata = PngInfo()
- metadata.add_text("SD:prompt", prompts[i])
- metadata.add_text("SD:seed", str(seeds[i]))
- metadata.add_text("SD:width", str(width))
- metadata.add_text("SD:height", str(height))
- metadata.add_text("SD:steps", str(steps))
- metadata.add_text("SD:cfg_scale", str(cfg_scale))
- metadata.add_text("SD:normalize_prompt_weights", str(normalize_prompt_weights))
- if init_img is not None:
- metadata.add_text("SD:denoising_strength", str(denoising_strength))
- metadata.add_text("SD:GFPGAN", str(use_GFPGAN and st.session_state["GFPGAN"] is not None))
- image.save(f"{filename_i}.png", pnginfo=metadata)
- else:
- image.save(f"{filename_i}.png")
- else:
- image.save(f"{filename_i}.jpg", 'jpeg', quality=100, optimize=True)
-
- if write_info_files:
- # toggles differ for txt2img vs. img2img:
- offset = 0 if init_img is None else 2
- toggles = []
- if prompt_matrix:
- toggles.append(0)
- if normalize_prompt_weights:
- toggles.append(1)
- if init_img is not None:
- if uses_loopback:
- toggles.append(2)
- if uses_random_seed_loopback:
- toggles.append(3)
- if save_individual_images:
- toggles.append(2 + offset)
- if save_grid:
- toggles.append(3 + offset)
- if sort_samples:
- toggles.append(4 + offset)
- if write_info_files:
- toggles.append(5 + offset)
- if use_GFPGAN:
- toggles.append(6 + offset)
- info_dict = dict(
- target="txt2img" if init_img is None else "img2img",
- prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name,
- ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale,
- seed=seeds[i], width=width, height=height
- )
- if init_img is not None:
- # Not yet any use for these, but they bloat up the files:
- #info_dict["init_img"] = init_img
- #info_dict["init_mask"] = init_mask
- info_dict["denoising_strength"] = denoising_strength
- info_dict["resize_mode"] = resize_mode
- with open(f"{filename_i}.yaml", "w", encoding="utf8") as f:
- yaml.dump(info_dict, f, allow_unicode=True, width=10000)
-
- # render the image on the frontend
- st.session_state["preview_image"].image(image)
-
-def get_next_sequence_number(path, prefix=''):
- """
- Determines and returns the next sequence number to use when saving an
- image in the specified directory.
-
- If a prefix is given, only consider files whose names start with that
- prefix, and strip the prefix from filenames before extracting their
- sequence number.
-
- The sequence starts at 0.
- """
- result = -1
- for p in Path(path).iterdir():
- if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix):
- tmp = p.name[len(prefix):]
- try:
- result = max(int(tmp.split('-')[0]), result)
- except ValueError:
- pass
- return result + 1
-
-
-def oxlamon_matrix(prompt, seed, n_iter, batch_size):
- pattern = re.compile(r'(,\s){2,}')
-
- class PromptItem:
- def __init__(self, text, parts, item):
- self.text = text
- self.parts = parts
- if item:
- self.parts.append( item )
-
- def clean(txt):
- return re.sub(pattern, ', ', txt)
-
- def getrowcount( txt ):
- for data in re.finditer( ".*?\\((.*?)\\).*", txt ):
- if data:
- return len(data.group(1).split("|"))
- break
- return None
-
- def repliter( txt ):
- for data in re.finditer( ".*?\\((.*?)\\).*", txt ):
- if data:
- r = data.span(1)
- for item in data.group(1).split("|"):
- yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip())
- break
-
- def iterlist( items ):
- outitems = []
- for item in items:
- for newitem, newpart in repliter(item.text):
- outitems.append( PromptItem(newitem, item.parts.copy(), newpart) )
-
- return outitems
-
- def getmatrix( prompt ):
- dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ]
- while True:
- newdataitems = iterlist( dataitems )
- if len( newdataitems ) == 0:
- return dataitems
- dataitems = newdataitems
-
- def classToArrays( items, seed, n_iter ):
- texts = []
- parts = []
- seeds = []
-
- for item in items:
- itemseed = seed
- for i in range(n_iter):
- texts.append( item.text )
- parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) )
- seeds.append( itemseed )
- itemseed += 1
-
- return seeds, texts, parts
-
- all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter)
- n_iter = math.ceil(len(all_prompts) / batch_size)
-
- needrows = getrowcount(prompt)
- if needrows:
- xrows = math.sqrt(len(all_prompts))
- xrows = round(xrows)
- # if columns is to much
- cols = math.ceil(len(all_prompts) / xrows)
- if cols > needrows*4:
- needrows *= 2
-
- return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows
-
-
-def process_images(
- outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
- n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
- fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
- keep_mask=False, mask_blur_strength=3, denoising_strength=0.75, resize_mode=None, uses_loopback=False,
- uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
- variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
- """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
- assert prompt is not None
- torch_gc()
- # start time after garbage collection (or before?)
- start_time = time.time()
-
- # We will use this date here later for the folder name, need to start_time if not need
- run_start_dt = datetime.datetime.now()
-
- mem_mon = MemUsageMonitor('MemMon')
- mem_mon.start()
-
- if hasattr(st.session_state["model"], "embedding_manager"):
- load_embeddings(fp)
-
- os.makedirs(outpath, exist_ok=True)
-
- sample_path = os.path.join(outpath, "samples")
- os.makedirs(sample_path, exist_ok=True)
-
- if not ("|" in prompt) and prompt.startswith("@"):
- prompt = prompt[1:]
-
- comments = []
-
- prompt_matrix_parts = []
- simple_templating = False
- add_original_image = not (use_RealESRGAN or use_GFPGAN)
-
- if prompt_matrix:
- if prompt.startswith("@"):
- simple_templating = True
- add_original_image = not (use_RealESRGAN or use_GFPGAN)
- all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size)
- else:
- all_prompts = []
- prompt_matrix_parts = prompt.split("|")
- combination_count = 2 ** (len(prompt_matrix_parts) - 1)
- for combination_num in range(combination_count):
- current = prompt_matrix_parts[0]
-
- for n, text in enumerate(prompt_matrix_parts[1:]):
- if combination_num & (2 ** n) > 0:
- current += ("" if text.strip().startswith(",") else ", ") + text
-
- all_prompts.append(current)
-
- n_iter = math.ceil(len(all_prompts) / batch_size)
- all_seeds = len(all_prompts) * [seed]
-
- print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
- else:
-
- if not defaults.general.no_verify_input:
- try:
- check_prompt_length(prompt, comments)
- except:
- import traceback
- print("Error verifying input:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- all_prompts = batch_size * n_iter * [prompt]
- all_seeds = [seed + x for x in range(len(all_prompts))]
-
- precision_scope = autocast if defaults.general.precision == "autocast" else nullcontext
- output_images = []
- grid_captions = []
- stats = []
- with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not defaults.general.optimized else nullcontext()):
- init_data = func_init()
- tic = time.time()
-
-
- # if variant_amount > 0.0 create noise from base seed
- base_x = None
- if variant_amount > 0.0:
- target_seed_randomizer = seed_to_int('') # random seed
- torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration)
- base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed])
- # we don't want all_seeds to be sequential from starting seed with variants,
- # since that makes the same variants each time,
- # so we add target_seed_randomizer as a random offset
- for si in range(len(all_seeds)):
- all_seeds[si] += target_seed_randomizer
-
- for n in range(n_iter):
- print(f"Iteration: {n+1}/{n_iter}")
- prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
- captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size]
- seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
-
- print(prompt)
-
- if defaults.general.optimized:
- modelCS.to(defaults.general.gpu)
-
- uc = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(len(prompts) * [""])
-
- if isinstance(prompts, tuple):
- prompts = list(prompts)
-
- # split the prompt if it has : for weighting
- # TODO for speed it might help to have this occur when all_prompts filled??
- weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights)
-
- # sub-prompt weighting used if more than 1
- if len(weighted_subprompts) > 1:
- c = torch.zeros_like(uc) # i dont know if this is correct.. but it works
- for i in range(0, len(weighted_subprompts)):
- # note if alpha negative, it functions same as torch.sub
- c = torch.add(c, (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
- else: # just behave like usual
- c = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(prompts)
-
-
- shape = [opt_C, height // opt_f, width // opt_f]
-
- if defaults.general.optimized:
- mem = torch.cuda.memory_allocated()/1e6
- modelCS.to("cpu")
- while(torch.cuda.memory_allocated()/1e6 >= mem):
- time.sleep(1)
-
- if variant_amount == 0.0:
- # we manually generate all input noises because each one should have a specific seed
- x = create_random_tensors(shape, seeds=seeds)
-
- else: # we are making variants
- # using variant_seed as sneaky toggle,
- # when not None or '' use the variant_seed
- # otherwise use seeds
- if variant_seed != None and variant_seed != '':
- specified_variant_seed = seed_to_int(variant_seed)
- torch.manual_seed(specified_variant_seed)
- seeds = [specified_variant_seed]
- target_x = create_random_tensors(shape, seeds=seeds)
- # finally, slerp base_x noise to target_x noise for creating a variant
- x = slerp(defaults.general.gpu, max(0.0, min(1.0, variant_amount)), base_x, target_x)
-
- samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
-
- if defaults.general.optimized:
- modelFS.to(defaults.general.gpu)
-
- x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(samples_ddim)
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-
- for i, x_sample in enumerate(x_samples_ddim):
- sanitized_prompt = slugify(prompts[i])
-
- if sort_samples:
- full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt)
-
-
- sanitized_prompt = sanitized_prompt[:220-len(full_path)]
- sample_path_i = os.path.join(sample_path, sanitized_prompt)
-
- #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}")
- #print(os.path.join(os.getcwd(), sample_path_i))
-
- os.makedirs(sample_path_i, exist_ok=True)
- base_count = get_next_sequence_number(sample_path_i)
- filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}"
- else:
- full_path = os.path.join(os.getcwd(), sample_path)
- sample_path_i = sample_path
- base_count = get_next_sequence_number(sample_path_i)
- filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before
-
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
- x_sample = x_sample.astype(np.uint8)
- image = Image.fromarray(x_sample)
- original_sample = x_sample
- original_filename = filename
-
- if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN:
- #skip_save = True # #287 >_>
- torch_gc()
- cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
- gfpgan_sample = restored_img[:,:,::-1]
- gfpgan_image = Image.fromarray(gfpgan_sample)
- gfpgan_filename = original_filename + '-gfpgan'
-
- save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
- normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback,
- uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta,
- n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False)
-
- output_images.append(gfpgan_image) #287
- if simple_templating:
- grid_captions.append( captions[i] + "\ngfpgan" )
-
- if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN:
- #skip_save = True # #287 >_>
- torch_gc()
-
- if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
- #try_loading_RealESRGAN(realesrgan_model_name)
- load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
-
- output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
- esrgan_filename = original_filename + '-esrgan4x'
- esrgan_sample = output[:,:,::-1]
- esrgan_image = Image.fromarray(esrgan_sample)
-
- #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
- #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
- #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
-
- save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
- normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
- save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False)
-
- output_images.append(esrgan_image) #287
- if simple_templating:
- grid_captions.append( captions[i] + "\nesrgan" )
-
- if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None:
- #skip_save = True # #287 >_>
- torch_gc()
- cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
- gfpgan_sample = restored_img[:,:,::-1]
-
- if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
- #try_loading_RealESRGAN(realesrgan_model_name)
- load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
-
- output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
- gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
- gfpgan_esrgan_sample = output[:,:,::-1]
- gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
-
- save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
- normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
- save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False)
-
- output_images.append(gfpgan_esrgan_image) #287
-
- if simple_templating:
- grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
-
- if save_individual_images:
- save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
- normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
- save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images)
-
- if not use_GFPGAN or not use_RealESRGAN:
- output_images.append(image)
-
- #if add_original_image or not simple_templating:
- #output_images.append(image)
- #if simple_templating:
- #grid_captions.append( captions[i] )
-
- if defaults.general.optimized:
- mem = torch.cuda.memory_allocated()/1e6
- modelFS.to("cpu")
- while(torch.cuda.memory_allocated()/1e6 >= mem):
- time.sleep(1)
-
- if prompt_matrix or save_grid:
- if prompt_matrix:
- if simple_templating:
- grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions)
- else:
- grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
- try:
- grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
- except:
- import traceback
- print("Error creating prompt_matrix text:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- else:
- grid = image_grid(output_images, batch_size)
-
- if grid and (batch_size > 1 or n_iter > 1):
- output_images.insert(0, grid)
-
- grid_count = get_next_sequence_number(outpath, 'grid-')
- grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}"
- grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True)
-
- toc = time.time()
-
- mem_max_used, mem_total = mem_mon.read_and_stop()
- time_diff = time.time()-start_time
-
- info = f"""
- {prompt}
- Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
- stats = f'''
- Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
- Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
-
- for comment in comments:
- info += "\n\n" + comment
-
- #mem_mon.stop()
- #del mem_mon
- torch_gc()
-
- return output_images, seed, info, stats
-
-
-def resize_image(resize_mode, im, width, height):
- LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
- if resize_mode == 0:
- res = im.resize((width, height), resample=LANCZOS)
- elif resize_mode == 1:
- ratio = width / height
- src_ratio = im.width / im.height
-
- src_w = width if ratio > src_ratio else im.width * height // im.height
- src_h = height if ratio <= src_ratio else im.height * width // im.width
-
- resized = im.resize((src_w, src_h), resample=LANCZOS)
- res = Image.new("RGBA", (width, height))
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
- else:
- ratio = width / height
- src_ratio = im.width / im.height
-
- src_w = width if ratio < src_ratio else im.width * height // im.height
- src_h = height if ratio >= src_ratio else im.height * width // im.width
-
- resized = im.resize((src_w, src_h), resample=LANCZOS)
- res = Image.new("RGBA", (width, height))
- res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
-
- if ratio < src_ratio:
- fill_height = height // 2 - src_h // 2
- res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
- elif ratio > src_ratio:
- fill_width = width // 2 - src_w // 2
- res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
-
- return res
-
-def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
- ddim_steps: int = 50, sampler_name: str = 'DDIM',
- n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
- seed: int = -1, height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
- variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
- write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
- separate_prompts:bool = False, normalize_prompt_weights:bool = True,
- save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
- save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False,
- random_seed_loopback: bool = False
- ):
-
- outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples"
- err = False
- #loopback = False
- #skip_save = False
- seed = seed_to_int(seed)
-
- batch_size = 1
-
- #prompt_matrix = 0
- #normalize_prompt_weights = 1 in toggles
- #loopback = 2 in toggles
- #random_seed_loopback = 3 in toggles
- #skip_save = 4 not in toggles
- #save_grid = 5 in toggles
- #sort_samples = 6 in toggles
- #write_info_files = 7 in toggles
- #write_sample_info_to_log_file = 8 in toggles
- #jpg_sample = 9 in toggles
- #use_GFPGAN = 10 in toggles
- #use_RealESRGAN = 11 in toggles
-
- if sampler_name == 'PLMS':
- sampler = PLMSSampler(st.session_state["model"])
- elif sampler_name == 'DDIM':
- sampler = DDIMSampler(st.session_state["model"])
- elif sampler_name == 'k_dpm_2_a':
- sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
- elif sampler_name == 'k_dpm_2':
- sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
- elif sampler_name == 'k_euler_a':
- sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
- elif sampler_name == 'k_euler':
- sampler = KDiffusionSampler(st.session_state["model"],'euler')
- elif sampler_name == 'k_heun':
- sampler = KDiffusionSampler(st.session_state["model"],'heun')
- elif sampler_name == 'k_lms':
- sampler = KDiffusionSampler(st.session_state["model"],'lms')
- else:
- raise Exception("Unknown sampler: " + sampler_name)
-
- init_img = init_info
- init_mask = None
- keep_mask = False
-
- assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
- t_enc = int(denoising_strength * ddim_steps)
-
- def init():
-
- image = init_img
- image = np.array(image).astype(np.float32) / 255.0
- image = image[None].transpose(0, 3, 1, 2)
- image = torch.from_numpy(image)
-
- mask = None
- if defaults.general.optimized:
- modelFS.to(st.session_state["device"] )
-
- init_image = 2. * image - 1.
- init_image = init_image.to(st.session_state["device"])
- init_latent = (st.session_state["model"] if not defaults.general.optimized else modelFS).get_first_stage_encoding((st.session_state["model"] if not defaults.general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
-
- if defaults.general.optimized:
- mem = torch.cuda.memory_allocated()/1e6
- modelFS.to("cpu")
- while(torch.cuda.memory_allocated()/1e6 >= mem):
- time.sleep(1)
-
- return init_latent, mask,
-
- def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
- t_enc_steps = t_enc
- obliterate = False
- if ddim_steps == t_enc_steps:
- t_enc_steps = t_enc_steps - 1
- obliterate = True
-
- if sampler_name != 'DDIM':
- x0, z_mask = init_data
-
- sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
- noise = x * sigmas[ddim_steps - t_enc_steps - 1]
-
- xi = x0 + noise
-
- # Obliterate masked image
- if z_mask is not None and obliterate:
- random = torch.randn(z_mask.shape, device=xi.device)
- xi = (z_mask * noise) + ((1-z_mask) * xi)
-
- sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
- model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
- samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
- extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
- 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
- callback=generation_callback)
- else:
-
- x0, z_mask = init_data
-
- sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
- z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] ))
-
- # Obliterate masked image
- if z_mask is not None and obliterate:
- random = torch.randn(z_mask.shape, device=z_enc.device)
- z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
-
- # decode it
- samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
- unconditional_guidance_scale=cfg_scale,
- unconditional_conditioning=unconditional_conditioning,
- z_mask=z_mask, x0=x0)
- return samples_ddim
-
-
-
- if loopback:
- output_images, info = None, None
- history = []
- initial_seed = None
-
- do_color_correction = False
- try:
- from skimage import exposure
- do_color_correction = True
- except:
- print("Install scikit-image to perform color correction on loopback")
-
- for i in range(1):
- if do_color_correction and i == 0:
- correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
-
- output_images, seed, info, stats = process_images(
- outpath=outpath,
- func_init=init,
- func_sample=sample,
- prompt=prompt,
- seed=seed,
- sampler_name=sampler_name,
- save_grid=save_grid,
- batch_size=1,
- n_iter=n_iter,
- steps=ddim_steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- prompt_matrix=separate_prompts,
- use_GFPGAN=use_GFPGAN,
- use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback
- realesrgan_model_name=RealESRGAN_model,
- fp=fp,
- normalize_prompt_weights=normalize_prompt_weights,
- save_individual_images=save_individual_images,
- init_img=init_img,
- init_mask=init_mask,
- keep_mask=keep_mask,
- mask_blur_strength=mask_blur_strength,
- denoising_strength=denoising_strength,
- resize_mode=resize_mode,
- uses_loopback=loopback,
- uses_random_seed_loopback=random_seed_loopback,
- sort_samples=group_by_prompt,
- write_info_files=write_info_files,
- jpg_sample=save_as_jpg
- )
-
- if initial_seed is None:
- initial_seed = seed
-
- init_img = output_images[0]
-
- if do_color_correction and correction_target is not None:
- init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
- cv2.cvtColor(
- np.asarray(init_img),
- cv2.COLOR_RGB2LAB
- ),
- correction_target,
- channel_axis=2
- ), cv2.COLOR_LAB2RGB).astype("uint8"))
-
- if not random_seed_loopback:
- seed = seed + 1
- else:
- seed = seed_to_int(None)
-
- denoising_strength = max(denoising_strength * 0.95, 0.1)
- history.append(init_img)
-
- output_images = history
- seed = initial_seed
-
- else:
- output_images, seed, info, stats = process_images(
- outpath=outpath,
- func_init=init,
- func_sample=sample,
- prompt=prompt,
- seed=seed,
- sampler_name=sampler_name,
- save_grid=save_grid,
- batch_size=batch_size,
- n_iter=n_iter,
- steps=ddim_steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- prompt_matrix=separate_prompts,
- use_GFPGAN=use_GFPGAN,
- use_RealESRGAN=use_RealESRGAN,
- realesrgan_model_name=RealESRGAN_model,
- fp=fp,
- normalize_prompt_weights=normalize_prompt_weights,
- save_individual_images=save_individual_images,
- init_img=init_img,
- init_mask=init_mask,
- keep_mask=keep_mask,
- mask_blur_strength=2,
- denoising_strength=denoising_strength,
- resize_mode=resize_mode,
- uses_loopback=loopback,
- sort_samples=group_by_prompt,
- write_info_files=write_info_files,
- jpg_sample=save_as_jpg
- )
-
- del sampler
-
- return output_images, seed, info, stats
-
-#@retry(RuntimeError, tries=3)
-def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str,
- n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
- height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
- save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
- save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True,
- RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None,
- variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True):
-
- outpath = defaults.general.outdir_txt2img or defaults.general.outdir or "outputs/txt2img-samples"
-
- err = False
- seed = seed_to_int(seed)
-
- #prompt_matrix = 0 in toggles
- #normalize_prompt_weights = 1 in toggles
- #skip_save = 2 not in toggles
- #save_grid = 3 not in toggles
- #sort_samples = 4 in toggles
- #write_info_files = 5 in toggles
- #jpg_sample = 6 in toggles
- #use_GFPGAN = 7 in toggles
- #use_RealESRGAN = 8 in toggles
-
- if sampler_name == 'PLMS':
- sampler = PLMSSampler(st.session_state["model"])
- elif sampler_name == 'DDIM':
- sampler = DDIMSampler(st.session_state["model"])
- elif sampler_name == 'k_dpm_2_a':
- sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
- elif sampler_name == 'k_dpm_2':
- sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
- elif sampler_name == 'k_euler_a':
- sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
- elif sampler_name == 'k_euler':
- sampler = KDiffusionSampler(st.session_state["model"],'euler')
- elif sampler_name == 'k_heun':
- sampler = KDiffusionSampler(st.session_state["model"],'heun')
- elif sampler_name == 'k_lms':
- sampler = KDiffusionSampler(st.session_state["model"],'lms')
- else:
- raise Exception("Unknown sampler: " + sampler_name)
-
- def init():
- pass
-
- def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
- samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale,
- unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback,
- log_every_t=int(defaults.general.update_preview_frequency))
-
- return samples_ddim
-
- #try:
- output_images, seed, info, stats = process_images(
- outpath=outpath,
- func_init=init,
- func_sample=sample,
- prompt=prompt,
- seed=seed,
- sampler_name=sampler_name,
- save_grid=save_grid,
- batch_size=batch_size,
- n_iter=n_iter,
- steps=ddim_steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- prompt_matrix=separate_prompts,
- use_GFPGAN=use_GFPGAN,
- use_RealESRGAN=use_RealESRGAN,
- realesrgan_model_name=realesrgan_model_name,
- fp=fp,
- ddim_eta=ddim_eta,
- normalize_prompt_weights=normalize_prompt_weights,
- save_individual_images=save_individual_images,
- sort_samples=group_by_prompt,
- write_info_files=write_info_files,
- jpg_sample=save_as_jpg,
- variant_amount=variant_amount,
- variant_seed=variant_seed,
- )
-
- del sampler
-
- return output_images, seed, info, stats
-
- #except RuntimeError as e:
- #err = e
- #err_msg = f'CRASHED:
Please wait while the program restarts.'
- #stats = err_msg
- #return [], seed, 'err', stats
-
-
-
+#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
+#os.environ["CUDA_VISIBLE_DEVICES"] = str(st.session_state["defaults"].general.gpu)
# functions to load css locally OR remotely starts here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection
# TODO, maybe look into async loading the file especially for remote fetching
def local_css(file_name):
- with open(file_name) as f:
- st.markdown(f'', unsafe_allow_html=True)
+ with open(file_name) as f:
+ st.markdown(f'', unsafe_allow_html=True)
def remote_css(url):
- st.markdown(f'', unsafe_allow_html=True)
+ st.markdown(f'', unsafe_allow_html=True)
def load_css(isLocal, nameOrURL):
if(isLocal):
@@ -1466,289 +57,89 @@ def load_css(isLocal, nameOrURL):
else:
remote_css(nameOrURL)
-
-# main functions to define streamlit layout here
def layout():
-
- st.set_page_config(page_title="Stable Diffusion Playground", layout="wide", initial_sidebar_state="collapsed")
+ """Layout functions to define all the streamlit layout here."""
+ st.set_page_config(page_title="Stable Diffusion Playground", layout="wide")
with st.empty():
# load css as an external file, function has an option to local or remote url. Potential use when running from cloud infra that might not have access to local path.
load_css(True, 'frontend/css/streamlit.main.css')
-
+
# check if the models exist on their respective folders
- if os.path.exists(os.path.join(defaults.general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
- GFPGAN_available = True
+ if os.path.exists(os.path.join(st.session_state["defaults"].general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
+ st.session_state["GFPGAN_available"] = True
else:
- GFPGAN_available = False
+ st.session_state["GFPGAN_available"] = False
- if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")):
- RealESRGAN_available = True
+ if os.path.exists(os.path.join(st.session_state["defaults"].general.RealESRGAN_dir, "experiments","pretrained_models", f"{st.session_state['defaults'].general.RealESRGAN_model}.pth")):
+ st.session_state["RealESRGAN_available"] = True
else:
- RealESRGAN_available = False
+ st.session_state["RealESRGAN_available"] = False
+
+ # Allow for custom models to be used instead of the default one,
+ # an example would be Waifu-Diffusion or any other fine tune of stable diffusion
+ st.session_state["custom_models"]:sorted = []
+ for root, dirs, files in os.walk(os.path.join("models", "custom")):
+ for file in files:
+ if os.path.splitext(file)[1] == '.ckpt':
+ #fullpath = os.path.join(root, file)
+ #print(fullpath)
+ st.session_state["custom_models"].append(os.path.splitext(file)[0])
+ #print (os.path.splitext(file)[0])
+
+ if len(st.session_state["custom_models"]) > 0:
+ st.session_state["CustomModel_available"] = True
+ st.session_state["custom_models"].append("Stable Diffusion v1.4")
+ else:
+ st.session_state["CustomModel_available"] = False
with st.sidebar:
- # we should use an expander and group things together when more options are added so the sidebar is not too messy.
+ # The global settings section will be moved to the Settings page.
#with st.expander("Global Settings:"):
- st.write("Global Settings:")
- defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview,
- help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \
- Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.")
- defaults.general.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency,
- help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.")
-
-
-
- txt2img_tab, img2img_tab, txt2video, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified", "Text-to-Video","Post-Processing"])
-
- with txt2img_tab:
- with st.form("txt2img-inputs"):
- st.session_state["generation_mode"] = "txt2img"
-
- input_col1, generate_col1 = st.columns([10,1])
- with input_col1:
- #prompt = st.text_area("Input Text","")
- prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
-
- # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
- generate_col1.write("")
- generate_col1.write("")
- generate_button = generate_col1.form_submit_button("Generate")
-
- # creating the page layout using columns
- col1, col2, col3 = st.columns([1,2,1], gap="large")
-
- with col1:
- width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.txt2img.width, step=64)
- height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.txt2img.height, step=64)
- cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
- seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
- batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
- #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2img.batch_size, step=1,
- #help="How many images are at once in a batch.\
- #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
- #Default: 1")
-
- with col2:
- preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
-
- with preview_tab:
- #st.write("Image")
- #Image for testing
- #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
- #new_image = image.resize((175, 240))
- #preview_image = st.image(image)
-
- # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
- st.session_state["preview_image"] = st.empty()
-
- st.session_state["loading"] = st.empty()
-
- st.session_state["progress_bar_text"] = st.empty()
- st.session_state["progress_bar"] = st.empty()
-
- message = st.empty()
-
- with gallery_tab:
- st.write('Here should be the image gallery, if I could make a grid in streamlit.')
-
- with col3:
- st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250)
-
- sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
- sampler_name = st.selectbox("Sampling method", sampler_name_list,
- index=sampler_name_list.index(defaults.txt2img.default_sampler), help="Sampling method to use. Default: k_euler")
-
-
-
- #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
-
- #with basic_tab:
- #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
- #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
-
- with st.expander("Advanced"):
- separate_prompts = st.checkbox("Create Prompt Matrix.", value=False, help="Separate multiple prompts using the `|` character, and get all combinations of them.")
- normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=True, help="Ensure the sum of all weights add up to 1.0")
- save_individual_images = st.checkbox("Save individual images.", value=True, help="Save each image generated before any filter or enhancement is applied.")
- save_grid = st.checkbox("Save grid",value=True, help="Save a grid with all the images generated into a single image.")
- group_by_prompt = st.checkbox("Group results by prompt", value=True,
- help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
- write_info_files = st.checkbox("Write Info file", value=True, help="Save a file next to the image with informartion about the generation.")
- save_as_jpg = st.checkbox("Save samples as jpg", value=False, help="Saves the images as jpg instead of png.")
-
- if GFPGAN_available:
- use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.txt2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
- else:
- use_GFPGAN = False
-
- if RealESRGAN_available:
- use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.txt2img.use_RealESRGAN, help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
- RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
- else:
- use_RealESRGAN = False
- RealESRGAN_model = "RealESRGAN_x4plus"
-
- variant_amount = st.slider("Variant Amount:", value=defaults.txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
- variant_seed = st.text_input("Variant Seed:", value=defaults.txt2img.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
-
-
- if generate_button:
- #print("Loading models")
- # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
- load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)
-
- try:
- output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1,
- cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
- save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp,
- variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
+ #st.write("Global Settings:")
+ #defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview,
+ #help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \
+ #Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.")
+ #st.session_state.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency,
+ #help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.")
+
+ tabs = on_hover_tabs(tabName=['Stable Diffusion', "Textual Inversion","Model Manager","Settings"],
+ iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0)
+
+ if tabs =='Stable Diffusion':
+ # txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab, concept_library_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified",
+ # "Text-to-Video","Post-Processing", "Concept Library"])
+ txt2img_tab, img2img_tab, txt2vid_tab = st.tabs(
+ ["Text-to-Image Unified", "Image-to-Image Unified", "Text-to-Video"]
+ )
+ #with home_tab:
+ #from home import layout
+ #layout()
+
+ with txt2img_tab:
+ from txt2img import layout
+ layout()
+
+ with img2img_tab:
+ from img2img import layout
+ layout()
+
+ with txt2vid_tab:
+ from txt2vid import layout
+ layout()
+
+ # with concept_library_tab:
+ # from sd_concept_library import layout
+ # layout()
+
+ #
+ elif tabs == 'Model Manager':
+ from ModelManager import layout
+ layout()
- message.success('Done!', icon="✅")
-
- except (StopException, KeyError):
- print(f"Received Streamlit StopException")
-
- # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
- # use the current col2 first tab to show the preview_img and update it as its generated.
- #preview_image.image(output_images)
-
- with img2img_tab:
- with st.form("img2img-inputs"):
- st.session_state["generation_mode"] = "img2img"
-
- img2img_input_col, img2img_generate_col = st.columns([10,1])
- with img2img_input_col:
- #prompt = st.text_area("Input Text","")
- prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
-
- # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
- img2img_generate_col.write("")
- img2img_generate_col.write("")
- generate_button = img2img_generate_col.form_submit_button("Generate")
-
-
- # creating the page layout using columns
- col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small")
-
- with col1_img2img_layout:
- st.session_state["sampling_steps"] = st.slider("Sampling Steps", value=defaults.img2img.sampling_steps, min_value=1, max_value=250)
- st.session_state["sampler_name"] = st.selectbox("Sampling method", ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"],
- index=0, help="Sampling method to use. Default: k_lms")
-
- uploaded_images = st.file_uploader("Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg"],
- help="Upload an image which will be used for the image to image generation."
- )
-
- width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.img2img.width, step=64)
- height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.img2img.height, step=64)
- seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
- batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.img2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
-
- #
- with st.expander("Advanced"):
- separate_prompts = st.checkbox("Create Prompt Matrix.", value=defaults.img2img.separate_prompts, help="Separate multiple prompts using the `|` character, and get all combinations of them.")
- normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=defaults.img2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
- loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.")
- random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed")
- save_individual_images = st.checkbox("Save individual images.", value=True, help="Save each image generated before any filter or enhancement is applied.")
- save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.")
- group_by_prompt = st.checkbox("Group results by prompt", value=defaults.img2img.group_by_prompt,
- help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
- write_info_files = st.checkbox("Write Info file", value=True, help="Save a file next to the image with informartion about the generation.")
- save_as_jpg = st.checkbox("Save samples as jpg", value=False, help="Saves the images as jpg instead of png.")
-
- if GFPGAN_available:
- use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
- This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
- else:
- use_GFPGAN = False
-
- if RealESRGAN_available:
- use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.img2img.use_RealESRGAN, help="Uses the RealESRGAN model to upscale the images after the generation.\
- This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
- RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
- else:
- use_RealESRGAN = False
- RealESRGAN_model = "RealESRGAN_x4plus"
-
- variant_amount = st.slider("Variant Amount:", value=defaults.img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
- variant_seed = st.text_input("Variant Seed:", value=defaults.img2img.variant_seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
- cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.img2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
- batch_size = st.slider("Batch size", min_value=1, max_value=100, value=defaults.img2img.batch_size, step=1,
- help="How many images are at once in a batch.\
- It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
- Default: 1")
-
- st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength, min_value=0.01, max_value=1.0, step=0.01)
-
-
- with col2_img2img_layout:
- editor_tab = st.tabs(["Editor"])
-
- editor_image = st.empty()
- st.session_state["editor_image"] = editor_image
-
- if uploaded_images:
- image = Image.open(uploaded_images).convert('RGB')
- #img_array = np.array(image) # if you want to pass it to OpenCV
- new_img = image.resize((width, height))
- st.image(new_img)
-
-
- with col3_img2img_layout:
- result_tab = st.tabs(["Result"])
-
- # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
- preview_image = st.empty()
- st.session_state["preview_image"] = preview_image
-
- #st.session_state["loading"] = st.empty()
-
- st.session_state["progress_bar_text"] = st.empty()
- st.session_state["progress_bar"] = st.empty()
-
-
- message = st.empty()
-
- #if uploaded_images:
- #image = Image.open(uploaded_images).convert('RGB')
- ##img_array = np.array(image) # if you want to pass it to OpenCV
- #new_img = image.resize((width, height))
- #st.image(new_img, use_column_width=True)
-
-
- if generate_button:
- #print("Loading models")
- # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
- load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model)
- if uploaded_images:
- image = Image.open(uploaded_images).convert('RGB')
- new_img = image.resize((width, height))
- #img_array = np.array(image) # if you want to pass it to OpenCV
-
- try:
- output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, ddim_steps=st.session_state["sampling_steps"],
- sampler_name=st.session_state["sampler_name"], n_iter=batch_count,
- cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
- seed=seed, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount,
- ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=RealESRGAN_model,
- separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
- save_individual_images=save_individual_images, save_grid=save_grid,
- group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
- use_RealESRGAN=use_RealESRGAN if not loopback else False, loopback=loopback
- )
+ # elif tabs == 'Textual Inversion':
+ # from textual_inversion import layout
+ # layout()
- #show a message when the generation is complete.
- message.success('Done!', icon="✅")
-
- except (StopException, KeyError):
- print(f"Received Streamlit StopException")
-
- # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
- # use the current col2 first tab to show the preview_img and update it as its generated.
- #preview_image.image(output_images, width=750)
-
-
if __name__ == '__main__':
layout()
\ No newline at end of file
diff --git a/scripts/webui_streamlit_old.py b/scripts/webui_streamlit_old.py
new file mode 100644
index 0000000..ad1b9da
--- /dev/null
+++ b/scripts/webui_streamlit_old.py
@@ -0,0 +1,2738 @@
+import warnings
+
+import piexif
+import piexif.helper
+import json
+
+import streamlit as st
+from streamlit import StopException
+
+#streamlit components section
+from st_on_hover_tabs import on_hover_tabs
+
+import base64, cv2
+import os, sys, re, random, datetime, timeit
+from PIL import Image, ImageFont, ImageDraw, ImageFilter, ImageOps
+from PIL.PngImagePlugin import PngInfo
+from scipy import integrate
+import pandas as pd
+import torch
+from torchdiffeq import odeint
+import k_diffusion as K
+import math
+import mimetypes
+import numpy as np
+import pynvml
+import threading
+import time, inspect
+import torch
+from torch import autocast
+from torchvision import transforms
+import torch.nn as nn
+import yaml
+from typing import Union
+from pathlib import Path
+#from tqdm import tqdm
+from contextlib import nullcontext
+from einops import rearrange
+from omegaconf import OmegaConf
+from io import StringIO
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
+
+from retry import retry
+
+# these are for testing txt2vid, should be removed and we should use things from our own code.
+from diffusers import StableDiffusionPipeline
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+#will be used for saving and reading a video made by the txt2vid function
+import imageio, io
+
+# we use python-slugify to make the filenames safe for windows and linux, its better than doing it manually
+# install it with 'pip install python-slugify'
+from slugify import slugify
+
+try:
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
+ from transformers import logging
+
+ logging.set_verbosity_error()
+except:
+ pass
+
+# remove some annoying deprecation warnings that show every now and then.
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+defaults = OmegaConf.load("configs/webui/webui_streamlit.yaml")
+if (os.path.exists("configs/webui/userconfig_streamlit.yaml")):
+ user_defaults = OmegaConf.load("configs/webui/userconfig_streamlit.yaml");
+ defaults = OmegaConf.merge(defaults, user_defaults)
+
+# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
+mimetypes.init()
+mimetypes.add_type('application/javascript', '.js')
+
+# some of those options should not be changed at all because they would break the model, so I removed them from options.
+opt_C = 4
+opt_f = 8
+
+# should and will be moved to a settings menu in the UI at some point
+grid_format = [s.lower() for s in defaults.general.grid_format.split(':')]
+grid_lossless = False
+grid_quality = 100
+if grid_format[0] == 'png':
+ grid_ext = 'png'
+ grid_format = 'png'
+elif grid_format[0] in ['jpg', 'jpeg']:
+ grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100
+ grid_ext = 'jpg'
+ grid_format = 'jpeg'
+elif grid_format[0] == 'webp':
+ grid_quality = int(grid_format[1]) if len(grid_format) > 1 else 100
+ grid_ext = 'webp'
+ grid_format = 'webp'
+ if grid_quality < 0: # e.g. webp:-100 for lossless mode
+ grid_lossless = True
+ grid_quality = abs(grid_quality)
+
+# should and will be moved to a settings menu in the UI at some point
+save_format = [s.lower() for s in defaults.general.save_format.split(':')]
+save_lossless = False
+save_quality = 100
+if save_format[0] == 'png':
+ save_ext = 'png'
+ save_format = 'png'
+elif save_format[0] in ['jpg', 'jpeg']:
+ save_quality = int(save_format[1]) if len(save_format) > 1 else 100
+ save_ext = 'jpg'
+ save_format = 'jpeg'
+elif save_format[0] == 'webp':
+ save_quality = int(save_format[1]) if len(save_format) > 1 else 100
+ save_ext = 'webp'
+ save_format = 'webp'
+ if save_quality < 0: # e.g. webp:-100 for lossless mode
+ save_lossless = True
+ save_quality = abs(save_quality)
+
+# this should force GFPGAN and RealESRGAN onto the selected gpu as well
+os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
+os.environ["CUDA_VISIBLE_DEVICES"] = str(defaults.general.gpu)
+
+@retry(tries=5)
+def load_models(continue_prev_run = False, use_GFPGAN=False, use_RealESRGAN=False, RealESRGAN_model="RealESRGAN_x4plus",
+ CustomModel_available=False, custom_model="Stable Diffusion v1.4"):
+ """Load the different models. We also reuse the models that are already in memory to speed things up instead of loading them again. """
+
+ print ("Loading models.")
+
+ st.session_state["progress_bar_text"].text("Loading models...")
+
+ # Generate random run ID
+ # Used to link runs linked w/ continue_prev_run which is not yet implemented
+ # Use URL and filesystem safe version just in case.
+ st.session_state["run_id"] = base64.urlsafe_b64encode(
+ os.urandom(6)
+ ).decode("ascii")
+
+ # check what models we want to use and if the they are already loaded.
+
+ if use_GFPGAN:
+ if "GFPGAN" in st.session_state:
+ print("GFPGAN already loaded")
+ else:
+ # Load GFPGAN
+ if os.path.exists(defaults.general.GFPGAN_dir):
+ try:
+ st.session_state["GFPGAN"] = load_GFPGAN()
+ print("Loaded GFPGAN")
+ except Exception:
+ import traceback
+ print("Error loading GFPGAN:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ if "GFPGAN" in st.session_state:
+ del st.session_state["GFPGAN"]
+
+ if use_RealESRGAN:
+ if "RealESRGAN" in st.session_state and st.session_state["RealESRGAN"].model.name == RealESRGAN_model:
+ print("RealESRGAN already loaded")
+ else:
+ #Load RealESRGAN
+ try:
+ # We first remove the variable in case it has something there,
+ # some errors can load the model incorrectly and leave things in memory.
+ del st.session_state["RealESRGAN"]
+ except KeyError:
+ pass
+
+ if os.path.exists(defaults.general.RealESRGAN_dir):
+ # st.session_state is used for keeping the models in memory across multiple pages or runs.
+ st.session_state["RealESRGAN"] = load_RealESRGAN(RealESRGAN_model)
+ print("Loaded RealESRGAN with model "+ st.session_state["RealESRGAN"].model.name)
+
+ else:
+ if "RealESRGAN" in st.session_state:
+ del st.session_state["RealESRGAN"]
+
+
+
+ if "model" in st.session_state:
+ if "model" in st.session_state and st.session_state["custom_model"] == custom_model:
+ print("Model already loaded")
+ else:
+ try:
+ del st.session_state["model"]
+ except KeyError:
+ pass
+
+ config = OmegaConf.load(defaults.general.default_model_config)
+
+ if custom_model == defaults.general.default_model:
+ model = load_model_from_config(config, defaults.general.default_model_path)
+ else:
+ model = load_model_from_config(config, os.path.join("models","custom", f"{custom_model}.ckpt"))
+
+ st.session_state["custom_model"] = custom_model
+ st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu")
+ st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] )
+ else:
+ config = OmegaConf.load(defaults.general.default_model_config)
+
+ if custom_model == defaults.general.default_model:
+ model = load_model_from_config(config, defaults.general.default_model_path)
+ else:
+ model = load_model_from_config(config, os.path.join("models","custom", f"{custom_model}.ckpt"))
+
+ st.session_state["custom_model"] = custom_model
+ st.session_state["device"] = torch.device(f"cuda:{defaults.general.gpu}") if torch.cuda.is_available() else torch.device("cpu")
+ st.session_state["model"] = (model if defaults.general.no_half else model.half()).to(st.session_state["device"] )
+
+ print("Model loaded.")
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+
+ print(f"Loading model from {ckpt}")
+
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+def load_sd_from_config(ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ return sd
+#
+@retry(tries=5)
+def generation_callback(img, i=0):
+
+ try:
+ if i == 0:
+ if img['i']: i = img['i']
+ except TypeError:
+ pass
+
+
+ if i % int(defaults.general.update_preview_frequency) == 0 and defaults.general.update_preview:
+ #print (img)
+ #print (type(img))
+ # The following lines will convert the tensor we got on img to an actual image we can render on the UI.
+ # It can probably be done in a better way for someone who knows what they're doing. I don't.
+ #print (img,isinstance(img, torch.Tensor))
+ if isinstance(img, torch.Tensor):
+ x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img)
+ else:
+ # When using the k Diffusion samplers they return a dict instead of a tensor that look like this:
+ # {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}
+ x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(img["denoised"])
+
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ pil_image = transforms.ToPILImage()(x_samples_ddim.squeeze_(0))
+
+ # update image on the UI so we can see the progress
+ st.session_state["preview_image"].image(pil_image)
+
+ # Show a progress bar so we can keep track of the progress even when the image progress is not been shown,
+ # Dont worry, it doesnt affect the performance.
+ if st.session_state["generation_mode"] == "txt2img":
+ percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
+ st.session_state["progress_bar_text"].text(
+ f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} {percent if percent < 100 else 100}%")
+ else:
+ if st.session_state["generation_mode"] == "img2img":
+ round_sampling_steps = round(st.session_state.sampling_steps * st.session_state["denoising_strength"])
+ percent = int(100 * float(i+1 if i+1 < round_sampling_steps else round_sampling_steps)/float(round_sampling_steps))
+ st.session_state["progress_bar_text"].text(
+ f"""Running step: {i+1 if i+1 < round_sampling_steps else round_sampling_steps}/{round_sampling_steps} {percent if percent < 100 else 100}%""")
+ else:
+ if st.session_state["generation_mode"] == "txt2vid":
+ percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
+ st.session_state["progress_bar_text"].text(
+ f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps}"
+ f"{percent if percent < 100 else 100}%")
+
+ st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
+
+
+
+class MemUsageMonitor(threading.Thread):
+ stop_flag = False
+ max_usage = 0
+ total = -1
+
+ def __init__(self, name):
+ threading.Thread.__init__(self)
+ self.name = name
+
+ def run(self):
+ try:
+ pynvml.nvmlInit()
+ except:
+ print(f"[{self.name}] Unable to initialize NVIDIA management. No memory stats. \n")
+ return
+ print(f"[{self.name}] Recording max memory usage...\n")
+ handle = pynvml.nvmlDeviceGetHandleByIndex(defaults.general.gpu)
+ self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
+ while not self.stop_flag:
+ m = pynvml.nvmlDeviceGetMemoryInfo(handle)
+ self.max_usage = max(self.max_usage, m.used)
+ # print(self.max_usage)
+ time.sleep(0.1)
+ print(f"[{self.name}] Stopped recording.\n")
+ pynvml.nvmlShutdown()
+
+ def read(self):
+ return self.max_usage, self.total
+
+ def stop(self):
+ self.stop_flag = True
+
+ def read_and_stop(self):
+ self.stop_flag = True
+ return self.max_usage, self.total
+
+class CFGMaskedDenoiser(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+
+ def forward(self, x, sigma, uncond, cond, cond_scale, mask, x0, xi):
+ x_in = x
+ x_in = torch.cat([x_in] * 2)
+ sigma_in = torch.cat([sigma] * 2)
+ cond_in = torch.cat([uncond, cond])
+ uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+ denoised = uncond + (cond - uncond) * cond_scale
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = x0
+ mask_inv = 1. - mask
+ denoised = (img_orig * mask_inv) + (mask * denoised)
+
+ return denoised
+
+class CFGDenoiser(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.inner_model = model
+
+ def forward(self, x, sigma, uncond, cond, cond_scale):
+ x_in = torch.cat([x] * 2)
+ sigma_in = torch.cat([sigma] * 2)
+ cond_in = torch.cat([uncond, cond])
+ uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+ return uncond + (cond - uncond) * cond_scale
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
+ """Constructs the noise schedule of Karras et al. (2022)."""
+ ramp = torch.linspace(0, 1, n)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return append_zero(sigmas).to(device)
+
+
+def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
+ """Constructs an exponential noise schedule."""
+ sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
+ return append_zero(sigmas)
+
+
+def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
+ """Constructs a continuous VP noise schedule."""
+ t = torch.linspace(1, eps_s, n, device=device)
+ sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
+ return append_zero(sigmas)
+
+
+def to_d(x, sigma, denoised):
+ """Converts a denoiser output to a Karras ODE derivative."""
+ return (x - denoised) / append_dims(sigma, x.ndim)
+def linear_multistep_coeff(order, t, i, j):
+ if order - 1 > i:
+ raise ValueError(f'Order {order} too high for step {i}')
+ def fn(tau):
+ prod = 1.
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
+
+class KDiffusionSampler:
+ def __init__(self, m, sampler):
+ self.model = m
+ self.model_wrap = K.external.CompVisDenoiser(m)
+ self.schedule = sampler
+ def get_sampler_name(self):
+ return self.schedule
+ def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T, img_callback=None, log_every_t=None):
+ sigmas = self.model_wrap.get_sigmas(S)
+ x = x_T * sigmas[0]
+ model_wrap_cfg = CFGDenoiser(self.model_wrap)
+ samples_ddim = None
+ samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas,
+ extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
+ 'cond_scale': unconditional_guidance_scale}, disable=False, callback=generation_callback)
+ #
+ return samples_ddim, None
+
+
+@torch.no_grad()
+def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ v = torch.randint_like(x, 2) * 2 - 1
+ fevals = 0
+ def ode_fn(sigma, x):
+ nonlocal fevals
+ with torch.enable_grad():
+ x = x[0].detach().requires_grad_()
+ denoised = model(x, sigma * s_in, **extra_args)
+ d = to_d(x, sigma, denoised)
+ fevals += 1
+ grad = torch.autograd.grad((d * v).sum(), x)[0]
+ d_ll = (v * grad).flatten(1).sum(1)
+ return d.detach(), d_ll
+ x_min = x, x.new_zeros([x.shape[0]])
+ t = x.new_tensor([sigma_min, sigma_max])
+ sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
+ latent, delta_ll = sol[0][-1], sol[1][-1]
+ ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
+ return ll_prior + delta_ll, {'fevals': fevals}
+
+
+def create_random_tensors(shape, seeds):
+ xs = []
+ for seed in seeds:
+ torch.manual_seed(seed)
+
+ # randn results depend on device; gpu and cpu get different results for same seed;
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
+ # but the original script had it like this so i do not dare change it for now because
+ # it will break everyone's seeds.
+ xs.append(torch.randn(shape, device=defaults.general.gpu))
+ x = torch.stack(xs)
+ return x
+
+def torch_gc():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+def load_GFPGAN():
+ model_name = 'GFPGANv1.3'
+ model_path = os.path.join(defaults.general.GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
+ if not os.path.isfile(model_path):
+ raise Exception("GFPGAN model not found at path "+model_path)
+
+ sys.path.append(os.path.abspath(defaults.general.GFPGAN_dir))
+ from gfpgan import GFPGANer
+
+ if defaults.general.gfpgan_cpu or defaults.general.extra_models_cpu:
+ instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
+ elif defaults.general.extra_models_gpu:
+ instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gfpgan_gpu}'))
+ else:
+ instance = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device(f'cuda:{defaults.general.gpu}'))
+ return instance
+
+def load_RealESRGAN(model_name: str):
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ RealESRGAN_models = {
+ 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
+ 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+ }
+
+ model_path = os.path.join(defaults.general.RealESRGAN_dir, 'experiments/pretrained_models', model_name + '.pth')
+ if not os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{model_name}.pth")):
+ raise Exception(model_name+".pth not found at path "+model_path)
+
+ sys.path.append(os.path.abspath(defaults.general.RealESRGAN_dir))
+ from realesrgan import RealESRGANer
+
+ if defaults.general.esrgan_cpu or defaults.general.extra_models_cpu:
+ instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=False) # cpu does not support half
+ instance.device = torch.device('cpu')
+ instance.model.to('cpu')
+ elif defaults.general.extra_models_gpu:
+ instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.esrgan_gpu}'))
+ else:
+ instance = RealESRGANer(scale=2, model_path=model_path, model=RealESRGAN_models[model_name], pre_pad=0, half=not defaults.general.no_half, device=torch.device(f'cuda:{defaults.general.gpu}'))
+ instance.model.name = model_name
+
+ return instance
+
+prompt_parser = re.compile("""
+ (?P # capture group for 'prompt'
+ [^:]+ # match one or more non ':' characters
+ ) # end 'prompt'
+ (?: # non-capture group
+ :+ # match one or more ':' characters
+ (?P # capture group for 'weight'
+ -?\\d+(?:\\.\\d+)? # match positive or negative decimal number
+ )? # end weight capture group, make optional
+ \\s* # strip spaces after weight
+ | # OR
+ $ # else, if no ':' then match end of line
+ ) # end non-capture group
+""", re.VERBOSE)
+
+# grabs all text up to the first occurrence of ':' as sub-prompt
+# takes the value following ':' as weight
+# if ':' has no value defined, defaults to 1.0
+# repeats until no text remaining
+def split_weighted_subprompts(input_string, normalize=True):
+ parsed_prompts = [(match.group("prompt"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, input_string)]
+ if not normalize:
+ return parsed_prompts
+ # this probably still doesn't handle negative weights very well
+ weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+ return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
+
+def slerp(device, t, v0:torch.Tensor, v1:torch.Tensor, DOT_THRESHOLD=0.9995):
+ v0 = v0.detach().cpu().numpy()
+ v1 = v1.detach().cpu().numpy()
+
+ dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+ if np.abs(dot) > DOT_THRESHOLD:
+ v2 = (1 - t) * v0 + t * v1
+ else:
+ theta_0 = np.arccos(dot)
+ sin_theta_0 = np.sin(theta_0)
+ theta_t = theta_0 * t
+ sin_theta_t = np.sin(theta_t)
+ s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+ s1 = sin_theta_t / sin_theta_0
+ v2 = s0 * v0 + s1 * v1
+
+ v2 = torch.from_numpy(v2).to(device)
+
+ return v2
+
+
+def optimize_update_preview_frequency(current_chunk_speed, previous_chunk_speed, update_preview_frequency):
+ """Find the optimal update_preview_frequency value maximizing
+ performance while minimizing the time between updates."""
+ if current_chunk_speed >= previous_chunk_speed:
+ #print(f"{current_chunk_speed} >= {previous_chunk_speed}")
+ update_preview_frequency +=1
+ previous_chunk_speed = current_chunk_speed
+ else:
+ #print(f"{current_chunk_speed} <= {previous_chunk_speed}")
+ update_preview_frequency -=1
+ previous_chunk_speed = current_chunk_speed
+
+ return current_chunk_speed, previous_chunk_speed, update_preview_frequency
+
+# -----------------------------------------------------------------------------
+
+@torch.no_grad()
+def diffuse(
+ pipe,
+ cond_embeddings, # text conditioning, should be (1, 77, 768)
+ cond_latents, # image conditioning, should be (1, 4, 64, 64)
+ num_inference_steps,
+ cfg_scale,
+ eta,
+ ):
+
+ torch_device = cond_latents.get_device()
+
+ # classifier guidance: add the unconditional embedding
+ max_length = cond_embeddings.shape[1] # 77
+ uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
+ uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
+ text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
+ if isinstance(pipe.scheduler, LMSDiscreteScheduler):
+ cond_latents = cond_latents * pipe.scheduler.sigmas[0]
+
+ # init the scheduler
+ accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ pipe.scheduler.set_timesteps(num_inference_steps + st.session_state.sampling_steps, **extra_set_kwargs)
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+
+ step_counter = 0
+ inference_counter = 0
+ current_chunk_speed = 0
+ previous_chunk_speed = 0
+
+ # diffuse!
+ for i, t in enumerate(pipe.scheduler.timesteps):
+ start = timeit.default_timer()
+
+ #status_text.text(f"Running step: {step_counter}{total_number_steps} {percent} | {duration:.2f}{speed}")
+
+ # expand the latents for classifier free guidance
+ latent_model_input = torch.cat([cond_latents] * 2)
+ if isinstance(pipe.scheduler, LMSDiscreteScheduler):
+ sigma = pipe.scheduler.sigmas[i]
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+
+ # cfg
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(pipe.scheduler, LMSDiscreteScheduler):
+ cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
+ else:
+ cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]
+
+ #print (st.session_state["update_preview_frequency"])
+ #update the preview image if it is enabled and the frequency matches the step_counter
+ if defaults.general.update_preview:
+ step_counter += 1
+
+ if st.session_state.dynamic_preview_frequency:
+ current_chunk_speed, previous_chunk_speed, defaults.general.update_preview_frequency = optimize_update_preview_frequency(
+ current_chunk_speed, previous_chunk_speed, defaults.general.update_preview_frequency)
+
+ if defaults.general.update_preview_frequency == step_counter or step_counter == st.session_state.sampling_steps:
+ #scale and decode the image latents with vae
+ cond_latents_2 = 1 / 0.18215 * cond_latents
+ image_2 = pipe.vae.decode(cond_latents_2)
+
+ # generate output numpy image as uint8
+ image_2 = (image_2 / 2 + 0.5).clamp(0, 1)
+ image_2 = image_2.cpu().permute(0, 2, 3, 1).numpy()
+ image_2 = (image_2[0] * 255).astype(np.uint8)
+
+ st.session_state["preview_image"].image(image_2)
+
+ step_counter = 0
+
+ duration = timeit.default_timer() - start
+
+ current_chunk_speed = duration
+
+ if duration >= 1:
+ speed = "s/it"
+ else:
+ speed = "it/s"
+ duration = 1 / duration
+
+ if i > st.session_state.sampling_steps:
+ inference_counter += 1
+ inference_percent = int(100 * float(inference_counter if inference_counter < num_inference_steps else num_inference_steps)/float(num_inference_steps))
+ inference_progress = f"{inference_counter if inference_counter < num_inference_steps else num_inference_steps}/{num_inference_steps} {inference_percent}% "
+ else:
+ inference_progress = ""
+
+ percent = int(100 * float(i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps)/float(st.session_state.sampling_steps))
+ frames_percent = int(100 * float(st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames)/float(st.session_state.max_frames))
+
+ st.session_state["progress_bar_text"].text(
+ f"Running step: {i+1 if i+1 < st.session_state.sampling_steps else st.session_state.sampling_steps}/{st.session_state.sampling_steps} "
+ f"{percent if percent < 100 else 100}% {inference_progress}{duration:.2f}{speed} | "
+ f"Frame: {st.session_state.current_frame if st.session_state.current_frame < st.session_state.max_frames else st.session_state.max_frames}/{st.session_state.max_frames} "
+ f"{frames_percent if frames_percent < 100 else 100}% {st.session_state.frame_duration:.2f}{st.session_state.frame_speed}"
+ )
+ st.session_state["progress_bar"].progress(percent if percent < 100 else 100)
+
+ # scale and decode the image latents with vae
+ cond_latents = 1 / 0.18215 * cond_latents
+ image = pipe.vae.decode(cond_latents)
+
+ # generate output numpy image as uint8
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ image = (image[0] * 255).astype(np.uint8)
+
+ return image
+
+
+def ModelLoader(models,load=False,unload=False,imgproc_realesrgan_model_name='RealESRGAN_x4plus'):
+ #get global variables
+ global_vars = globals()
+ #check if m is in globals
+ if unload:
+ for m in models:
+ if m in global_vars:
+ #if it is, delete it
+ del global_vars[m]
+ if defaults.general.optimized:
+ if m == 'model':
+ del global_vars[m+'FS']
+ del global_vars[m+'CS']
+ if m =='model':
+ m='Stable Diffusion'
+ print('Unloaded ' + m)
+ if load:
+ for m in models:
+ if m not in global_vars or m in global_vars and type(global_vars[m]) == bool:
+ #if it isn't, load it
+ if m == 'GFPGAN':
+ global_vars[m] = load_GFPGAN()
+ elif m == 'model':
+ sdLoader = load_sd_from_config()
+ global_vars[m] = sdLoader[0]
+ if defaults.general.optimized:
+ global_vars[m+'CS'] = sdLoader[1]
+ global_vars[m+'FS'] = sdLoader[2]
+ elif m == 'RealESRGAN':
+ global_vars[m] = load_RealESRGAN(imgproc_realesrgan_model_name)
+ elif m == 'LDSR':
+ global_vars[m] = load_LDSR()
+ if m =='model':
+ m='Stable Diffusion'
+ print('Loaded ' + m)
+ torch_gc()
+
+
+
+def get_font(fontsize):
+ fonts = ["arial.ttf", "DejaVuSans.ttf"]
+ for font_name in fonts:
+ try:
+ return ImageFont.truetype(font_name, fontsize)
+ except OSError:
+ pass
+
+ # ImageFont.load_default() is practically unusable as it only supports
+ # latin1, so raise an exception instead if no usable font was found
+ raise Exception(f"No usable font found (tried {', '.join(fonts)})")
+
+def load_embeddings(fp):
+ if fp is not None and hasattr(st.session_state["model"], "embedding_manager"):
+ st.session_state["model"].embedding_manager.load(fp['name'])
+
+def image_grid(imgs, batch_size, force_n_rows=None, captions=None):
+ #print (len(imgs))
+ if force_n_rows is not None:
+ rows = force_n_rows
+ elif defaults.general.n_rows > 0:
+ rows = defaults.general.n_rows
+ elif defaults.general.n_rows == 0:
+ rows = batch_size
+ else:
+ rows = math.sqrt(len(imgs))
+ rows = round(rows)
+
+ cols = math.ceil(len(imgs) / rows)
+
+ w, h = imgs[0].size
+ grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+
+ fnt = get_font(30)
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+ if captions and i= 2**32:
+ n = n >> 32
+ return n
+
+def check_prompt_length(prompt, comments):
+ """this function tests if prompt is too long, and if so, adds a message to comments"""
+
+ tokenizer = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer
+ max_length = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.max_length
+
+ info = (st.session_state["model"] if not defaults.general.optimized else modelCS).cond_stage_model.tokenizer([prompt], truncation=True, max_length=max_length,
+ return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
+ ovf = info['overflowing_tokens'][0]
+ overflowing_count = ovf.shape[0]
+ if overflowing_count == 0:
+ return
+
+ vocab = {v: k for k, v in tokenizer.get_vocab().items()}
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+
+ comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+def save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+ save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images):
+
+ filename_i = os.path.join(sample_path_i, filename)
+
+ if defaults.general.save_metadata or write_info_files:
+ # toggles differ for txt2img vs. img2img:
+ offset = 0 if init_img is None else 2
+ toggles = []
+ if prompt_matrix:
+ toggles.append(0)
+ if normalize_prompt_weights:
+ toggles.append(1)
+ if init_img is not None:
+ if uses_loopback:
+ toggles.append(2)
+ if uses_random_seed_loopback:
+ toggles.append(3)
+ if save_individual_images:
+ toggles.append(2 + offset)
+ if save_grid:
+ toggles.append(3 + offset)
+ if sort_samples:
+ toggles.append(4 + offset)
+ if write_info_files:
+ toggles.append(5 + offset)
+ if use_GFPGAN:
+ toggles.append(6 + offset)
+ metadata = \
+ dict(
+ target="txt2img" if init_img is None else "img2img",
+ prompt=prompts[i], ddim_steps=steps, toggles=toggles, sampler_name=sampler_name,
+ ddim_eta=ddim_eta, n_iter=n_iter, batch_size=batch_size, cfg_scale=cfg_scale,
+ seed=seeds[i], width=width, height=height, normalize_prompt_weights=normalize_prompt_weights)
+ # Not yet any use for these, but they bloat up the files:
+ # info_dict["init_img"] = init_img
+ # info_dict["init_mask"] = init_mask
+ if init_img is not None:
+ metadata["denoising_strength"] = str(denoising_strength)
+ metadata["resize_mode"] = resize_mode
+
+ if write_info_files:
+ with open(f"{filename_i}.yaml", "w", encoding="utf8") as f:
+ yaml.dump(metadata, f, allow_unicode=True, width=10000)
+
+ if defaults.general.save_metadata:
+ # metadata = {
+ # "SD:prompt": prompts[i],
+ # "SD:seed": str(seeds[i]),
+ # "SD:width": str(width),
+ # "SD:height": str(height),
+ # "SD:steps": str(steps),
+ # "SD:cfg_scale": str(cfg_scale),
+ # "SD:normalize_prompt_weights": str(normalize_prompt_weights),
+ # }
+ metadata = {"SD:" + k:v for (k,v) in metadata.items()}
+
+ if save_ext == "png":
+ mdata = PngInfo()
+ for key in metadata:
+ mdata.add_text(key, str(metadata[key]))
+ image.save(f"{filename_i}.png", pnginfo=mdata)
+ else:
+ if jpg_sample:
+ image.save(f"{filename_i}.jpg", quality=save_quality,
+ optimize=True)
+ elif save_ext == "webp":
+ image.save(f"{filename_i}.{save_ext}", f"webp", quality=save_quality,
+ lossless=save_lossless)
+ else:
+ # not sure what file format this is
+ image.save(f"{filename_i}.{save_ext}", f"{save_ext}")
+ try:
+ exif_dict = piexif.load(f"{filename_i}.{save_ext}")
+ except:
+ exif_dict = { "Exif": dict() }
+ exif_dict["Exif"][piexif.ExifIFD.UserComment] = piexif.helper.UserComment.dump(
+ json.dumps(metadata), encoding="unicode")
+ piexif.insert(piexif.dump(exif_dict), f"{filename_i}.{save_ext}")
+
+ # render the image on the frontend
+ st.session_state["preview_image"].image(image)
+
+def get_next_sequence_number(path, prefix=''):
+ """
+ Determines and returns the next sequence number to use when saving an
+ image in the specified directory.
+
+ If a prefix is given, only consider files whose names start with that
+ prefix, and strip the prefix from filenames before extracting their
+ sequence number.
+
+ The sequence starts at 0.
+ """
+ result = -1
+ for p in Path(path).iterdir():
+ if p.name.endswith(('.png', '.jpg')) and p.name.startswith(prefix):
+ tmp = p.name[len(prefix):]
+ try:
+ result = max(int(tmp.split('-')[0]), result)
+ except ValueError:
+ pass
+ return result + 1
+
+
+def oxlamon_matrix(prompt, seed, n_iter, batch_size):
+ pattern = re.compile(r'(,\s){2,}')
+
+ class PromptItem:
+ def __init__(self, text, parts, item):
+ self.text = text
+ self.parts = parts
+ if item:
+ self.parts.append( item )
+
+ def clean(txt):
+ return re.sub(pattern, ', ', txt)
+
+ def getrowcount( txt ):
+ for data in re.finditer( ".*?\\((.*?)\\).*", txt ):
+ if data:
+ return len(data.group(1).split("|"))
+ break
+ return None
+
+ def repliter( txt ):
+ for data in re.finditer( ".*?\\((.*?)\\).*", txt ):
+ if data:
+ r = data.span(1)
+ for item in data.group(1).split("|"):
+ yield (clean(txt[:r[0]-1] + item.strip() + txt[r[1]+1:]), item.strip())
+ break
+
+ def iterlist( items ):
+ outitems = []
+ for item in items:
+ for newitem, newpart in repliter(item.text):
+ outitems.append( PromptItem(newitem, item.parts.copy(), newpart) )
+
+ return outitems
+
+ def getmatrix( prompt ):
+ dataitems = [ PromptItem( prompt[1:].strip(), [], None ) ]
+ while True:
+ newdataitems = iterlist( dataitems )
+ if len( newdataitems ) == 0:
+ return dataitems
+ dataitems = newdataitems
+
+ def classToArrays( items, seed, n_iter ):
+ texts = []
+ parts = []
+ seeds = []
+
+ for item in items:
+ itemseed = seed
+ for i in range(n_iter):
+ texts.append( item.text )
+ parts.append( f"Seed: {itemseed}\n" + "\n".join(item.parts) )
+ seeds.append( itemseed )
+ itemseed += 1
+
+ return seeds, texts, parts
+
+ all_seeds, all_prompts, prompt_matrix_parts = classToArrays(getmatrix( prompt ), seed, n_iter)
+ n_iter = math.ceil(len(all_prompts) / batch_size)
+
+ needrows = getrowcount(prompt)
+ if needrows:
+ xrows = math.sqrt(len(all_prompts))
+ xrows = round(xrows)
+ # if columns is to much
+ cols = math.ceil(len(all_prompts) / xrows)
+ if cols > needrows*4:
+ needrows *= 2
+
+ return all_seeds, n_iter, prompt_matrix_parts, all_prompts, needrows
+
+
+import find_noise_for_image
+import matched_noise
+
+
+def process_images(
+ outpath, func_init, func_sample, prompt, seed, sampler_name, save_grid, batch_size,
+ n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, use_RealESRGAN, realesrgan_model_name,
+ fp=None, ddim_eta=0.0, normalize_prompt_weights=True, init_img=None, init_mask=None,
+ mask_blur_strength=3, mask_restore=False, denoising_strength=0.75, noise_mode=0, find_noise_steps=1, resize_mode=None, uses_loopback=False,
+ uses_random_seed_loopback=False, sort_samples=True, write_info_files=True, jpg_sample=False,
+ variant_amount=0.0, variant_seed=None, save_individual_images: bool = True):
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+ assert prompt is not None
+ torch_gc()
+ # start time after garbage collection (or before?)
+ start_time = time.time()
+
+ # We will use this date here later for the folder name, need to start_time if not need
+ run_start_dt = datetime.datetime.now()
+
+ mem_mon = MemUsageMonitor('MemMon')
+ mem_mon.start()
+
+ if hasattr(st.session_state["model"], "embedding_manager"):
+ load_embeddings(fp)
+
+ os.makedirs(outpath, exist_ok=True)
+
+ sample_path = os.path.join(outpath, "samples")
+ os.makedirs(sample_path, exist_ok=True)
+
+ if not ("|" in prompt) and prompt.startswith("@"):
+ prompt = prompt[1:]
+
+ comments = []
+
+ prompt_matrix_parts = []
+ simple_templating = False
+ add_original_image = not (use_RealESRGAN or use_GFPGAN)
+
+ if prompt_matrix:
+ if prompt.startswith("@"):
+ simple_templating = True
+ add_original_image = not (use_RealESRGAN or use_GFPGAN)
+ all_seeds, n_iter, prompt_matrix_parts, all_prompts, frows = oxlamon_matrix(prompt, seed, n_iter, batch_size)
+ else:
+ all_prompts = []
+ prompt_matrix_parts = prompt.split("|")
+ combination_count = 2 ** (len(prompt_matrix_parts) - 1)
+ for combination_num in range(combination_count):
+ current = prompt_matrix_parts[0]
+
+ for n, text in enumerate(prompt_matrix_parts[1:]):
+ if combination_num & (2 ** n) > 0:
+ current += ("" if text.strip().startswith(",") else ", ") + text
+
+ all_prompts.append(current)
+
+ n_iter = math.ceil(len(all_prompts) / batch_size)
+ all_seeds = len(all_prompts) * [seed]
+
+ print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
+ else:
+
+ if not defaults.general.no_verify_input:
+ try:
+ check_prompt_length(prompt, comments)
+ except:
+ import traceback
+ print("Error verifying input:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ all_prompts = batch_size * n_iter * [prompt]
+ all_seeds = [seed + x for x in range(len(all_prompts))]
+
+ precision_scope = autocast if defaults.general.precision == "autocast" else nullcontext
+ output_images = []
+ grid_captions = []
+ stats = []
+ with torch.no_grad(), precision_scope("cuda"), (st.session_state["model"].ema_scope() if not defaults.general.optimized else nullcontext()):
+ init_data = func_init()
+ tic = time.time()
+
+
+ # if variant_amount > 0.0 create noise from base seed
+ base_x = None
+ if variant_amount > 0.0:
+ target_seed_randomizer = seed_to_int('') # random seed
+ torch.manual_seed(seed) # this has to be the single starting seed (not per-iteration)
+ base_x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=[seed])
+ # we don't want all_seeds to be sequential from starting seed with variants,
+ # since that makes the same variants each time,
+ # so we add target_seed_randomizer as a random offset
+ for si in range(len(all_seeds)):
+ all_seeds[si] += target_seed_randomizer
+
+ for n in range(n_iter):
+ print(f"Iteration: {n+1}/{n_iter}")
+ prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
+ captions = prompt_matrix_parts[n * batch_size:(n + 1) * batch_size]
+ seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
+
+ print(prompt)
+
+ if defaults.general.optimized:
+ modelCS.to(defaults.general.gpu)
+
+ uc = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(len(prompts) * [""])
+
+ if isinstance(prompts, tuple):
+ prompts = list(prompts)
+
+ # split the prompt if it has : for weighting
+ # TODO for speed it might help to have this occur when all_prompts filled??
+ weighted_subprompts = split_weighted_subprompts(prompts[0], normalize_prompt_weights)
+
+ # sub-prompt weighting used if more than 1
+ if len(weighted_subprompts) > 1:
+ c = torch.zeros_like(uc) # i dont know if this is correct.. but it works
+ for i in range(0, len(weighted_subprompts)):
+ # note if alpha negative, it functions same as torch.sub
+ c = torch.add(c, (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(weighted_subprompts[i][0]), alpha=weighted_subprompts[i][1])
+ else: # just behave like usual
+ c = (st.session_state["model"] if not defaults.general.optimized else modelCS).get_learned_conditioning(prompts)
+
+
+ shape = [opt_C, height // opt_f, width // opt_f]
+
+ if defaults.general.optimized:
+ mem = torch.cuda.memory_allocated()/1e6
+ modelCS.to("cpu")
+ while(torch.cuda.memory_allocated()/1e6 >= mem):
+ time.sleep(1)
+
+ if noise_mode == 1 or noise_mode == 3:
+ # TODO params for find_noise_to_image
+ x = torch.cat(batch_size * [find_noise_for_image.find_noise_for_image(
+ st.session_state["model"], st.session_state["device"],
+ init_img.convert('RGB'), '', find_noise_steps, 0.0, normalize=True,
+ generation_callback=generation_callback,
+ )], dim=0)
+ else:
+ # we manually generate all input noises because each one should have a specific seed
+ x = create_random_tensors(shape, seeds=seeds)
+
+ if variant_amount > 0.0: # we are making variants
+ # using variant_seed as sneaky toggle,
+ # when not None or '' use the variant_seed
+ # otherwise use seeds
+ if variant_seed != None and variant_seed != '':
+ specified_variant_seed = seed_to_int(variant_seed)
+ torch.manual_seed(specified_variant_seed)
+ seeds = [specified_variant_seed]
+ # finally, slerp base_x noise to target_x noise for creating a variant
+ x = slerp(defaults.general.gpu, max(0.0, min(1.0, variant_amount)), base_x, x)
+
+ samples_ddim = func_sample(init_data=init_data, x=x, conditioning=c, unconditional_conditioning=uc, sampler_name=sampler_name)
+
+ if defaults.general.optimized:
+ modelFS.to(defaults.general.gpu)
+
+ x_samples_ddim = (st.session_state["model"] if not defaults.general.optimized else modelFS).decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for i, x_sample in enumerate(x_samples_ddim):
+ sanitized_prompt = slugify(prompts[i])
+
+ if sort_samples:
+ full_path = os.path.join(os.getcwd(), sample_path, sanitized_prompt)
+
+
+ sanitized_prompt = sanitized_prompt[:220-len(full_path)]
+ sample_path_i = os.path.join(sample_path, sanitized_prompt)
+
+ #print(f"output folder length: {len(os.path.join(os.getcwd(), sample_path_i))}")
+ #print(os.path.join(os.getcwd(), sample_path_i))
+
+ os.makedirs(sample_path_i, exist_ok=True)
+ base_count = get_next_sequence_number(sample_path_i)
+ filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}"
+ else:
+ full_path = os.path.join(os.getcwd(), sample_path)
+ sample_path_i = sample_path
+ base_count = get_next_sequence_number(sample_path_i)
+ filename = f"{base_count:05}-{steps}_{sampler_name}_{seeds[i]}_{sanitized_prompt}"[:220-len(full_path)] #same as before
+
+ x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = x_sample.astype(np.uint8)
+ image = Image.fromarray(x_sample)
+ original_sample = x_sample
+ original_filename = filename
+
+ if use_GFPGAN and st.session_state["GFPGAN"] is not None and not use_RealESRGAN:
+ #skip_save = True # #287 >_>
+ torch_gc()
+ cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
+ gfpgan_sample = restored_img[:,:,::-1]
+ gfpgan_image = Image.fromarray(gfpgan_sample)
+ gfpgan_filename = original_filename + '-gfpgan'
+
+ save_sample(gfpgan_image, sample_path_i, gfpgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback,
+ uses_random_seed_loopback, save_grid, sort_samples, sampler_name, ddim_eta,
+ n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False)
+
+ output_images.append(gfpgan_image) #287
+ if simple_templating:
+ grid_captions.append( captions[i] + "\ngfpgan" )
+
+ if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and not use_GFPGAN:
+ #skip_save = True # #287 >_>
+ torch_gc()
+
+ if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ output, img_mode = st.session_state["RealESRGAN"].enhance(x_sample[:,:,::-1])
+ esrgan_filename = original_filename + '-esrgan4x'
+ esrgan_sample = output[:,:,::-1]
+ esrgan_image = Image.fromarray(esrgan_sample)
+
+ #save_sample(image, sample_path_i, original_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ #normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback, skip_save,
+ #save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode)
+
+ save_sample(esrgan_image, sample_path_i, esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+ save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False)
+
+ output_images.append(esrgan_image) #287
+ if simple_templating:
+ grid_captions.append( captions[i] + "\nesrgan" )
+
+ if use_RealESRGAN and st.session_state["RealESRGAN"] is not None and use_GFPGAN and st.session_state["GFPGAN"] is not None:
+ #skip_save = True # #287 >_>
+ torch_gc()
+ cropped_faces, restored_faces, restored_img = st.session_state["GFPGAN"].enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
+ gfpgan_sample = restored_img[:,:,::-1]
+
+ if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ output, img_mode = st.session_state["RealESRGAN"].enhance(gfpgan_sample[:,:,::-1])
+ gfpgan_esrgan_filename = original_filename + '-gfpgan-esrgan4x'
+ gfpgan_esrgan_sample = output[:,:,::-1]
+ gfpgan_esrgan_image = Image.fromarray(gfpgan_esrgan_sample)
+
+ save_sample(gfpgan_esrgan_image, sample_path_i, gfpgan_esrgan_filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, False, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+ save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images=False)
+
+ output_images.append(gfpgan_esrgan_image) #287
+
+ if simple_templating:
+ grid_captions.append( captions[i] + "\ngfpgan_esrgan" )
+
+ if mask_restore and init_mask:
+ #init_mask = init_mask if keep_mask else ImageOps.invert(init_mask)
+ init_mask = init_mask.filter(ImageFilter.GaussianBlur(mask_blur_strength))
+ init_mask = init_mask.convert('L')
+ init_img = init_img.convert('RGB')
+ image = image.convert('RGB')
+
+ if use_RealESRGAN and st.session_state["RealESRGAN"] is not None:
+ if st.session_state["RealESRGAN"].model.name != realesrgan_model_name:
+ #try_loading_RealESRGAN(realesrgan_model_name)
+ load_models(use_GFPGAN=use_GFPGAN, use_RealESRGAN=use_RealESRGAN, RealESRGAN_model=realesrgan_model_name)
+
+ output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_img, dtype=np.uint8))
+ init_img = Image.fromarray(output)
+ init_img = init_img.convert('RGB')
+
+ output, img_mode = st.session_state["RealESRGAN"].enhance(np.array(init_mask, dtype=np.uint8))
+ init_mask = Image.fromarray(output)
+ init_mask = init_mask.convert('L')
+
+ image = Image.composite(init_img, image, init_mask)
+
+ if save_individual_images:
+ save_sample(image, sample_path_i, filename, jpg_sample, prompts, seeds, width, height, steps, cfg_scale,
+ normalize_prompt_weights, use_GFPGAN, write_info_files, prompt_matrix, init_img, uses_loopback, uses_random_seed_loopback,
+ save_grid, sort_samples, sampler_name, ddim_eta, n_iter, batch_size, i, denoising_strength, resize_mode, save_individual_images)
+
+ if not use_GFPGAN or not use_RealESRGAN:
+ output_images.append(image)
+
+ #if add_original_image or not simple_templating:
+ #output_images.append(image)
+ #if simple_templating:
+ #grid_captions.append( captions[i] )
+
+ if defaults.general.optimized:
+ mem = torch.cuda.memory_allocated()/1e6
+ modelFS.to("cpu")
+ while(torch.cuda.memory_allocated()/1e6 >= mem):
+ time.sleep(1)
+
+ if prompt_matrix or save_grid:
+ if prompt_matrix:
+ if simple_templating:
+ grid = image_grid(output_images, n_iter, force_n_rows=frows, captions=grid_captions)
+ else:
+ grid = image_grid(output_images, n_iter, force_n_rows=1 << ((len(prompt_matrix_parts)-1)//2))
+ try:
+ grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
+ except:
+ import traceback
+ print("Error creating prompt_matrix text:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ else:
+ grid = image_grid(output_images, batch_size)
+
+ if grid and (batch_size > 1 or n_iter > 1):
+ output_images.insert(0, grid)
+
+ grid_count = get_next_sequence_number(outpath, 'grid-')
+ grid_file = f"grid-{grid_count:05}-{seed}_{slugify(prompts[i].replace(' ', '_')[:220-len(full_path)])}.{grid_ext}"
+ grid.save(os.path.join(outpath, grid_file), grid_format, quality=grid_quality, lossless=grid_lossless, optimize=True)
+
+ toc = time.time()
+
+ mem_max_used, mem_total = mem_mon.read_and_stop()
+ time_diff = time.time()-start_time
+
+ info = f"""
+ {prompt}
+ Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', Denoising strength: '+str(denoising_strength) if init_img is not None else ''}{', GFPGAN' if use_GFPGAN and st.session_state["GFPGAN"] is not None else ''}{', '+realesrgan_model_name if use_RealESRGAN and st.session_state["RealESRGAN"] is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
+ stats = f'''
+ Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
+ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
+
+ for comment in comments:
+ info += "\n\n" + comment
+
+ #mem_mon.stop()
+ #del mem_mon
+ torch_gc()
+
+ return output_images, seed, info, stats
+
+
+def resize_image(resize_mode, im, width, height):
+ LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+ if resize_mode == 0:
+ res = im.resize((width, height), resample=LANCZOS)
+ elif resize_mode == 1:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio > src_ratio else im.width * height // im.height
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
+
+ resized = im.resize((src_w, src_h), resample=LANCZOS)
+ res = Image.new("RGBA", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+ else:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio < src_ratio else im.width * height // im.height
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
+
+ resized = im.resize((src_w, src_h), resample=LANCZOS)
+ res = Image.new("RGBA", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+ if ratio < src_ratio:
+ fill_height = height // 2 - src_h // 2
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ elif ratio > src_ratio:
+ fill_width = width // 2 - src_w // 2
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+
+ return res
+
+import skimage
+
+def img2img(prompt: str = '', init_info: any = None, init_info_mask: any = None, mask_mode: int = 0, mask_blur_strength: int = 3,
+ mask_restore: bool = False, ddim_steps: int = 50, sampler_name: str = 'DDIM',
+ n_iter: int = 1, cfg_scale: float = 7.5, denoising_strength: float = 0.8,
+ seed: int = -1, noise_mode: int = 0, find_noise_steps: str = "", height: int = 512, width: int = 512, resize_mode: int = 0, fp = None,
+ variant_amount: float = None, variant_seed: int = None, ddim_eta:float = 0.0,
+ write_info_files:bool = True, RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B",
+ separate_prompts:bool = False, normalize_prompt_weights:bool = True,
+ save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
+ save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True, loopback: bool = False,
+ random_seed_loopback: bool = False
+ ):
+
+ outpath = defaults.general.outdir_img2img or defaults.general.outdir or "outputs/img2img-samples"
+ err = False
+ #loopback = False
+ #skip_save = False
+ seed = seed_to_int(seed)
+
+ batch_size = 1
+
+ #prompt_matrix = 0
+ #normalize_prompt_weights = 1 in toggles
+ #loopback = 2 in toggles
+ #random_seed_loopback = 3 in toggles
+ #skip_save = 4 not in toggles
+ #save_grid = 5 in toggles
+ #sort_samples = 6 in toggles
+ #write_info_files = 7 in toggles
+ #write_sample_info_to_log_file = 8 in toggles
+ #jpg_sample = 9 in toggles
+ #use_GFPGAN = 10 in toggles
+ #use_RealESRGAN = 11 in toggles
+
+ if sampler_name == 'PLMS':
+ sampler = PLMSSampler(st.session_state["model"])
+ elif sampler_name == 'DDIM':
+ sampler = DDIMSampler(st.session_state["model"])
+ elif sampler_name == 'k_dpm_2_a':
+ sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
+ elif sampler_name == 'k_dpm_2':
+ sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
+ elif sampler_name == 'k_euler_a':
+ sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
+ elif sampler_name == 'k_euler':
+ sampler = KDiffusionSampler(st.session_state["model"],'euler')
+ elif sampler_name == 'k_heun':
+ sampler = KDiffusionSampler(st.session_state["model"],'heun')
+ elif sampler_name == 'k_lms':
+ sampler = KDiffusionSampler(st.session_state["model"],'lms')
+ else:
+ raise Exception("Unknown sampler: " + sampler_name)
+
+ def process_init_mask(init_mask: Image):
+ if init_mask.mode == "RGBA":
+ init_mask = init_mask.convert('RGBA')
+ background = Image.new('RGBA', init_mask.size, (0, 0, 0))
+ init_mask = Image.alpha_composite(background, init_mask)
+ init_mask = init_mask.convert('RGB')
+ return init_mask
+
+ init_img = init_info
+ init_mask = None
+ if mask_mode == 0:
+ if init_info_mask:
+ init_mask = process_init_mask(init_info_mask)
+ elif mask_mode == 1:
+ if init_info_mask:
+ init_mask = process_init_mask(init_info_mask)
+ init_mask = ImageOps.invert(init_mask)
+ elif mask_mode == 2:
+ init_img_transparency = init_img.split()[-1].convert('L')#.point(lambda x: 255 if x > 0 else 0, mode='1')
+ init_mask = init_img_transparency
+ init_mask = init_mask.convert("RGB")
+ init_mask = resize_image(resize_mode, init_mask, width, height)
+ init_mask = init_mask.convert("RGB")
+
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+ t_enc = int(denoising_strength * ddim_steps)
+
+ if init_mask is not None and (noise_mode == 2 or noise_mode == 3) and init_img is not None:
+ noise_q = 0.99
+ color_variation = 0.0
+ mask_blend_factor = 1.0
+
+ np_init = (np.asarray(init_img.convert("RGB"))/255.0).astype(np.float64) # annoyingly complex mask fixing
+ np_mask_rgb = 1. - (np.asarray(ImageOps.invert(init_mask).convert("RGB"))/255.0).astype(np.float64)
+ np_mask_rgb -= np.min(np_mask_rgb)
+ np_mask_rgb /= np.max(np_mask_rgb)
+ np_mask_rgb = 1. - np_mask_rgb
+ np_mask_rgb_hardened = 1. - (np_mask_rgb < 0.99).astype(np.float64)
+ blurred = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
+ blurred2 = skimage.filters.gaussian(np_mask_rgb_hardened[:], sigma=16., channel_axis=2, truncate=32.)
+ #np_mask_rgb_dilated = np_mask_rgb + blurred # fixup mask todo: derive magic constants
+ #np_mask_rgb = np_mask_rgb + blurred
+ np_mask_rgb_dilated = np.clip((np_mask_rgb + blurred2) * 0.7071, 0., 1.)
+ np_mask_rgb = np.clip((np_mask_rgb + blurred) * 0.7071, 0., 1.)
+
+ noise_rgb = matched_noise.get_matched_noise(np_init, np_mask_rgb, noise_q, color_variation)
+ blend_mask_rgb = np.clip(np_mask_rgb_dilated,0.,1.) ** (mask_blend_factor)
+ noised = noise_rgb[:]
+ blend_mask_rgb **= (2.)
+ noised = np_init[:] * (1. - blend_mask_rgb) + noised * blend_mask_rgb
+
+ np_mask_grey = np.sum(np_mask_rgb, axis=2)/3.
+ ref_mask = np_mask_grey < 1e-3
+
+ all_mask = np.ones((height, width), dtype=bool)
+ noised[all_mask,:] = skimage.exposure.match_histograms(noised[all_mask,:]**1., noised[ref_mask,:], channel_axis=1)
+
+ init_img = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
+ st.session_state["editor_image"].image(init_img) # debug
+
+ def init():
+ image = init_img.convert('RGB')
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+
+ mask_channel = None
+ if init_mask:
+ alpha = resize_image(resize_mode, init_mask, width // 8, height // 8)
+ mask_channel = alpha.split()[-1]
+
+ mask = None
+ if mask_channel is not None:
+ mask = np.array(mask_channel).astype(np.float32) / 255.0
+ mask = (1 - mask)
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3)
+ mask = torch.from_numpy(mask).to(st.session_state["device"])
+
+ if defaults.general.optimized:
+ modelFS.to(st.session_state["device"] )
+
+ init_image = 2. * image - 1.
+ init_image = init_image.to(st.session_state["device"])
+ init_latent = (st.session_state["model"] if not defaults.general.optimized else modelFS).get_first_stage_encoding((st.session_state["model"] if not defaults.general.optimized else modelFS).encode_first_stage(init_image)) # move to latent space
+
+ if defaults.general.optimized:
+ mem = torch.cuda.memory_allocated()/1e6
+ modelFS.to("cpu")
+ while(torch.cuda.memory_allocated()/1e6 >= mem):
+ time.sleep(1)
+
+ return init_latent, mask,
+
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ t_enc_steps = t_enc
+ obliterate = False
+ if ddim_steps == t_enc_steps:
+ t_enc_steps = t_enc_steps - 1
+ obliterate = True
+
+ if sampler_name != 'DDIM':
+ x0, z_mask = init_data
+
+ sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
+ noise = x * sigmas[ddim_steps - t_enc_steps - 1]
+
+ xi = x0 + noise
+
+ # Obliterate masked image
+ if z_mask is not None and obliterate:
+ random = torch.randn(z_mask.shape, device=xi.device)
+ xi = (z_mask * noise) + ((1-z_mask) * xi)
+
+ sigma_sched = sigmas[ddim_steps - t_enc_steps - 1:]
+ model_wrap_cfg = CFGMaskedDenoiser(sampler.model_wrap)
+ samples_ddim = K.sampling.__dict__[f'sample_{sampler.get_sampler_name()}'](model_wrap_cfg, xi, sigma_sched,
+ extra_args={'cond': conditioning, 'uncond': unconditional_conditioning,
+ 'cond_scale': cfg_scale, 'mask': z_mask, 'x0': x0, 'xi': xi}, disable=False,
+ callback=generation_callback)
+ else:
+
+ x0, z_mask = init_data
+
+ sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)
+ z_enc = sampler.stochastic_encode(x0, torch.tensor([t_enc_steps]*batch_size).to(st.session_state["device"] ))
+
+ # Obliterate masked image
+ if z_mask is not None and obliterate:
+ random = torch.randn(z_mask.shape, device=z_enc.device)
+ z_enc = (z_mask * random) + ((1-z_mask) * z_enc)
+
+ # decode it
+ samples_ddim = sampler.decode(z_enc, conditioning, t_enc_steps,
+ unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ z_mask=z_mask, x0=x0)
+ return samples_ddim
+
+
+
+ if loopback:
+ output_images, info = None, None
+ history = []
+ initial_seed = None
+
+ do_color_correction = False
+ try:
+ from skimage import exposure
+ do_color_correction = True
+ except:
+ print("Install scikit-image to perform color correction on loopback")
+
+ for i in range(n_iter):
+ if do_color_correction and i == 0:
+ correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
+
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=1,
+ n_iter=1,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ use_RealESRGAN=use_RealESRGAN, # Forcefully disable upscaling when using loopback
+ realesrgan_model_name=RealESRGAN_model,
+ fp=fp,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ init_img=init_img,
+ init_mask=init_mask,
+ mask_blur_strength=mask_blur_strength,
+ mask_restore=mask_restore,
+ denoising_strength=denoising_strength,
+ noise_mode=noise_mode,
+ find_noise_steps=find_noise_steps,
+ resize_mode=resize_mode,
+ uses_loopback=loopback,
+ uses_random_seed_loopback=random_seed_loopback,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg
+ )
+
+ if initial_seed is None:
+ initial_seed = seed
+
+ init_img = output_images[0]
+
+ if do_color_correction and correction_target is not None:
+ init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
+ cv2.cvtColor(
+ np.asarray(init_img),
+ cv2.COLOR_RGB2LAB
+ ),
+ correction_target,
+ channel_axis=2
+ ), cv2.COLOR_LAB2RGB).astype("uint8"))
+
+ if not random_seed_loopback:
+ seed = seed + 1
+ else:
+ seed = seed_to_int(None)
+
+ denoising_strength = max(denoising_strength * 0.95, 0.1)
+ history.append(init_img)
+
+ output_images = history
+ seed = initial_seed
+
+ else:
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ use_RealESRGAN=use_RealESRGAN,
+ realesrgan_model_name=RealESRGAN_model,
+ fp=fp,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ init_img=init_img,
+ init_mask=init_mask,
+ mask_blur_strength=mask_blur_strength,
+ denoising_strength=denoising_strength,
+ noise_mode=noise_mode,
+ find_noise_steps=find_noise_steps,
+ mask_restore=mask_restore,
+ resize_mode=resize_mode,
+ uses_loopback=loopback,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg
+ )
+
+ del sampler
+
+ return output_images, seed, info, stats
+
+@retry((RuntimeError, KeyError) , tries=3)
+def txt2img(prompt: str, ddim_steps: int, sampler_name: str, realesrgan_model_name: str,
+ n_iter: int, batch_size: int, cfg_scale: float, seed: Union[int, str, None],
+ height: int, width: int, separate_prompts:bool = False, normalize_prompt_weights:bool = True,
+ save_individual_images: bool = True, save_grid: bool = True, group_by_prompt: bool = True,
+ save_as_jpg: bool = True, use_GFPGAN: bool = True, use_RealESRGAN: bool = True,
+ RealESRGAN_model: str = "RealESRGAN_x4plus_anime_6B", fp = None, variant_amount: float = None,
+ variant_seed: int = None, ddim_eta:float = 0.0, write_info_files:bool = True):
+
+ outpath = defaults.general.outdir_txt2img or defaults.general.outdir or "outputs/txt2img-samples"
+
+ err = False
+ seed = seed_to_int(seed)
+
+ #prompt_matrix = 0 in toggles
+ #normalize_prompt_weights = 1 in toggles
+ #skip_save = 2 not in toggles
+ #save_grid = 3 not in toggles
+ #sort_samples = 4 in toggles
+ #write_info_files = 5 in toggles
+ #jpg_sample = 6 in toggles
+ #use_GFPGAN = 7 in toggles
+ #use_RealESRGAN = 8 in toggles
+
+ if sampler_name == 'PLMS':
+ sampler = PLMSSampler(st.session_state["model"])
+ elif sampler_name == 'DDIM':
+ sampler = DDIMSampler(st.session_state["model"])
+ elif sampler_name == 'k_dpm_2_a':
+ sampler = KDiffusionSampler(st.session_state["model"],'dpm_2_ancestral')
+ elif sampler_name == 'k_dpm_2':
+ sampler = KDiffusionSampler(st.session_state["model"],'dpm_2')
+ elif sampler_name == 'k_euler_a':
+ sampler = KDiffusionSampler(st.session_state["model"],'euler_ancestral')
+ elif sampler_name == 'k_euler':
+ sampler = KDiffusionSampler(st.session_state["model"],'euler')
+ elif sampler_name == 'k_heun':
+ sampler = KDiffusionSampler(st.session_state["model"],'heun')
+ elif sampler_name == 'k_lms':
+ sampler = KDiffusionSampler(st.session_state["model"],'lms')
+ else:
+ raise Exception("Unknown sampler: " + sampler_name)
+
+ def init():
+ pass
+
+ def sample(init_data, x, conditioning, unconditional_conditioning, sampler_name):
+ samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale,
+ unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x, img_callback=generation_callback,
+ log_every_t=int(defaults.general.update_preview_frequency))
+
+ return samples_ddim
+
+ #try:
+ output_images, seed, info, stats = process_images(
+ outpath=outpath,
+ func_init=init,
+ func_sample=sample,
+ prompt=prompt,
+ seed=seed,
+ sampler_name=sampler_name,
+ save_grid=save_grid,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=ddim_steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=separate_prompts,
+ use_GFPGAN=use_GFPGAN,
+ use_RealESRGAN=use_RealESRGAN,
+ realesrgan_model_name=realesrgan_model_name,
+ fp=fp,
+ ddim_eta=ddim_eta,
+ normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images,
+ sort_samples=group_by_prompt,
+ write_info_files=write_info_files,
+ jpg_sample=save_as_jpg,
+ variant_amount=variant_amount,
+ variant_seed=variant_seed,
+ )
+
+ del sampler
+
+ return output_images, seed, info, stats
+
+ #except RuntimeError as e:
+ #err = e
+ #err_msg = f'CRASHED:
Please wait while the program restarts.'
+ #stats = err_msg
+ #return [], seed, 'err', stats
+
+
+#
+def txt2vid(
+ # --------------------------------------
+ # args you probably want to change
+ prompts = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about
+ gpu:int = defaults.general.gpu, # id of the gpu to run on
+ #name:str = 'test', # name of this project, for the output directory
+ #rootdir:str = defaults.general.outdir,
+ num_steps:int = 200, # number of steps between each pair of sampled points
+ max_frames:int = 10000, # number of frames to write and then exit the script
+ num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
+ cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
+ do_loop = False,
+ use_lerp_for_text = False,
+ seeds = None,
+ quality:int = 100, # for jpeg compression of the output images
+ eta:float = 0.0,
+ width:int = 256,
+ height:int = 256,
+ weights_path = "CompVis/stable-diffusion-v1-4",
+ scheduler="klms", # choices: default, ddim, klms
+ disable_tqdm = False,
+ #-----------------------------------------------
+ beta_start = 0.0001,
+ beta_end = 0.00012,
+ beta_schedule = "scaled_linear"
+ ):
+ """
+ prompt = ["blueberry spaghetti", "strawberry spaghetti"], # prompt to dream about
+ gpu:int = defaults.general.gpu, # id of the gpu to run on
+ #name:str = 'test', # name of this project, for the output directory
+ #rootdir:str = defaults.general.outdir,
+ num_steps:int = 200, # number of steps between each pair of sampled points
+ max_frames:int = 10000, # number of frames to write and then exit the script
+ num_inference_steps:int = 50, # more (e.g. 100, 200 etc) can create slightly better images
+ cfg_scale:float = 5.0, # can depend on the prompt. usually somewhere between 3-10 is good
+ do_loop = False,
+ use_lerp_for_text = False,
+ seed = None,
+ quality:int = 100, # for jpeg compression of the output images
+ eta:float = 0.0,
+ width:int = 256,
+ height:int = 256,
+ weights_path = "CompVis/stable-diffusion-v1-4",
+ scheduler="klms", # choices: default, ddim, klms
+ disable_tqdm = False,
+ beta_start = 0.0001,
+ beta_end = 0.00012,
+ beta_schedule = "scaled_linear"
+ """
+ mem_mon = MemUsageMonitor('MemMon')
+ mem_mon.start()
+
+
+ seeds = seed_to_int(seeds)
+
+ # We add an extra frame because most
+ # of the time the first frame is just the noise.
+ max_frames +=1
+
+ assert torch.cuda.is_available()
+ assert height % 8 == 0 and width % 8 == 0
+ torch.manual_seed(seeds)
+ torch_device = f"cuda:{gpu}"
+
+ # init the output dir
+ sanitized_prompt = slugify(prompts)
+
+ full_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
+
+ if len(full_path) > 220:
+ sanitized_prompt = sanitized_prompt[:220-len(full_path)]
+ full_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples", "samples", sanitized_prompt)
+
+ os.makedirs(full_path, exist_ok=True)
+
+ # Write prompt info to file in output dir so we can keep track of what we did
+ if st.session_state.write_info_files:
+ with open(os.path.join(full_path , f'{slugify(str(seeds))}_config.json' if len(prompts) > 1 else "prompts_config.json"), "w") as outfile:
+ outfile.write(json.dumps(
+ dict(
+ prompts = prompts,
+ gpu = gpu,
+ num_steps = num_steps,
+ max_frames = max_frames,
+ num_inference_steps = num_inference_steps,
+ cfg_scale = cfg_scale,
+ do_loop = do_loop,
+ use_lerp_for_text = use_lerp_for_text,
+ seeds = seeds,
+ quality = quality,
+ eta = eta,
+ width = width,
+ height = height,
+ weights_path = weights_path,
+ scheduler=scheduler,
+ disable_tqdm = disable_tqdm,
+ beta_start = beta_start,
+ beta_end = beta_end,
+ beta_schedule = beta_schedule
+ ),
+ indent=2,
+ sort_keys=False,
+ ))
+
+ #print(scheduler)
+ default_scheduler = PNDMScheduler(
+ beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+ )
+ # ------------------------------------------------------------------------------
+ #Schedulers
+ ddim_scheduler = DDIMScheduler(
+ beta_start=beta_start,
+ beta_end=beta_end,
+ beta_schedule=beta_schedule,
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+
+ klms_scheduler = LMSDiscreteScheduler(
+ beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule
+ )
+
+ SCHEDULERS = dict(default=default_scheduler, ddim=ddim_scheduler, klms=klms_scheduler)
+
+ # ------------------------------------------------------------------------------
+
+ #if weights_path == "Stable Diffusion v1.4":
+ #weights_path = "CompVis/stable-diffusion-v1-4"
+ #else:
+ #weights_path = os.path.join("./models", "custom", f"{weights_path}.ckpt")
+
+ try:
+ if "model" in st.session_state:
+ del st.session_state["model"]
+ except:
+ pass
+
+ #print (st.session_state["weights_path"] != weights_path)
+
+ try:
+ if not st.session_state["pipe"] or st.session_state["weights_path"] != weights_path:
+ if st.session_state["weights_path"] != weights_path:
+ del st.session_state["weights_path"]
+
+ st.session_state["weights_path"] = weights_path
+ st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
+ weights_path,
+ use_local_file=True,
+ use_auth_token=True,
+ #torch_dtype=torch.float16 if not defaults.general.no_half else None,
+ revision="fp16" if not defaults.general.no_half else None
+ )
+
+ st.session_state["pipe"].unet.to(torch_device)
+ st.session_state["pipe"].vae.to(torch_device)
+ st.session_state["pipe"].text_encoder.to(torch_device)
+ print("Tx2Vid Model Loaded")
+ else:
+ print("Tx2Vid Model already Loaded")
+
+ except:
+ #del st.session_state["weights_path"]
+ #del st.session_state["pipe"]
+
+ st.session_state["weights_path"] = weights_path
+ st.session_state["pipe"] = StableDiffusionPipeline.from_pretrained(
+ weights_path,
+ use_local_file=True,
+ use_auth_token=True,
+ #torch_dtype=torch.float16 if not defaults.general.no_half else None,
+ revision="fp16" if not defaults.general.no_half else None
+ )
+
+ st.session_state["pipe"].unet.to(torch_device)
+ st.session_state["pipe"].vae.to(torch_device)
+ st.session_state["pipe"].text_encoder.to(torch_device)
+ print("Tx2Vid Model Loaded")
+
+ st.session_state["pipe"].scheduler = SCHEDULERS[scheduler]
+
+ # get the conditional text embeddings based on the prompt
+ text_input = st.session_state["pipe"].tokenizer(prompts, padding="max_length", max_length=st.session_state["pipe"].tokenizer.model_max_length, truncation=True, return_tensors="pt")
+ cond_embeddings = st.session_state["pipe"].text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]
+
+ # sample a source
+ init1 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
+
+ if do_loop:
+ prompts = [prompts, prompts]
+ seeds = [seeds, seeds]
+ #first_seed, *seeds = seeds
+ #prompts.append(prompts)
+ #seeds.append(first_seed)
+
+
+ # iterate the loop
+ frames = []
+ frame_index = 0
+
+ st.session_state["frame_total_duration"] = 0
+ st.session_state["frame_total_speed"] = 0
+
+ try:
+ while frame_index < max_frames:
+ st.session_state["frame_duration"] = 0
+ st.session_state["frame_speed"] = 0
+ st.session_state["current_frame"] = frame_index
+
+ # sample the destination
+ init2 = torch.randn((1, st.session_state["pipe"].unet.in_channels, height // 8, width // 8), device=torch_device)
+
+ for i, t in enumerate(np.linspace(0, 1, num_steps)):
+ start = timeit.default_timer()
+ print(f"COUNT: {frame_index+1}/{num_steps}")
+
+ #if use_lerp_for_text:
+ #init = torch.lerp(init1, init2, float(t))
+ #else:
+ #init = slerp(gpu, float(t), init1, init2)
+
+ init = slerp(gpu, float(t), init1, init2)
+
+ with autocast("cuda"):
+ image = diffuse(st.session_state["pipe"], cond_embeddings, init, num_inference_steps, cfg_scale, eta)
+
+ im = Image.fromarray(image)
+ outpath = os.path.join(full_path, 'frame%06d.png' % frame_index)
+ im.save(outpath, quality=quality)
+
+ # send the image to the UI to update it
+ #st.session_state["preview_image"].image(im)
+
+ #append the frames to the frames list so we can use them later.
+ frames.append(np.asarray(im))
+
+ #increase frame_index counter.
+ frame_index += 1
+
+ st.session_state["current_frame"] = frame_index
+
+ duration = timeit.default_timer() - start
+
+ if duration >= 1:
+ speed = "s/it"
+ else:
+ speed = "it/s"
+ duration = 1 / duration
+
+ st.session_state["frame_duration"] = duration
+ st.session_state["frame_speed"] = speed
+
+ init1 = init2
+
+ except StopException:
+ pass
+
+
+ if st.session_state['save_video']:
+ # write video to memory
+ #output = io.BytesIO()
+ #writer = imageio.get_writer(os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples"), im, extension=".mp4", fps=30)
+ try:
+ video_path = os.path.join(os.getcwd(), defaults.general.outdir, "txt2vid-samples","temp.mp4")
+ writer = imageio.get_writer(video_path, fps=24)
+ for frame in frames:
+ writer.append_data(frame)
+ writer.close()
+ except:
+ print("Can't save video, skipping.")
+
+ # show video preview on the UI
+ st.session_state["preview_video"].video(open(video_path, 'rb').read())
+
+ mem_max_used, mem_total = mem_mon.read_and_stop()
+ time_diff = time.time()- start
+
+ info = f"""
+ {prompts}
+ Sampling Steps: {num_steps}, Sampler: {scheduler}, CFG scale: {cfg_scale}, Seed: {seeds}, Max Frames: {max_frames}""".strip()
+ stats = f'''
+ Took { round(time_diff, 2) }s total ({ round(time_diff/(max_frames),2) }s per image)
+ Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
+
+ return im, seeds, info, stats
+
+
+# functions to load css locally OR remotely starts here. Options exist for future flexibility. Called as st.markdown with unsafe_allow_html as css injection
+# TODO, maybe look into async loading the file especially for remote fetching
+def local_css(file_name):
+ with open(file_name) as f:
+ st.markdown(f'', unsafe_allow_html=True)
+
+def remote_css(url):
+ st.markdown(f'', unsafe_allow_html=True)
+
+def load_css(isLocal, nameOrURL):
+ if(isLocal):
+ local_css(nameOrURL)
+ else:
+ remote_css(nameOrURL)
+
+
+# main functions to define streamlit layout here
+def layout():
+
+ st.set_page_config(page_title="Stable Diffusion Playground", layout="wide")
+
+ with st.empty():
+ # load css as an external file, function has an option to local or remote url. Potential use when running from cloud infra that might not have access to local path.
+ load_css(True, 'frontend/css/streamlit.main.css')
+
+ # check if the models exist on their respective folders
+ if os.path.exists(os.path.join(defaults.general.GFPGAN_dir, "experiments", "pretrained_models", "GFPGANv1.3.pth")):
+ GFPGAN_available = True
+ else:
+ GFPGAN_available = False
+
+ if os.path.exists(os.path.join(defaults.general.RealESRGAN_dir, "experiments","pretrained_models", f"{defaults.general.RealESRGAN_model}.pth")):
+ RealESRGAN_available = True
+ else:
+ RealESRGAN_available = False
+
+ # Allow for custom models to be used instead of the default one,
+ # an example would be Waifu-Diffusion or any other fine tune of stable diffusion
+ custom_models:sorted = []
+ for root, dirs, files in os.walk(os.path.join("models", "custom")):
+ for file in files:
+ if os.path.splitext(file)[1] == '.ckpt':
+ fullpath = os.path.join(root, file)
+ #print(fullpath)
+ custom_models.append(os.path.splitext(file)[0])
+ #print (os.path.splitext(file)[0])
+
+ if len(custom_models) > 0:
+ CustomModel_available = True
+ custom_models.append("Stable Diffusion v1.4")
+ else:
+ CustomModel_available = False
+
+ with st.sidebar:
+ # The global settings section will be moved to the Settings page.
+ #with st.expander("Global Settings:"):
+ #st.write("Global Settings:")
+ #defaults.general.update_preview = st.checkbox("Update Image Preview", value=defaults.general.update_preview,
+ #help="If enabled the image preview will be updated during the generation instead of at the end. You can use the Update Preview \
+ #Frequency option bellow to customize how frequent it's updated. By default this is enabled and the frequency is set to 1 step.")
+ #st.session_state.update_preview_frequency = st.text_input("Update Image Preview Frequency", value=defaults.general.update_preview_frequency,
+ #help="Frequency in steps at which the the preview image is updated. By default the frequency is set to 1 step.")
+
+ tabs = on_hover_tabs(tabName=['Stable Diffusion', "Textual Inversion","Model Manager","Settings"],
+ iconName=['dashboard','model_training' ,'cloud_download', 'settings'], default_choice=0)
+
+
+ if tabs =='Stable Diffusion':
+ txt2img_tab, img2img_tab, txt2vid_tab, postprocessing_tab = st.tabs(["Text-to-Image Unified", "Image-to-Image Unified",
+ "Text-to-Video","Post-Processing"])
+ with txt2img_tab:
+ with st.form("txt2img-inputs"):
+ st.session_state["generation_mode"] = "txt2img"
+
+ input_col1, generate_col1 = st.columns([10,1])
+
+ with input_col1:
+ #prompt = st.text_area("Input Text","")
+ prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
+
+ # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
+ generate_col1.write("")
+ generate_col1.write("")
+ generate_button = generate_col1.form_submit_button("Generate")
+
+ # creating the page layout using columns
+ col1, col2, col3 = st.columns([1,2,1], gap="large")
+
+ with col1:
+ width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.txt2img.width, step=64)
+ height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.txt2img.height, step=64)
+ cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2img.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
+ seed = st.text_input("Seed:", value=defaults.txt2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
+ batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2img.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
+ #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2img.batch_size, step=1,
+ #help="How many images are at once in a batch.\
+ #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
+ #Default: 1")
+
+ with st.expander("Preview Settings"):
+ st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.txt2img.update_preview,
+ help="If enabled the image preview will be updated during the generation instead of at the end. \
+ You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
+ By default this is enabled and the frequency is set to 1 step.")
+
+ st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.txt2img.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 1 step.")
+
+ with col2:
+ preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
+
+ with preview_tab:
+ #st.write("Image")
+ #Image for testing
+ #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
+ #new_image = image.resize((175, 240))
+ #preview_image = st.image(image)
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ st.session_state["preview_image"] = st.empty()
+ st.session_state["preview_video"] = st.empty()
+
+ st.session_state["loading"] = st.empty()
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar"] = st.empty()
+
+ message = st.empty()
+
+ with gallery_tab:
+ st.write('Here should be the image gallery, if I could make a grid in streamlit.')
+
+ with col3:
+ # If we have custom models available on the "models/custom"
+ #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
+ if CustomModel_available:
+ custom_model = st.selectbox("Custom Model:", custom_models,
+ index=custom_models.index(defaults.general.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
+ else:
+ custom_model = "Stable Diffusion v1.4"
+
+ st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2img.sampling_steps, min_value=1, max_value=250)
+
+ sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ sampler_name = st.selectbox("Sampling method", sampler_name_list,
+ index=sampler_name_list.index(defaults.txt2img.default_sampler), help="Sampling method to use. Default: k_euler")
+
+
+
+ #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
+
+ #with basic_tab:
+ #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
+ #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
+
+ with st.expander("Advanced"):
+ separate_prompts = st.checkbox("Create Prompt Matrix.", value=False,
+ help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+ normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.",
+ value=defaults.txt2img.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
+ save_individual_images = st.checkbox("Save individual images.", value=defaults.txt2img.save_individual_images,
+ help="Save each image generated before any filter or enhancement is applied.")
+ save_grid = st.checkbox("Save grid",value=defaults.txt2img.save_grid, help="Save a grid with all the images generated into a single image.")
+ group_by_prompt = st.checkbox("Group results by prompt", value=defaults.txt2img.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. \
+ When using a prompt matrix each prompt combination will have its own folder.")
+ write_info_files = st.checkbox("Write Info file", value=defaults.txt2img.write_info_files,
+ help="Save a file next to the image with informartion about the generation.")
+ save_as_jpg = st.checkbox("Save samples as jpg", value=defaults.txt2img.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ if GFPGAN_available:
+ use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.txt2img.use_GFPGAN,
+ help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and \
+ consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
+ else:
+ use_GFPGAN = False
+
+ if RealESRGAN_available:
+ use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.txt2img.use_RealESRGAN,
+ help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the \
+ quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
+ RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
+ else:
+ use_RealESRGAN = False
+ RealESRGAN_model = "RealESRGAN_x4plus"
+
+ variant_amount = st.slider("Variant Amount:", value=defaults.txt2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
+ variant_seed = st.text_input("Variant Seed:", value=defaults.txt2img.seed,
+ help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+
+
+ if generate_button:
+ #print("Loading models")
+ # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
+ load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, CustomModel_available, custom_model)
+
+ try:
+ output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1,
+ cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
+ save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp,
+ variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
+
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ except KeyError:
+ output_images, seed, info, stats = txt2img(prompt, st.session_state.sampling_steps, sampler_name, RealESRGAN_model, batch_count, 1,
+ cfg_scale, seed, height, width, separate_prompts, normalize_prompt_weights, save_individual_images,
+ save_grid, group_by_prompt, save_as_jpg, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, fp=defaults.general.fp,
+ variant_amount=variant_amount, variant_seed=variant_seed, write_info_files=write_info_files)
+
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ except (StopException):
+ print(f"Received Streamlit StopException")
+
+ # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
+ # use the current col2 first tab to show the preview_img and update it as its generated.
+ #preview_image.image(output_images)
+
+ with img2img_tab:
+ with st.form("img2img-inputs"):
+ st.session_state["generation_mode"] = "img2img"
+
+ img2img_input_col, img2img_generate_col = st.columns([10,1])
+ with img2img_input_col:
+ #prompt = st.text_area("Input Text","")
+ prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
+
+ # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
+ img2img_generate_col.write("")
+ img2img_generate_col.write("")
+ generate_button = img2img_generate_col.form_submit_button("Generate")
+
+
+ # creating the page layout using columns
+ col1_img2img_layout, col2_img2img_layout, col3_img2img_layout = st.columns([1,2,2], gap="small")
+
+ with col1_img2img_layout:
+ # If we have custom models available on the "models/custom"
+ #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
+ if CustomModel_available:
+ custom_model = st.selectbox("Custom Model:", custom_models,
+ index=custom_models.index(defaults.general.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
+ else:
+ custom_model = "Stable Diffusion v1.4"
+
+ st.session_state["sampling_steps"] = st.slider("Sampling Steps", value=defaults.img2img.sampling_steps, min_value=1, max_value=500)
+ st.session_state["sampler_name"] = st.selectbox("Sampling method",
+ ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"],
+ index=sampler_name_list.index(defaults.img2img.sampler_name),
+ help="Sampling method to use.")
+
+ mask_mode_list = ["Mask", "Inverted mask", "Image alpha"]
+ mask_mode = st.selectbox("Mask Mode", mask_mode_list,
+ help="Select how you want your image to be masked.\"Mask\" modifies the image where the mask is white.\n\
+ \"Inverted mask\" modifies the image where the mask is black. \"Image alpha\" modifies the image where the image is transparent."
+ )
+ mask_mode = mask_mode_list.index(mask_mode)
+
+ width = st.slider("Width:", min_value=64, max_value=1024, value=defaults.img2img.width, step=64)
+ height = st.slider("Height:", min_value=64, max_value=1024, value=defaults.img2img.height, step=64)
+ seed = st.text_input("Seed:", value=defaults.img2img.seed, help=" The seed to use, if left blank a random seed will be generated.")
+ noise_mode_list = ["Seed", "Find Noise", "Matched Noise", "Find+Matched Noise"]
+ noise_mode = st.selectbox(
+ "Noise Mode", noise_mode_list,
+ help=""
+ )
+ noise_mode = noise_mode_list.index(noise_mode)
+ find_noise_steps = st.slider("Find Noise Steps", value=100, min_value=1, max_value=500)
+ batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.img2img.batch_count, step=1,
+ help="How many iterations or batches of images to generate in total.")
+
+ #
+ with st.expander("Advanced"):
+ separate_prompts = st.checkbox("Create Prompt Matrix.", value=defaults.img2img.separate_prompts,
+ help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+ normalize_prompt_weights = st.checkbox("Normalize Prompt Weights.", value=defaults.img2img.normalize_prompt_weights,
+ help="Ensure the sum of all weights add up to 1.0")
+ loopback = st.checkbox("Loopback.", value=defaults.img2img.loopback, help="Use images from previous batch when creating next batch.")
+ random_seed_loopback = st.checkbox("Random loopback seed.", value=defaults.img2img.random_seed_loopback, help="Random loopback seed")
+ save_individual_images = st.checkbox("Save individual images.", value=defaults.img2img.save_individual_images,
+ help="Save each image generated before any filter or enhancement is applied.")
+ save_grid = st.checkbox("Save grid",value=defaults.img2img.save_grid, help="Save a grid with all the images generated into a single image.")
+ group_by_prompt = st.checkbox("Group results by prompt", value=defaults.img2img.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. \
+ When using a prompt matrix each prompt combination will have its own folder.")
+ write_info_files = st.checkbox("Write Info file", value=defaults.img2img.write_info_files,
+ help="Save a file next to the image with informartion about the generation.")
+ save_as_jpg = st.checkbox("Save samples as jpg", value=defaults.img2img.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ if GFPGAN_available:
+ use_GFPGAN = st.checkbox("Use GFPGAN", value=defaults.img2img.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation.\
+ This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
+ else:
+ use_GFPGAN = False
+
+ if RealESRGAN_available:
+ use_RealESRGAN = st.checkbox("Use RealESRGAN", value=defaults.img2img.use_RealESRGAN,
+ help="Uses the RealESRGAN model to upscale the images after the generation.\
+ This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
+ RealESRGAN_model = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
+ else:
+ use_RealESRGAN = False
+ RealESRGAN_model = "RealESRGAN_x4plus"
+
+ variant_amount = st.slider("Variant Amount:", value=defaults.img2img.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
+ variant_seed = st.text_input("Variant Seed:", value=defaults.img2img.variant_seed,
+ help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+ cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.img2img.cfg_scale, step=0.5,
+ help="How strongly the image should follow the prompt.")
+ batch_size = st.slider("Batch size", min_value=1, max_value=100, value=defaults.img2img.batch_size, step=1,
+ help="How many images are at once in a batch.\
+ It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish \
+ generation as more images are generated at once.\
+ Default: 1")
+
+ st.session_state["denoising_strength"] = st.slider("Denoising Strength:", value=defaults.img2img.denoising_strength,
+ min_value=0.01, max_value=1.0, step=0.01)
+
+ with st.expander("Preview Settings"):
+ st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.img2img.update_preview,
+ help="If enabled the image preview will be updated during the generation instead of at the end. \
+ You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
+ By default this is enabled and the frequency is set to 1 step.")
+
+ st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.img2img.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 1 step.")
+
+ with col2_img2img_layout:
+ editor_tab = st.tabs(["Editor"])
+
+ editor_image = st.empty()
+ st.session_state["editor_image"] = editor_image
+
+ refresh_button = st.form_submit_button("Refresh")
+
+ masked_image_holder = st.empty()
+ image_holder = st.empty()
+
+ uploaded_images = st.file_uploader(
+ "Upload Image", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
+ help="Upload an image which will be used for the image to image generation.",
+ )
+ if uploaded_images:
+ image = Image.open(uploaded_images).convert('RGBA')
+ new_img = image.resize((width, height))
+ image_holder.image(new_img)
+
+ mask_holder = st.empty()
+
+ uploaded_masks = st.file_uploader(
+ "Upload Mask", accept_multiple_files=False, type=["png", "jpg", "jpeg", "webp"],
+ help="Upload an mask image which will be used for masking the image to image generation.",
+ )
+ if uploaded_masks:
+ mask = Image.open(uploaded_masks)
+ if mask.mode == "RGBA":
+ mask = mask.convert('RGBA')
+ background = Image.new('RGBA', mask.size, (0, 0, 0))
+ mask = Image.alpha_composite(background, mask)
+ mask = mask.resize((width, height))
+ mask_holder.image(mask)
+
+ if uploaded_images and uploaded_masks:
+ if mask_mode != 2:
+ final_img = new_img.copy()
+ alpha_layer = mask.convert('L')
+ strength = st.session_state["denoising_strength"]
+ if mask_mode == 0:
+ alpha_layer = ImageOps.invert(alpha_layer)
+ alpha_layer = alpha_layer.point(lambda a: a * strength)
+ alpha_layer = ImageOps.invert(alpha_layer)
+ elif mask_mode == 1:
+ alpha_layer = alpha_layer.point(lambda a: a * strength)
+ alpha_layer = ImageOps.invert(alpha_layer)
+
+ final_img.putalpha(alpha_layer)
+
+ with masked_image_holder.container():
+ st.text("Masked Image Preview")
+ st.image(final_img)
+
+
+ with col3_img2img_layout:
+ result_tab = st.tabs(["Result"])
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ preview_image = st.empty()
+ st.session_state["preview_image"] = preview_image
+
+ #st.session_state["loading"] = st.empty()
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar"] = st.empty()
+
+
+ message = st.empty()
+
+ #if uploaded_images:
+ #image = Image.open(uploaded_images).convert('RGB')
+ ##img_array = np.array(image) # if you want to pass it to OpenCV
+ #new_img = image.resize((width, height))
+ #st.image(new_img, use_column_width=True)
+
+
+ if generate_button:
+ #print("Loading models")
+ # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
+ load_models(False, use_GFPGAN, use_RealESRGAN, RealESRGAN_model, CustomModel_available, custom_model)
+ if uploaded_images:
+ image = Image.open(uploaded_images).convert('RGBA')
+ new_img = image.resize((width, height))
+ #img_array = np.array(image) # if you want to pass it to OpenCV
+ new_mask = None
+ if uploaded_masks:
+ mask = Image.open(uploaded_masks).convert('RGBA')
+ new_mask = mask.resize((width, height))
+
+ try:
+ output_images, seed, info, stats = img2img(prompt=prompt, init_info=new_img, init_info_mask=new_mask, mask_mode=mask_mode, ddim_steps=st.session_state["sampling_steps"],
+ sampler_name=st.session_state["sampler_name"], n_iter=batch_count,
+ cfg_scale=cfg_scale, denoising_strength=st.session_state["denoising_strength"], variant_seed=variant_seed,
+ seed=seed, noise_mode=noise_mode, find_noise_steps=find_noise_steps, width=width, height=height, fp=defaults.general.fp, variant_amount=variant_amount,
+ ddim_eta=0.0, write_info_files=write_info_files, RealESRGAN_model=RealESRGAN_model,
+ separate_prompts=separate_prompts, normalize_prompt_weights=normalize_prompt_weights,
+ save_individual_images=save_individual_images, save_grid=save_grid,
+ group_by_prompt=group_by_prompt, save_as_jpg=save_as_jpg, use_GFPGAN=use_GFPGAN,
+ use_RealESRGAN=use_RealESRGAN if not loopback else False, loopback=loopback
+ )
+
+ #show a message when the generation is complete.
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ except (StopException, KeyError):
+ print(f"Received Streamlit StopException")
+
+ # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
+ # use the current col2 first tab to show the preview_img and update it as its generated.
+ #preview_image.image(output_images, width=750)
+
+ with txt2vid_tab:
+ with st.form("txt2vid-inputs"):
+ st.session_state["generation_mode"] = "txt2vid"
+
+ input_col1, generate_col1 = st.columns([10,1])
+ with input_col1:
+ #prompt = st.text_area("Input Text","")
+ prompt = st.text_input("Input Text","", placeholder="A corgi wearing a top hat as an oil painting.")
+
+ # Every form must have a submit button, the extra blank spaces is a temp way to align it with the input field. Needs to be done in CSS or some other way.
+ generate_col1.write("")
+ generate_col1.write("")
+ generate_button = generate_col1.form_submit_button("Generate")
+
+ # creating the page layout using columns
+ col1, col2, col3 = st.columns([1,2,1], gap="large")
+
+ with col1:
+ width = st.slider("Width:", min_value=64, max_value=2048, value=defaults.txt2vid.width, step=64)
+ height = st.slider("Height:", min_value=64, max_value=2048, value=defaults.txt2vid.height, step=64)
+ cfg_scale = st.slider("CFG (Classifier Free Guidance Scale):", min_value=1.0, max_value=30.0, value=defaults.txt2vid.cfg_scale, step=0.5, help="How strongly the image should follow the prompt.")
+ seed = st.text_input("Seed:", value=defaults.txt2vid.seed, help=" The seed to use, if left blank a random seed will be generated.")
+ batch_count = st.slider("Batch count.", min_value=1, max_value=100, value=defaults.txt2vid.batch_count, step=1, help="How many iterations or batches of images to generate in total.")
+ #batch_size = st.slider("Batch size", min_value=1, max_value=250, value=defaults.txt2vid.batch_size, step=1,
+ #help="How many images are at once in a batch.\
+ #It increases the VRAM usage a lot but if you have enough VRAM it can reduce the time it takes to finish generation as more images are generated at once.\
+ #Default: 1")
+
+ st.session_state["max_frames"] = int(st.text_input("Max Frames:", value=defaults.txt2vid.max_frames, help="Specify the max number of frames you want to generate."))
+
+ with st.expander("Preview Settings"):
+ st.session_state["update_preview"] = st.checkbox("Update Image Preview", value=defaults.txt2vid.update_preview,
+ help="If enabled the image preview will be updated during the generation instead of at the end. \
+ You can use the Update Preview \Frequency option bellow to customize how frequent it's updated. \
+ By default this is enabled and the frequency is set to 1 step.")
+
+ st.session_state["update_preview_frequency"] = st.text_input("Update Image Preview Frequency", value=defaults.txt2vid.update_preview_frequency,
+ help="Frequency in steps at which the the preview image is updated. By default the frequency \
+ is set to 1 step.")
+ with col2:
+ preview_tab, gallery_tab = st.tabs(["Preview", "Gallery"])
+
+ with preview_tab:
+ #st.write("Image")
+ #Image for testing
+ #image = Image.open(requests.get("https://icon-library.com/images/image-placeholder-icon/image-placeholder-icon-13.jpg", stream=True).raw).convert('RGB')
+ #new_image = image.resize((175, 240))
+ #preview_image = st.image(image)
+
+ # create an empty container for the image, progress bar, etc so we can update it later and use session_state to hold them globally.
+ st.session_state["preview_image"] = st.empty()
+
+ st.session_state["loading"] = st.empty()
+
+ st.session_state["progress_bar_text"] = st.empty()
+ st.session_state["progress_bar"] = st.empty()
+
+ generate_video = st.empty()
+ st.session_state["preview_video"] = st.empty()
+
+ message = st.empty()
+
+ with gallery_tab:
+ st.write('Here should be the image gallery, if I could make a grid in streamlit.')
+
+ with col3:
+ # If we have custom models available on the "models/custom"
+ #folder then we show a menu to select which model we want to use, otherwise we use the main model for SD
+ #if CustomModel_available:
+ custom_model = st.selectbox("Custom Model:", defaults.txt2vid.custom_models_list,
+ index=defaults.txt2vid.custom_models_list.index(defaults.txt2vid.default_model),
+ help="Select the model you want to use. This option is only available if you have custom models \
+ on your 'models/custom' folder. The model name that will be shown here is the same as the name\
+ the file for the model has on said folder, it is recommended to give the .ckpt file a name that \
+ will make it easier for you to distinguish it from other models. Default: Stable Diffusion v1.4")
+
+ #st.session_state["weights_path"] = custom_model
+ #else:
+ #custom_model = "CompVis/stable-diffusion-v1-4"
+ #st.session_state["weights_path"] = f"CompVis/{slugify(custom_model.lower())}"
+
+ st.session_state.sampling_steps = st.slider("Sampling Steps", value=defaults.txt2vid.sampling_steps, min_value=10, step=10, max_value=500,
+ help="Number of steps between each pair of sampled points")
+ st.session_state.num_inference_steps = st.slider("Inference Steps:", value=defaults.txt2vid.num_inference_steps, min_value=10,step=10, max_value=500,
+ help="Higher values (e.g. 100, 200 etc) can create better images.")
+
+ #sampler_name_list = ["k_lms", "k_euler", "k_euler_a", "k_dpm_2", "k_dpm_2_a", "k_heun", "PLMS", "DDIM"]
+ #sampler_name = st.selectbox("Sampling method", sampler_name_list,
+ #index=sampler_name_list.index(defaults.txt2vid.default_sampler), help="Sampling method to use. Default: k_euler")
+ scheduler_name_list = ["klms", "ddim"]
+ scheduler_name = st.selectbox("Scheduler:", scheduler_name_list,
+ index=scheduler_name_list.index(defaults.txt2vid.scheduler_name), help="Scheduler to use. Default: klms")
+
+ beta_scheduler_type_list = ["scaled_linear", "linear"]
+ beta_scheduler_type = st.selectbox("Beta Schedule Type:", beta_scheduler_type_list,
+ index=beta_scheduler_type_list.index(defaults.txt2vid.beta_scheduler_type), help="Schedule Type to use. Default: linear")
+
+
+ #basic_tab, advanced_tab = st.tabs(["Basic", "Advanced"])
+
+ #with basic_tab:
+ #summit_on_enter = st.radio("Submit on enter?", ("Yes", "No"), horizontal=True,
+ #help="Press the Enter key to summit, when 'No' is selected you can use the Enter key to write multiple lines.")
+
+ with st.expander("Advanced"):
+ st.session_state["separate_prompts"] = st.checkbox("Create Prompt Matrix.", value=defaults.txt2vid.separate_prompts,
+ help="Separate multiple prompts using the `|` character, and get all combinations of them.")
+ st.session_state["normalize_prompt_weights"] = st.checkbox("Normalize Prompt Weights.",
+ value=defaults.txt2vid.normalize_prompt_weights, help="Ensure the sum of all weights add up to 1.0")
+ st.session_state["save_individual_images"] = st.checkbox("Save individual images.",
+ value=defaults.txt2vid.save_individual_images, help="Save each image generated before any filter or enhancement is applied.")
+ st.session_state["save_video"] = st.checkbox("Save video",value=defaults.txt2vid.save_video, help="Save a video with all the images generated as frames at the end of the generation.")
+ st.session_state["group_by_prompt"] = st.checkbox("Group results by prompt", value=defaults.txt2vid.group_by_prompt,
+ help="Saves all the images with the same prompt into the same folder. When using a prompt matrix each prompt combination will have its own folder.")
+ st.session_state["write_info_files"] = st.checkbox("Write Info file", value=defaults.txt2vid.write_info_files,
+ help="Save a file next to the image with informartion about the generation.")
+ st.session_state["dynamic_preview_frequency"] = st.checkbox("Dynamic Preview Frequency", value=defaults.txt2vid.dynamic_preview_frequency,
+ help="This option tries to find the best value at which we can update \
+ the preview image during generation while minimizing the impact it has in performance. Default: True")
+ st.session_state["do_loop"] = st.checkbox("Do Loop", value=defaults.txt2vid.do_loop,
+ help="Do loop")
+ st.session_state["save_as_jpg"] = st.checkbox("Save samples as jpg", value=defaults.txt2vid.save_as_jpg, help="Saves the images as jpg instead of png.")
+
+ if GFPGAN_available:
+ st.session_state["use_GFPGAN"] = st.checkbox("Use GFPGAN", value=defaults.txt2vid.use_GFPGAN, help="Uses the GFPGAN model to improve faces after the generation. This greatly improve the quality and consistency of faces but uses extra VRAM. Disable if you need the extra VRAM.")
+ else:
+ st.session_state["use_GFPGAN"] = False
+
+ if RealESRGAN_available:
+ st.session_state["use_RealESRGAN"] = st.checkbox("Use RealESRGAN", value=defaults.txt2vid.use_RealESRGAN,
+ help="Uses the RealESRGAN model to upscale the images after the generation. This greatly improve the quality and lets you have high resolution images but uses extra VRAM. Disable if you need the extra VRAM.")
+ st.session_state["RealESRGAN_model"] = st.selectbox("RealESRGAN model", ["RealESRGAN_x4plus", "RealESRGAN_x4plus_anime_6B"], index=0)
+ else:
+ st.session_state["use_RealESRGAN"] = False
+ st.session_state["RealESRGAN_model"] = "RealESRGAN_x4plus"
+
+ st.session_state["variant_amount"] = st.slider("Variant Amount:", value=defaults.txt2vid.variant_amount, min_value=0.0, max_value=1.0, step=0.01)
+ st.session_state["variant_seed"] = st.text_input("Variant Seed:", value=defaults.txt2vid.seed, help="The seed to use when generating a variant, if left blank a random seed will be generated.")
+ st.session_state["beta_start"] = st.slider("Beta Start:", value=defaults.txt2vid.beta_start, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
+ st.session_state["beta_end"] = st.slider("Beta End:", value=defaults.txt2vid.beta_end, min_value=0.0001, max_value=0.03, step=0.0001, format="%.4f")
+
+ if generate_button:
+ #print("Loading models")
+ # load the models when we hit the generate button for the first time, it wont be loaded after that so dont worry.
+ #load_models(False, False, False, RealESRGAN_model, CustomModel_available=CustomModel_available, custom_model=custom_model)
+
+ # run video generation
+ image, seed, info, stats = txt2vid(prompts=prompt, gpu=defaults.general.gpu,
+ num_steps=st.session_state.sampling_steps, max_frames=int(st.session_state.max_frames),
+ num_inference_steps=st.session_state.num_inference_steps,
+ cfg_scale=cfg_scale,do_loop=st.session_state["do_loop"],
+ seeds=seed, quality=100, eta=0.0, width=width,
+ height=height, weights_path=custom_model, scheduler=scheduler_name,
+ disable_tqdm=False, beta_start=st.session_state["beta_start"], beta_end=st.session_state["beta_end"],
+ beta_schedule=beta_scheduler_type)
+
+ #message.success('Done!', icon="✅")
+ message.success('Render Complete: ' + info + '; Stats: ' + stats, icon="✅")
+
+ #except (StopException, KeyError):
+ #print(f"Received Streamlit StopException")
+
+ # this will render all the images at the end of the generation but its better if its moved to a second tab inside col2 and shown as a gallery.
+ # use the current col2 first tab to show the preview_img and update it as its generated.
+ #preview_image.image(output_images)
+
+ #
+ elif tabs == 'Model Manager':
+ #search = st.text_input(label="Search", placeholder="Type the name of the model you want to search for.", help="")
+
+ csvString = f"""
+ ,Stable Diffusion v1.4 , ./models/ldm/stable-diffusion-v1 , https://www.googleapis.com/storage/v1/b/aai-blog-files/o/sd-v1-4.ckpt?alt=media
+ ,GFPGAN v1.3 , ./src/gfpgan/experiments/pretrained_models , https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
+ ,RealESRGAN_x4plus , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
+ ,RealESRGAN_x4plus_anime_6B , ./src/realesrgan/experiments/pretrained_models , https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
+ ,Waifu Diffusion v1.2 , ./models/custom , http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt
+ ,TrinArt Stable Diffusion v2 , ./models/custom , https://huggingface.co/naclbit/trinart_stable_diffusion_v2/resolve/main/trinart2_step115000.ckpt
+ """
+ colms = st.columns((1, 3, 5, 5))
+ columns = ["№",'Model Name','Save Location','Download Link']
+
+ # Convert String into StringIO
+ csvStringIO = StringIO(csvString)
+ df = pd.read_csv(csvStringIO, sep=",", header=None, names=columns)
+
+ for col, field_name in zip(colms, columns):
+ # table header
+ col.write(field_name)
+
+ for x, model_name in enumerate(df["Model Name"]):
+ col1, col2, col3, col4 = st.columns((1, 3, 4, 6))
+ col1.write(x) # index
+ col2.write(df['Model Name'][x])
+ col3.write(df['Save Location'][x])
+ col4.write(df['Download Link'][x])
+
+
+ elif tabs == 'Settings':
+ import Settings
+
+ st.write("Settings")
+
+if __name__ == '__main__':
+ layout()
diff --git a/setup.py b/setup.py
index a24d541..0e768e1 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
- name='latent-diffusion',
+ name='sd-webui',
version='0.0.1',
description='',
packages=find_packages(),
diff --git a/webui.sh b/webui.sh
index ea7028f..7b07a49 100755
--- a/webui.sh
+++ b/webui.sh
@@ -37,7 +37,7 @@ if ! conda env list | grep ".*${ENV_NAME}.*" >/dev/null 2>&1; then
ENV_UPDATED=1
elif [[ ! -z $CONDA_FORCE_UPDATE && $CONDA_FORCE_UPDATE == "true" ]] || (( $ENV_MODIFIED > $ENV_MODIFIED_CACHED )); then
echo "Updating conda env: ${ENV_NAME} ..."
- conda env update --file $ENV_FILE --prune
+ PIP_EXISTS_ACTION=w conda env update --file $ENV_FILE --prune
ENV_UPDATED=1
fi
@@ -56,4 +56,4 @@ if [ ! -e "models/ldm/stable-diffusion-v1/model.ckpt" ]; then
exit 1
fi
-python scripts/relauncher.py
\ No newline at end of file
+python scripts/relauncher.py